# Owner(s): ["module: inductor"]
import copy
import itertools
import os
import sys
import tempfile
import types
import unittest
from typing import Dict, Tuple
from unittest import skip

import torch
import torch._export
import torch._inductor
import torch._inductor.config
import torch.nn as nn
from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim, export
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
from torch.testing._internal.common_quantization import (
    skip_if_no_torchvision,
    skipIfNoFBGEMM,
)
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    find_library_location,
    IS_CI,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    skipIfRocm,
    TEST_WITH_ROCM,
)
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
from torch.utils import _pytree as pytree


if HAS_CUDA:
    import triton

    from torch.testing._internal.triton_utils import (
        add_kernel,
        add_kernel_2d_autotuned,
        add_kernel_autotuned,
        add_kernel_autotuned_weird_param_order,
        add_kernel_with_optional_param,
        add_kernel_with_scaling,
        mul2_inplace_kernel,
    )

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

try:
    try:
        from .test_aot_inductor_utils import AOTIRunnerUtil
        from .test_control_flow import (
            CondModels,
            prepend_counters,
            prepend_predicates,
            WhileLoopModels,
        )
        from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
    except ImportError:
        from test_aot_inductor_utils import AOTIRunnerUtil
        from test_control_flow import (
            CondModels,
            prepend_counters,
            prepend_predicates,
            WhileLoopModels,
        )
        from test_torchinductor import copy_tests, requires_multigpu, TestFailure
except (unittest.SkipTest, ImportError) as e:
    if __name__ == "__main__":
        sys.exit(0)
    raise


def check_model(
    self: TestCase,
    model,
    example_inputs,
    options=None,
    dynamic_shapes=None,
    disable_constraint_solver=False,
    atol=None,
    rtol=None,
):
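    """Compile `model` with AOT Inductor under this test class's ABI /
    stack-allocation settings and compare its outputs on `example_inputs`
    against eager execution."""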
    with torch.no_grad(), config.patch(
        {
            "abi_compatible": self.abi_compatible,
            "allow_stack_allocation": self.allow_stack_allocation,
            "use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
        }
    ):
        torch.manual_seed(0)
        if not isinstance(model, types.FunctionType):
            model = model.to(self.device)
        ref_model = copy.deepcopy(model)
        ref_inputs = copy.deepcopy(example_inputs)
        expected = ref_model(*ref_inputs)

        torch.manual_seed(0)
        actual = AOTIRunnerUtil.run(
            self.device,
            model,
            example_inputs,
            options,
            dynamic_shapes,
            disable_constraint_solver,
        )

    self.assertEqual(actual, expected, atol=atol, rtol=rtol)


def check_model_with_multiple_inputs(
    self: TestCase,
    model,
    list_example_inputs,
    options=None,
    dynamic_shapes=None,
):
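    """Variant of `check_model` that runs the AOT-compiled model on every
    input tuple in `list_example_inputs` and compares each result against
    eager execution."""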
    with torch.no_grad(), config.patch(
        {
            "abi_compatible": self.abi_compatible,
            "allow_stack_allocation": self.allow_stack_allocation,
        }
    ):
        torch.manual_seed(0)
        model = model.to(self.device)
        ref_model = copy.deepcopy(model)
        ref_inputs = copy.deepcopy(list_example_inputs)
        list_expected = [ref_model(*inputs) for inputs in ref_inputs]

        torch.manual_seed(0)
        list_actual = AOTIRunnerUtil.run_multiple(
            self.device, model, list_example_inputs, options, dynamic_shapes
        )

    self.assertTrue(same(list_actual, list_expected))


def code_check_count(
    self: TestCase,
    model,
    example_inputs,
    target_str: str,
    target_count: int,
):
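    """AOT-compile `model` and assert that `target_str` occurs exactly
    `target_count` times in the generated C++ wrapper source."""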
    so_path = torch._export.aot_compile(model, example_inputs)
    with open(os.path.splitext(so_path)[0] + ".cpp") as cpp:
        src_code = cpp.read()
        FileCheck().check_count(
            target_str,
            target_count,
            exactly=True,
        ).run(src_code)


class AOTInductorTestsTemplate:
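    """Shared test bodies for AOT Inductor.

    Concrete test classes provide `device`, `abi_compatible`, `check_model`,
    etc. (typically wired up via `copy_tests` from test_torchinductor).
    """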
    def test_simple(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_small_constant(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        example_inputs = (torch.randn(4, 4, device=self.device),)
        with config.patch({"always_keep_tensor_constants": True}):
            self.check_model(Model().to(self.device), example_inputs)

    def test_output_path_1(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        with config.patch("aot_inductor.output_path", "tmp_output_"):
            self.check_model(Model(), example_inputs)

    def test_output_path_2(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        model = Model().to(device=self.device)
        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        expected_path = os.path.join(tempfile.mkdtemp(dir=cache_dir()), "model.so")
        actual_path = AOTIRunnerUtil.compile(
            model, example_inputs, options={"aot_inductor.output_path": expected_path}
        )
        self.assertTrue(actual_path == expected_path)

    def test_constant_folding(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.w_pre = torch.randn(4, 4, device=device)
                self.b = torch.randn(4, device=device)

            def forward(self, x):
                w_transpose = torch.transpose(self.w_pre, 0, 1)
                w_relu = torch.nn.functional.relu(w_transpose)
                w = w_relu + self.b
                return torch.matmul(x, w)

        example_inputs = (torch.randn(4, 4, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    @requires_cuda
    def test_duplicate_constant_folding(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.w1 = torch.randn(4, 4, device=device)
                self.w2 = torch.randn(4, 4, device=device)
                self.w3 = torch.randn(4, 4, device=device)
                self.w4 = torch.randn(4, 4, device=device)

            def forward(self, x):
                w_concat = torch.cat((self.w1, self.w2, self.w3, self.w4))
                return torch.cat((x, w_concat))

        example_inputs = (torch.randn(4, 4, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    @requires_cuda
    def test_multi_device(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                x = x + 1
                x = x.cpu()
                x = x + 2
                x = x.cuda()
                return x

        example_inputs = (torch.randn(32, 64, device=self.device),)
        self.check_model(Model(), example_inputs)

    def test_large_weight(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2048, 262144)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(1, 262144, device=self.device),
            torch.randn(1, 2048, device=self.device),
        )

        # We only test compilation since we often get OOM running in CI.
        model = Model()
        model = model.to(self.device)
        AOTIRunnerUtil.compile(model, example_inputs)

    def test_large_mmaped_weights(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(512, 250112)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(1, 250112, device=self.device),
            torch.randn(1, 512, device=self.device),
        )
        with config.patch({"aot_inductor.force_mmap_weights": True}):
            self.check_model(Model(), example_inputs)

    def test_with_offset(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.orig_tensor = torch.randn(2, 15, 10, device=device)[0]
                self.tensor = self.orig_tensor[5:, :]

            def forward(self, x, y):
                return (
                    x
                    + torch.nn.functional.linear(y, self.orig_tensor[:10, :])
                    + self.tensor
                )

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_freezing(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(9, 10, device=device)
                self.padding = torch.randn(1, 10, device=device)

            def forward(self, x, y):
                padded_weight = torch.cat((self.weight, self.padding), dim=0)
                return x + torch.nn.functional.linear(y, padded_weight)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )

        with config.patch({"freezing": True}):
            self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_conv_freezing(self):
        for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
            iC = 2
            oC = 3

            class Model(torch.nn.Module):
                def __init__(self, device):
                    super().__init__()
                    self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
                        dtype
                    )

                def forward(self, y):
                    return torch.nn.functional.conv2d(y, self.weight, groups=groups)

            example_inputs = (
                torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
            )

            with config.patch({"freezing": True}):
                self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_deconv_freezing(self):
        dtypes = [torch.float]
        if torch._C._has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        for dtype, groups in itertools.product(dtypes, [2, 1]):
            iC = 4
            oC = 2

            class Model(torch.nn.Module):
                def __init__(self, device):
                    super().__init__()
                    self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
                        dtype
                    )

                def forward(self, y):
                    return torch.nn.functional.conv_transpose2d(
                        y, self.weight, groups=groups
                    )

            example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
            with config.patch({"freezing": True}):
                self.check_model(Model(self.device), example_inputs)

    @unittest.skipIf(
        IS_FBCODE,
        "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
    )
    def test_linear_freezing(self):
        for dtype in [torch.float32, torch.bfloat16]:

            class LinearModel(torch.nn.Module):
                def __init__(self, device):
                    super().__init__()
                    self.weight = torch.randn(10, 10, device=device).to(dtype)
                    self.bias = torch.randn(10, device=device).to(dtype)

                def forward(self, y):
                    return torch.nn.functional.linear(y, self.weight, self.bias)

            example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)

            with config.patch({"freezing": True}):
                self.check_model(LinearModel(self.device), example_inputs)

    @torch._inductor.config.patch(
        pre_grad_fusion_options={
            "normalization_pass": {},
            "remove_split_with_size_one_pass": {},
            "merge_getitem_cat_pass": {},
            "merge_stack_tahn_unbind_pass": {},
            "merge_splits_pass": {},
            "mutate_cat_pass": {},
            "split_cat_pass": {},
            "unbind_stack_pass": {},
        },
        post_grad_fusion_options={},
    )
    def test_simple_split(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)

        example_inputs = (torch.randn(2, 8, device=self.device),)
        counters.clear()
        self.check_model(Model(), example_inputs)
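        # The pre-grad split/cat passes enabled above should have removed the
        # split and the cat; the counters below verify that they fired.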
        self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
        self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
        self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)

    def test_amp_fallback_random(self):
        def fn(x, w):
            return torch.functional.F.linear(x, w)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        if self.device == "cuda":
            ctx = torch.cuda.amp.autocast
        elif self.device == "cpu":
            ctx = torch.cpu.amp.autocast
        else:
            raise AssertionError("Unsupported device")

        with config.patch({"fallback_random": True}):
            with ctx():
                self.check_model(fn, example_inputs)

    def test_missing_output(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                a = torch.sin(x)
                b = torch.mm(a, y)
                c = torch.cos(b)
                return c

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_output_misaligned(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x_unsqueeze = torch.unsqueeze(x, dim=0)
                y_unsqueeze = torch.unsqueeze(y, dim=0)
                cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
                x_getitem = cat[0]
                y_getitem = cat[1]
                x_sigmoid = torch.sigmoid(x_getitem)
                return x_sigmoid, y_getitem

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    @skip("Test was marked as expected failure, but does not fail always anymore.")
    def test_dynamic_smem_above_default_limit(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x @ y

        model = Model().to(self.device)
        # on A100, the generated Triton kernel for this MM
        # requires 55296 bytes of dynamic SMEM which is above
        # the A100's default dynamic SMEM limit of 49152 bytes.
        example_inputs = (
            torch.randn(10285, 96, device=self.device),
            torch.randn(96, 1, device=self.device),
        )
        self.check_model(
            model,
            example_inputs,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )

    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    def test_seq(self):
        layernorm = torch.nn.LayerNorm(10)
        net = torch.nn.Sequential(
            layernorm,
            torch.nn.ReLU(),
            layernorm,
            torch.nn.ReLU(),
        )

        example_inputs = (torch.randn(10, device=self.device),)
        self.check_model(net.eval(), example_inputs)

    def test_addmm(self):
        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M = 8
        N = 6
        K = 16
        model = Model(N, K, self.device)
        batch = 2
        a = torch.randn(batch, M, K, device=self.device)
        example_inputs = (a,)
        self.check_model(model, example_inputs)

    def test_aliased_buffer_reuse(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = 2 * x
                y = 2 * y
                c = torch.cat([x, y], dim=-1)
                d = 1 + c
                m = torch.mm(d, d)
                return m[:, :2] + x

        example_inputs = (
            torch.randn(4, 2, device=self.device),
            torch.randn(4, 2, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_buffer_reuse(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                a = torch.sin(x)
                b = torch.cos(y)
                c = torch.mm(a, b)
                d = torch.relu(c)
                e = torch.sigmoid(d)
                f = torch.mm(x, y)
                g = e + f
                return g

        example_inputs = (
            torch.randn(4, 4, device=self.device),
            torch.randn(4, 4, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_duplicated_params(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.rand(6))
                self.q = self.p

            def forward(self, x):
                return self.p * x + self.q

        example_inputs = (torch.rand(6, device=self.device),)
        self.check_model(Model(), example_inputs)

    @unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
    def test_inf(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        x = torch.randn(10, 10, device=self.device)
        x[0][0] = float("Inf")
        example_inputs = (
            x,
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(
            Model().to(self.device),
            example_inputs,
            options={"debug_check_inf_and_nan": True},
        )

    @unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
    def test_nan(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        x = torch.randn(10, 10, device=self.device)
        x[0][0] = float("nan")
        example_inputs = (
            x,
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(
            Model().to(self.device),
            example_inputs,
            options={"debug_check_inf_and_nan": True},
        )

    def test_assert_async(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                u0 = x.item()
                torch._check(u0 > 3)
                return torch.ones(u0)[0]

        x = torch.tensor(23, device=self.device)
        example_inputs = (x,)
        self.check_model(Model(), example_inputs)

    def test_simple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                add_0 = x + y
                return torch.nn.functional.relu(input=add_0, inplace=False)

        x = torch.randn(128, 2048, device=self.device)
        y = torch.randn(128, 2048, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
        example_inputs = (x, y)
        self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)

    @unittest.skipIf(
        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
        "FP8 is only supported on H100+",
    )
    @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
    def test_fp8(self):
        class Model(torch.nn.Module):
            def __init__(self, dtype):
                super().__init__()
                self.out_dtype = dtype

            def forward(self, x, weight, bias, scale_a, scale_b):
                weight = weight.to(torch.float8_e4m3fn)
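                # NB: `input_bias` is captured from the enclosing test
                # function's scope, so the `bias` argument above is unused.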
                output = torch._scaled_mm(
                    x,
                    weight,
                    bias=input_bias,
                    out_dtype=self.out_dtype,
                    scale_a=scale_a,
                    scale_b=scale_b,
                )
                return output

        dtype = torch.float16

        a_scale = torch.Tensor([1.0]).to(device="cuda")
        b_scale = torch.Tensor([1.0]).to(device="cuda")
        input_bias = torch.rand(32, device="cuda", dtype=dtype)
        weight_shape = (32, 16)
        weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T
        a_inverse_scale = 1 / a_scale
        b_inverse_scale = 1 / b_scale

        x_shape = (16, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None, None, None)
        self.check_model(
            Model(dtype),
            (x, weight, input_bias, a_inverse_scale, b_inverse_scale),
            dynamic_shapes=dynamic_shapes,
        )

    @unittest.skipIf(
        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
        "FP8 is only supported on H100+",
    )
    @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
    def test_fp8_view_of_param(self):
        # cuda only
        if self.device != "cuda":
            return

        class Model(torch.nn.Module):
            def __init__(self, dtype, weight):
                super().__init__()
                self.out_dtype = dtype
                self.weight = weight

            def forward(self, x, bias, scale_a, scale_b):
                # test: do the view inside of the graph,
                # AOTI needs to materialize this view before passing
                # it into the scaled_mm extern kernel
                weight = self.weight.T
                output = torch._scaled_mm(
                    x,
                    weight,
                    bias=input_bias,
                    out_dtype=self.out_dtype,
                    scale_a=scale_a,
                    scale_b=scale_b,
                )
                return output

        dtype = torch.float16

        a_scale = torch.Tensor([1.0]).to(device=self.device)
        b_scale = torch.Tensor([1.0]).to(device=self.device)
        input_bias = torch.rand(32, device=self.device, dtype=dtype)
        weight_shape = (32, 16)
        weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).to(
            torch.float8_e4m3fn
        )
        a_inverse_scale = 1 / a_scale
        b_inverse_scale = 1 / b_scale

        x_shape = (16, 16)
        x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(
            torch.float8_e4m3fn
        )
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None, None)
        self.check_model(
            Model(dtype, weight),
            (x, input_bias, a_inverse_scale, b_inverse_scale),
            dynamic_shapes=dynamic_shapes,
        )

    def test_poi_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                add_0 = x + y
                return torch.nn.functional.relu(input=add_0, inplace=False)

        x = torch.randn(128, 2048, device=self.device)
        y = torch.randn(128, 2048, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
        list_example_inputs = [(x, y)]
        list_example_inputs.append(
            (
                torch.randn(64, 2048, device=self.device),
                torch.randn(64, 2048, device=self.device),
            ),
        )
        list_example_inputs.append(
            (
                torch.randn(211, 2048, device=self.device),
                torch.randn(211, 2048, device=self.device),
            ),
        )
        self.check_model_with_multiple_inputs(
            Model(), list_example_inputs, dynamic_shapes=dynamic_shapes
        )

    def test_addmm_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M = 8
        N = 6
        K = 16
        model = Model(N, K, self.device)
        batch = 2
        a = torch.randn(batch, M, K, device=self.device)
        dim0_a = Dim("dim0_a", min=1, max=2048)
        dynamic_shapes = {"a": {0: dim0_a}}
        list_example_inputs = [(a,)]
        batch = 2048
        list_example_inputs.append(
            (torch.randn(batch, M, K, device=self.device),),
        )
        batch = 128
        list_example_inputs.append(
            (torch.randn(batch, M, K, device=self.device),),
        )
        self.check_model_with_multiple_inputs(
            model,
            list_example_inputs,
            dynamic_shapes=dynamic_shapes,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )

    def test_bmm_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        M = 8
        N = 6
        K = 16
        model = Model()
        batch = 1024
        a = torch.randn(batch, M, K, device=self.device)
        b = torch.randn(batch, K, N, device=self.device)
        dim0_a = Dim("dim0_a", min=1, max=2048)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_a}}
        list_example_inputs = [(a, b)]
        batch = 2048
        list_example_inputs.append(
            (
                torch.randn(batch, M, K, device=self.device),
                torch.randn(batch, K, N, device=self.device),
            ),
        )
        batch = 128
        list_example_inputs.append(
            (
                torch.randn(batch, M, K, device=self.device),
                torch.randn(batch, K, N, device=self.device),
            ),
        )
        self.check_model_with_multiple_inputs(
            model,
            list_example_inputs,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
            dynamic_shapes=dynamic_shapes,
        )

    def test_foreach_multiple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x_unsqueeze = torch.unsqueeze(x, dim=0)
                y_unsqueeze = torch.unsqueeze(y, dim=0)
                cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
                return cat

        model = Model()
        x = torch.randn(128, 2048, device=self.device)
        y = torch.randn(128, 2048, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
        list_example_inputs = [(x, y)]
        list_example_inputs.append(
            (
                torch.randn(64, 2048, device=self.device),
                torch.randn(64, 2048, device=self.device),
            ),
        )
        list_example_inputs.append(
            (
                torch.randn(211, 2048, device=self.device),
                torch.randn(211, 2048, device=self.device),
            ),
        )
        self.check_model_with_multiple_inputs(
            model,
            list_example_inputs,
            dynamic_shapes=dynamic_shapes,
        )

    # scaled_dot_product_flash_attention
    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
    def test_sdpa(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, q, k, v):
                return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]

        example_inputs = (
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
    def test_sdpa_2(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, q, k, v, x):
                t = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, is_causal=True
                )[0]
                return x + t

        example_inputs = (
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    @skipIfNoFBGEMM
    def test_quantized_linear(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(10, 10, device=device)
                self.bias = torch.randn(10, device=device)

            def forward(self, x):
                return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
                    x, self.weight, self.bias
                )

        example_inputs = (torch.randn(10, 10, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    @skipIfNoFBGEMM
    def test_quanatized_int8_linear(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(10, 10, device=device)
                self.bias = torch.randn(10, device=device)
                self.input_scale = torch.tensor(0.1)
                self.input_zero_point = torch.tensor(0)
                self.weight_scale = torch.tensor(0.1)
                self.weight_zero_point = torch.tensor(0)
                self.output_scale = torch.tensor(0.1)
                self.output_zero_point = torch.tensor(0)
                self.out_channel = 10

            def forward(self, x):
                return torch.ops._quantized.wrapped_quantized_linear(
                    x,
                    self.input_scale,
                    self.input_zero_point,
                    self.weight,
                    self.weight_scale,
                    self.weight_zero_point,
                    self.bias,
                    self.output_scale,
                    self.output_zero_point,
                    self.out_channel,
                )

        example_inputs = (torch.randn(10, 10, device=self.device),)
        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
            self.check_model(Model(self.device), example_inputs)

    def test_zero_grid_with_unbacked_symbols(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                nz = torch.nonzero(x)
                b = torch.ones_like(nz, dtype=torch.float16)
                c = torch.zeros_like(nz, dtype=torch.float16)
                d = (b + c) @ y
                return d.sum()

        example_inputs = (
            torch.tensor([1, 1, 1], device=self.device),
            torch.randn((1, 32), dtype=torch.float16, device=self.device),
        )
        self.check_model(Repro(), example_inputs)

    def test_large_grid(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, primals_5):
                view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
                primals_5 = None
                permute = torch.ops.aten.permute.default(view, [0, 2, 1])
                clone = torch.ops.aten.clone.default(
                    permute, memory_format=torch.contiguous_format
                )
                return clone

        # let y_grid = 65537
        s0 = 16777472
        s1 = 8
        example_inputs = (torch.rand(s0, s1, device=self.device),)
        self.check_model(Model(), example_inputs)

    def test_cond_simple(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.Simple(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_nested(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_abc = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p0": {},
            "p1": {},
            "p2": {},
            "a": {0: dim0_abc, 1: None},
            "b": {0: dim0_abc, 1: None},
            "c": {0: dim0_abc, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.Nested(),
            prepend_predicates(inputs, num_predicates=3),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_parameters(self):
        inputs = (torch.randn((10, 20), device=self.device),)
        dim0_abc = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_abc, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.Parameters(self.device),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_reinterpret_view_inputs_outputs(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=3, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.ReinterpretView(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_multiple_outputs(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
            torch.randn((30, 40), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dim0_c = Dim("s1", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
            "c": {0: dim0_c, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.MultipleOutputs(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_with_outer_code_before_after(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.OuterCode(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_cond_use_buffers_from_outer_scope(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_abc = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "p": {},
            "a": {0: dim0_abc, 1: None},
            "b": {0: dim0_abc, 1: None},
            "c": {0: dim0_abc, 1: None},
        }
        self.check_model_with_multiple_inputs(
            CondModels.OuterBuffers(),
            prepend_predicates(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    @common_utils.parametrize("dynamic", [False, True])
    def test_cond_non_tensor_predicates(self, dynamic):
        inputs1 = (
            torch.randn((10, 20), device=self.device),
            torch.randn((15, 20), device=self.device),
        )
        inputs2 = (
            torch.randn((10, 20), device=self.device),
            torch.randn((5, 20), device=self.device),
        )
        inputs = (inputs1,)
        dynamic_shapes = None
        if dynamic:
            inputs = (inputs1, inputs2)
            dim0_a = Dim("s0", min=2, max=1024)
            dim0_b = Dim("s1", min=2, max=1024)
            dynamic_shapes = {
                "a": {0: dim0_a, 1: None},
                "b": {0: dim0_b, 1: None},
            }
        self.check_model_with_multiple_inputs(
            CondModels.WithNonTensorPredicate(),
            inputs,
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_simple(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "ci": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.Simple(),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_nested(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "ci": {},
            "cj": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.Nested(),
            prepend_counters(inputs, num_counters=2),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_with_outer_code(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        dim0_ab = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "c": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.OuterCode(),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_with_parameters(self):
        inputs = (torch.randn((10, 20), device=self.device),)
        dim0_a = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "c": {},
            "a": {0: dim0_a, 1: None},
        }
        self.check_model_with_multiple_inputs(
            WhileLoopModels.Parameters(self.device),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    def test_while_loop_with_outer_buffers(self):
        inputs = (
            torch.randn((10, 20), device=self.device),
            torch.randn((10, 20), device=self.device),
        )
        # dynamic shapes don't work now due to
        # https://github.com/pytorch/pytorch/issues/123596
        # dim0_ab = Dim("s0", min=2, max=1024)
        # dynamic_shapes = {
        #     "c": {},
        #     "a": {0: dim0_ab, 1: None},
        #     "b": {0: dim0_ab, 1: None},
        # }
        dynamic_shapes = None
        self.check_model_with_multiple_inputs(
            WhileLoopModels.OuterBuffers(),
            prepend_counters(inputs),
            dynamic_shapes=dynamic_shapes,
        )

    @config.patch({"is_predispatch": True})
    def test_constant(self):
        class M(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.device = device

            def forward(self, x):
                t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
                t = torch.sqrt(t * 3)
                return x * t

        self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))

    def test_zero_grid_with_backed_symbols(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, b):
                return x + b

        example_inputs = (
            x := torch.randn((3, 2), device=self.device),
            torch.randn((1, 2), device=self.device),
        )
        torch._dynamo.mark_dynamic(x, index=0)  # Create dynamic symbol

        # Compile & run model where dynamic dim size > 0.
        so_path: str = AOTIRunnerUtil.compile(
            Repro(),
            example_inputs,
        )
        aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path)
        aot_inductor_module(*example_inputs)

        # Re-run where dynamic dim size is 0.
        example_inputs = (
            torch.randn((0, 2), device=self.device),
            torch.randn((1, 2), device=self.device),
        )
        actual = aot_inductor_module(*example_inputs)
        expected = Repro()(*example_inputs)
        torch.testing.assert_close(actual, expected)

    def test_repeat_interleave(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.ops.aten.repeat_interleave.Tensor(x, output_size=12)

        example_inputs = (torch.ones((1,), dtype=torch.int32, device=self.device) * 12,)
        self.check_model(Repro(), example_inputs)

    def test_dynamic_cat(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        a = torch.randn(2, 4, device=self.device)
        b = torch.randn(3, 4, device=self.device)
        dim0_a = Dim("dim0_a", min=1, max=10)
        dim0_b = Dim("dim0_b", min=1, max=20)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
        example_inputs = (a, b)
        self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)

    def test_buffer_mutation_1(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.foo = torch.nn.Buffer(torch.randn(4, 4, device=device))

            def forward(self, x):
                self.foo.add_(1)
                return self.foo + x

        example_inputs = (torch.rand(4, 4, device=self.device),)
        self.check_model(Model(self.device), example_inputs)

    def test_non_tensor_input(self):
        class Model(torch.nn.Module):
            def forward(self, a, b, alpha=1.0):
                return torch.add(a, b, alpha=alpha)

        a = torch.randn(10, device=self.device)
        b = torch.randn(10, device=self.device)

        for simdlen in [0, None]:
            with torch._inductor.config.patch({"cpp.simdlen": simdlen}):
                so_path = torch._export.aot_compile(
                    torch.ops.aten.add,
                    args=(a, b),
                    kwargs={"alpha": 2.0},
                )
                kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path)
                res = kernel_runner.run([a, b])
                self.assertTrue(isinstance(res, list))
                self.assertTrue(len(res) == 1)
                self.assertEqual(Model()(a, b, alpha=2.0), res[0])

    def test_buffer_mutation_2(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.foo = torch.nn.Buffer(torch.arange(10, device=device))
                self.bar = torch.nn.Buffer(torch.arange(10, device=device))

            def forward(self, x):
                self.bar.mul_(2)
                self.foo[5] = self.bar[0]
                return x + self.bar, x * self.foo

        example_inputs = (torch.randn(10, device=self.device),)
        self.check_model(Model(self.device), example_inputs)

    def test_buffer_mutation_3(self):
        class KVCache(torch.nn.Module):
            def __init__(
                self,
                max_batch_size,
                max_seq_length,
                n_heads,
                head_dim,
                dtype=torch.float,
            ):
                super().__init__()
                cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
                self.k_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))
                self.v_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))

            def update(self, input_pos, k_val, v_val):
                # input_pos: [S], k_val: [B, H, S, D]
                k_out = self.k_cache
                v_out = self.v_cache
                k_out[:, :, input_pos] = k_val
                v_out[:, :, input_pos] = v_val

                return k_out, v_out

        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.kv_cache = KVCache(1, 256, 6, 48)

            def forward(self, inp_pos, k, v):
                self.kv_cache.update(inp_pos, k, v)
                return self.kv_cache.k_cache + 1, self.kv_cache.v_cache / 2

        example_inputs = (
            torch.tensor([0], device=self.device),
            torch.randn(1, 6, 1, 48, device=self.device),
            torch.randn(1, 6, 1, 48, device=self.device),
        )
        model = Model(self.device)
        self.check_model(model, example_inputs)
        self.code_check_count(model, example_inputs, "empty_strided", 2)

    def test_buffer_mutation_4(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer(
                    "_tensor_constant0",
                    torch.randint(1, size=[38], dtype=torch.int64, device="cpu"),
                )

            def forward(self, x):
                return x + self._tensor_constant0.to(torch.device(type="cuda", index=0))

        example_inputs = (
            torch.randint(1, size=[38], dtype=torch.int64, device="cuda"),
        )
        torch._export.aot_compile(Model(), example_inputs)

    @requires_multigpu()
    def test_replicate_on_devices(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self, w1, w2):
                super().__init__()
                self.w1 = w1
                self.w2 = w2

            def forward(self, x, y):
                a = x * self.w1
                b = y * self.w2
                return a + b

        w1 = torch.randn(10, 10)
        w2 = torch.randn(10, 10)
        inputs = (torch.randn(10, 10), torch.randn(10, 10))
        result_cpu = Model(w1, w2)(*inputs)

        # Compile model with AOTInductor
        with torch.cuda.device(0), config.patch("abi_compatible", self.abi_compatible):
            so_path = AOTIRunnerUtil.compile(
                model=Model(w1.cuda(0), w2.cuda(0)),
                example_inputs=tuple(t.cuda(0) for t in inputs),
            )

        # Run model on cuda:N
        for i in range(torch.cuda.device_count()):
            with torch.cuda.device(i):
                example_inputs = tuple(t.cuda(i) for t in inputs)
                optimized = AOTIRunnerUtil.load("cuda", so_path)
                result_cuda = optimized(*example_inputs)
            self.assertTrue(same(result_cpu, result_cuda.cpu()))

    def test_pytree_inputs(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: Dict[str, torch.Tensor]):
                device = next(iter(x.values())).device
                add_ = torch.zeros(5, device=device)
                mul_ = torch.ones(5, device=device)
                for v in x.values():
                    add_ += v
                    mul_ *= v

                return [add_, mul_]

        self.check_model(
            M(),
            (
                {
                    "x": torch.ones(5, device=self.device),
                    "y": torch.ones(5, device=self.device),
                },
            ),
        )

    @requires_multigpu()
    def test_non_default_cuda_device(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x, y):
                return x + torch.nn.functional.linear(y, self.weight)

        weight = torch.randn(10, 10)
        inputs = (torch.randn(10, 10), torch.randn(10, 10))
        result_cpu = Model(weight)(*inputs)

        with torch.cuda.device(0), torch.no_grad(), config.patch(
            "abi_compatible", self.abi_compatible
        ):
            result_cuda_0 = AOTIRunnerUtil.run(
                "cuda", Model(weight.cuda(0)), tuple(t.cuda(0) for t in inputs)
            )

        with torch.cuda.device(1), torch.no_grad(), config.patch(
            "abi_compatible", self.abi_compatible
        ):
            result_cuda_1 = AOTIRunnerUtil.run(
                "cuda", Model(weight.cuda(1)), tuple(t.cuda(1) for t in inputs)
            )

        self.assertTrue(same(result_cpu, result_cuda_0.cpu()))
        self.assertTrue(same(result_cpu, result_cuda_1.cpu()))
1593
1594    def test_reuse_kernel(self):
1595        class Model(torch.nn.Module):
1596            def __init__(self) -> None:
1597                super().__init__()
1598
1599            def forward(self, x, y):
1600                a = torch.sin(x)
1601                b = torch.mm(a, y)
1602                c = torch.sin(b)
1603                d = torch.mm(b, c)
1604                return d
1605
1606        example_inputs = (
1607            torch.randn(87, 87, device=self.device),
1608            torch.randn(87, 87, device=self.device),
1609        )
1610        model = Model()
1611        self.check_model(
1612            model, example_inputs, atol=1e-4, rtol=1e-4
1613        )  # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
1614
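            # sin() appears twice in the forward pass; asserting a single loadKernel
            # call for the fused sin kernel verifies that the generated code reuses
            # the compiled kernel instead of loading it twice.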
1615        if self.device == "cuda":
1616            self.code_check_count(
1617                model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1
1618            )
1619
1620    def test_reuse_kernel_dynamic(self):
1621        class Model(torch.nn.Module):
1622            def __init__(self, device):
1623                super().__init__()
1624                self.cst = torch.randn(48, device=device, dtype=torch.float)
1625                self.weights = torch.randn(6, 48, 48, device=device, dtype=torch.float)
1626                self.cst_1 = torch.randn(48, device=device, dtype=torch.float)
1627                self.weights_1 = torch.randn(
1628                    6, 48, 48, device=device, dtype=torch.float
1629                )
1630
1631            def forward(self, x, y, z):
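                    # The two halves of this forward are identical except for the
                    # constants/weights involved, so the compiled kernels can be
                    # reused between them even though dim0 is dynamic.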
1632                dim0 = x.size(1)
1633                add_0 = z + z
1634                expand_2 = add_0.expand(-1, -1, 48)
1635                # [s0, 6, 48]
1636                mul_3 = add_0 * expand_2
1637                # [6, s0, 48]
1638                permute_4 = torch.permute(mul_3, (1, 0, 2))
1639                # [6, s0, 48]
1640                bmm_5 = torch.bmm(permute_4, self.weights)
1641                add_6 = bmm_5 + self.cst
1642                reshape_7 = torch.reshape(add_6, [6, dim0 * 6, 8])
1643                # [6*s0, 6, 8]
1644                permute_8 = torch.permute(reshape_7, (1, 0, 2))
1645                mul_9 = permute_8 * 0.123
1646                reshape_10 = torch.reshape(y, [8, dim0 * 6, 4])
1647                # [6*s0, 8, 4]
1648                permute_11 = torch.permute(reshape_10, (1, 0, 2))
1649                bmm_12 = torch.bmm(mul_9, permute_11)
1650
1651                add_0_1 = z + z
1652                expand_2_1 = add_0_1.expand(-1, -1, 48)
1653                # [s0, 6, 48]
1654                mul_3_1 = add_0_1 * expand_2_1
1655                # [6, s0, 48]
1656                permute_4_1 = torch.permute(mul_3_1, (1, 0, 2))
1657                # [6, s0, 48]
1658                bmm_5_1 = torch.bmm(permute_4_1, self.weights_1)
1659                add_6_1 = bmm_5_1 + self.cst_1
1660                reshape_7_1 = torch.reshape(add_6_1, [6, dim0 * 6, 8])
1661                # [6*s0, 6, 8]
1662                permute_8_1 = torch.permute(reshape_7_1, (1, 0, 2))
1663                mul_9_1 = permute_8_1 * 0.123
1664                reshape_10_1 = torch.reshape(y, [8, dim0 * 6, 4])
1665                # [6*s0, 8, 4]
1666                permute_11_1 = torch.permute(reshape_10_1, (1, 0, 2))
1667                bmm_12_1 = torch.bmm(mul_9_1, permute_11_1)
1668                return bmm_12 + bmm_12_1
1669
1670        x = torch.randn(6, 2, 48, device=self.device, dtype=torch.float)
1671        y = torch.randn(48, 2, 4, device=self.device, dtype=torch.float)
1672        z = torch.randn(2, 6, 1, device=self.device, dtype=torch.float)
1673        dim0 = Dim("dim0", min=1, max=2048)
1674        dynamic_shapes = {
1675            "x": {1: dim0},
1676            "y": {1: dim0},
1677            "z": {0: dim0},
1678        }
1679
1680        example_inputs = (x, y, z)
1681        m = Model(self.device).to(dtype=torch.float)
1682        self.check_model(m, example_inputs, dynamic_shapes=dynamic_shapes)
1683
1684    def test_fake_tensor_device_validation(self):
1685        if self.device != "cuda":
1686            raise unittest.SkipTest("requires CUDA")
1687
1688        class Model(torch.nn.Module):
1689            def __init__(self) -> None:
1690                super().__init__()
1691
1692            def forward(self, x, y):
1693                return x + y
1694
1695        example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
1696
1697        # Export on CPU
1698        exported_program = export(Model(), example_inputs)
1699
1700        # Compile exported model on CUDA
1701        gm = exported_program.graph_module.to(self.device)
1702        with self.assertRaisesRegex(ValueError, "Device mismatch between fake input"):
1703            torch._inductor.aot_compile(
1704                gm, tuple(i.to(self.device) for i in example_inputs)
1705            )
1706
1707    def test_fx_gm_return_tuple_validation(self):
1708        from torch.fx.experimental.proxy_tensor import make_fx
1709
1710        class Model(torch.nn.Module):
1711            def __init__(self) -> None:
1712                super().__init__()
1713
1714            def forward(self, x, y):
1715                return x + y
1716
1717        example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
1718
1719        gm = make_fx(Model(), tracing_mode="symbolic")(*example_inputs)
1720        with self.assertRaisesRegex(
1721            AssertionError,
1722            r"Graph output must be a tuple\(\). This is so that we can avoid "
1723            "pytree processing of the outputs.",
1724        ):
1725            torch._inductor.aot_compile(gm, example_inputs)
1726
1727    @unittest.mock.patch("torch._inductor.graph.supported_dtype_of_cpp_wrapper")
1728    def test_unsupported_input_dtype(self, supported_dtype_of_cpp_wrapper_mock):
1729        supported_dtype_of_cpp_wrapper_mock.return_value = False
1730
1731        class Model(torch.nn.Module):
1732            def __init__(self) -> None:
1733                super().__init__()
1734
1735            def forward(self, x, y):
1736                return x + y
1737
1738        example_inputs = (
1739            torch.randn(10, 10).to(self.device),
1740            torch.randn(10, 10).to(self.device),
1741        )
1742        with self.assertRaisesRegex(
1743            CppWrapperCodeGenError, "Unsupported input dtype torch.float32"
1744        ):
1745            torch._export.aot_compile(Model(), example_inputs)
1746
1747        supported_dtype_of_cpp_wrapper_mock.assert_called_once_with(
1748            torch.float32, self.device == "cuda"
1749        )
1750
1751    def test_consecutive_compiles(self):
1752        """Test that compilation behaves correctly with cache hits"""
1753
1754        class TestModule(torch.nn.Module):
1755            def __init__(self) -> None:
1756                super().__init__()
1757
1758            def forward(self, x):
1759                return x + 1
1760
1761        mod = TestModule()
1762        inp = torch.rand(1)
1763        mod(inp)
1764        mod2 = torch.fx.symbolic_trace(mod, concrete_args=[inp])
1765        so = torch._export.aot_compile(mod2, (inp,))
1766        assert so is not None
1767        # compile a second time, expecting a cache hit
1768        so = torch._export.aot_compile(mod2, (inp,))
1769        assert so is not None
1770
1771    def test_normal_functional(self):
1772        class Model(torch.nn.Module):
1773            def __init__(self) -> None:
1774                super().__init__()
1775
1776            def forward(self, x):
1777                return torch.ops.aten.normal_functional.default(x)
1778
1779        self.check_model(Model(), (torch.empty(4, 1, 4, 4),))
1780
1781    def test_empty_graph(self):
1782        class Model(torch.nn.Module):
1783            def __init__(self) -> None:
1784                super().__init__()
1785
1786            def forward(self, x):
1787                return x
1788
1789        example_inputs = (torch.randn(8, 4, 4, device=self.device),)
1790        self.check_model(Model(), example_inputs)
1791
1792    @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
1793    def test_dup_unbacked_sym_decl(self):
1794        class Model(torch.nn.Module):
1795            def __init__(self) -> None:
1796                super().__init__()
1797
1798            def forward(self, x):
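                    # Indexing twice with the same boolean mask produces two tensors
                    # with data-dependent (unbacked) sizes, exercising the duplicate
                    # unbacked symbol declaration that the test name refers to.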
1799                abs_1 = torch.ops.aten.abs.default(x)
1800                lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
1801                eq = torch.ops.aten.eq.Scalar(lt, 0)
1802                index_1 = torch.ops.aten.index.Tensor(x, [eq])
1803                sin = torch.ops.aten.sin.default(index_1)
1804                index_2 = torch.ops.aten.index.Tensor(x, [eq])
1805                div_3 = torch.ops.aten.div.Tensor(sin, index_2)
1806                return div_3
1807
1808        example_inputs = (torch.randn(4, 4, 4, 4).to(self.device),)
1809        self.check_model(Model(), example_inputs)
1810
1811    # This exercises the _eliminate_unbacked path in ShapeEnv
1812    @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
1813    def test_dup_unbacked_sym_decl_with_refinement(self):
1814        class Model(torch.nn.Module):
1815            def __init__(self) -> None:
1816                super().__init__()
1817
1818            def forward(self, x):
1819                abs_1 = torch.ops.aten.abs.default(x)
1820                lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
1821                eq = torch.ops.aten.eq.Scalar(lt, 0)
1822                index_1 = torch.ops.aten.index.Tensor(x, [eq])
1823                torch._check(index_1.size(0) == 4**4)
1824                sin = torch.ops.aten.sin.default(index_1)
1825                index_2 = torch.ops.aten.index.Tensor(x, [eq])
1826                div_3 = torch.ops.aten.div.Tensor(sin, index_2)
1827                return div_3
1828
1829        example_inputs = (torch.ones(4, 4, 4, 4).to(self.device),)
1830        self.check_model(Model(), example_inputs)
1831
1832    def test_run_with_grad_enabled(self):
1833        class Model(torch.nn.Module):
1834            def forward(self, x, weight, bias):
1835                return torch.ops.aten.addmm(bias, weight, x)
1836
1837        m = Model().to(device=self.device)
1838        x = torch.rand(8, 8, device=self.device, requires_grad=True)
1839        weight = torch.rand(8, 8, device=self.device, requires_grad=True)
1840        bias = torch.rand(8, device=self.device, requires_grad=True)
1841        example_inputs = (x, weight, bias)
1842
1843        expected = m(*example_inputs)
1844        expected = pytree.tree_leaves(expected)
1845
1846        # compile under no_grad
1847        with torch.no_grad():
1848            so_path = AOTIRunnerUtil.compile(m, example_inputs)
1849
1850        # run under grad enabled
1851        self.assertTrue(torch.is_grad_enabled())
1852
1853        optimized = AOTIRunnerUtil.load(self.device, so_path)
1854        actual = optimized(*example_inputs)
1855        actual = pytree.tree_leaves(actual)
1856
1857        self.assertTrue(same(actual, expected))
1858
1859    def test_return_constant(self):
1860        class Model(torch.nn.Module):
1861            def __init__(self, device):
1862                super().__init__()
1863                self.cst = torch.randn(5, 5, device=device)
1864
1865            def forward(self, x):
1866                a = self.cst.clone()
1867                return (x, a)
1868
1869        x = torch.randn(5, device=self.device)
1870        self.check_model(Model(self.device), (x,))
1871
1872    def test_return_view_constant(self):
1873        class Model(torch.nn.Module):
1874            def __init__(self, device):
1875                super().__init__()
1876                self.cst = torch.randn(5, 5, device=device)
1877
1878            def forward(self, x):
1879                a = torch.transpose(self.cst, 0, 1)
1880                return (x, a)
1881
1882        x = torch.randn(5, device=self.device)
1883        self.check_model(Model(self.device), (x,))
1884
1885    def test_with_profiler(self):
1886        class Model(torch.nn.Module):
1887            def __init__(self) -> None:
1888                super().__init__()
1889                self.linear = torch.nn.Linear(10, 10)
1890
1891            def forward(self, x, y):
1892                return x + self.linear(y)
1893
1894        example_inputs = (
1895            torch.randn(10, 10, device=self.device),
1896            torch.randn(10, 10, device=self.device),
1897        )
1898        with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
1899            self.check_model(Model(), example_inputs)
1900
1901    def test_with_no_triton_profiler(self):
1902        class Model(torch.nn.Module):
1903            def __init__(self) -> None:
1904                super().__init__()
1905
1906            def forward(self, x):
1907                return torch.permute(x, (1, 0))
1908
1909        example_inputs = (torch.randn(10, 10, device=self.device),)
1910        with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
1911            self.check_model(Model(), example_inputs)
1912
1913    def test_repeat_output(self):
1914        class Model(torch.nn.Module):
1915            def __init__(self) -> None:
1916                super().__init__()
1917
1918            def forward(self, x):
1919                y = torch.sin(x)
1920                return y, y
1921
1922        example_inputs = (torch.randn(3, 10, device=self.device),)
1923        self.check_model(Model(), example_inputs)
1924
1925    def test_view_outputs(self):
1926        class Model(torch.nn.Module):
1927            def forward(self, x):
1928                y = torch.sin(x)
1929                y_same_size = y.view(*y.shape)
1930                y_diff_size = y.view(1, *y.shape)
1931                return y, y_same_size, y_diff_size
1932
1933        example_inputs = (torch.randn(3, 10, device=self.device),)
1934        self.check_model(Model(), example_inputs)
1935
1936    @skip_if_no_torchvision
1937    def test_missing_cubin(self):
1938        from torchvision.models.resnet import Bottleneck, ResNet
1939
1940        class Model(ResNet):
1941            def __init__(self) -> None:
1942                super().__init__(
1943                    block=Bottleneck,
1944                    layers=[3, 4, 6, 3],
1945                    replace_stride_with_dilation=[False, False, True],
1946                    norm_layer=None,
1947                )
1948
1949            def forward(self, x):
1950                x = self.conv1(x)
1951                x = self.bn1(x)
1952                x = self.relu(x)
1953                f1 = x
1954                x = self.maxpool(x)
1955                x = self.layer1(x)
1956                f2 = x
1957                x = self.layer2(x)
1958                f3 = x
1959                x = self.layer3(x)
1960                x = self.layer4(x)
1961                f4 = x
1962                return [f1, f2, f3, f4]
1963
1964        # Call eval() here so that batch_norm won't update the running stats
1965        # Use float64 to avoid numerical difference failures
1966        model = Model().to(device=self.device, dtype=torch.float64).eval()
1967        example_inputs = (
1968            torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
1969        )
1970        self.check_model(model, example_inputs)
1971
1972    @common_utils.parametrize("grid_type", [1, 2, 3])
1973    @common_utils.parametrize("num_dims", [1, 2])
1974    @common_utils.parametrize("dynamic", [False, True])
1975    @common_utils.parametrize("autotune", [False, True])
1976    def test_triton_kernel(self, grid_type, num_dims, dynamic, autotune):
1977        if self.device != "cuda":
1978            raise unittest.SkipTest("requires CUDA")
1979
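            # grid_type selects how the launch grid is expressed (tuple, lambda, or
            # named function), autotune switches between the plain and autotuned add
            # kernels, and dynamic additionally marks dim 0 of both inputs as a Dim.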
1980        class Model(torch.nn.Module):
1981            def __init__(self) -> None:
1982                super().__init__()
1983
1984            def forward(self, x, y):
1985                output = torch.zeros_like(x)
1986                if autotune and num_dims == 2:
1987                    x_elements = output.size()[0]
1988                    y_elements = output.size()[1]
1989                else:
1990                    n_elements = output.numel()
1991
1992                # Select grid
1993                if autotune and num_dims == 2:
1994                    if grid_type == 1:
1995                        grid = (x_elements, y_elements)
1996                    elif grid_type == 2:
1997                        grid = lambda meta: (  # noqa: E731
1998                            triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
1999                            triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
2000                        )
2001                    else:
2002
2003                        def grid_fn(meta):
2004                            return (
2005                                triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
2006                                triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
2007                            )
2008
2009                        grid = grid_fn
2010                else:
2011                    if grid_type == 1:
2012                        grid = (n_elements,)
2013                    elif grid_type == 2:
2014                        grid = lambda meta: (  # noqa: E731
2015                            triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
2016                        )
2017                    else:
2018
2019                        def grid_fn(meta):
2020                            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
2021
2022                        grid = grid_fn
2023
2024                # Select kernel
2025                if autotune:
2026                    if num_dims == 1:
2027                        add_kernel_autotuned[grid](x, y, output, n_elements)
2028                    else:
2029                        add_kernel_2d_autotuned[grid](
2030                            x, y, output, x_elements, y_elements
2031                        )
2032                else:
2033                    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
2034                return output
2035
2036        dims = [10] * num_dims
2037        x = torch.randn(*dims, device=self.device)
2038        y = torch.randn(*dims, device=self.device)
2039        dynamic_shapes = []
2040        if dynamic:
2041            dim0_x = Dim("dim0_x", min=1, max=10)
2042            dim0_y = Dim("dim0_y", min=1, max=10)
2043            dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
2044        self.check_model(Model(), (x, y), dynamic_shapes=dynamic_shapes)
2045
2046    def test_triton_kernel_dynamic_shape_with_div(self):
2047        if self.device != "cuda":
2048            raise unittest.SkipTest("requires CUDA")
2049
2050        @triton.jit
2051        def pass_kernel(x, num):
2052            pass
2053
2054        class Model(torch.nn.Module):
2055            def __init__(self) -> None:
2056                super().__init__()
2057
2058            def forward(self, x):
2059                num = x.numel() // 4
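                    # With a dynamic dim 0, x.numel() // 4 becomes a floor-division
                    # expression that feeds the grid computation below, which is the
                    # case this test exercises.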
2060
2061                grid = lambda meta: (triton.cdiv(num, 16),)  # noqa: E731
2062                pass_kernel[grid](x, num)
2063                return x
2064
2065        x = torch.randn(10, device=self.device)
2066        dim0_x = Dim("dim0_x", min=1, max=10)
2067        dynamic_shapes = {"x": {0: dim0_x}}
2068        self.check_model(Model(), (x,), dynamic_shapes=dynamic_shapes)
2069
2070    def test_triton_kernel_reinterpret_view(self):
2071        if self.device != "cuda":
2072            raise unittest.SkipTest("requires CUDA")
2073
2074        @triton.jit
2075        def pass_kernel(x, y):
2076            pass
2077
2078        class Model(torch.nn.Module):
2079            def __init__(self) -> None:
2080                super().__init__()
2081
2082            def forward(self, x):
2083                out = torch.zeros_like(x[:, 4:])
2084                # the slicing below creates two ReinterpretView
2085                # instances: with offset=3 and offset=4
2086                add_kernel[(10,)](
2087                    in_ptr0=x[:, 3:-1],
2088                    in_ptr1=x[:, 4:],
2089                    out_ptr=out,
2090                    n_elements=160,
2091                    BLOCK_SIZE=16,
2092                )
2093                return out
2094
2095        example_inputs = (torch.randn(10, 20, device=self.device),)
2096        self.check_model(Model(), example_inputs)
2097
2098    def test_triton_kernel_sympy_expr_arg(self):
2099        if self.device != "cuda":
2100            raise unittest.SkipTest("requires CUDA")
2101
2102        class Model(torch.nn.Module):
2103            def forward(self, x, e):
2104                sympy_expr = max(1, e.item())
2105                out = torch.zeros_like(x)
2106                add_kernel[(1,)](
2107                    in_ptr0=x,
2108                    in_ptr1=x,
2109                    out_ptr=out,
2110                    n_elements=sympy_expr,
2111                    BLOCK_SIZE=1,
2112                )
2113                return out
2114
2115        NUMEL = 64
2116        inputs = (
2117            torch.randn(NUMEL, device=self.device),
2118            torch.tensor(NUMEL, device=self.device),
2119        )
2120        self.check_model(Model(), inputs)
2121
2122    def test_triton_kernel_sympy_fn_like_arg(self):
2123        # This test should hit sympy.expand("sqrt") which crashes with
2124        # AttributeError: 'function' object has no attribute 'expand'.
2125        if self.device != "cuda":
2126            raise unittest.SkipTest("requires CUDA")
2127
2128        class Model(torch.nn.Module):
2129            def forward(self, x):
2130                out = torch.zeros_like(x)
2131                add_kernel_with_optional_param[1,](
2132                    in_ptr0=x,
2133                    in_ptr1=x,
2134                    out_ptr=out,
2135                    n_elements=x.numel(),
2136                    BLOCK_SIZE=1,
2137                    ARGS_PASSED="sqrt",  # sqrt is a valid sympy fn
2138                )
2139                return out
2140
2141        inputs = (torch.randn(4, device=self.device),)
2142        self.check_model(Model(), inputs)
2143
2144    def test_triton_kernel_with_none_input(self):
2145        if self.device != "cuda":
2146            raise unittest.SkipTest("requires CUDA")
2147
2148        class Model(torch.nn.Module):
2149            def __init__(self) -> None:
2150                super().__init__()
2151
2152            def forward(self, x, y):
2153                n_elements = x.size()[0]
2154                BLOCK_SIZE = 1024
2155
2156                output_wo_y = torch.empty_like(x)
2157                output_with_y = torch.empty_like(x)
2158
2159                wo_kernel = add_kernel_with_optional_param[(1,)](
2160                    x,
2161                    None,
2162                    output_wo_y,
2163                    n_elements,
2164                    ARGS_PASSED="one",
2165                    BLOCK_SIZE=BLOCK_SIZE,
2166                )
2167                with_kernel = add_kernel_with_optional_param[(1,)](
2168                    x,
2169                    y,
2170                    output_with_y,
2171                    n_elements,
2172                    ARGS_PASSED="two",
2173                    BLOCK_SIZE=BLOCK_SIZE,
2174                )
2175
2176                return 2.71 * output_wo_y + 3.14 * output_with_y
2177
2178        example_inputs = (
2179            torch.randn(1023, device=self.device),
2180            torch.randn(1023, device=self.device),
2181        )
2182
2183        self.check_model(Model(), example_inputs)
2184
2185    def test_triton_kernel_equal_to_1_arg(self):
2186        if self.device != "cuda":
2187            raise unittest.SkipTest("requires CUDA")
2188
2189        class Model(torch.nn.Module):
2190            def forward(self, x, y):
2191                out = torch.empty_like(x)
2192                n_elements = x.numel()
2193                add_kernel[(n_elements,)](x, y, out, n_elements, BLOCK_SIZE=16)
2194                return out
2195
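            # Size-1 inputs make n_elements == 1, exercising Triton's specialization
            # of arguments equal to 1 in the AOTI-generated kernel launch.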
2196        example_inputs = (
2197            torch.randn(1, device=self.device),
2198            torch.randn(1, device=self.device),
2199        )
2200
2201        self.check_model(Model(), example_inputs)
2202
2203    @common_utils.parametrize("dynamic", [False, True])
2204    def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
2205        if self.device != "cuda":
2206            raise unittest.SkipTest("requires CUDA")
2207
2208        class Model(torch.nn.Module):
2209            def forward(self, x, y):
2210                out = torch.empty_like(x)
2211                n_elements = x.numel()
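                    # (n_elements**0) / 1.0 always evaluates to 1.0; passing it to the
                    # kernel checks that a float argument equal to 1 is handled
                    # correctly rather than being dropped by the equal-to-1
                    # specialization.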
2212                scaling_factor = (n_elements**0) / 1.0
2213                add_kernel_with_scaling[(n_elements,)](
2214                    x,
2215                    y,
2216                    out,
2217                    n_elements,
2218                    scaling_factor,
2219                    BLOCK_SIZE=16,
2220                )
2221                return out
2222
2223        dynamic_shapes = None
2224        if dynamic:
2225            dim0_xy = Dim("s0", min=2, max=1024)
2226            dynamic_shapes = {
2227                "x": {0: dim0_xy, 1: None},
2228                "y": {0: dim0_xy, 1: None},
2229            }
2230        example_inputs = (
2231            torch.randn(2, device=self.device),
2232            torch.randn(2, device=self.device),
2233        )
2234        self.check_model(
2235            Model(),
2236            example_inputs,
2237            dynamic_shapes=dynamic_shapes,
2238        )
2239
2240    def test_triton_kernel_weird_param_order(self):
2241        if self.device != "cuda":
2242            raise unittest.SkipTest("requires CUDA")
2243
2244        class Model(torch.nn.Module):
2245            def __init__(self) -> None:
2246                super().__init__()
2247
2248            def forward(self, x):
2249                out = torch.empty_like(x)
2250                add_kernel_autotuned_weird_param_order[16,](
2251                    in_ptr0=x,
2252                    in_ptr1=x,
2253                    n_elements=x.numel(),
2254                    out_ptr=out,
2255                )
2256                return out
2257
2258        x = torch.randn(16, 16, device=self.device)
2259        self.check_model(Model(), (x,))
2260
2261    def test_shifted_constraint_ranges(self):
2262        class Model(torch.nn.Module):
2263            def __init__(self) -> None:
2264                super().__init__()
2265
2266            def forward(
2267                self,
2268                x: torch.Tensor,
2269                y: torch.Tensor,
2270            ):
2271                torch._check(y.size(0) == x.size(0) + 1)
2272                return x.sum(0) + y.sum(0)
2273
2274        a = torch.randn((4, 5), device=self.device)
2275        b = torch.randn((5, 5), device=self.device)
2276        dim0_x = Dim("dim0_x", min=2, max=1024)
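            # Expressing dim0_y as dim0_x + 1 mirrors the
            # torch._check(y.size(0) == x.size(0) + 1) constraint in forward(),
            # giving the two inputs shifted but linked dynamic ranges.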
2277        dim0_y = dim0_x + 1
2278        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
2279        self.check_model(
2280            Model(),
2281            (a, b),
2282            dynamic_shapes=dynamic_shapes,
2283        )
2284
2285    def test_scatter_fallback(self):
2286        class Model(torch.nn.Module):
2287            def __init__(self) -> None:
2288                super().__init__()
2289
2290            def forward(
2291                self,
2292                inp: torch.Tensor,
2293                index: torch.Tensor,
2294                src: torch.Tensor,
2295            ):
2296                return torch.scatter(inp, 1, index, src)
2297
2298        inputs = (
2299            torch.ones((3, 5), device=self.device, dtype=torch.int64),
2300            torch.tensor([[0, 1, 2, 0]], device=self.device, dtype=torch.int64),
2301            torch.zeros((2, 5), device=self.device, dtype=torch.int64),
2302        )
2303
2304        self.check_model(Model(), inputs)
2305
2306    def test_scatter_reduce_fallback(self):
2307        class Model(torch.nn.Module):
2308            def __init__(self) -> None:
2309                super().__init__()
2310
2311            def forward(
2312                self,
2313                inp: torch.Tensor,
2314                index: torch.Tensor,
2315                src: torch.Tensor,
2316            ):
2317                return torch.scatter_reduce(inp, 0, index, src, reduce="sum")
2318
2319        inputs = (
2320            torch.tensor([1, 10, 100, 1000], device=self.device, dtype=torch.int64),
2321            torch.tensor([0, 1, 0, 1, 2, 1], device=self.device, dtype=torch.int64),
2322            torch.tensor([1, 2, 3, 4, 5, 6], device=self.device, dtype=torch.int64),
2323        )
2324
2325        self.check_model(Model(), inputs)
2326
2327    def test_index_put_fallback(self):
2328        # index_put falls back in deterministic mode
2329        with DeterministicGuard(True):
2330
2331            class Model(torch.nn.Module):
2332                def __init__(self) -> None:
2333                    super().__init__()
2334
2335                def forward(
2336                    self,
2337                    self_tensor: torch.Tensor,
2338                    indices: Tuple[torch.Tensor],
2339                    values: torch.Tensor,
2340                ):
2341                    return torch.index_put(
2342                        self_tensor, indices, values, accumulate=True
2343                    )
2344
2345            inputs = (
2346                torch.ones(4, device=self.device, dtype=torch.int64),
2347                (torch.tensor([1, 1, 2, 2], device=self.device, dtype=torch.bool),),
2348                torch.ones(4, device=self.device, dtype=torch.int64),
2349            )
2350
2351            self.check_model(Model(), inputs)
2352
2353    def test_repeated_user_defined_triton_kernel(self):
2354        if self.device != "cuda":
2355            raise unittest.SkipTest("requires CUDA")
2356
2357        class Model(torch.nn.Module):
2358            def __init__(self) -> None:
2359                super().__init__()
2360
2361            def forward(self, x):
2362                for _ in range(3):
2363                    mul2_inplace_kernel[4,](x, n_elements=4, BLOCK_SIZE=16)
2364                return x
2365
2366        inputs = (torch.randn(4, 4, device=self.device),)
2367        self.check_model(Model(), inputs)
2368
2369    def test_convolution(self):
2370        class Model(torch.nn.Module):
2371            def __init__(self) -> None:
2372                super().__init__()
2373
2374            def forward(self, x, w, b):
2375                return torch.ops.aten.convolution(x, w, b, [4], [0], [1], True, [0], 1)
2376
2377        example_inputs = (
2378            torch.randn([2, 32, 90], device=self.device),
2379            torch.randn([32, 16, 8], device=self.device),
2380            torch.randn([16], device=self.device),
2381        )
2382        with config.patch(
2383            {
2384                "max_autotune": True,
2385                "max_autotune_gemm_backends": "Triton",
2386            }
2387        ):
2388            self.check_model(Model(), example_inputs)
2389
2390    def test_zero_size_weight(self):
2391        class Model(torch.nn.Module):
2392            def __init__(self, channel, r=8):
2393                super().__init__()
2394                self.pool = torch.nn.AdaptiveAvgPool2d(1)
2395                self.net = torch.nn.Sequential(
2396                    torch.nn.Linear(channel, channel // r, bias=False),
2397                    torch.nn.ReLU(inplace=True),
2398                    torch.nn.Linear(channel // r, channel, bias=False),
2399                    torch.nn.Sigmoid(),
2400                )
2401
2402            def forward(self, inp):
2403                b, c, _, _ = inp.shape
2404                x = self.pool(inp).view(b, c)
2405                x = self.net(x).view(b, c, 1, 1)
2406                x = inp * x
2407                return x
2408
2409        inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
2410        self.check_model(Model(4), inputs)
2411
2412    def test_no_args(self):
2413        class Model(torch.nn.Module):
2414            def __init__(self, m, n):
2415                super().__init__()
2416                self.weight = torch.nn.Parameter(
2417                    torch.randn(m, n),
2418                )
2419                self.alpha = torch.nn.Parameter(torch.randn(m, n))
2420
2421            def forward(self):
2422                return self.weight * self.alpha
2423
2424        self.check_model(Model(6, 4), ())
2425
2426    def test_dynamic_scalar(self):
2427        class Model(torch.nn.Module):
2428            def __init__(self) -> None:
2429                super().__init__()
2430                self.criterion_ce = torch.nn.CrossEntropyLoss(reduction="none")
2431
2432            def forward(self, inputs, targets, split_index=None):
2433                statistics = {}
2434                total_loss = self.criterion_ce(inputs, targets).sum()
2435                statistics["dl"] = total_loss.item()
2436                return total_loss, statistics
2437
2438        inputs = (
2439            torch.rand(4, 4, 4, 4, device=self.device),
2440            torch.rand(4, 4, 4, 4, device=self.device),
2441        )
2442        self.check_model(Model(), inputs)
2443
2444    def test_constant_original_fqn_and_dtype(self):
2445        class FooBarModule(torch.nn.Module):
2446            def __init__(self) -> None:
2447                super().__init__()
2448                self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4)))
2449                self.test_buf = torch.nn.Buffer(torch.randn(3, 4))
2450                self.register_parameter(
2451                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
2452                )
2453
2454            def forward(self, x):
2455                return ((x + self.test_buf) * getattr(self, "0")) / self.test_param
2456
2457        class TestModule(torch.nn.Module):
2458            def __init__(self) -> None:
2459                super().__init__()
2460                self.foo_bar = FooBarModule()
2461                self.register_parameter(
2462                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
2463                )
2464                self.test_buf = torch.nn.Buffer(torch.randn(3, 4))
2465
2466            def forward(self, x):
2467                return (self.foo_bar(x) + self.test_param) * self.test_buf
2468
2469        with torch.no_grad():
2470            so_path = AOTIRunnerUtil.compile(
2471                model=TestModule().to(device=self.device),
2472                example_inputs=(torch.rand(3, 4, device=self.device),),
2473            )
2474
2475        runner = AOTIRunnerUtil.load_runner(self.device, so_path)
2476
2477        expected_original_fqns = {
2478            "L__self___test_param": "test_param",
2479            "L__self___test_buf": "test_buf",
2480            "getattr_L__self___foo_bar___0__": "foo_bar.0",
2481            "L__self___foo_bar_test_param": "foo_bar.test_param",
2482            "L__self___foo_bar_test_buf": "foo_bar.test_buf",
2483        }
2484        self.assertEqual(
2485            expected_original_fqns, runner.get_constant_names_to_original_fqns()
2486        )
2487
2488        expected_dtypes = {
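            # The runner reports dtypes as integer ScalarType codes; all constants
            # here are float32, and 6 is the ScalarType value for torch.float32.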
2489            "L__self___test_param": 6,
2490            "L__self___test_buf": 6,
2491            "getattr_L__self___foo_bar___0__": 6,
2492            "L__self___foo_bar_test_param": 6,
2493            "L__self___foo_bar_test_buf": 6,
2494        }
2495        self.assertEqual(expected_dtypes, runner.get_constant_names_to_dtypes())
2496
2497    def test_fqn(self):
2498        class NestedChild(torch.nn.Module):
2499            def __init__(self) -> None:
2500                super().__init__()
2501                self.nestedchild3buffer = torch.nn.Buffer(torch.ones(2, 3) * 3)
2502
2503            def forward(self, x):
2504                return x / self.nestedchild3buffer
2505
2506        class Child1(torch.nn.Module):
2507            def __init__(self) -> None:
2508                super().__init__()
2509                self.nested = NestedChild()
2510                self.register_parameter(
2511                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
2512                )
2513
2514            def forward(self, x):
2515                x = self.nested(x)
2516                return x + self.child1param
2517
2518        class Child2(torch.nn.Module):
2519            def __init__(self) -> None:
2520                super().__init__()
2521                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3) * 2)
2522
2523            def forward(self, x):
2524                return x - self.child2buffer
2525
2526        class MyModule(torch.nn.Module):
2527            def __init__(self) -> None:
2528                super().__init__()
2529                self.foo = Child1()
2530                self.bar = Child2()
2531                self.register_parameter(
2532                    "rootparam", torch.nn.Parameter(torch.ones(2, 3) * 4)
2533                )
2534
2535            def forward(self, x):
2536                x = x * self.rootparam
2537                x = self.foo(x)
2538                x = self.bar(x)
2539                return x
2540
2541        orig_eager = MyModule()
2542
2543        self.check_model(MyModule(), (torch.randn(2, 3, device=self.device),))
2544
2545    def test_model_modified_weights(self):
2546        class Model(torch.nn.Module):
2547            def __init__(self, n, k, device):
2548                super().__init__()
2549                self.weight = torch.randn(n, k, device=device)
2550                self.bias = torch.randn(n, device=device)
2551
2552            def forward(self, a):
2553                return torch.nn.functional.linear(a, self.weight, self.bias)
2554
2555        M = 16
2556        N = 10
2557        K = 128
2558        batch = 8
2559        example_inputs = (torch.randn(2, M, K, device=self.device),)
2560        model = Model(N, K, self.device)
2561        self.check_model(model, example_inputs)
2562        # Update model weights; after this, AOTInductor should re-generate model.so
2563        # if the weights are stored in the model.so
2564        model.weight += 1
2565        self.check_model(model, example_inputs)
2566
2567    def test_custom_op_add(self) -> None:
2568        class M(torch.nn.Module):
2569            def forward(self, x, y):
2570                return torch.ops.aoti_custom_ops.custom_add(x, y)
2571
2572        m = M().to(device=self.device)
2573        args = (
2574            torch.randn(3, 3, device=self.device),
2575            torch.randn(3, 3, device=self.device),
2576        )
2577        self.check_model(m, args)
2578
2579    def test_custom_op_all_inputs(self) -> None:
2580        class MyModel(torch.nn.Module):
2581            # pyre-fixme[3]: Return type must be annotated.
2582            def __init__(self):
2583                super().__init__()
2584
2585            # pyre-fixme[3]: Return type must be annotated.
2586            # pyre-fixme[2]: Parameter must be annotated.
2587            def forward(self, x, y):
2588                with torch.no_grad():
2589                    x_dim0 = x.shape[0]
2590                    x_dim1 = x.shape[1]
2591                    y_dim0 = y.shape[0]
2592                    y_dim1 = y.shape[1]
2593                    symint_0 = x_dim0 + x_dim1
2594                    symint_1 = y_dim0 * y_dim1
2595
2596                    z = torch.concat((x, x))
2597
2598                    _2547 = torch.ops.aoti_custom_ops.fn_with_all_inputs(
2599                        tensor=x,
2600                        tensors=[x, y],
2601                        optional_tensors=[None, z],
2602                        b8=False,
2603                        b8s=[True, False],
2604                        i64=42,
2605                        i64s=[16, 17],
2606                        symint=symint_0,
2607                        symints=[symint_0, symint_1],
2608                        f64=3.14,
2609                        f64s=[2.2, 3.3],
2610                        scalar=1.23,
2611                        scalars=[45, 67],
2612                        string="hello",
2613                        strings=["ab", "cde"],
2614                        # dtype=torch.float16,
2615                        # memory_format=torch.contiguous_format,
2616                        # layout=torch.strided,
2617                        device=torch.device("cpu"),
2618                        # optional
2619                        o_tensor=None,
2620                        o_tensors=[x, y],
2621                        o_b8=False,
2622                        o_b8s=[True, False],
2623                        o_i64=None,
2624                        o_i64s=[16, 17],
2625                        o_symint=symint_1,
2626                        o_symints=[symint_1, symint_0],
2627                        o_f64=3.14,
2628                        o_f64s=None,
2629                        o_scalar=None,
2630                        o_scalars=[89, 910],
2631                        o_string="hello",
2632                        o_strings=["ab", "cde"],
2633                        # o_dtype=None,
2634                        # o_memory_format=torch.contiguous_format,
2635                        # o_layout=torch.strided,
2636                        o_device=None,
2637                    )
2638
2639                return _2547
2640
2641        m = MyModel().to(device=self.device)
2642        x = torch.zeros(4, 8, device=self.device)
2643        y = torch.ones(3, 9, device=self.device)
2644        args = (x, y)
2645        m(*args)
2646
2647        self.check_model(m, args)
2648
2649    def test_custom_op_with_multiple_outputs(self) -> None:
2650        class Model(torch.nn.Module):
2651            def forward(self, x, y):
2652                out = x + y
2653                # tuple of Tensor output
2654                out3, out4 = torch.ops.aoti_custom_ops.fn_with_tuple_output(out, 1)
2655                # TensorList output
2656                out5, out6 = torch.ops.aoti_custom_ops.fn_with_list_output(
2657                    [out3, out4], 1
2658                )
2659                # tuple of Tensor and TensorList
2660                out7, [out8, out9] = torch.ops.aoti_custom_ops.fn_with_mix_outputs(
2661                    out5, [out6, out4]
2662                )
2663                return out3, out4, out5, out6, out7, out8, out9
2664
2665        m = Model().to(device=self.device)
2666        args = (
2667            torch.randn(4, 4, device=self.device),
2668            torch.randn(4, 4, device=self.device),
2669        )
2670        m(*args)
2671
2672        self.check_model(m, args)
2673
2674    def test_custom_op_with_reinterpret_view_inputs(self) -> None:
2675        class Model(torch.nn.Module):
2676            def forward(self, x):
2677                out = x.permute([1, 0])
2678                return torch.ops.aoti_custom_ops.fn_with_default_input(out, 1)
2679
2680        m = Model().to(device=self.device)
2681        args = (torch.randn(2, 3, device=self.device),)
2682
2683        self.check_model(m, args)
2684
2685    def test_custom_op_with_concat_inputs(self) -> None:
2686        class Model(torch.nn.Module):
2687            def forward(self, x, y):
2688                out = torch.concat([x, y], dim=0)
2689                return torch.ops.aoti_custom_ops.fn_with_default_input(out, 1)
2690
2691        m = Model().to(device=self.device)
2692        args = (
2693            torch.randn(2, 3, device=self.device),
2694            torch.randn(2, 3, device=self.device),
2695        )
2696
2697        self.check_model(m, args)
2698
2699    def test_custom_op_missing_arg_with_default_value(self) -> None:
2700        class Model(torch.nn.Module):
2701            def forward(self, x):
2702                # missing second arg
2703                return torch.ops.aoti_custom_ops.fn_with_default_input(x)
2704
2705        m = Model().to(device=self.device)
2706        args = (torch.randn(2, 3, device=self.device),)
2707
2708        self.check_model(m, args)
2709
2710    def test_triton_kernel_extern_kernel_arg(self):
2711        if self.device != "cuda":
2712            raise unittest.SkipTest("requires CUDA")
2713
2714        class Model(torch.nn.Module):
2715            def forward(self, x, y):
2716                out = torch.zeros_like(x)
2717                # torch.mm is an ExternKernelOut
2718                add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
2719                return out
2720
2721        example_inputs = (
2722            torch.randn(4, 4, device="cuda"),
2723            torch.randn(4, 4, device="cuda"),
2724        )
2725
2726        self.check_model(Model(), example_inputs)
2727
2728    def test_triton_kernel_multi_output_arg(self):
2729        if self.device != "cuda":
2730            raise unittest.SkipTest("requires CUDA")
2731
2732        class Model(torch.nn.Module):
2733            def forward(self, x, y):
2734                out = torch.zeros_like(x)
2735                # torch.sort creates a fallback kernel and hence a MultiOutput
2736                add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
2737                return out
2738
2739        example_inputs = (
2740            torch.randn(4, 4, device="cuda"),
2741            torch.randn(4, 4, device="cuda"),
2742        )
2743
2744        self.check_model(Model(), example_inputs)
2745
2746    @config.patch({"abi_compatible": True})
2747    def test_triton_kernel_reinterpret_view_mem_leak(self):
2748        # Check for memory leak when using user-defined Triton Kernel + AOTI.
2749        if self.device != "cuda":
2750            raise unittest.SkipTest("requires CUDA")
2751
2752        class Model(torch.nn.Module):
2753            def __init__(self) -> None:
2754                super().__init__()
2755
2756            def forward(self, x, y):
2757                out = torch.zeros_like(x)
2758                yy = y * y
2759                # reshape creates a ReinterpretView
2760                add_kernel[(4,)](x, yy.reshape_as(x), out, 4, 16)
2761                return out
2762
2763        example_inputs = (
2764            torch.randn(4, 4, device="cuda"),
2765            torch.randn(1, 16, device="cuda"),
2766        )
2767
2768        so_path: str = AOTIRunnerUtil.compile(
2769            Model(),
2770            example_inputs,
2771        )
2772        aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path)
2773
2774        # Don't assign outputs to a variable b/c it will allocate GPU memory.
2775        device: int = torch.cuda.current_device()
2776        mem_before = torch.cuda.memory_allocated(device)
2777        aot_inductor_module(*example_inputs)
2778        aot_inductor_module(*example_inputs)
2779        mem_after = torch.cuda.memory_allocated(device)
2780        self.assertEqual(mem_before, mem_after)
2781
2782        actual = aot_inductor_module(*example_inputs)
2783        expected = Model()(*example_inputs)
2784        torch.testing.assert_close(actual, expected)
2785
2786    @torch._dynamo.config.patch(capture_scalar_outputs=True)
2787    @common_utils.parametrize("dynamic", [False, True])
2788    @common_utils.parametrize("autotuning", [False, True])
2789    def test_triton_kernel_unbacked_symint_in_grid(self, dynamic, autotuning):
2790        if self.device != "cuda":
2791            raise unittest.SkipTest("requires CUDA")
2792
2793        class Model(torch.nn.Module):
2794            def forward(self, x, y, n_elements_tensor):
2795                output = torch.zeros_like(x)
2796                n_elements_symint = n_elements_tensor.item()
2797                n_elements = x.numel()
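                    # n_elements_symint comes from .item() and is therefore an unbacked
                    # SymInt (captured via capture_scalar_outputs); it is only used in
                    # the grid computation, which is what this test exercises.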
2798
2799                def grid(meta):
2800                    return (triton.cdiv(n_elements_symint, meta["BLOCK_SIZE"]),)
2801
2802                if autotuning:
2803                    add_kernel_autotuned[grid](
2804                        x,
2805                        y,
2806                        output,
2807                        n_elements,
2808                    )
2809                else:
2810                    add_kernel[grid](
2811                        x,
2812                        y,
2813                        output,
2814                        n_elements,
2815                        BLOCK_SIZE=16,
2816                    )
2817
2818                return output
2819
2820        example_inputs = (
2821            torch.randn(123, device="cuda"),
2822            torch.randn(123, device="cuda"),
2823            torch.tensor(123),
2824        )
2825
2826        dynamic_shapes = None
2827        if dynamic:
2828            dim0 = Dim("s0", min=2, max=1024)
2829            dynamic_shapes = {
2830                "x": {0: dim0},
2831                "y": {0: dim0},
2832                "n_elements_tensor": {},
2833            }
2834
2835        self.check_model(
2836            Model(),
2837            example_inputs,
2838            dynamic_shapes=dynamic_shapes,
2839        )
2840
2841    @skipIfRocm  # USE_MEM_EFF_ATTENTION was not enabled for build.
2842    def test_scaled_dot_product_efficient_attention(self):
2843        if self.device != "cuda":
2844            raise unittest.SkipTest("requires CUDA")
2845
2846        class Model(torch.nn.Module):
2847            def forward(self, q, k, v, attn_bias):
2848                return torch.ops.aten._scaled_dot_product_efficient_attention(
2849                    q, k, v, attn_bias, False
2850                )[0]
2851
2852        example_inputs = (
2853            torch.randn(4, 4, 36, 36, device="cuda"),
2854            torch.randn(4, 4, 36, 36, device="cuda"),
2855            torch.randn(4, 4, 36, 36, device="cuda"),
2856            torch.randn(4, 4, 36, 36, device="cuda"),
2857        )
2858        self.check_model(Model(), example_inputs)
2859
2860    def test_index_put_with_none_index(self):
2861        # index_put falls back in deterministic mode
2862        with DeterministicGuard(True):
2863
2864            class Model(torch.nn.Module):
2865                def forward(self, x, i1, i2, y):
2866                    return torch.ops.aten.index_put(
2867                        x,
2868                        (None, None, i1, i2.transpose(0, 1)),
2869                        y,
2870                        accumulate=True,
2871                    )
2872
2873            example_inputs = (
2874                torch.rand(8, 192, 30, 30, device=self.device),
2875                torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device),
2876                torch.ones(14, 3, dtype=torch.int64, device=self.device),
2877                torch.randn(8, 192, 3, 14, 3, 14, device=self.device),
2878            )
2879            self.check_model(Model(), example_inputs)
2880
2881    def test_runtime_checks(self):
2882        class Model(torch.nn.Module):
2883            def __init__(self) -> None:
2884                super().__init__()
2885
2886            def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9):
2887                return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9)
2888
2889        inputs = []
2890        for dtype in (
2891            torch.float16,
2892            torch.float32,
2893            torch.float64,
2894            torch.bfloat16,
2895            torch.bool,
2896            torch.int8,
2897            torch.int16,
2898            torch.int32,
2899            torch.int64,
2900            torch.uint8,
2901        ):
2902            inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device))
2903        dim0 = Dim("s0", min=2, max=1024)
2904        dim1 = Dim("s1", min=2, max=512)
2905        dim2 = Dim("s2", min=2, max=128)
2906        dynamic_shapes = {
2907            "x0": {0: dim0},
2908            "x1": {0: dim0},
2909            "x2": {0: dim0},
2910            "x3": {1: dim1},
2911            "x4": {1: dim1},
2912            "x5": {1: dim1},
2913            "x6": {},
2914            "x7": {2: dim2},
2915            "x8": {2: dim2},
2916            "x9": {2: dim2},
2917        }
2918        m = Model()
2919        inputs = tuple(inputs)
2920        with torch.no_grad(), config.patch(
2921            {
2922                "abi_compatible": self.abi_compatible,
2923                "aot_inductor.debug_compile": True,
2924            }
2925        ):
2926            so_path = AOTIRunnerUtil.compile(m, inputs, dynamic_shapes=dynamic_shapes)
2927        with open(os.path.splitext(so_path)[0] + ".cpp") as cpp:
2928            src_code = cpp.read()
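                # With 10 rank-3 inputs there are 30 dims and 30 strides in total:
                # each input gets one dtype check (10), the 21 static dims get
                # exact-value checks, the 9 dynamic dims each get a min/max pair (18),
                # and the 21 non-symbolic strides get exact-value checks.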
2929            FileCheck().check_count(
2930                "unmatched dtype",
2931                10,
2932                exactly=True,
2933            ).run(src_code)
2934            FileCheck().check_count(
2935                "unmatched dim value at",
2936                21,  # we have 9 dynamic dims for which we generate different checks
2937                exactly=True,
2938            ).run(src_code)
2939            FileCheck().check_count(
2940                "dim value is too",
2941                18,  # each of the 9 dynamic dims gets two range checks (lower and upper bound)
2942                exactly=True,
2943            ).run(src_code)
2944            FileCheck().check_count(
2945                "unmatched stride value at",
2946                21,  # 30 strides in total; the 9 symbolic strides get no checks, leaving 21
2947                exactly=True,
2948            ).run(src_code)
2949        optimized = AOTIRunnerUtil.load(self.device, so_path)
2950        actual = optimized(*inputs)
2951        expected = m(*inputs)
2952        torch.testing.assert_close(actual, expected)
2953
2954    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
2955    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
2956    def test_runtime_checks_fp8(self):
2957        class Model(torch.nn.Module):
2958            def __init__(self) -> None:
2959                super().__init__()
2960
2961            def forward(self, x0, x1):
2962                t = x0.to(torch.float) + x1.to(torch.float)
2963                return t
2964
2965        inputs = []
2966        for dtype in (
2967            torch.float8_e4m3fn,
2968            torch.float8_e5m2,
2969            # FP8 fnuz dtypes are AMD-specific
2970            # see https://github.com/pytorch/pytorch/issues/126734
2971            # torch.float8_e4m3fnuz,
2972            # torch.float8_e5m2fnuz,
2973        ):
2974            inputs.append(torch.ones(8, 8, 8, dtype=dtype, device=self.device))
2975        dim0 = Dim("s0", min=2, max=1024)
2976        dynamic_shapes = {
2977            "x0": {0: dim0},
2978            "x1": {0: dim0},
2979        }
2980        with torch.no_grad(), config.patch(
2981            {
2982                "abi_compatible": self.abi_compatible,
2983                "aot_inductor.debug_compile": True,
2984            }
2985        ):
2986            self.check_model(
2987                Model(),
2988                tuple(inputs),
2989                dynamic_shapes=dynamic_shapes,
2990            )
2991
2992    def test_runtime_checks_complex(self):
2993        class Model(torch.nn.Module):
2994            def __init__(self) -> None:
2995                super().__init__()
2996
2997            def forward(self, x0, x1, x2):
2998                return (x0, x1, x2)
2999
3000        inputs = []
3001        x0 = torch.tensor([1, -1], dtype=torch.complex32, device=self.device)
3002        x1 = torch.tensor(
3003            [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1],
3004            dtype=torch.complex64,
3005            device=self.device,
3006        )
3007        x2 = torch.tensor(128, dtype=torch.complex128, device=self.device)
3008        inputs.append(x0)
3009        inputs.append(x1)
3010        inputs.append(x2)
3011        dim0 = Dim("s0", min=2, max=1024)
3012        dynamic_shapes = {
3013            "x0": {0: dim0},
3014            "x1": {},
3015            "x2": {},
3016        }
3017        with torch.no_grad(), config.patch(
3018            {
3019                "abi_compatible": self.abi_compatible,
3020                "aot_inductor.debug_compile": True,
3021            }
3022        ):
3023            self.check_model(
3024                Model(),
3025                tuple(inputs),
3026                dynamic_shapes=dynamic_shapes,
3027            )
3028
3029    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
3030    def test_runtime_checks_dtype_failed(self):
3031        class Model(torch.nn.Module):
3032            def __init__(self) -> None:
3033                super().__init__()
3034
3035            def forward(self, x):
3036                y = x.type(torch.float)
3037                return y
3038
3039        x = torch.randn(1, 4, dtype=torch.float16, device=self.device)
3040        model = Model()
3041        with torch.no_grad(), config.patch(
3042            {
3043                "abi_compatible": self.abi_compatible,
3044                "aot_inductor.debug_compile": True,
3045            }
3046        ):
3047            so_path: str = AOTIRunnerUtil.compile(
3048                model,
3049                (x,),
3050            )
3051        aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
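        # a float32 input to a model compiled for float16 should trip the dtype runtime check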
3052        x_casted = x.float()
3053        with self.assertRaisesRegex(Exception, ""):
3054            aot_inductor_module(x_casted)
3055
3056    def test_non_contiguous_output_alias(self):
3057        # Test return x, x.contiguous() where x is non-contiguous.
3058        class Model(torch.nn.Module):
3059            def forward(self, x):
3060                squared = x * x
3061                transposed = squared.t()  # non-contiguous
3062                contig = transposed.contiguous()
3063                return transposed, contig
3064
3065        x = torch.randn(3, 4, dtype=torch.float16, device=self.device)
3066        model = Model()
3067        with torch.no_grad(), config.patch(
3068            {
3069                "abi_compatible": self.abi_compatible,
3070            }
3071        ):
3072            result = AOTIRunnerUtil.run(
3073                self.device,
3074                model,
3075                (x,),
3076            )
3077        actual = model(x)
3078        self.assertTrue(same(result, actual))
3079
3080        # contiguous() should create a new tensor
3081        self.assertTrue(result[0].data_ptr() != result[1].data_ptr())
3082
3083    def test_multiple_output_alias(self):
3084        # Test when multiple outputs alias the same tensor
3085        class Model(torch.nn.Module):
3086            def forward(self, x):
3087                squared = x * x
3088                contig = squared.contiguous()  # alias
3089                reshaped = squared.reshape(squared.shape)  # alias
3090                cubed = squared * x
3091                return squared, contig, reshaped, cubed
3092
3093        x = torch.randn(3, 4, dtype=torch.float32, device=self.device)
3094        model = Model()
3095
3096        with torch.no_grad(), config.patch(
3097            {
3098                "abi_compatible": self.abi_compatible,
3099            }
3100        ):
3101            result = AOTIRunnerUtil.run(
3102                self.device,
3103                model,
3104                (x,),
3105            )
3106        actual = model(x)
3107        self.assertTrue(same(result, actual))
3108
3109        # squared, contig and reshaped alias the same tensor.
3110        self.assertTrue(result[0].data_ptr() == result[1].data_ptr())
3111        self.assertTrue(result[0].data_ptr() == result[2].data_ptr())
3112        # cubed shouldn't be an alias.
3113        self.assertTrue(result[0].data_ptr() != result[3].data_ptr())
3114
3115    def test_runtime_checks_shape_failed(self):
3116        class Model(torch.nn.Module):
3117            def __init__(self) -> None:
3118                super().__init__()
3119
3120            def forward(self, x):
3121                return x
3122
3123        x = torch.randn(4, 4, 4, dtype=torch.float16, device=self.device)
3124        y0 = torch.randn(8, 4, 4, dtype=torch.float16, device=self.device)  # valid: only the dynamic dim differs
3125        y1 = torch.randn(4, 8, 4, dtype=torch.float16, device=self.device)  # static dim mismatch
3126        y2 = rand_strided(  # stride mismatch for the same shape
3127            (4, 4, 4), (16, 1, 4), dtype=torch.float16, device=self.device
3128        )
3129        # batch size 2048 is outside of the allowed range [4, 1024] (y3 also has a wrong static dim)
3130        y3 = torch.randn(2048, 3, 4, dtype=torch.float16, device=self.device)
3131        y4 = torch.randn(2048, 4, 4, dtype=torch.float16, device=self.device)
3132        dim0 = Dim("s0", min=4, max=1024)
3133        dynamic_shapes = {
3134            "x": {0: dim0},
3135        }
3136        model = Model()
3137        with torch.no_grad(), config.patch(
3138            {
3139                "abi_compatible": self.abi_compatible,
3140                "aot_inductor.debug_compile": True,
3141            }
3142        ):
3143            so_path: str = AOTIRunnerUtil.compile(
3144                model, (x,), dynamic_shapes=dynamic_shapes
3145            )
3146        aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
3147        # a valid value of the dynamic batch dim works fine
3148        _ = aot_inductor_module(y0)
3149        with self.assertRaisesRegex(Exception, ""):
3150            aot_inductor_module(y1)
3151        with self.assertRaisesRegex(Exception, ""):
3152            aot_inductor_module(y2)
3153        with self.assertRaisesRegex(Exception, ""):
3154            aot_inductor_module(y3)
3155        with self.assertRaisesRegex(Exception, ""):
3156            aot_inductor_module(y4)
3157
3158    def test_add_complex(self):
3159        class Model(torch.nn.Module):
3160            def forward(self, a, b):
3161                return torch.add(a, b)
3162
3163        x = torch.tensor(
3164            [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
3165        )
3166        y = torch.tensor(
3167            [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
3168        )
3169        self.check_model(Model(), (x, y))
3170
3171    def test_embedding_bag(self):
3172        class Model(torch.nn.Module):
3173            def forward(self, w, i, o):
3174                return torch.ops.aten._embedding_bag(w, i, o, False, 0, False, None)
3175
3176        example_inputs = (
3177            torch.randn([10, 4], device=self.device),
3178            torch.randint(10, [8], device=self.device),
3179            torch.tensor([0, 2, 6], device=self.device),
3180        )
3181        self.check_model(Model(), example_inputs)
3182
3183    def test_fft_c2c(self):
3184        class Model(torch.nn.Module):
3185            def forward(self, x):
3186                return torch.fft.fftn(x), torch.fft.fftn(x).real
3187
3188        example_inputs = (torch.randn(16, 16, 16, device=self.device),)
3189        self.check_model(Model(), example_inputs)
3190
3191    def test_bool_input(self):
3192        # Specializes on the branch selected by the example value of b
3193        class Model(torch.nn.Module):
3194            def forward(self, x, b):
3195                if b:
3196                    return x * x
3197                else:
3198                    return x + x
3199
3200        example_inputs = (torch.randn(3, 3, device=self.device), True)
3201        self.check_model(Model(), example_inputs)
3202
3203    def test_int_list_input(self):
3204        class Model(torch.nn.Module):
3205            def forward(self, x, i):
3206                return x * i[0] * i[1]
3207
3208        example_inputs = (torch.randn(3, 3, device=self.device), [3, 4])
3209        self.check_model(Model(), example_inputs)
3210
3211    def test_nested_tensor_from_jagged(self):
3212        class Model(nn.Module):
3213            def __init__(self) -> None:
3214                super().__init__()
3215                self.mlp = nn.Sequential(
3216                    nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
3217                )
3218
3219            def forward(self, values, offsets):
3220                nt = torch.nested.nested_tensor_from_jagged(values, offsets)
3221                res = self.mlp(nt)
3222                return res.values()
3223
3224        model = Model().to(device=self.device)
3225
3226        example_inputs_1 = (
3227            torch.randn((15, 128), device=self.device),
3228            torch.tensor([0, 3, 4, 10, 15], device=self.device),
3229        )
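        # NT batch size = len(offsets) - 1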
3230
3231        # same "NT batch size", different actual amount of data
3232        example_inputs_2 = (
3233            torch.randn((31, 128), device=self.device),
3234            torch.tensor([0, 1, 20, 25, 31], device=self.device),
3235        )
3236
3237        # same actual amount of data, different "NT batch size"
3238        example_inputs_3 = (
3239            torch.randn((15, 128), device=self.device),
3240            torch.tensor([0, 3, 10, 15], device=self.device),
3241        )
3242
3243        # different "NT batch size"
3244        example_inputs_4 = (
3245            torch.randn((37, 128), device=self.device),
3246            torch.tensor([0, 5, 16, 25, 29, 37], device=self.device),
3247        )
3248
3249        dim0_values = Dim("dim0_values", min=1, max=128)
3250        dim0_offsets = Dim("dim0_offsets", min=1, max=9)
3251        dynamic_shapes = {"values": {0: dim0_values}, "offsets": {0: dim0_offsets}}
3252        example_inputs_list = [
3253            example_inputs_1,
3254            example_inputs_2,
3255            example_inputs_3,
3256            example_inputs_4,
3257        ]
3258
3259        self.check_model_with_multiple_inputs(
3260            model, example_inputs_list, dynamic_shapes=dynamic_shapes
3261        )
3262
3263    @common_utils.parametrize("max_autotune", [False, True])
3264    def test_misc_1(self, max_autotune):
3265        if self.device == "cpu" and IS_MACOS and max_autotune:
3266            raise unittest.SkipTest("max_autotune not supported on macos")
3267
3268        class Model(nn.Module):
3269            def __init__(self) -> None:
3270                super().__init__()
3271                self.mlp = nn.Sequential(
3272                    nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
3273                )
3274                self.emb = nn.EmbeddingBag(num_embeddings=128, embedding_dim=32)
3275                self.over_arch = nn.Sequential(
3276                    nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32), nn.Sigmoid()
3277                )
3278
3279            def forward(self, x, y):
3280                mlp_output = self.mlp(x)
3281                emb_output = self.emb(y)
3282                return self.over_arch(torch.concat([mlp_output, emb_output], dim=1))
3283
3284        example_inputs = (
3285            torch.randn(16, 128, device=self.device),
3286            torch.randint(0, 128, (16, 10), device=self.device),
3287        )
3288        self.check_model(
3289            Model(), example_inputs, options=dict(max_autotune=max_autotune)
3290        )
3291
3292    def test_aoti_debug_printer_codegen(self):
3293        # basic addmm model to test codegen for aoti intermediate debug printer
3294        class Model(torch.nn.Module):
3295            def __init__(self, n, k, device):
3296                super().__init__()
3297                self.weight = torch.randn(n, k, device=device)
3298                self.bias = torch.randn(n, device=device)
3299
3300            def forward(self, a):
3301                return torch.nn.functional.linear(a, self.weight, self.bias)
3302
3303        M = 8
3304        N = 6
3305        K = 16
3306        model = Model(N, K, self.device)
3307        batch = 2
3308        a = torch.randn(batch, M, K, device=self.device)
3309        example_inputs = (a,)
3310
3311        kernel_calls = (
3312            [
3313                ("triton_poi_fused_0", 1),
3314                ("aoti_torch_cuda_addmm_out", 2),
3315            ]
3316            if self.device == "cuda"
3317            else [
3318                ("aoti_torch_cpu_addmm_out", 2),
3319            ]
3320        )
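        # (kernel name, expected number of before/after_launch debug-print lines) per generated kernel call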
3321
3322        # test the default codegen that debug-prints all tensor values
3323        with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
3324            result, code = run_and_get_cpp_code(
3325                AOTIRunnerUtil.compile, model, example_inputs
3326            )
3327
3328            # check that the C shim aoti_torch_print_tensor_handle call is triggered by the config and injected into the generated cpp code as expected
3329            self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
3330
3331            # check that the debug-printing codegen around the actual kernel calls is as expected
3332
3333            for kernel_call, count in kernel_calls:
3334                FileCheck().check_count(
3335                    f"before_launch - {kernel_call}",
3336                    count,
3337                ).run(code)
3338                FileCheck().check_count(
3339                    f"after_launch - {kernel_call}",
3340                    count,
3341                ).run(code)
3342
3343        # test codegen that debug-prints only the selected kernel's tensor values
3344        filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out"
3345        with config.patch(
3346            {
3347                "aot_inductor.debug_intermediate_value_printer": "2",
3348                "aot_inductor.filtered_kernel_names": filtered_kernel_name,
3349            }
3350        ):
3351            result, code = run_and_get_cpp_code(
3352                AOTIRunnerUtil.compile, model, example_inputs
3353            )
3354            filtered_kernel_calls = [
3355                (filtered_kernel_name, 2),
3356            ]
3357            for kernel_call, count in filtered_kernel_calls:
3358                FileCheck().check_count(
3359                    f"before_launch - {kernel_call}",
3360                    count,
3361                ).run(code)
3362                FileCheck().check_count(
3363                    f"after_launch - {kernel_call}",
3364                    count,
3365                ).run(code)
3366
3367            kernel_calls_not_to_print = [
3368                kernel_call
3369                for kernel_call in kernel_calls
3370                if kernel_call[0] != filtered_kernel_name
3371            ]
3372            for kernel_name, _ in kernel_calls_not_to_print:
3373                FileCheck().check_not(f"before_launch - {kernel_name}").run(code)
3374                FileCheck().check_not(f"after_launch - {kernel_name}").run(code)
3375
3376    def test_aoti_debug_printer_user_defined_triton_kernel(self):
3377        if self.device != "cuda":
3378            raise unittest.SkipTest("requires CUDA")
3379
3380        class Model(torch.nn.Module):
3381            def __init__(self) -> None:
3382                super().__init__()
3383
3384            def forward(self, x, y):
3385                out = torch.zeros_like(x)
3386                add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16)
3387                return out
3388
3389        example_inputs = (
3390            torch.randn(4, 4, device=self.device),
3391            torch.randn(4, 4, device=self.device),
3392        )
3393
3394        kernel_calls = [
3395            ("add_kernel_0", 3),
3396        ]
3397
3398        with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
3399            result, code = run_and_get_cpp_code(
3400                AOTIRunnerUtil.compile, Model(), example_inputs
3401            )
3402            # check that the C shim aoti_torch_print_tensor_handle call is triggered by the config and injected into the generated cpp code as expected
3403            self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
3404            # check that the debug-printing codegen around the actual kernel call is as expected
3405            for kernel_call, count in kernel_calls:
3406                FileCheck().check_count(
3407                    f"before_launch - {kernel_call}",
3408                    count,
3409                ).run(code)
3410                FileCheck().check_count(
3411                    f"after_launch - {kernel_call}",
3412                    count,
3413                ).run(code)
3414
3415    def test_size_from_multi_output(self):
3416        class Model(torch.nn.Module):
3417            def __init__(self):
3418                super().__init__()
3419                self.relu = torch.nn.ReLU()
3420
3421            def forward(self, x):
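                # the output size of unique() is data-dependent and comes from a multi-output op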
3422                _x, _i = torch.unique(x, sorted=True, return_inverse=True)
3423                _x = _x.clone().detach()
3424                return self.relu(_x), _i
3425
3426        example_inputs = (torch.randn(8, device=self.device),)
3427        self.check_model(Model(), example_inputs)
3428
3429
3430common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
3431
3432
3433class AOTITestCase(TestCase):
3434    def setUp(self):
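        # load the custom-ops test library; the lookup differs between fbcode/Sandcastle, OSS, and Windows builds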
3435        if IS_SANDCASTLE or IS_FBCODE:
3436            torch.ops.load_library("//caffe2/test/inductor:custom_ops")
3437        elif IS_MACOS:
3438            raise unittest.SkipTest("non-portable load_library call used in test")
3439        else:
3440            lib_file_path = find_library_location("libaoti_custom_ops.so")
3441            if IS_WINDOWS:
3442                lib_file_path = find_library_location("aoti_custom_ops.dll")
3443            torch.ops.load_library(str(lib_file_path))
3444        super().setUp()
3445
3446
3447class AOTInductorTestABICompatibleCpu(AOTITestCase):
3448    device = "cpu"
3449    abi_compatible = True
3450    check_model = check_model
3451    check_model_with_multiple_inputs = check_model_with_multiple_inputs
3452    code_check_count = code_check_count
3453    allow_stack_allocation = False
3454    use_minimal_arrayref_interface = False
3455
3456
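# Helpers that build TestFailure entries (xfail by default, skip when is_skip=True) for the listed test-class suffixes.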
3457def fail_with_and_without_stack_allocation(is_skip=False):
3458    return TestFailure(
3459        (
3460            "abi_compatible_cpu",
3461            "abi_compatible_cpu_with_stack_allocation",
3462            "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface",
3463        ),
3464        is_skip=is_skip,
3465    )
3466
3467
3468def fail_stack_allocation(is_skip=False):
3469    return TestFailure(
3470        (
3471            "abi_compatible_cpu_with_stack_allocation",
3472            "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface",
3473        ),
3474        is_skip=is_skip,
3475    )
3476
3477
3478def fail_minimal_arrayref_interface(is_skip=False):
3479    return TestFailure(
3480        ("abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface",),
3481        is_skip=is_skip,
3482    )
3483
3484
3485def fail_cuda(is_skip=False):
3486    return TestFailure(
3487        ("abi_compatible_cuda", "non_abi_compatible_cuda"),
3488        is_skip=is_skip,
3489    )
3490
3491
3492def fail_abi_compatible_cuda(is_skip=False):
3493    return TestFailure(
3494        ("abi_compatible_cuda",),
3495        is_skip=is_skip,
3496    )
3497
3498
3499def fail_non_abi_compatible_cuda(is_skip=False):
3500    return TestFailure(
3501        ("non_abi_compatible_cuda",),
3502        is_skip=is_skip,
3503    )
3504
3505
3506# test_failures, xfail by default, set is_skip=True to skip
3507CPU_TEST_FAILURES = {
3508    # TODO: error: ‘complex64’ was not declared in this scope
3509    "test_add_complex": fail_minimal_arrayref_interface(is_skip=True),
3510    # TODO: test_conv_freezing_abi_compatible_cpu fails,
3511    #   AssertionError: None, i.e. optional output is not supported
3512    "test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
3513    # TODO: test_deconv_freezing_abi_compatible_cpu fails,
3514    #   AssertionError: None, i.e. optional output is not supported
3515    "test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
3516    # FIXME: failed with Segfault while exiting the Python runtime
3517    "test_duplicate_constant_folding": fail_with_and_without_stack_allocation(
3518        is_skip=True
3519    ),
3520    # TODO: use of deleted function RAIIAtenTensorHandle
3521    "test_dup_unbacked_sym_decl": fail_minimal_arrayref_interface(is_skip=True),
3522    # TODO: use of deleted function RAIIAtenTensorHandle
3523    "test_dup_unbacked_sym_decl_with_refinement": fail_minimal_arrayref_interface(
3524        is_skip=True
3525    ),
3526    # TODO:  error: cannot convert ArrayRefTensor<float> to AtenTensorHandle
3527    "test_dynamic_cat": fail_minimal_arrayref_interface(),
3528    # https://github.com/pytorch/pytorch/issues/129550
3529    # https://github.com/pytorch/pytorch/issues/123691
3530    "test_dynamic_scalar": fail_minimal_arrayref_interface(is_skip=True),
3531    # https://github.com/pytorch/pytorch/issues/122980
3532    "test_fft_c2c": fail_stack_allocation(is_skip=True),
3533    # TODO: test_freezing_abi_compatible_cpu fails,
3534    #   AssertionError: None, i.e. optional output is not supported
3535    "test_freezing": fail_with_and_without_stack_allocation(is_skip=True),
3536    # TODO: test_linear_freezing_abi_compatible_cpu fails,
3537    #   AssertionError: None, i.e. optional output is not supported
3538    "test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True),
3539    # FIXME: failed with Segfault while exiting the Python runtime
3540    "test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True),
3541    # minimal arrayref interface only works with CPU; test crashes.
3542    # https://github.com/pytorch/pytorch/issues/122983
3543    "test_multi_device": fail_minimal_arrayref_interface(is_skip=True),
3544    # TODO: AssertionError: unsupported Optional type in convert_arg_type: Generator
3545    "test_normal_functional": fail_with_and_without_stack_allocation(is_skip=True),
3546    # TODO: The same issue as https://github.com/pytorch/pytorch/issues/122978
3547    # error: cannot convert ArrayRefTensor<float> to AtenTensorHandle
3548    "test_reuse_kernel_dynamic": fail_minimal_arrayref_interface(is_skip=True),
3549    # the test segfaults
3550    "test_repeat_output": fail_stack_allocation(is_skip=True),
3551    # TODO: failed internally
3552    "test_multiple_output_alias": fail_with_and_without_stack_allocation(is_skip=True),
3553    # segfault
3554    "test_buffer_mutation_1": fail_stack_allocation(is_skip=True),
3555    # segfault
3556    "test_buffer_mutation_2": fail_stack_allocation(is_skip=True),
3557    # segfault
3558    "test_bool_input": fail_stack_allocation(is_skip=True),
3559    # segfault
3560    "test_int_list_input": fail_stack_allocation(is_skip=True),
3561    # segfault
3562    # 'AOTInductorTestABICompatibleCpuWithStackAllocation' object has no attribute 'code_check_count'
3563    "test_buffer_mutation_3": fail_stack_allocation(is_skip=True),
3564    # FIXME: failed with Segfault while exiting the Python runtime
3565    "test_scatter_fallback": fail_stack_allocation(is_skip=True),
3566    # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978
3567    "test_scatter_reduce_fallback": fail_minimal_arrayref_interface(is_skip=True),
3568    # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978
3569    "test_index_put_fallback": fail_minimal_arrayref_interface(is_skip=True),
3570    # https://github.com/pytorch/pytorch/issues/122984
3571    "test_index_put_with_none_index": fail_minimal_arrayref_interface(is_skip=True),
3572    # FIXME: failed with Segfault while exiting the Python runtime
3573    "test_constant": fail_stack_allocation(is_skip=True),
3574    # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978
3575    "test_shifted_constraint_ranges": fail_with_and_without_stack_allocation(
3576        is_skip=True
3577    ),
3578    # https://github.com/pytorch/pytorch/issues/123691
3579    "test_amp_fallback_random": fail_minimal_arrayref_interface(is_skip=True),
3580    "test_simple_dynamic": fail_minimal_arrayref_interface(),
3581    # https://github.com/pytorch/pytorch/issues/123691
3582    "test_zero_grid_with_unbacked_symbols": fail_minimal_arrayref_interface(
3583        is_skip=True
3584    ),
3585    # fails on MacOS
3586    "test_zero_grid_with_backed_symbols": fail_with_and_without_stack_allocation(
3587        is_skip=True
3588    ),
3589    # https://github.com/pytorch/pytorch/issues/122990
3590    "test_cond_non_tensor_predicates_dynamic_False": fail_stack_allocation(
3591        is_skip=True
3592    ),
3593    # same issue as https://github.com/pytorch/pytorch/issues/122990
3594    "test_cond_non_tensor_predicates_dynamic_True": fail_stack_allocation(is_skip=True),
3595    # https://github.com/pytorch/pytorch/issues/122991
3596    "test_runtime_checks_complex": fail_with_and_without_stack_allocation(is_skip=True),
3597    "test_runtime_checks_fp8": fail_with_and_without_stack_allocation(is_skip=True),
3598    "test_while_loop_simple": fail_stack_allocation(is_skip=True),
3599    "test_while_loop_nested": fail_stack_allocation(is_skip=True),
3600    "test_while_loop_with_outer_code": fail_stack_allocation(is_skip=True),
3601    # TODO: error: cannot convert ArrayRefTensor<float> to AtenTensorHandle
3602    "test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True),
3603    # TODO: use of undeclared identifier 'float8_e4m3fn' and 'half'
3604    "test_fp8": fail_minimal_arrayref_interface(is_skip=True),
3605    "test_custom_op_add": fail_minimal_arrayref_interface(is_skip=True),
3606    "test_custom_op_all_inputs": fail_minimal_arrayref_interface(is_skip=True),
3607    "test_custom_op_with_multiple_outputs": fail_minimal_arrayref_interface(
3608        is_skip=True
3609    ),
3610    "test_custom_op_with_reinterpret_view_inputs": fail_minimal_arrayref_interface(
3611        is_skip=True
3612    ),
3613    "test_custom_op_with_concat_inputs": fail_minimal_arrayref_interface(is_skip=True),
3614    "test_custom_op_missing_arg_with_default_value": fail_minimal_arrayref_interface(
3615        is_skip=True
3616    ),
3617    "test_size_from_multi_output": fail_stack_allocation(is_skip=True),
3618}
3619
3620# test_failures, xfail by default, set is_skip=True to skip
3621CUDA_TEST_FAILURES = {
3622    # TODO: AssertionError: unsupported Optional type in convert_arg_type: Generator
3623    "test_normal_functional": fail_abi_compatible_cuda(is_skip=True),
3624    # no runtime checks for non_abi_compatible mode
3625    "test_runtime_checks": fail_non_abi_compatible_cuda(is_skip=True),
3626    "test_runtime_checks_complex": fail_non_abi_compatible_cuda(is_skip=True),
3627    "test_runtime_checks_fp8": fail_non_abi_compatible_cuda(is_skip=True),
3628    "test_runtime_checks_dtype_failed": fail_non_abi_compatible_cuda(is_skip=True),
3629    "test_runtime_checks_shape_failed": fail_non_abi_compatible_cuda(is_skip=True),
3630    # quantized ops are unsupported on GPU
3631    "test_quantized_linear": fail_cuda(is_skip=True),
3632    "test_quanatized_int8_linear": fail_cuda(is_skip=True),
3633    "test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True),
3634    # fp8 to be re-enabled for AOTI
3635    "test_fp8": fail_cuda(is_skip=True),
3636    "test_custom_op_all_inputs": fail_non_abi_compatible_cuda(is_skip=True),
3637    "test_custom_op_missing_arg_with_default_value": fail_non_abi_compatible_cuda(
3638        is_skip=True
3639    ),
3640    "test_custom_op_with_concat_inputs": fail_non_abi_compatible_cuda(is_skip=True),
3641    "test_custom_op_with_reinterpret_view_inputs": fail_non_abi_compatible_cuda(
3642        is_skip=True
3643    ),
3644    "test_custom_op_with_multiple_outputs": fail_non_abi_compatible_cuda(is_skip=True),
3645    # non-abi compatible mode aoti debug printer is not supported yet
3646    "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True),
3647    "test_aoti_debug_printer_user_defined_triton_kernel": fail_non_abi_compatible_cuda(
3648        is_skip=True
3649    ),
3650}
3651
3652
3653if not IS_FBCODE:
3654    # The following tests look like they pass in both pytest and unittest (xml
3655    # and terminal output say pass), but the process will segfault.  This only
3656    # happens in OSS CI and is fine internally.
3657    CPU_TEST_FAILURES.update(
3658        {
3659            "test_duplicated_params": fail_stack_allocation(is_skip=True),
3660            "test_embedding_bag": fail_stack_allocation(is_skip=True),
3661            "test_fqn": fail_stack_allocation(is_skip=True),
3662            "test_no_args": fail_stack_allocation(is_skip=True),
3663            "test_output_misaligned": fail_stack_allocation(is_skip=True),
3664            "test_pytree_inputs": fail_stack_allocation(is_skip=True),
3665            "test_seq": fail_stack_allocation(is_skip=True),
3666            "test_simple_split": fail_stack_allocation(is_skip=True),
3667            "test_addmm": fail_minimal_arrayref_interface(is_skip=True),
3668            "test_aliased_buffer_reuse": fail_minimal_arrayref_interface(is_skip=True),
3669            "test_buffer_reuse": fail_minimal_arrayref_interface(is_skip=True),
3670            "test_constant_folding": fail_minimal_arrayref_interface(is_skip=True),
3671            "test_convolution": fail_minimal_arrayref_interface(is_skip=True),
3672            "test_empty_graph": fail_minimal_arrayref_interface(is_skip=True),
3673            "test_large_weight": fail_minimal_arrayref_interface(is_skip=True),
3674            "test_large_mmaped_weights": fail_minimal_arrayref_interface(is_skip=True),
3675            "test_normal_functional": fail_minimal_arrayref_interface(is_skip=True),
3676            "test_misc_1": fail_minimal_arrayref_interface(is_skip=True),
3677            "test_missing_output": fail_minimal_arrayref_interface(is_skip=True),
3678            "test_model_modified_weights": fail_minimal_arrayref_interface(
3679                is_skip=True
3680            ),
3681            "test_output_path_1": fail_minimal_arrayref_interface(is_skip=True),
3682            "test_quantized_linear": fail_minimal_arrayref_interface(is_skip=True),
3683            "test_quanatized_int8_linear": fail_minimal_arrayref_interface(
3684                is_skip=True
3685            ),
3686            "test_repeat_interleave": fail_minimal_arrayref_interface(is_skip=True),
3687            "test_return_constant": fail_minimal_arrayref_interface(is_skip=True),
3688            "test_reuse_kernel": fail_minimal_arrayref_interface(is_skip=True),
3689            "test_simple": fail_minimal_arrayref_interface(is_skip=True),
3690            "test_small_constant": fail_minimal_arrayref_interface(is_skip=True),
3691            "test_with_no_triton_profiler": fail_minimal_arrayref_interface(
3692                is_skip=True
3693            ),
3694            "test_with_offset": fail_minimal_arrayref_interface(is_skip=True),
3695            "test_with_profiler": fail_minimal_arrayref_interface(is_skip=True),
3696            "test_zero_size_weight": fail_minimal_arrayref_interface(is_skip=True),
3697            "test_aoti_debug_printer_codegen": fail_with_and_without_stack_allocation(
3698                is_skip=True
3699            ),
3700        }
3701    )
3702    # The following tests pass internally but fail in OSS CI. To be investigated.
3703    CUDA_TEST_FAILURES.update(
3704        {
3705            "test_aoti_debug_printer_codegen": fail_cuda(is_skip=True),
3706            "test_aoti_debug_printer_user_defined_triton_kernel": fail_cuda(
3707                is_skip=True
3708            ),
3709        }
3710    )
3711
3712copy_tests(
3713    AOTInductorTestsTemplate,
3714    AOTInductorTestABICompatibleCpu,
3715    "abi_compatible_cpu",
3716    CPU_TEST_FAILURES,
3717)
3718
3719
3720class AOTInductorTestABICompatibleCpuWithStackAllocation(AOTITestCase):
3721    device = "cpu"
3722    abi_compatible = True
3723    check_model = check_model
3724    check_model_with_multiple_inputs = check_model_with_multiple_inputs
3725    code_check_count = code_check_count
3726    allow_stack_allocation = True
3727    use_minimal_arrayref_interface = False
3728
3729
3730copy_tests(
3731    AOTInductorTestsTemplate,
3732    AOTInductorTestABICompatibleCpuWithStackAllocation,
3733    "abi_compatible_cpu_with_stack_allocation",
3734    CPU_TEST_FAILURES,
3735)
3736
3737
3738class AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface(
3739    TestCase
3740):
3741    device = "cpu"
3742    abi_compatible = True
3743    check_model = check_model
3744    check_model_with_multiple_inputs = check_model_with_multiple_inputs
3745    allow_stack_allocation = True
3746    use_minimal_arrayref_interface = True
3747
3748
3749copy_tests(
3750    AOTInductorTestsTemplate,
3751    AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface,
3752    "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface",
3753    CPU_TEST_FAILURES,
3754)
3755
3756
3757@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
3758class AOTInductorTestABICompatibleCuda(AOTITestCase):
3759    device = "cuda"
3760    abi_compatible = True
3761    check_model = check_model
3762    check_model_with_multiple_inputs = check_model_with_multiple_inputs
3763    code_check_count = code_check_count
3764    allow_stack_allocation = False
3765    use_minimal_arrayref_interface = False
3766
3767
3768copy_tests(
3769    AOTInductorTestsTemplate,
3770    AOTInductorTestABICompatibleCuda,
3771    "abi_compatible_cuda",
3772    CUDA_TEST_FAILURES,
3773)
3774
3775
3776@unittest.skipIf(
3777    IS_FBCODE or sys.platform == "darwin",
3778    "NonABI mode should not be used in fbcode nor on MacOS",
3779)
3780class AOTInductorTestNonABICompatibleCpu(AOTITestCase):
3781    device = "cpu"
3782    abi_compatible = False
3783    check_model = check_model
3784    check_model_with_multiple_inputs = check_model_with_multiple_inputs
3785    code_check_count = code_check_count
3786    allow_stack_allocation = False
3787    use_minimal_arrayref_interface = False
3788
3789
3790copy_tests(
3791    AOTInductorTestsTemplate,
3792    AOTInductorTestNonABICompatibleCpu,
3793    "non_abi_compatible_cpu",
3794    # test_failures, xfail by default, set is_skip=True to skip
3795    {
3796        "test_duplicate_constant_folding": TestFailure(
3797            ("non_abi_compatible_cpu",), is_skip=True
3798        ),
3799        # no runtime checks for non_abi_compatible mode
3800        "test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
3801        "test_runtime_checks_dtype_failed": TestFailure(
3802            ("non_abi_compatible_cpu",), is_skip=True
3803        ),
3804        "test_runtime_checks_shape_failed": TestFailure(
3805            ("non_abi_compatible_cpu",), is_skip=True
3806        ),
3807        "test_custom_op_add": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
3808        "test_aoti_debug_printer_codegen": TestFailure(
3809            ("non_abi_compatible_cpu",), is_skip=True
3810        ),
3811        "test_custom_op_all_inputs": TestFailure(
3812            ("non_abi_compatible_cpu",), is_skip=True
3813        ),
3814        "test_custom_op_missing_arg_with_default_value": TestFailure(
3815            ("non_abi_compatible_cpu",), is_skip=True
3816        ),
3817        "test_custom_op_with_concat_inputs": TestFailure(
3818            ("non_abi_compatible_cpu",), is_skip=True
3819        ),
3820        "test_custom_op_with_multiple_outputs": TestFailure(
3821            ("non_abi_compatible_cpu",), is_skip=True
3822        ),
3823        "test_custom_op_with_reinterpret_view_inputs": TestFailure(
3824            ("non_abi_compatible_cpu",), is_skip=True
3825        ),
3826    },
3827)
3828
3829
3830@unittest.skipIf(
3831    IS_FBCODE or sys.platform == "darwin",
3832    "NonABI mode should not be used in fbcode nor on MacOS",
3833)
3834class AOTInductorTestNonABICompatibleCuda(AOTITestCase):
3835    device = "cuda"
3836    abi_compatible = False
3837    check_model = check_model
3838    check_model_with_multiple_inputs = check_model_with_multiple_inputs
3839    code_check_count = code_check_count
3840    allow_stack_allocation = False
3841    use_minimal_arrayref_interface = False
3842
3843
3844copy_tests(
3845    AOTInductorTestsTemplate,
3846    AOTInductorTestNonABICompatibleCuda,
3847    "non_abi_compatible_cuda",
3848    CUDA_TEST_FAILURES,
3849)
3850
3851
3852if __name__ == "__main__":
3853    from torch._inductor.test_case import run_tests
3854
3855    # cpp_extension N/A in fbcode
3856    if HAS_CUDA or sys.platform == "darwin":
3857        run_tests(needs="filelock")
3858