# Owner(s): ["module: inductor"]
import logging
import math
import os
import unittest
from typing import Callable, List, Optional
from unittest import mock

import torch
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.codegen.cuda.cutlass_utils import get_max_alignment
from torch._inductor.ir import ChoiceCaller, FixedLayout
from torch._inductor.select_algorithm import NoValidChoicesError
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
from torch.testing._internal.common_cuda import SM75OrLater, SM80OrLater, SM90OrLater
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


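# Global test setup: use TF32-enabled ("high") float32 matmul precision and
# disable expandable segments in the CUDA caching allocator.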
torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")

log = logging.getLogger(__name__)

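# The CUTLASS backend is CUDA-only, so treat all compute-capability gates as
# unavailable when running on ROCm.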
HAS_CUDA = HAS_CUDA and not torch.version.hip
SM75OrLater = SM75OrLater and not torch.version.hip
SM80OrLater = SM80OrLater and not torch.version.hip
SM90OrLater = SM90OrLater and not torch.version.hip
SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)


def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.
    """
    path_envs = os.environ.get("PATH", "").split(":")
    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
    return ":".join(path_envs)


@instantiate_parametrized_tests
class TestCutlassBackend(TestCase):
    def setUp(self):
        # The new inductor cache refresh mechanism
        # introduced with https://github.com/pytorch/pytorch/pull/122661
        # interacts badly with persistent subprocesses during
        # autotuning. So we need to disable automatic cache refresh
        # before calling setUp() on the parent class.
        old_disable_fresh_cache_envvar = os.environ.get(
            "INDUCTOR_TEST_DISABLE_FRESH_CACHE", ""
        )
        try:
            os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1"
            super().setUp()
        finally:
            os.environ[
                "INDUCTOR_TEST_DISABLE_FRESH_CACHE"
            ] = old_disable_fresh_cache_envvar
        torch.random.manual_seed(1234)

    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_threshold(self):
        """
        Make sure the Cutlass GEMM size threshold works as intended.
        """

        if torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        a = torch.randn(100, 10).cuda().half()
        b = torch.randn(10, 100).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "CUTLASS,ATen",
                "compile_threads": 4,
                "cuda.cutlass_backend_min_gemm_size": 100000,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller

            with mock.patch(
                "torch._inductor.select_algorithm.autotune_select_algorithm"
            ) as mocked_select_algorithm:
                Y_compiled = torch.compile(mm, dynamic=False)(a, b)
                Y = mm(a, b)
                # Retrieve the second positional argument (the list of choices)
                # from the recorded mock call.
                passed_choice_callers: List[ChoiceCaller] = (
                    mocked_select_algorithm.call_args[0][1]
                )
                assert all(
                    isinstance(cc, ChoiceCaller) for cc in passed_choice_callers
                ), "Argument 1 to autotune_select_algorithm should be a list of ChoiceCaller instances"
                # We expect that no Cutlass Kernels are considered, due to the threshold
                assert all(
                    not isinstance(cc, CUDATemplateCaller)
                    for cc in passed_choice_callers
                ), "Cutlass Kernels should have been filtered, GEMM size is too small"
            torch.testing.assert_close(Y_compiled, Y)
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_precompile(self):
        """
        Make sure autotuning mm in subprocesses works without crashes.
        """

        if torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        a = torch.randn(100, 10).cuda().half()
        b = torch.randn(10, 100).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "CUTLASS,Triton,ATen",
                "compile_threads": 4,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False, True))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_regular_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning mm works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        a = torch.randn(128, 16).cuda().half()
        b = torch.randn(16, 128).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_regular_mm_streamk(
        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
    ):
        """
        Make sure autotuning mm in subprocesses works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        a = torch.randn(128, 16).cuda().half()
        b = torch.randn(16, 128).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
                "cuda.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
            }
        ):
            for M, K, N in (
                (128, 16, 128),
                (1024, 256, 1024),
                (16384, 1024, 16384),
                (16384, 1408, 16384),
            ):
                a = torch.randn(M, K).cuda().half()
                b = torch.randn(K, N).cuda().half()
                Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
                Y = mm(a, b)
                # We need relaxed numerical limits due to the sheer size of the
                # matmuls involved: many small addition differences add up.
                torch.testing.assert_close(Y_compiled, Y, atol=0.01, rtol=0.01)

    def _test_max_autotune_cutlass_backend_epilogue_fusion(
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=False,
        fp16=True,
        expected_fuse_count=0,
        mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
    ):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        # Note: the ops that are available also depend on the alignment of the
        # shapes, so if these shapes don't all align to at least 8 elements,
        # it can happen that no Cutlass 3.x op is available that allows fusions.
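        # (With fp16 operands, 8 elements correspond to a 16-byte access; the
        # 256 and 32 dimensions used below both satisfy this.)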
        if batch_size is None:
            a = torch.randn(256, 32).cuda()
            b = torch.randn(32, 256).cuda()
        else:
            a = torch.randn(batch_size, 256, 32).cuda()
            b = torch.randn(batch_size, 32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.version": "12.2",  # required to enable the Kernels we need
            }
        ):
            counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
            assert (
                actual_count == expected_fuse_count
            ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.0

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.0

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self):
        def mm(a, b):
            # this should not be fused, since the output dtype is different from the matmul dtype
            return (a @ b).to(torch.float32) * 0.00001

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    def test_max_autotune_cutlass_backend_simple_bmm(self):
        def bmm(a, b):
            return torch.bmm(a, b)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(  # test bmm
            mixed_precision=False,
            fp16=True,
            expected_fuse_count=0,
            mm=bmm,
            batch_size=10,
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self):
        def mm(a, b):
            return (a @ b) / b.size(1)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_mm_bias(
        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
    ):
        """
        Make sure autotuning mm with a bias in subprocesses works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b, bias):
            return torch.nn.functional.linear(a, b, bias)

        a = torch.randn(2048, 4096).cuda().half()
        bias = torch.randn(2048).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y = mm(a, a, bias)
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, a, bias)
            torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)

    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_addmm(
        self, dynamic, max_autotune_gemm_backends
    ):
        """
        Make sure autotuning addmm in subprocesses works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=alpha, beta=beta)

        def compare_results(
            m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
        ) -> None:
            x = torch.randn(x_shape).cuda().half()
            a = torch.randn(m, k).cuda().half()
            b = torch.randn(k, n).cuda().half()
            y_expected = addmm(x, a, b, alpha, beta)

            compiled_fn = torch.compile(addmm, dynamic=dynamic)
            y = compiled_fn(x, a, b, alpha, beta)
            torch.testing.assert_close(y, y_expected)

        with config.patch(
            {
                "max_autotune": True,
                # Some Cutlass Kernels fail with IMA on this example, which leads to unrecoverable CUDA errors
                # unless we tune in a subproc here.
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.cutlass_op_allowlist_regex": "",
                "cuda.cutlass_op_denylist_regex": "pingpong",  # Pingpong Kernels can lead to numerical issues
            }
        ):
            # No broadcast
            compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 2048])
            # Broadcast first dim.
            compare_results(4096, 25728, 2048, 2.0, 0.4, [2048])
            # Broadcast last dim.
            compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,ATen"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_int_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning _int_mm in subprocesses works without crashes.
        """

        if "CUTLASS" in max_autotune_gemm_backends.upper() and torch.version.hip:
            return

        def mm(a, b):
            return torch._int_mm(a, b)

        # CUTLASS only supports the row-major/column-major combination of
        # layouts for this operation, hence the transpose of tensor b
        # (Triton, on the other hand, doesn't currently support this
        # combination, so it's excluded from the test). Also, to satisfy
        # CUTLASS alignment requirements, the number of columns in both
        # tensors has to be divisible by 16.
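        # (For int8 operands, 16 elements correspond to a 16-byte access.)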
        a = torch.randint(0, 5, (100, 16), dtype=torch.int8).cuda()
        b = torch.randint(0, 5, (32, 16), dtype=torch.int8).cuda().T

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM80, "need sm_80 exactly")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_mixed_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning mixed mm in subprocesses works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return torch.mm(a, b.to(torch.half))

        # CUTLASS only supports the row-major/column-major combination of
        # layouts for this operation, hence the transpose of tensor b.
        # Also, to satisfy CUTLASS alignment requirements, the number of
        # columns of the first tensor has to be divisible by 16.
        m, n, k = 100, 16, 100
        a = torch.randn(m, k).cuda().half()
        b = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda().T

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
                "use_mixed_mm": True,
                "autotune_local_cache": True,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

        cache = torch._inductor.codecache.LocalCache().lookup("mixed_mm")
        high = cache[
            f"[('cuda', 'torch.float16', {m}, {k}, {k}, 1, 0), "
            f"('cuda', 'torch.int8', {k}, {n}, 1, {k}, 0)]"
        ]["high"]
        cutlass_kernels_count = 0
        for kernel, time in high.items():
            if kernel.startswith("cutlass_gemm") and not math.isinf(time):
                cutlass_kernels_count += 1
        assert cutlass_kernels_count > 0

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM80, "need sm_80 exactly")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_sparse_semi_structured_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning sparse semi-structured mm in subprocesses works
        without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        SparseSemiStructuredTensor._FORCE_CUTLASS = True

        def mm(a, b):
            return torch.mm(a, b)

        m, n, k = 32, 8, 64
        mask = torch.tensor([0, 0, 1, 1]).tile(m, k // 4).cuda().half()
        a = torch.rand(m, k).cuda().half() * mask
        a_sparse = to_sparse_semi_structured(a)
        b = torch.rand(k, n).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
                "autotune_local_cache": True,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a_sparse, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

        cache = torch._inductor.codecache.LocalCache().lookup(
            "sparse_semi_structured_mm"
        )
        high = cache[
            f"[('cuda', 'torch.float16', {m}, {k // 2}, {k // 2}, 1, 0), "
            f"('cuda', 'torch.int16', {m}, {k // 16}, {k // 16}, 1, 0), "
            f"('cuda', 'torch.float16', {k}, {n}, {n}, 1, 0)]"
        ]["high"]
        cutlass_kernels_count = 0
        for kernel, time in high.items():
            if kernel.startswith("cutlass_gemm") and not math.isinf(time):
                cutlass_kernels_count += 1
        assert cutlass_kernels_count > 0

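    # The next two tests don't run any autotuning: autotune_select_algorithm is
    # mocked to raise NoValidChoicesError, so only the generated list of choices
    # is inspected to verify that the allowlist/denylist regexes were applied.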
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_cutlass_backend_op_denylist(
        self,
    ):
        def my_addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=beta, beta=alpha)

        x = torch.randn((128, 128)).cuda().half()
        a = torch.randn(128, 128).cuda().half()
        b = torch.randn(128, 128).cuda().half()

        def select_no_algorithm(*args, **kwargs):
            raise NoValidChoicesError

        with fresh_inductor_cache():
            with config.patch(
                {
                    "max_autotune": True,
                    # No tuning actually runs here, since autotune_select_algorithm
                    # is mocked out below, so in-process tuning is safe.
                    "autotune_in_subproc": False,
                    "max_autotune_gemm_backends": "CUTLASS,ATen",
                    "cuda.cutlass_dir": _CUTLASS_DIR,
                    "cuda.cutlass_max_profiling_configs": 2,
                    "cuda.cutlass_op_allowlist_regex": "",
                    "cuda.cutlass_op_denylist_regex": "pingpong",  # Pingpong Kernels can lead to numerical issues
                }
            ):
                with mock.patch(
                    "torch._inductor.kernel.mm.autotune_select_algorithm",
                    wraps=select_no_algorithm,
                ) as sa:
                    torch.compile(my_addmm, dynamic=False)(x, a, b, 1.0, 2.0)
                    args, kwargs = sa.call_args
                    op_name, choices, _, __ = args
                    assert op_name == "addmm"
                    cuda_template_count = 0
                    for choice in choices:
                        if isinstance(choice, CUDATemplateCaller):
                            choice_info = choice.info_dict()
                            assert (
                                "pingpong" not in choice_info["op_conf_name"]
                            ), "All pingpong Kernels should have been filtered"
                            cuda_template_count += 1
                    assert cuda_template_count > 0, "No CUDATemplateCaller choices"

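    # Mirror image of the denylist test above: with an allowlist of "pingpong",
    # only pingpong kernel configurations should survive the filtering.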
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_cutlass_backend_op_allowlist(
        self,
    ):
        def addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=alpha, beta=beta)

        x = torch.randn((128, 128)).cuda().half()
        a = torch.randn(128, 128).cuda().half()
        b = torch.randn(128, 128).cuda().half()

        def select_no_algorithm(*args, **kwargs):
            raise NoValidChoicesError

        with fresh_inductor_cache():
            with config.patch(
                {
                    "max_autotune": True,
                    # No tuning actually runs here, since autotune_select_algorithm
                    # is mocked out below, so in-process tuning is safe.
                    "autotune_in_subproc": False,
                    "max_autotune_gemm_backends": "CUTLASS,ATen",
                    "cuda.cutlass_dir": _CUTLASS_DIR,
                    "cuda.cutlass_max_profiling_configs": 2,
                    "cuda.cutlass_op_allowlist_regex": "pingpong",  # only allow pingpong Kernels
                    "cuda.cutlass_op_denylist_regex": None,
                }
            ):
                with mock.patch(
                    "torch._inductor.kernel.mm.autotune_select_algorithm",
                    wraps=select_no_algorithm,
                ) as sa:
                    torch.compile(addmm, dynamic=False)(x, a, b, 1.0, 1.0)
                    args, kwargs = sa.call_args
                    op_name, choices, _, __ = args
                    assert op_name == "addmm"
                    cuda_template_count = 0
                    for choice in choices:
                        if isinstance(choice, CUDATemplateCaller):
                            choice_info = choice.info_dict()
                            assert (
                                "pingpong" in choice_info["op_conf_name"]
                            ), "Only pingpong Kernels should have been allowed"
                            cuda_template_count += 1
                    assert cuda_template_count > 0, "No CUDATemplateCaller choices"

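    # get_max_alignment infers the largest usable alignment (in elements) from a
    # FixedLayout's sizes, strides, offset and dtype, as the cases below exercise:
    # contiguous, strided, offset and wider-dtype layouts.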
    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_get_max_alignment(self):
        l4 = FixedLayout("cpu", torch.half, size=(1, 2, 4), stride=(0, 4, 1))
        m4 = get_max_alignment(l4)
        self.assertEqual(
            m4, 4, "Wrong max alignment. Should have been 4. (simple, contiguous case)"
        )

        l4_2 = FixedLayout("cpu", torch.half, size=(1, 4, 2), stride=(0, 1, 4))
        m4_2 = get_max_alignment(l4_2)
        self.assertEqual(
            m4_2,
            4,
            "Wrong max alignment. Should have been 4. Did not deal with strides correctly",
        )

        l1 = FixedLayout("cpu", torch.half, size=(2, 4, 2), stride=(23, 1, 4))
        m1 = get_max_alignment(l1)
        self.assertEqual(
            m1,
            1,
            "Wrong max alignment. Should have been 1. Did not take stride into account correctly",
        )

        l2 = FixedLayout("cpu", torch.half, size=(1, 2, 4), stride=(0, 4, 1), offset=6)
        m2 = get_max_alignment(l2)
        self.assertEqual(
            m2, 2, "Wrong max alignment. Should have been 2. (due to choice of offset)"
        )

        l8 = FixedLayout(
            "cpu", torch.half, size=(2, 2, 8), stride=(32, 8, 1), offset=24
        )
        m8 = get_max_alignment(l8)
        self.assertEqual(m8, 8, "Wrong max alignment. Should have been 8.")

        l4 = FixedLayout(
            "cpu", torch.float32, size=(2, 2, 8), stride=(32, 8, 1), offset=24
        )
        m4 = get_max_alignment(l4)
        self.assertEqual(
            m4, 4, "Wrong max alignment. Should have been 4 (due to float32 dtype)."
        )


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Set env to make it work in CI.
    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
        run_tests()