# mypy: allow-untyped-defs
import logging
import os
import pathlib
from typing import Any, List

from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.utils._ordered_set import OrderedSet

from .. import config
from ..codecache import get_path, TritonFuture
from ..runtime.benchmarking import benchmarker
from ..utils import cache_on_self, IndentedBuffer
from ..virtualized import V
from .common import TensorArg


log = logging.getLogger(__name__)


def get_kernel_argdefs(kernel):
    arg_defs, _, _, _ = kernel.args.python_argdefs()
    return arg_defs


def _get_all_args(args_list, arg_types_list=None):
    """
    Pick the longest list in args_list as the list of all args; every other
    list must be a subset of it.
    """
    all_args = max(args_list, key=len)[:]
    arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
    for args in args_list:
        assert set(args).issubset(set(all_args)), f"{args} vs. {all_args}"

    return all_args, arg_types


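# A minimal illustration of _get_all_args (the argument lists are
# hypothetical, not from the original source): the longest list wins and
# every other list must be a subset of it.
#
#     >>> _get_all_args([["a", "b"], ["a", "b", "c"]])
#     (['a', 'b', 'c'], None)

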
def get_all_kernel_argdefs(kernels):
    """
    The logic here must match `get_all_call_args`, except that there is no
    need to gather arg_types here.
    """
    argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]

    return _get_all_args(argdefs_list)[0]


def get_all_call_args(call_args_list, arg_types_list):
    """
    Given the call_args for each subkernel, return the call_args for the
    combined multi-kernel.

    Note that the following algorithm does not always work:
    ```
        all_call_args: Dict[
            Any, None
        ] = {}  # use a dict rather than set to maintain insertion order
        for call_args in call_args_list:
            all_call_args.update({arg: None for arg in call_args})

        all_call_args = list(all_call_args.keys())
    ```
    It fails if any kernel is passed the same argument multiple times,
    because deduplicating via dict keys drops the repeats. Check
    test_pass_same_arg_multi_times in test_multi_kernel.py

    Instead, we pick the longest call args and assert that the other call
    args are a subset of it.
    """
    return _get_all_args(call_args_list, arg_types_list)


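# A hypothetical illustration of why dict-based deduplication fails (the
# buffer names are made up): if kernel1 is called with ["buf0", "buf0"] and
# kernel2 with ["buf0", "buf0", "buf1"], the dict approach collapses the
# repeated "buf0" and yields ["buf0", "buf1"], while the longest-list
# approach keeps the repeat:
#
#     >>> get_all_call_args([["buf0", "buf0"], ["buf0", "buf0", "buf1"]], None)
#     (['buf0', 'buf0', 'buf1'], None)

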
def get_numel_argdefs(kernel):
    numel_argdefs = []
    for tree in kernel.range_trees:
        # include the numel arg for every non-reduction tree; include the
        # reduction ("r") tree only when the kernel is inside a reduction
        if tree.prefix != "r" or kernel.inside_reduction:
            numel_argdefs.append(f"{tree.prefix}numel")

    return numel_argdefs


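# For example (inferred from the logic above, not from the original source):
# a reduction kernel with range trees ("x", "r") and inside_reduction=True
# yields ["xnumel", "rnumel"], while a pointwise kernel with a single "x"
# tree yields ["xnumel"].

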
class MultiKernelState:
    """
    Maintain the state of multi-kernel compilation so we don't define a
    duplicate multi-kernel for the same set of sub-kernels.

    V.graph.wrapper_code has a reference to the MultiKernelState instance.
    """

    def __init__(self):
        self.subkernel_to_kernel_name = {}

    def define_kernel(self, kernels):
        """
        Previously we named the multi-kernel "multi_kernel_{kernel_names[0]}".
        This has a minor issue.

        E.g. for the persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca ,
        there are 2 flavors of non-persistent reduction:
          https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4
        and
          https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd

        The only difference is the cache eviction policy.

        We should name the multi-kernel differently in these 2 cases.
        """
        kernel_names = tuple(k.kernel_name for k in kernels)
        if kernel_names in self.subkernel_to_kernel_name:
            return self.subkernel_to_kernel_name[kernel_names]

        # number the multi-kernel by how many have been defined so far
        multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
        self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name

        if V.graph.cpp_wrapper:
            # we should not generate any python code for multi-kernel during
            # the second pass of cpp-wrapper.
            return multi_kernel_name

        buf = IndentedBuffer()
        buf.writeline(
            f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
        )
        with buf.indent():
            for name in kernel_names:
                buf.writeline(f"{name},")
        buf.writeline("])")

        wrapper = V.graph.wrapper_code
        wrapper.header.splice(buf)
        if config.triton.autotune_at_compile_time:
            wrapper.kernel_autotune_defs.splice(buf)

        return multi_kernel_name


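# For illustration, the wrapper definition generated above looks roughly
# like this (the sub-kernel names are hypothetical):
#
#     multi_kernel_0 = async_compile.multi_kernel('multi_kernel_0', [
#         triton_red_fused_sum_0,
#         triton_per_fused_sum_0,
#     ])

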
class MultiKernel:
    """
    This class maintains the compile-time state for multi-kernels.

    Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
    The generated definition for the multi-kernel will look like:
    ```
    multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
    ```

    Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
    """

    def __init__(self, kernels):
        assert len(kernels) >= 2

        self.kernels = kernels
        self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
            kernels
        )

        # need this since some code in inductor checks whether the kernel object
        # has an args attribute to decide if it's a non-null kernel.
        self.args = object()

    def call_kernel(self, kernel_name):
        """
        Collect the union of arguments from all subkernels as the arguments
        for the multi-kernel.
        """
        assert kernel_name == self.kernel_name
        V.graph.wrapper_code.write_triton_header_once()
        _, call_args, _, arg_types = self.kernels[0].args.python_argdefs()
        for kernel in self.kernels[1:]:
            _, other_call_args, _, other_arg_types = kernel.args.python_argdefs()
            assert call_args == other_call_args
            assert arg_types == other_arg_types

        grid: List[Any] = []

        if V.graph.cpp_wrapper:
            # for the second pass of cpp-wrapper codegen, we should call
            # the picked kernel directly
            picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
            kernel_name = self.kernels[picked_kernel].kernel_name

        # numels for all subkernels should be the same. Use kernels[0] here
        self.kernels[0].add_numel_to_call_args_and_grid(
            kernel_name, call_args, arg_types, grid
        )

        grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
        V.graph.wrapper_code.generate_kernel_call(
            kernel_name,
            call_args,
            grid,
            arg_types=arg_types,
        )

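    # The generated call site in the python wrapper then looks roughly like
    # this (buffer, numel, and stream names are hypothetical):
    #
    #     multi_kernel_0.run(arg0_1, buf0, 8192, grid=grid(8192), stream=stream0)
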
    def codegen_nan_check(self):
        wrapper = V.graph.wrapper_code
        seen = set()
        for k in self.kernels:
            _, call_args, precompile_args, _ = k.args.python_argdefs()
            for arg, precompile_arg in zip(call_args, precompile_args):
                if arg in seen:
                    continue
                seen.add(arg)
                if isinstance(precompile_arg, TensorArg):
                    line = f"assert not {arg}.isnan().any().item()"
                    wrapper.writeline(line)
                    line = f"assert not {arg}.isinf().any().item()"
                    wrapper.writeline(line)

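    # For each tensor argument this emits wrapper lines of the form
    # (with a hypothetical buffer name):
    #
    #     assert not buf0.isnan().any().item()
    #     assert not buf0.isinf().any().item()
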
    @property
    def removed_buffers(self):
        return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels])

    @property
    def inplaced_to_remove(self):
        return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels])

    @property
    @cache_on_self
    def inplace_update_buffers(self):
        """
        Make sure all kernels have the same inplace update mappings.
        """
        for k in self.kernels[1:]:
            assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers
        return self.kernels[0].inplace_update_buffers

    def warn_mix_layout(self, kernel_name: str):
        pass


class MultiKernelCall:
    """
    This class is used at run time to actually run the kernel: on the first
    call it benchmarks the sub-kernels and dispatches to the fastest one.
    """

    def __init__(self, multi_kernel_name, kernels):
        assert len(kernels) >= 2
        self._kernels = kernels
        self.multi_kernel_name = multi_kernel_name

        self.disable_cache = os.environ.get(
            "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE"
        ) == "1" or is_metric_table_enabled("persistent_red_perf")

        self.picked_kernel = None
        if config.triton.multi_kernel > 1:
            # manually force a subkernel to ease perf testing
            picked_by_config = config.triton.multi_kernel - 2
            assert picked_by_config < len(self._kernels)
            self.picked_kernel = picked_by_config
        elif not self.disable_cache:
            self.load_cache()

        self._recorded = False

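    # As implied above: config.triton.multi_kernel = 2 forces sub-kernel 0,
    # 3 forces sub-kernel 1, and so on (the value minus 2 indexes the list).
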
    def cache_file_path(self):
        _, _, path = get_path(self.kernels[0].fn.cache_key, "picked_kernel")
        return pathlib.Path(path)

    def load_cache(self):
        assert self.picked_kernel is None
        path = self.cache_file_path()
        if path.exists():
            with path.open() as fd:
                self.picked_kernel = int(fd.read())
                assert 0 <= self.picked_kernel < len(self._kernels)
                log.debug(
                    "Load picked kernel %d from cache file %s", self.picked_kernel, path
                )

    def store_cache(self):
        assert self.picked_kernel is not None
        path = self.cache_file_path()
        path.parent.mkdir(parents=True, exist_ok=True)

        with path.open("w") as fd:
            fd.write(str(self.picked_kernel))
        log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path)

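    # The cache file holds nothing but the picked index as text, e.g. a file
    # named like <cache_dir>/.../<cache_key>.picked_kernel containing "1"
    # (this path layout is illustrative; get_path determines the real one).
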
    @property
    def kernels(self):
        """
        Read results from the futures.

        This should be called after parallel compilation is done. Calling it
        before compilation finishes may slow down the parallel compilation.
        """
        for i, kernel in enumerate(self._kernels):
            if isinstance(kernel, TritonFuture):
                self._kernels[i] = kernel.result()

        return self._kernels

    def benchmark_sub_kernels(self, *args, **kwargs):
        """
        Benchmark all the sub-kernels and return the execution time
        (in milliseconds) for each of them.

        Unit tests may mock this method to force a specific kernel to
        be picked.
        """

        def wrap_fn(kernel):
            def inner():
                args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs)
                return kernel.run(*args_clone, **kwargs_clone)

            return inner

        return [
            benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40, fast_flush=True)
            for kernel in self.kernels
        ]

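    # A hypothetical sketch of how a test might force the second sub-kernel
    # (the names and timings are illustrative, not from the original source):
    #
    #     from unittest.mock import patch
    #
    #     with patch.object(
    #         MultiKernelCall, "benchmark_sub_kernels", lambda self, *a, **kw: [1.0, 0.5]
    #     ):
    #         ...  # run the compiled model; index 1 wins the benchmark
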
    # record_choice and lookup_choice are helper functions for cpp-wrapper
    # codegen. The first pass uses record_choice to store the choice, and
    # the second pass looks it up by calling lookup_choice.
    #
    # An alternative that reuses the multi-kernel cache does not work well,
    # since during codegen of the second pass it's very hard to know the
    # path of the cache file. Also, reading the cache file requires IO,
    # which can be slower.
    @staticmethod
    def record_choice(multi_kernel_name, choice):
        """
        Record the multi-kernel choice during the first pass of cpp-wrapper
        codegen so the second pass can look it up.

        Do nothing if this function is not called during codegen.
        """
        from torch._inductor.graph import GraphLowering

        if not isinstance(V.graph, GraphLowering):
            return

        if not V.graph.record_multi_kernel_choice:
            return

        V.graph.multi_kernel_to_choice[multi_kernel_name] = choice

    @staticmethod
    def lookup_choice(multi_kernel_name):
        # this should always be called during cpp-wrapper codegen
        assert V.graph.record_multi_kernel_choice
        # there should be no cache miss
        return V.graph.multi_kernel_to_choice[multi_kernel_name]

    def run(self, *args, **kwargs):
        if self.picked_kernel is None:
            timings = self.benchmark_sub_kernels(*args, **kwargs)
            self.picked_kernel = timings.index(min(timings))
            k0 = self.kernels[0]
            log.debug(
                "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s",
                self.picked_kernel,
                [k.inductor_meta.get("kernel_name") for k in self.kernels],
                k0.size_hints,
                k0.inductor_meta.get("reduction_hint"),
                timings,
            )

            def get_kernel_path(k):
                return k.fn.fn.__code__.co_filename

            get_metric_table("persistent_red_perf").add_row(
                lambda: {
                    "kernel1_name": get_kernel_path(self.kernels[0]),
                    "kernel2_name": get_kernel_path(self.kernels[1]),
                    "kernel1_latency": timings[0],
                    "kernel2_latency": timings[1],
                    "size_hints": k0.size_hints,
                    "reduction_hint": k0.inductor_meta.get("reduction_hint"),
                    "speedup": timings[1] / timings[0],
                }
            )

            if not self.disable_cache:
                self.store_cache()

        if not self._recorded:
            self._recorded = True
            self.record_choice(self.multi_kernel_name, self.picked_kernel)
        # rebind self.run to the picked kernel's run method so that subsequent
        # calls skip the benchmarking/dispatch logic above
        self.run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
        self.run(*args, **kwargs)