# Owner(s): ["module: inductor"]
import os
import unittest
from typing import Callable, List, Optional

import torch
from torch import multiprocessing as mp, nn
from torch._dynamo import reset
from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.testing import rand_strided, reset_rng_state
from torch._inductor import config
from torch._inductor.autotune_process import (
    BenchmarkRequest,
    CUDA_VISIBLE_DEVICES,
    TuningProcessPool,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
    AlgorithmSelectorCache,
    TritonTemplateCaller,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


try:
    from .mock_cache import global_stats, PatchCaches
except ImportError:
    from mock_cache import global_stats, PatchCaches  # @manual


torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")


def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.
    """
    path_envs = os.environ.get("PATH", "").split(":")
    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
    return ":".join(path_envs)


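# Note: benchmark_choice below runs in a spawned child process (see the tests
# further down). Results are written into a pre-allocated tensor rather than
# returned, because mutations to ordinary Python objects made in a subprocess
# are not synced back to the parent process.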
def benchmark_choice(choice, args, out, expected_out, timings):
    result = choice.benchmark(*args, out=out)
    if expected_out is not None:
        torch.testing.assert_close(out, expected_out)

    timings.copy_(torch.tensor(result))


class FailChoiceCaller(ChoiceCaller):
    def benchmark(self, *args, out):
        raise RuntimeError("This choice caller will always throw")


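# Tests covering Inductor's max-autotune path: benchmarking algorithm choices in
# subprocesses, precompilation threading, and autotuned mm / addmm / bmm / conv
# lowerings. These tests assume a CUDA device (see the __main__ gate at the end
# of the file).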
@instantiate_parametrized_tests
class TestMaxAutotune(TestCase):
    def _create_buffer(self, name, shape):
        return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape))

    def test_benchmark_choice_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            # expected_out = (mat1 @ mat2) + (mat3 @ mat4)
            expected_out = None

            choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout)
            # use a tensor since the mutation to a Python list in a subprocess
            # is not synced back to the parent process
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertEqual(0, child.exitcode)
            print(f"timings is {timings}, out {out}, expected_out {expected_out}")

    def test_benchmark_choice_fail_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            expected_out = (mat1 @ mat2) + (mat3 @ mat4)

            choice = FailChoiceCaller("fail_choice_caller", [], None)

            # use a tensor since a Python list is not synced back
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertNotEqual(0, child.exitcode)

    @parametrize("autotune_in_subproc", (True, False))
    @parametrize("autotune_multi_device", (True, False))
    def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device):
        """
        This crashed previously due to a triton issue: https://github.com/openai/triton/issues/1298 .
        With autotuning in a subprocess, we don't crash anymore.
157        """
158        m, n, k = 2048, 1536, 64
159
160        def mm_plus_mm(a, b, c, d):
161            return a @ b + c @ d
162
163        a = torch.randn(m, k).cuda()
164        b = torch.randn(k, n).cuda()
165        c = torch.randn(m, k).cuda()
166        d = torch.randn(k, n).cuda()
167
168        with config.patch(
169            {
170                "max_autotune": True,
171                "autotune_in_subproc": autotune_in_subproc,
172                "autotune_multi_device": autotune_multi_device,
173            }
174        ):
175            torch.compile(mm_plus_mm)(a, b, c, d)
176
177    @parametrize("dynamic", (False, True))
178    def test_max_autotune_mm_plus_mm_zero_size_input(self, dynamic):
179        """
180        Make sure autotuning mm_plus_mm with zero-size input works without crashes.
181        """
182        m, n, k = 0, 1536, 64
183
184        def mm_plus_mm(a, b, c, d):
185            return a @ b + c @ d
186
187        a = torch.randn(m, k).cuda()
188        b = torch.randn(k, n).cuda()
189        c = torch.randn(m, k).cuda()
190        d = torch.randn(k, n).cuda()
191
192        with config.patch({"max_autotune": True}):
193            torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d)
194
195    @parametrize("dynamic", (False, True))
196    def test_max_autotune_regular_mm(self, dynamic: bool):
197        """
        Make sure autotuning mm in subprocesses works without crashes.
199        """
200
201        def mm(a, b):
202            a = torch.sin(a)
203            return a @ b
204
205        a = torch.randn(100, 10).cuda()
206        b = torch.randn(10, 100).cuda()
207
208        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
209            torch.compile(mm, dynamic=dynamic)(a, b)
210
211    @parametrize("dynamic", (False, True))
212    def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool):
213        """
214        Make sure autotuning mm with zero-size input works without crashes.
215        """
216
217        def mm(a, b):
218            a = torch.sin(a)
219            return a @ b
220
221        a = torch.randn(0, 10).cuda()
222        b = torch.randn(10, 100).cuda()
223
224        with config.patch({"max_autotune": True}):
225            torch.compile(mm, dynamic=dynamic)(a, b)
226
227    @skipIfRocm
228    def test_precompilation_threads(self):
229        import threading
230        from typing import Any, Dict
231        from unittest.mock import Mock, patch
232
233        class FakeChoiceCaller(ChoiceCaller):
234            def __init__(self) -> None:
235                super().__init__("none", [], Mock())
236                self.thread_id = None
237
238            def precompile(self):
239                self.thread_id = threading.get_ident()
240
241            def call_name(self) -> str:
242                return None
243
244            def to_callable(self):
245                return None
246
247            def hash_key(self) -> str:
248                return str(hash(self))
249
250            def output_node(self) -> "TensorBox":  # noqa: F821
251                return None
252
253        fake_choices = [FakeChoiceCaller() for i in range(10)]
254        fake_lookup_result = dict.fromkeys(fake_choices, 0.123)
255
256        def no_lookup(
257            choices: List[ChoiceCaller],
258            op: str,
259            inputs: str,
260            benchmark: Callable[[Any], Dict[ChoiceCaller, float]],
261        ) -> Optional[Dict[ChoiceCaller, float]]:
262            if benchmark is not None:
263                return benchmark(choices)
264
265        asc = AlgorithmSelectorCache()
266
267        def fake_benchmark_fn(*args, **kwargs):
268            return fake_lookup_result
269
270        main_thread_id = threading.get_ident()
271        mock_debug_handler = Mock()
272        old_debug_handler = V.debug
273        try:
274            V.set_debug_handler(mock_debug_handler)
275            with patch.object(asc, "lookup", new=no_lookup):
276                with patch.object(
277                    asc, "make_benchmark_fn", return_value=fake_benchmark_fn
278                ):
279                    with config.patch(
280                        {
281                            "autotune_in_subproc": False,
282                            "compile_threads": len(fake_choices),
283                        }
284                    ):
285                        asc("test_call", fake_choices, [], Mock())
286            for fake_choice in fake_choices:
287                assert (
288                    fake_choice.thread_id is not None
289                ), "Expected all ChoiceCaller's precompile method to have been called"
290                assert (
291                    fake_choice.thread_id != main_thread_id
292                ), "Expected all ChoiceCaller's precompile method to have been called on separate thread"
293        finally:
294            V.set_debug_handler(old_debug_handler)
295
296    @parametrize("dynamic", (False, True))
297    def test_max_autotune_addmm(self, dynamic=False):
298        """
        Make sure autotuning addmm in subprocesses works without crashes.
300        """
301
302        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
303
304        def addmm(x, a, b):
305            return torch.addmm(x, a, b)
306
307        x = torch.randn(100).cuda()
308        a = torch.randn(100, 10).cuda()
309        b = torch.randn(10, 100).cuda()
310        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
311            Y_compiled = torch.compile(addmm, dynamic=dynamic)(x, a, b)
312            Y = addmm(x, a, b)
313            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)
314
315    @parametrize("dynamic", (False, True))
316    def test_max_autotune_addmm_zero_size_input(self, dynamic):
317        """
318        Make sure autotuning addmm with zero-size input works without crashes.
319        """
320
321        def addmm(x, a, b):
322            return torch.addmm(x, a, b)
323
324        x = torch.randn(100).cuda()
325        a = torch.randn(0, 10).cuda()
326        b = torch.randn(10, 100).cuda()
327        with config.patch({"max_autotune": True}):
328            torch.compile(addmm, dynamic=dynamic)(x, a, b)
329
330    @skipIfRocm
331    def test_autotune_conv1x1(self):
332        # Assuming input has 3 channels and we want to produce 16 channels as output
333        conv1x1 = (
334            torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
335            .to(memory_format=torch.channels_last)
336            .cuda()
337        )
338
339        # Example input tensor: batch size = 4, channels = 3, height = 32, width = 32
340        # The memory format is set to `channels_last`
341        input_tensor = (
342            torch.randn(4, 3, 32, 32)
343            .contiguous(memory_format=torch.channels_last)
344            .cuda()
345        )
346
347        with config.patch(
348            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
349        ):
350
351            @torch.compile()
352            def foo(mod, x):
353                return mod(x)
354
355            with torch.no_grad():
356                out, code = run_and_get_code(foo, conv1x1, input_tensor)
357
358            FileCheck().check_not("extern_kernels.convolution").run(code[0])
359            self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)
360
361    @skipIfRocm
362    def test_filled_cache_precompile(self):
363        def fn(a, b, c):
364            a = (a @ b) @ c
365            a, b, c = (t.to(torch.float16) for t in [a, b, c])
366            return (a @ b) @ c
367
368        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
369        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
370        from torch._dynamo.utils import counters
371
372        self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
373
374        torch._dynamo.reset()
375        counters.clear()
376
377        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
378        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
379
380    @skipIfRocm
381    @fresh_inductor_cache()
382    @config.patch(max_autotune=True, max_fusion_size=2)
383    def test_jit_fusion_matches_aot_fusion(self):
        # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
        # to proximity; we want to make sure the AOT-compile pass does the same.
        # AOT could do fuse(buf2, buf4) instead if buf3 were pushed to the end
        # of the V.graph.buffers list, because fuse(buf2, buf4) would have a
        # better proximity score than fuse(buf1, buf2). This scenario is possible
        # since finalizing MultiTemplateBuffers needs to replace buffers.
        def fn(x, number):
            buf0 = x + x
            buf1 = number.item()
            buf2 = x * x
            buf3 = x @ x  # MultiTemplateBuffer
            buf4 = x**2
            return buf0, buf1, buf2, buf3, buf4

        inputs = (
            torch.rand([256, 256], device="cuda"),
            torch.tensor(3, device="cuda"),
        )
        torch._export.aot_compile(fn, args=inputs)

    @config.patch(autotune_local_cache=False, autotune_remote_cache=False)
    @skipIfRocm
    def test_precompilations(self):
        def fn(a, b, c):
            a = (a @ b) @ c
            a, b, c = (t.to(torch.float16) for t in [a, b, c])
            return (a @ b) @ c

        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]

        torch.testing.assert_close(fn_c(*inputs), fn(*inputs), atol=1e-2, rtol=1e-2)

        from torch._dynamo.utils import counters

        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2)

    def test_cat_addmm(self):
        def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
            return torch.cat(
                [
                    torch.addmm(a, b, c),
                    torch.addmm(b, c, a),
                ],
                1,
            )

        args = [
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
        ]
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

    def test_triton_template_with_epilogues_and_dynamic_shape(self):
        def fn(
            x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor, mul: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.nn.functional.relu(
                    torch.matmul(torch.transpose(x, 0, 1), torch.transpose(w, 0, 1))
                    + bias
                )
                * mul
            )

        M0 = 5
        M1 = 8
        K = 4
        N = 3
        w = torch.rand(N, K).cuda().half()
        b = torch.rand(N).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            compiled_fn = torch.compile(
                fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs"
            )

            x0 = torch.rand(K, M0).cuda().half()
            mul0 = torch.rand(M0, N).cuda().half()
            y0 = compiled_fn(x0, w, b, mul0)
            y0_expected = fn(x0, w, b, mul0)
            torch.testing.assert_close(y0, y0_expected)

            x1 = torch.rand(K, M1).cuda().half()
            mul1 = torch.rand(M1, N).cuda().half()
            y1 = compiled_fn(x1, w, b, mul1)
            y1_expected = fn(x1, w, b, mul1)
            torch.testing.assert_close(y1, y1_expected)

    @config.patch(
        benchmark_kernel=True,
        fallback_random=True,
        max_autotune_gemm=True,
    )
    @parametrize("device", ("cpu", "cuda"))
    def test_matmul_dropout(self, device):
        def fwd(a, b):
            x = a @ b
            x = torch.nn.functional.dropout(x, 0.1)
            return x

        def fn(a, b):
            x = fwd(a, b).sum()
            x.backward()
            return a.grad

        N = 128
        a = torch.randn(N, N, device=device, requires_grad=True)
        b = torch.randn(N, N, device=device)

        opt_fn = torch.compile(fn)
        reset_rng_state()
        ref = fn(a, b)
        reset_rng_state()
        act = opt_fn(a, b)

        if N <= 8:
            print(f"ref\n{ref}\nact\n{act}")
        torch.testing.assert_close(ref, act, atol=1e-1, rtol=1e-1)

    @config.patch(
        max_autotune_gemm=True,
    )
    @unittest.skipIf(
        torch.cuda.device_count() < 2, "Need at least 2 devices for this test"
    )
    def test_autotune_device_guard(self):
        x = torch.randn(1024, 1024, device="cuda:1")
        y = torch.randn(1024, 1024, device="cuda:1")

        def f(x, y):
            return x @ y

        with fresh_inductor_cache():
            act = torch.compile(f)(x, y)
        ref = f(x, y)
        self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))

    @config.patch(max_autotune=True)
    def test_empty_conv_input(self, kernel_size=3):
        x = torch.randn(0, 256, 14, 14, device="cuda")
        weight = torch.randn(256, 256, kernel_size, kernel_size, device="cuda")

        def f(x, weight):
            return torch.convolution(
                x,
                weight,
                bias=None,
                stride=[1, 1],
                padding=[0, 0],
                dilation=[1, 1],
                transposed=False,
                output_padding=[0, 0],
                groups=1,
            )

        opt_f = torch.compile(f)
        ref = f(x, weight)
        act = opt_f(x, weight)
        self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3))

    @config.patch(max_autotune=True)
    def test_empty_conv_input_with_1x1_kernel(self):
        self.test_empty_conv_input(kernel_size=1)

    @config.patch(max_autotune=True)
    def test_conv1x1_with_free_symbols(self):
        """
        Make sure there is no exception due to free symbols.
        """
        conv = nn.Conv2d(
            3, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
        ).to(device="cuda")

        @torch.compile
        def f(x, y, z):
            h = y.nonzero().size(0)
            w = z.nonzero().size(0)
            x = x[:, :, :h, :w]
            x = conv(x)
            return x

        x = torch.randn(4, 3, 224, 224).to(
            memory_format=torch.channels_last, device="cuda"
        )
        for _ in range(2):
            y = torch.randint(0, 10, (224,)).to(device="cuda")
            z = torch.randint(0, 10, (224,)).to(device="cuda")
            f(x, y, z)

    def test_conv3d(self):
        fn = torch.nn.functional.conv3d
        image = torch.randn([1, 3, 8, 16, 32])
        filt = torch.randn([3, 3, 7, 7, 7])

        with config.patch({"max_autotune": True}):
            expected = fn(image, filt)
            actual = torch.compile(fn)(image, filt)
            torch.testing.assert_close(actual, expected, atol=6e-5, rtol=0.001)

    @config.patch(
        max_autotune=True, max_autotune_conv_backends="", layout_optimization=False
    )
    def test_conv_backend(self):
        m = torch.nn.Sequential(
            torch.nn.Conv2d(3, 3, 1, 1),
        ).cuda()
        inp = torch.randn([2, 3, 16, 16]).cuda()

        with self.assertRaises(BackendCompilerFailed) as context:
            torch.compile(m)(inp)

        self.assertIn("NoValidChoicesError", str(context.exception))

    def test_non_contiguous_input_mm(self):
        """
        Make sure the triton template works with non-contiguous inputs without crashing.
        Check https://github.com/pytorch/pytorch/issues/125437 for more details.
        """
        x = rand_strided(
            (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return x @ y

        ref = x @ y
        act = f(x, y)
        torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

    def test_non_contiguous_input_addmm(self):
        b = torch.randn((768), dtype=torch.bfloat16, device="cuda")
        x = rand_strided(
            (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return torch.addmm(b, x, y)

        ref = torch.addmm(b, x, y)
        act = f(x, y)
        torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

    def test_non_contiguous_input_bmm(self):
        x = rand_strided(
            (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided(
            (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device="cuda"
        )

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return torch.bmm(x, y)

        ref = torch.bmm(x, y)
        act = f(x, y)
        torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

    def test_non_contiguous_input_mm_plus_mm(self):
        x1 = rand_strided((50257, 32768), (1, 50304), device="cuda")
        y1 = rand_strided((32768, 768), (768, 1), device="cuda")

        x2 = rand_strided((50257, 32768), (1, 50304), device="cuda")
        y2 = rand_strided((32768, 768), (768, 1), device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x1, y1, x2, y2):
            return x1 @ y1 + x2 @ y2

        ref = x1 @ y1 + x2 @ y2
        act = f(x1, y1, x2, y2)
        torch.testing.assert_close(act, ref, atol=1e-2, rtol=1e-2)

    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="",
        autotune_fallback_to_aten=False,
    )
    def test_no_valid_choices(self):
        a = torch.zeros([2, 2], device="cuda")
        b = torch.zeros([2, 2], device="cuda")
        with self.assertRaises(BackendCompilerFailed) as context:
            torch.compile(lambda a, b: a.matmul(b))(a, b)
        self.assertIn("NoValidChoicesError", str(context.exception))

    @parametrize("multi_template", (True, False))
    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
        autotune_fallback_to_aten=False,
    )
    def test_inf_timing(self, multi_template):
        from unittest.mock import patch

        lookup = AlgorithmSelectorCache.lookup

        def mock_lookup(self, *args, **kwargs):
            timings = lookup(self, *args, **kwargs)
            return {choice: float("inf") for choice in timings.keys()}

        a = torch.zeros([16, 16], device="cuda")
        b = torch.zeros([16, 16], device="cuda")
        with patch.object(AlgorithmSelectorCache, "lookup", mock_lookup), config.patch(
            benchmark_epilogue_fusion=multi_template
        ):
            with self.assertRaises(BackendCompilerFailed) as context:
                torch.compile(lambda a, b: a.matmul(b))(a, b)
            self.assertIn("NoValidChoicesError", str(context.exception))


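# Remote autotune-cache tests: compilation is repeated under fresh local caches
# while PatchCaches intercepts the remote cache, so each workload is expected to
# produce one miss and one put followed by cache hits on subsequent compiles.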
@instantiate_parametrized_tests
class TestMaxAutotuneRemoteCache(TestCase):
    def setUp(self):
        super().setUp()
        PatchCaches.setUp()

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()

    @skipIfRocm
    @parametrize("dynamic", (False, True))
    def test_max_autotune_remote_caching(self, dynamic: bool):
        from unittest.mock import patch

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        def f(x, y):
            return Model()(x, y)

        x = torch.randn(100, 100).cuda()
        y = torch.randn(100, 100).cuda()

        with config.patch(
            {
                "autotune_local_cache": False,
                "autotune_remote_cache": True,
            }
        ), patch.dict(os.environ), PatchCaches():
            os.environ.pop("TRITON_CACHE_MANAGER", None)
            with config.patch({"max_autotune": True}):
                for _ in range(4):
                    with fresh_inductor_cache():
                        torch.compile(mm, dynamic=dynamic)(a, b)
                    reset()

                global_stats.report()
                self.assertEqual(global_stats.autotune.num_get_hit, 3)
                self.assertEqual(global_stats.autotune.num_get_miss, 1)
                self.assertEqual(global_stats.autotune.num_put, 1)

            global_stats.reset()
            for _ in range(4):
                with fresh_inductor_cache():
                    torch.compile(f, dynamic=dynamic)(x, y)
                reset()
            global_stats.report()
            self.assertEqual(global_stats.autotune.num_get_hit, 3)
            self.assertEqual(global_stats.autotune.num_get_miss, 1)
            self.assertEqual(global_stats.autotune.num_put, 1)


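# A fake BenchmarkRequest that skips real kernel benchmarking: benchmark() simply
# returns a fixed value and, as a side effect, validates CUDA_VISIBLE_DEVICES in
# the benchmarking subprocess (see the comment inside benchmark()).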
class TestBenchmarkRequest(BenchmarkRequest):
    def __init__(
        self, value: float, multi_device: bool, parent_visible_devices: Optional[str]
    ) -> None:
        self.value = value
        self.multi_device = multi_device
        self.parent_visible_devices = parent_visible_devices

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        # Verify that the visible devices env var is set correctly. If multi-device
        # auto-tuning is disabled, the visible devices should be passed through
        # unchanged from the parent process. If multi-device auto-tuning is enabled,
        # the visible devices should be a _single_ valid device number. Note that we
        # can't perform this validation directly from the test body because
        # benchmarks execute in a separate process. If the check fails, however, the
        # test will detect the failure by virtue of not receiving the expected
        # result back.
        visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)
        if not self.multi_device:
            assert visible_devices == self.parent_visible_devices
        else:
            assert self.parent_visible_devices is not None
            valid_devices = self.parent_visible_devices.split(",")
            assert visible_devices in valid_devices

        return self.value


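# Minimal TritonTemplateCaller wrapper so the tuning process pool can benchmark
# the fake request above without constructing a real Triton template.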
class TestTritonTemplateCaller(TritonTemplateCaller):
    def __init__(self, bmreq: TestBenchmarkRequest):
        self.bmreq = bmreq

    def __str__(self) -> str:
        return "test"


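# Tests for the autotuning subprocess pool itself: one verifies that a crashed
# benchmark subprocess is restarted and remains usable, the other that requests
# are distributed across multiple visible CUDA devices.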
class TestTuningProcess(TestCase):
    def test_tuning_pool_crash(self):
        # Use only one device/subprocess so we test that the process restarts
        # and is usable after a "crash".
        with config.patch({"autotune_multi_device": False}):
            tuning_pool = TuningProcessPool()
            tuning_pool.initialize()

            # First force the tuning process to "crash" by setting a bogus
            # string for the expected visible devices.
            bmreq = TestBenchmarkRequest(3.14, False, "invalid")
            choice = TestTritonTemplateCaller(bmreq)

            timings = tuning_pool.benchmark([choice])
            self.assertTrue(choice in timings)
            self.assertEqual(timings[choice], float("inf"))

            # Then send another request and make sure the sub-process
            # has restarted and is operational. parent_visible_devices may be
            # None here because autotune_multi_device is off.
            choice.bmreq.parent_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)

            timings = tuning_pool.benchmark([choice])
            self.assertTrue(choice in timings)
            self.assertEqual(timings[choice], bmreq.value)

            tuning_pool.terminate()

    def test_tuning_pool_multiple_devices(self):
        with config.patch({"autotune_multi_device": True}):
            # Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
            # is already set in the environment); use a subset of the available devices
            # to ensure only that subset is visible to the sub-processes.
            if CUDA_VISIBLE_DEVICES in os.environ:
                visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
            else:
                visible_devices = [str(d) for d in range(torch.cuda.device_count())]

            parent_visible_devices = ",".join(visible_devices[-2:])
            os.environ[CUDA_VISIBLE_DEVICES] = parent_visible_devices

            tuning_pool = TuningProcessPool()
            tuning_pool.initialize()

            choice1 = TestTritonTemplateCaller(
                TestBenchmarkRequest(3.14, True, parent_visible_devices),
            )
            choice2 = TestTritonTemplateCaller(
                TestBenchmarkRequest(2.718, True, parent_visible_devices),
            )

            timings = tuning_pool.benchmark([choice1, choice2])
            self.assertEqual(timings[choice1], choice1.bmreq.value)
            self.assertEqual(timings[choice2], choice2.bmreq.value)

            tuning_pool.terminate()


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Set env to make it work in CI.
    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
        run_tests()