# xref: /aosp_15_r20/external/pytorch/test/inductor/test_multi_kernel.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: inductor"]

import os
import re
import unittest

import torch
from torch import nn
from torch._dynamo.testing import reset_rng_state
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class TransformerSnippet(nn.Module):
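    """A small dropout + LayerNorm block used as a transformer-like workload for these tests."""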
    def __init__(self) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(64)
        self.ln2 = nn.LayerNorm(64)

    def forward(self, x1, x2):
        x1 = F.dropout(x1, 0.1)
        x2 = F.dropout(self.ln1(x2), 0.1)

        return self.ln2(x1 + x2)

    def example_inputs(self):
        return (torch.randn(2, 64).to(GPU_TYPE), torch.randn(2, 64).to(GPU_TYPE))


def _contains_multi_kernel_code(wrapper_code: str):
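    """Return True if the generated wrapper registers a multi-kernel via async_compile.multi_kernel."""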
    return (
        re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code)
        is not None
    )


def make_cpp_wrapper_test(orig_test, **extra_args):
48    """
49    Wrap an existing test into a new test with cpp-wrapper enabled.
50
51    Make this as a free function rather than staticmethod in MultiKernelTest.
52    Otherwise we get 'TypeError: 'staticmethod' object is not callable'
53    error in py3.8. (py3.10 works)
54    """

    @config.patch("cpp_wrapper", True)
    @skipIfXpu(msg="cpp wrapper doesn't currently work on the XPU stack")
    def fn(self):
        # The same kernel may have been compiled by previous tests with
        # cpp_wrapper disabled. Clear the cache so we re-compile the kernel
        # with cpp_wrapper enabled.
        from torch._inductor import codecache

        codecache.PyCodeCache.cache_clear()
        return orig_test(self, **extra_args)

    return fn


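# triton.multi_kernel makes Inductor emit both a persistent and a non-persistent
# reduction variant of a kernel and pick the faster one by benchmarking at runtime;
# benchmark_kernel is enabled so the generated kernels can be benchmarked.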
@config.patch(
    {
        "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
        "benchmark_kernel": True,
    }
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
    def test_softmax(self, expect_multi_kernel=True):
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = torch.softmax(x, -1)
        compiled_fn = torch.compile(torch.softmax)
        act, wrapper_code = run_and_get_code(compiled_fn, x, -1)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        if expect_multi_kernel:
            self.assertTrue(_contains_multi_kernel_code(wrapper_code))
        else:
            # Skip verifying the wrapper_code in fbcode since we may fail to
            # compile the cpp-wrapper CUDA code there because the CUDA compiler
            # is not set up properly. In that case, the last collected
            # wrapper_code corresponds to the first-pass cpp-wrapper codegen,
            # which contains the multi-kernel.
            if not config.is_fbcode():
                self.assertFalse(_contains_multi_kernel_code(wrapper_code))

    @parametrize("force_kernel", (0, 1))
    @unittest.mock.patch.dict(
        os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
    )
    def test_softmax_force_non_persistent_reduction(self, force_kernel):
105        """
106        Force a specific sub-kernel being picked by mocking the benchmark result.
107        """
108        x = torch.rand(2, 1024).to(GPU_TYPE)
109        mock_latency = [0.2, 0.2]
110        mock_latency[force_kernel] = 0.1  # this make sure force_kernel will be picked

        def f(x):
            return torch.softmax(x, -1) + force_kernel

        orig_run = MultiKernelCall.run
        picked_kernel = None

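        # mock_run wraps MultiKernelCall.run so we can record which sub-kernel
        # the multi-kernel dispatch ends up picking.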
        def mock_run(self, *args, **kwargs):
            out = orig_run(self, *args, **kwargs)
            nonlocal picked_kernel
            picked_kernel = self.picked_kernel
            return out

        with unittest.mock.patch.object(
            MultiKernelCall, "run", mock_run
        ), unittest.mock.patch.object(
            MultiKernelCall,
            "benchmark_sub_kernels",
            lambda *args, **kwargs: mock_latency,
        ):
            torch.compile(f)(x)
        self.assertEqual(picked_kernel, force_kernel)

    @config.patch("warn_mix_layout", True)
    def test_softmax_warn_mixed_layout(self):
        self.test_softmax()

    test_softmax_cpp_wrapper = make_cpp_wrapper_test(
        test_softmax, expect_multi_kernel=False
    )

    def test_layernorm(self):
        ln = nn.LayerNorm(1024).to(GPU_TYPE)
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = ln(x)
        act = torch.compile(ln)(x)
        self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_inplace_update(self):
        """
        Inductor generates an in-place kernel for the mul.
        """

        def f(x, y):
            return x.sum(dim=-1, keepdims=True) * (y @ y)

        x = torch.rand(1024, 1024).to(GPU_TYPE)
        y = torch.rand(1024, 1024).to(GPU_TYPE)
        ref = f(x, y)
        act = torch.compile(f)(x, y)
        self.assertEqual(ref, act)

    def test_transformer_snippet(self):
        model = TransformerSnippet().to(GPU_TYPE)
        x = model.example_inputs()

        def f(*x):
            y = model(*x)
            return y

        reset_rng_state()
        ref = f(*x)

        opt_f = torch.compile(f)
        reset_rng_state()
        act = opt_f(*x)

        # Don't compare tensors when using the inductor random number generator:
        # its implementation differs from eager's. We should fall back to eager
        # RNG if we want to test accuracy.
        if config.fallback_random:
            self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_transformer_snippet_with_fallback_random(self):
185        """
186        Same as test_transformer_snippet but fallback the random number
187        generator to eager so we can check accuracy.
188        """
        with config.patch("fallback_random", True):
            self.test_transformer_snippet()

    def test_batchnorm_training(self):
193        """
194        For training, batchnorm will tracking running mean/variance during forward pass.
195        The kernel generated by inductor currently will pass in those tensors twice as arguments:
196        once for input and once for output. They are ruled out as in-out argument because
197        they are considered as graph inputs.
198
199        Multi-kernel previously assumes that we never pass the same argument mutli times
200        for a kernel. No mater if we change inductor behavior to assure that, it's better
201        to make multi-kernel being able to handle those cases.
202        """
        bn = nn.BatchNorm2d(3).to(GPU_TYPE)

        @torch.compile
        def f(x):
            bn(x).sum().backward()

        _, (wrapper_code, _) = run_and_get_code(
            f, torch.randn(2, 3, 8, 8, device=GPU_TYPE)
        )
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    def test_pass_same_arg_multi_times(self):
215        """
216        A super simple example that simulate how BatchNorm update the running
217        stats.
218
219        Inductor currently pass the same tensor multiple times for the generated
220        kernel: once for input and once for output.
221
222        Here is a paster for the generated kernel (without multi-kernel enabled):
223        https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
224        """

        def f(x, y):
            x = x.sum(dim=1, keepdim=False)
            y.copy_(y * 0.9 + x * 0.1)

        x = torch.randn(8, 16, device=GPU_TYPE)
        y = torch.randn(8, device=GPU_TYPE)
        y_ref = y.clone()

        ref = f(x, y_ref)
        act = torch.compile(f)(x, y)
        self.assertEqual(y_ref, y)

    def test_reduction_scratch_buffer(self, force_multi_kernel=1):
239        """
240        The explicited realized buffer in the test function will be passed in
241        as a scratch buffer for the non-persistent reduction kernel but
242        can be skipped for the persistent reduction kernel.
243
244        This causes different argument lists for non-persistent reduction kernel and
245        persistent reduction kernel.
246
247        Check documentation around torch._inductor.config.triton.multi_kernel about
248        how to interpret the force_multi_kernel argument.
249        """

        def f(x):
            x = x.sum(dim=-1, keepdim=True) + x
            x = test_operators.realize(x)
            x = x.sum(dim=-1, keepdim=True) + x
            return x

        x = torch.rand(16, 16, device=GPU_TYPE)
        ref = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            act = torch.compile(f)(x)
        self.assertEqual(ref, act)

    def test_split_scan(self, force_multi_kernel=1):
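        """Large 1-D cumsum exercises the split-scan codegen path under the multi-kernel config."""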
        def f(x):
            x = x.view(-1)
            return torch.cumsum(x, 0)

        x = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    def test_sort_disables_multi_kernel(self, force_multi_kernel=1):
        """
        Sort currently requires a persistent kernel, so multi-kernel is not
        possible. Make sure this falls back gracefully.
        """

        def f(x):
            return x.sort(-1).values

        x = torch.rand(32, 32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    # Use benchmarking to pick the faster kernel
    test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
        test_reduction_scratch_buffer, force_multi_kernel=1
    )
    # force pick persistent reduction. This can be a good test since this persistent
    # reduction uses fewer call arguments than the corresponding non-persistent
    # reduction.
    test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
    )
    # force pick non-persistent reduction
    test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
    )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()