# mypy: allow-untyped-defs
import builtins
import contextlib
import functools
import inspect
import itertools
import json
import logging
import math
import operator
import os
import sys
import textwrap
import time
from collections import namedtuple
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import sympy
from filelock import FileLock

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, identity, preserve_rng_state

from . import config, ir
from .autotune_process import TensorMeta, TritonBenchmarkRequest
from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import IndentedBuffer, KernelTemplate
from .codegen.triton import (
    gen_common_triton_imports,
    texpr,
    TritonKernel,
    TritonPrinter,
    TritonScheduling,
)
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties
from .utils import (
    FakeIndentedBuffer,
    get_dtype_size,
    Placeholder,
    restore_stdout_stderr,
    sympy_dot,
    sympy_index_symbol,
    sympy_product,
    unique,
)
from .virtualized import V


log = logging.getLogger(__name__)

# correctness checks struggle with fp16/tf32
VERIFY: Dict[str, Any] = {}
PRINT_AUTOTUNE = True
DEBUG = False


class KernelNamespace:
    pass


# these objects are imported from the generated wrapper code
extern_kernels = KernelNamespace()


class PartialRender:
    """
    Some parts of a template need to be generated at the end, but
    inserted into the template at the start.  This class lets the
    initial render emit placeholder keys and substitute the real
    code via replacement hooks afterwards.
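
    Illustrative example (hypothetical hook key):

        pr = PartialRender("out = <HOOK>", {"<HOOK>": lambda: "ops()"})
        assert pr.finalize_all() == "out = ops()"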
79    """
80
81    def __init__(self, code, replacement_hooks) -> None:
82        super().__init__()
83        self.code = code
84        self.replacement_hooks = replacement_hooks
85
86    def finalize_hook(self, hook_key: str, strict=True) -> None:
87        if hook_key not in self.replacement_hooks:
88            if strict:
89                raise RuntimeError(
90                    f"{hook_key} not registered in self.replacement_hooks"
91                )
92            else:
93                return
94        assert (
95            self.replacement_hooks[hook_key] is not None
96        ), "hook_key can only be called once"
97        self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
98        self.replacement_hooks[hook_key] = None
99
100    def finalize_all(self) -> str:
101        for key, fn in self.replacement_hooks.items():
102            self.code = self.code.replace(key, fn())
103        return self.code
104
105
106# This is used to store info needed for lowering each subgraph in triton
107# templates
108SubgraphInfo = namedtuple(
109    "SubgraphInfo",
110    [
111        "body",
112        "template_mask",
113        "template_out",
114    ],
115)
116
117
118class TritonTemplateKernel(TritonKernel):
119    def __init__(
120        self,
121        kernel_name,
122        input_nodes,
123        output_node,
124        defines,
125        num_stages,
126        num_warps,
127        grid_fn,
128        meta,
129        call_sizes,
130        use_jit=False,
131        prefix_args=0,
132        suffix_args=0,
133        epilogue_fn=identity,
134        subgraphs: Optional[List[ir.ComputedBuffer]] = None,
135        *,
136        index_dtype,
137    ) -> None:
138        super().__init__(
139            sympy_product(output_node.get_size()),
140            sympy.Integer(1),
141            index_dtype=index_dtype,
142        )
143        self.input_nodes = input_nodes
144        self.output_node = output_node
145        self.named_input_nodes = {}  # type: ignore[var-annotated]
146        self.defines = defines
147        self.kernel_name = kernel_name
148        self.use_jit = use_jit
149        self.num_stages = num_stages
150        self.num_warps = num_warps
151        self.grid_fn = grid_fn
152        self.meta = meta
153        self.call_sizes = call_sizes
154        # for templates with fixed epilogues
155        self.prefix_args = prefix_args
156        self.suffix_args = suffix_args
157        self.epilogue_fn = epilogue_fn
158        self.render_hooks = {}  # type: ignore[var-annotated]
159        self.triton_meta: Optional[Dict[str, object]] = None
160        # For Templated Attention this can be a list of ir.Subgraph
161        self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs
162
        # The following attributes (body, template_mask, template_out) are
        # all used for triton kernel codegen.
        # They are swapped onto the TritonTemplateKernel object by
        # `set_subgraph_body`
        self.subgraph_bodies: Dict[str, SubgraphInfo] = {}

        self.body: IndentedBuffer = FakeIndentedBuffer()
        self.template_mask: Optional[str] = None
        self.template_out: Optional[str] = None

    @contextlib.contextmanager
    def set_subgraph_body(self, body_name: str):
        old_body, old_mask, old_out = self.body, self.template_mask, self.template_out
        assert body_name in self.subgraph_bodies, body_name
        self.body, self.template_mask, self.template_out = self.subgraph_bodies[
            body_name
        ]
        yield
        self.subgraph_bodies[body_name] = SubgraphInfo(
            self.body, self.template_mask, self.template_out
        )
        self.body, self.template_mask, self.template_out = old_body, old_mask, old_out

    @contextlib.contextmanager
    def create_subgraph_body(self, body_name: str):
        assert body_name not in self.subgraph_bodies
        self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None)
        with self.set_subgraph_body(body_name):
            yield

    def need_numel_args(self):
        return False

    def estimate_kernel_num_bytes(self):
        """
        Estimate the total number of bytes this kernel reads and writes.
        For in/out nodes, sizes are counted twice: once for reading and
        once for writing.
        """
        ninplace_args = len(unique(self.args.inplace_buffers.values()))
        num_bytes = []
        for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
            size = V.graph.sizevars.size_hints(inp.get_size())
            numel = functools.reduce(operator.mul, size, 1)
            dtype_size = get_dtype_size(inp.get_dtype())
            num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
        return sum(num_bytes)

    def jit_lines(self):
        if self.use_jit:
            return "@triton.jit"

        argdefs, _, signature, _ = self.args.python_argdefs()
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
            "device": DeviceProperties.create(self.output_node.get_device()),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
            triton_meta["constants"][arg_num] = 1  # type: ignore[index]
        matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
        if matrix_instr_nonkdim != 0:
            triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim

        self.triton_meta = triton_meta

        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            **TritonKernel.inductor_meta_common(),
        }
        if config.profile_bandwidth or config.benchmark_kernel:
            num_gb = self.estimate_kernel_num_bytes() / 1e9
            inductor_meta["kernel_num_gb"] = num_gb
        return f"""
            @triton_heuristics.template(
                num_stages={self.num_stages},
                num_warps={self.num_warps},
                triton_meta={triton_meta!r},
                inductor_meta={inductor_meta!r},
            )
            @triton.jit
        """

    def gen_argdefs(self):
        def hook():
            # python_argdefs() cannot be run until after the rest of the template lazily adds more args
            arg_defs, *_ = self.args.python_argdefs()
            return f"{', '.join(arg_defs)}"

        self.render_hooks["<ARGDEFS>"] = hook
        return "<ARGDEFS>"

    def gen_defines(self):
        return self.defines

    def def_kernel(self, *argnames):
        """
        Hook called from template code to generate function def and
        needed args.
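
        Illustrative template usage (argument names are hypothetical):

            {{def_kernel("A", "B")}}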
263        """
264        assert all(isinstance(x, str) for x in argnames)
265        renames = IndentedBuffer(initial_indent=1)
266
267        named_args = self.input_nodes[
268            self.prefix_args : len(self.input_nodes) - self.suffix_args
269        ]
270
271        assert len(argnames) == len(named_args), (
272            len(argnames),
273            len(named_args),
274            self.prefix_args,
275            len(self.input_nodes),
276        )
277
278        for input_node in self.input_nodes[: self.prefix_args]:
279            # get args in correct order
280            self.args.input(input_node.get_name())
281
282        for name, input_node in zip(argnames, named_args):
283            arg_name = f"arg_{name}"
284            self.named_input_nodes[name] = input_node
285            self.args.input_buffers[input_node.get_name()] = arg_name
286
287        # The args may be duplicated, so renaming must be after args are de-duplicated.
288        for name in argnames:
289            input_node = self.named_input_nodes[name]
290            arg_name = self.args.input_buffers[input_node.get_name()]
291            if input_node.get_layout().offset == 0:
292                renames.writeline(f"{name} = {arg_name}")
293            else:
294                offset = texpr(self.rename_indexing(input_node.get_layout().offset))
295                renames.writeline(f"{name} = {arg_name} + {offset}")
296
297        for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
298            # get args in correct order
299            self.args.input(input_node.get_name())
300
301        def hook():
302            # python_argdefs() cannot be run until after the rest of the template lazily adds more args
303            arg_defs, *_ = self.args.python_argdefs()
304            code = IndentedBuffer()
305            code.splice(gen_common_triton_imports())
306            code.splice(self.jit_lines())
307            code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
308            with code.indent():
309                code.splice(self.defines)
310                code.splice(renames.getvalue())
311            return code.getvalue()
312
313        assert "<DEF_KERNEL>" not in self.render_hooks
314        self.render_hooks["<DEF_KERNEL>"] = hook
315        return "<DEF_KERNEL>"
316
317    def size(self, name: str, index: int):
318        """
319        Hook called from template code to get the size of an arg.
320        Will add needed args to pass it in if it is dynamic.
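
        Illustrative template usage:

            {{size("A", 0)}}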
321        """
322        assert isinstance(index, int)
323        if name is None:
324            val = self.output_node.get_size()[index]
325        else:
326            assert isinstance(name, str)
327            val = self.named_input_nodes[name].get_size()[index]
328        return texpr(self.rename_indexing(val))
329
330    def stride(self, name, index=None):
331        """
332        Hook called from template code to get the stride of an arg.
333        Will add needed args to pass it in if it is dynamic.
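
        Illustrative template usage:

            {{stride("A", 1)}}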
334        """
335        if name is None:
336            val = self.output_node.get_stride()
337        else:
338            assert isinstance(name, str)
339            val = self.named_input_nodes[name].get_stride()
340
341        if isinstance(index, int):
342            return texpr(self.rename_indexing(val[index]))
343        else:
344            return ", ".join([texpr(self.rename_indexing(i)) for i in val])
345
346    def modification(
347        self, subgraph_number: int, output_name: str, **fixed_inputs
348    ) -> str:
349        """This creates a modification function for a subgraph.
350        To use this inside a template, the first argument should specify which subgraph to codegen for
351
352        Args:
353            subgraph_number (int): The index of the subgraph in self.subgraphs
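
        Illustrative template usage (names are hypothetical):

            {{modification(subgraph_number=0, output_name="post_mod_scores", score="qk")}}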
354        """
355        num = 0
356        while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
357            num += 1
358        with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
359            assert isinstance(subgraph_number, int)
360            assert isinstance(self.subgraphs, list)
361            assert (
362                self.body.getvalue() == ""
363            ), "Body should be clear before adding a modification"
364            assert subgraph_number < len(
365                self.subgraphs
366            ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
367
368            subgraph = self.subgraphs[subgraph_number]
369
370            def add_input(name):
371                return self.args.input(name)
372
373            name = f"PlaceholderSubstitution_{subgraph_number}"
374
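            # NB: a class body executes in the enclosing function's scope, so
            # `self` inside this class statement is the TritonTemplateKernel
            # from `modification`; `self.name = name` below tags the kernel,
            # not the handler.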
            class PlaceholderSubstitution(V.WrapperHandler):  # type: ignore[name-defined]
                self.name = name

                def load(self, name: str, index: sympy.Expr):
                    if name not in fixed_inputs:
                        # If it's not a fixed input, it's a load from a captured
                        # tensor
                        var = add_input(name)
                        return f"tl.load({var} + {index})"

                    return f"({fixed_inputs[name]})"

                def indirect_indexing(self, index_var, size, check, wrap_neg=True):
                    return sympy_index_symbol(str(index_var))

            with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
                assert isinstance(
                    subgraph, ir.ComputedBuffer
                ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}"
                if isinstance(subgraph.data, ir.InputBuffer):
                    out = subgraph.data.make_loader()(())
                else:
                    out = subgraph.data.inner_fn(())

            self.codegen_body()
            self.body.writeline(f"{output_name} = {out.value}")

            body_val = self.body.getvalue()
            self.cse.invalidate(set())  # type: ignore[arg-type]
            return body_val

    def store_output(
        self,
        indices: Union[List[Any], Tuple[Any]],
        val: str,
        mask: Optional[str] = None,
        indent_width: int = 4,
    ):
        """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.

        Args:
            indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of
                these indices and output strides must match `val`.
            val (str): The value to store.
            mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask
                will be applied to the store.
            indent_width (int): The number of spaces to use for indentation. This is used when the call to
                store_output is indented in the kernel definition.
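
        Illustrative template usage (mm-style template):

            {{store_output(("idx_m", "idx_n"), "acc", "mask")}}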
423        """
424        with self.create_subgraph_body("<STORE_OUTPUT>"):
425            assert isinstance(indices, (list, tuple))
426            assert isinstance(val, str)
427            assert isinstance(mask, (str, type(None)))
428            assert self.template_mask is None
429            indices = list(map(TritonPrinter.paren, indices))
430            index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
431            lengths = [
432                V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
433            ]
434            assert len(indices) == len(lengths)
435
436            # glue to make generated code use same indexing from template
437            for name, range_tree_entry in zip(
438                indices, self.range_trees[0].construct_entries(lengths)
439            ):
440                range_tree_entry.set_name(name)
441            contiguous_index = sympy_dot(
442                ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
443            )
444            contiguous_index = self.rename_indexing(contiguous_index)
445            self.body.writeline("xindex = " + texpr(contiguous_index))
446            self.range_trees[0].lookup(
447                sympy.Integer(1), sympy_product(lengths)
448            ).set_name("xindex")
449            self.template_mask = mask
450            self.template_out = val
451            self.template_indices = indices
452            output_index = self.output_node.get_layout().make_indexer()(index_symbols)
453            output_index = self.rename_indexing(output_index)
454            if output_index == contiguous_index:
455                output_index = sympy.Symbol("xindex", integer=True)
456
457            epilogue_args = [val]
458            for input_node in itertools.chain(
459                self.input_nodes[: self.prefix_args],
460                self.input_nodes[len(self.input_nodes) - self.suffix_args :],
461            ):
462                input_node.freeze_layout()
463                epilogue_args.append(input_node.make_loader()(index_symbols))
464
465            V.ops.store(
466                self.output_node.get_name(),
467                output_index,
468                self.epilogue_fn(*epilogue_args),
469            )
470            self.codegen_body()
471
472        def hook():
473            # more stuff might have been added since the codegen_body above
474            self.codegen_body()
475
476            return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
477
478        assert "<STORE_OUTPUT>" not in self.render_hooks
479        self.render_hooks["<STORE_OUTPUT>"] = hook
480        return "<STORE_OUTPUT>"
481
482    def render(self, template, kwargs):
483        return PartialRender(
484            template.render(**self.template_env(), **kwargs),
485            self.render_hooks,
486        )
487
    def make_load(self, name, indices, mask):
        """
        Optional helper called from template code to generate the code
        needed to load from a tensor.
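
        Illustrative template usage (names are hypothetical):

            {{make_load("A", ("a_m", "a_k"), "a_mask")}}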
492        """
493        assert isinstance(indices, (list, tuple))
494        assert isinstance(name, str)
495        assert isinstance(mask, str)
496        stride = self.named_input_nodes[name].get_stride()
497        indices = list(map(TritonPrinter.paren, indices))
498        assert len(indices) == len(stride)
499        index = " + ".join(
500            f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
501        )
502        return f"tl.load({name} + ({index}), {mask}, other=0.0)"
503
504    def template_env(self):
505        """
506        Generate the namespace visible in the template.
507        """
508        return {
509            fn.__name__: fn
510            for fn in [
511                self.def_kernel,
512                self.size,
513                self.stride,
514                self.store_output,
515                self.make_load,
516                self.modification,
517                self.gen_argdefs,
518                self.gen_defines,
519            ]
520        }
521
522    def indexing(
523        self,
524        index: sympy.Expr,
525        *,
526        dense_indexing=False,
527        copy_shape=None,
528        override_mask=None,
529        block_ptr=False,
530    ):
531        """
532        Override the default indexing to use our custom mask and force
533        dense indexing.
534        """
535        return super().indexing(
536            index,
537            dense_indexing=False,
538            # We pass template_out as the shape to broadcast the indexing to as
539            # the mask might be broadcast to the output shape
540            copy_shape=self.template_out,
541            override_mask=self.template_mask,
542            block_ptr=block_ptr,
543        )
544
545    def codegen_range_tree(self):
546        pass  # ignore default codegen
547
548    def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
549        wrapper = V.graph.wrapper_code
550        _, call_args, _, arg_types = self.args.python_argdefs()
551        if V.graph.cpp_wrapper:
552            # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
553            # if any dynamic dimension is involved. We rely on the Python version
554            # of the grid function to generate those grid configs, which may contain
555            # symbolic values. The wrapper will use cexpr to print out C++ code
556            # appropriately for the grid configs.
557            grid = self.call_sizes + [self.meta]
558            wrapper.generate_kernel_call(
559                name,
560                call_args,
561                grid=self.grid_fn(*grid),
562                arg_types=arg_types,
563                triton_meta=self.triton_meta,
564            )
565        else:
566            wrapper.add_import_once(f"import {self.grid_fn.__module__}")
567            meta = wrapper.add_meta_once(self.meta)
568            grid = self.call_sizes + [meta]
569            wrapper.generate_kernel_call(
570                name,
571                call_args,
572                grid=grid,
573                grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
574                arg_types=arg_types,
575                triton_meta=self.triton_meta,
576            )
577
578
579@functools.lru_cache(None)
580def _jinja2_env():
581    try:
582        import jinja2
583
584        return jinja2.Environment(
585            undefined=jinja2.StrictUndefined,
586        )
587    except ImportError:
588        return None
589
590
591class TritonTemplate(KernelTemplate):
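    """
    A Triton kernel template rendered from a jinja2 source string. The
    hooks exposed by TritonTemplateKernel.template_env() are available
    inside the template.

    Illustrative definition (a sketch modeled on torch/_inductor/kernel/mm.py):

        mm_template = TritonTemplate(
            name="mm",
            grid=mm_grid,
            source=r'''
                {{def_kernel("A", "B")}}
                ...
                {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
            ''',
        )
    """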
    index_counter = itertools.count()
    all_templates: Dict[str, "TritonTemplate"] = {}

    def __init__(self, name: str, grid: Any, source: str, debug=False) -> None:
        super().__init__(name)
        self.grid = grid
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug

    def generate(  # type: ignore[override]
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        subgraphs=None,
        mutated_inputs=None,
        call_sizes=None,
        **kwargs,
    ):
        """This function generates a TritonTemplateCaller

        Args:
            input_nodes: List of input nodes
            layout: Output layout
            num_stages: Number of stages for triton launch
            num_warps: Number of warps for triton launch
            prefix_args: Number of leading input nodes passed to the kernel
                directly, without being renamed via `def_kernel`
            suffix_args: Number of trailing input nodes passed to the kernel
                directly, without being renamed via `def_kernel`
            epilogue_fn: Optional epilogue function to be called on the output
            subgraphs: Optional subgraphs to be passed as arguments, these will be inlined
                into the triton template string
            mutated_inputs: Optional list of input nodes that are mutated by the kernel, this is helpful
                if you need to return multiple outputs. You can pass them as inputs and mark them as
                being mutated by the kernel.
        """
        assert self.template, "requires jinja2"
        defines = StringIO()
        for name, val in kwargs.items():
            defines.write(f"{name} : tl.constexpr = {val}\n")
        defines = defines.getvalue()

        fake_out = ir.Buffer("buf_out", layout)
        kernel_name = f"triton_{self.name}"

        numel = sympy_product(layout.size)
        buffers = itertools.chain(input_nodes, (fake_out,))
        if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
            raise NotImplementedError(
                "64-bit indexing is not yet implemented for triton templates"
            )

        if call_sizes is None:
            call_sizes = layout.size

        kernel_options = dict(
            input_nodes=input_nodes,
            defines=defines,
            num_stages=num_stages,
            num_warps=num_warps,
            grid_fn=self.grid,
            meta=kwargs,
            call_sizes=call_sizes,
            prefix_args=prefix_args,
            suffix_args=suffix_args,
            epilogue_fn=epilogue_fn,
            index_dtype="tl.int32",
            subgraphs=subgraphs,
        )

        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(fake_out)
        ), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            use_jit=False,
            **kernel_options,
        ) as kernel:
            try:
                template = kernel.render(self.template, kwargs)
                with kernel.set_subgraph_body("<STORE_OUTPUT>"):
                    code = template.finalize_all()
            except ZeroDivisionError:
                # TODO(nmacchioni): fix sympy division by zero
                return None
            if self.debug:
                print("Generated Code:\n", code)
            extra = (
                "-".join(
                    [
                        *[
                            f"{kwarg}={repr(kwargs[kwarg])}"
                            for kwarg in sorted(kwargs.keys())
                        ],
                        f"num_stages={num_stages}",
                        f"num_warps={num_warps}",
                    ]
                )
                + "-"
            )
            mod = PyCodeCache.load(code, extra)

        input_call_args = tuple(kernel.args.input_buffers.keys())
        output_call_args = tuple(kernel.args.output_buffers.keys())

        # We expect the input_buffer order to be [*input_nodes, *captured_buffers]
        expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
        expected_output_args = (fake_out.get_name(),)
        assert input_call_args[: len(expected_input_args)] == expected_input_args, (
            input_call_args,
            expected_input_args,
        )
        assert output_call_args == expected_output_args, (
            output_call_args,
            expected_output_args,
        )

        full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, tuple(kernel.args.sizevars.keys())),
            fallback=config.unbacked_symint_fallback,
        )

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"

        def make_kernel_render(out_node):
            kernel = TritonTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
                output_node=out_node,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                kernel.render,
                self.template,
                kwargs,
            )
            return kernel, render

        # create the BenchmarkRequest
        assert mod.__file__ is not None
        grid = self.grid(
            *V.graph.sizevars.size_hints(
                call_sizes,
                fallback=config.unbacked_symint_fallback,
            ),
            kwargs,
        )
        bmreq = TritonBenchmarkRequest(
            module_path=mod.__file__,
            module_cache_key=mod.key,
            kernel_name=kernel_name,
            grid=grid,
            extra_args=extra_args,
            num_stages=num_stages,
            num_warps=num_warps,
            matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
            input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),  # type: ignore[arg-type]
            output_tensor_meta=TensorMeta.from_irnodes(layout),
        )

        return TritonTemplateCaller(
            kernel_hash_name,
            full_input_nodes,
            layout,
            make_kernel_render,
            extra.strip("-").replace("-", ", "),
            bmreq,
            log_info={
                "tile_shape": str(
                    (
                        kwargs.get("BLOCK_M", -1),
                        kwargs.get("BLOCK_K", -1),
                        kwargs.get("BLOCK_N", -1),
                    )
                ),
                "num_stages": num_stages,
                "num_warps": num_warps,
                "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
                "acc_type": str(kwargs.get("ACC_TYPE", None)),
            },
            mutated_inputs=mutated_inputs,
        )


class ExternKernelChoice:
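    """
    Wraps an eagerly callable kernel (e.g. an ATen op) so it can be
    benchmarked against generated templates during autotuning.

    Illustrative definition (a sketch modeled on torch/_inductor/kernel/mm.py):

        aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
    """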
    def __init__(
        self,
        kernel,
        cpp_kernel=None,
        *,
        name=None,
        has_out_variant=True,
        op_overload=None,
        use_fallback_kernel=False,
        kernel_creator=None,
    ) -> None:
        super().__init__()
        name = name or kernel.__name__
        assert callable(kernel)
        assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}"
        self.name = name
        self.cpp_kernel_name = cpp_kernel
        self.has_out_variant = has_out_variant
        setattr(extern_kernels, name, kernel)
        self.op_overload = op_overload
        self.use_fallback_kernel = use_fallback_kernel
        self.kernel_creator = kernel_creator

    def to_callable(self):
        return getattr(extern_kernels, self.name)

    def call_name(self):
        return f"extern_kernels.{self.name}"

    @functools.lru_cache(None)  # noqa: B019
    def hash_key(self):
        fn = self.to_callable()
        parts = [
            self.name,
            getattr(fn, "__name__", ""),
            getattr(fn, "__module__", ""),
        ]
        try:
            parts.append(inspect.getsource(fn))
        except Exception:
            pass
        return code_hash("-".join(parts))

    def bind(
        self,
        input_nodes,
        layout,
        ordered_kwargs_for_cpp_kernel=(),
        **kwargs,
    ):
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        return ExternKernelCaller(
            self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
        )


class TritonTemplateCaller(ir.TritonTemplateCallerBase):
    def __init__(
        self,
        name,
        input_nodes,
        layout,
        make_kernel_render,
        debug_extra,
        bmreq,
        log_info: Optional[
            Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
        ] = None,
        mutated_inputs=None,
    ) -> None:
        super().__init__(name, input_nodes, layout)
        self.make_kernel_render = make_kernel_render
        self.debug_extra = debug_extra
        self.bmreq: TritonBenchmarkRequest = bmreq
        if log_info is None:
            log_info = {}
        self.log_info: Dict[str, Any] = log_info
        self.log_info.update(
            {
                "backend": "Triton",
                "grid": str(self.bmreq.grid),
                "num_stages": self.bmreq.num_stages,
                "num_warps": self.bmreq.num_warps,
            }
        )
        self.mutated_inputs = mutated_inputs

    def benchmark(self, *args, out):
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def precompile(self):
        assert self.bmreq is not None
        self.bmreq.precompile()

    def __str__(self) -> str:
        return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"

    def call_name(self):
        return f"template_kernels.{self.name}"

    def hash_key(self):
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                self.bmreq.module_cache_key,
            ]
        )

    def output_node(self):
        return ir.TensorBox.create(
            ir.TritonTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                debug_extra=self.debug_extra,
                mutated_inputs=self.mutated_inputs,
            )
        )

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return self.log_info

    def get_make_kernel_render(self):
        return self.make_kernel_render

    def autoheuristic_id(self):
        type_name = "triton"
        info = self.info_dict()
        # TODO(AlnisM): Does tile_shape always exist?
        tile = info["tile_shape"]
        tile_vals = eval(tile)  # type: ignore[arg-type]
        BLOCK_M = tile_vals[0]
        BLOCK_K = tile_vals[1]
        BLOCK_N = tile_vals[2]
        num_stages = info["num_stages"]
        num_warps = info["num_warps"]
        return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}"


class ExternKernelCaller(ChoiceCaller):
    def __init__(
        self,
        choice: ExternKernelChoice,
        input_nodes,
        layout,
        kwargs=None,
        *,
        has_out_variant=True,
    ) -> None:
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        self.kwargs = kwargs or {}
        self.has_out_variant = has_out_variant

    def __str__(self) -> str:
        return f"ExternKernelCaller({self.choice.call_name()})"

    def benchmark(self, *args, out):
        if out.numel() == 0:
            # no need to run the kernel or do benchmarking
            return 0.0
        if self.has_out_variant:
            return super().benchmark(*args, out=out)
        else:
            algo = self.to_callable()
            out_new = algo(*args)
            torch._C._dynamo.guards.assert_size_stride(
                out_new, tuple(out.size()), tuple(out.stride())
            )
            out.copy_(out_new)  # for correctness checking
            return benchmarker.benchmark(algo, args, {})

    def to_callable(self):
        fn = self.choice.to_callable()
        if self.kwargs:
            return functools.partial(fn, **self.kwargs)
        else:
            return fn

    def hash_key(self):
        return "-".join(
            [
                self.choice.name,
                *[
                    f"{kwarg}={repr(self.kwargs[kwarg])}"
                    for kwarg in sorted(self.kwargs.keys())
                ],
                self.choice.hash_key(),
            ]
        )

    def output_node(self):
        if config.abi_compatible and self.choice.use_fallback_kernel:
            assert (
                self.choice.op_overload is not None
            ), "Please provide an op_overload to use ir.FallbackKernel"
            inner = ir.FallbackKernel.create(
                self.choice.op_overload, *self.input_nodes, **self.kwargs
            )
        elif self.choice.kernel_creator is not None:
            inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs)
        else:
            cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc
            inner = cls(
                layout=self.layout,
                inputs=self.input_nodes,
                python_kernel_name=self.choice.call_name(),
                cpp_kernel_name=self.choice.cpp_kernel_name,
                ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel,
                op_overload=self.choice.op_overload,
                kwargs=self.kwargs,
            )

        return ir.TensorBox.create(inner)

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "extern",
            "kernel_call_name": self.choice.call_name(),
        }

    def autoheuristic_id(self):
        return f"extern_{self.choice.name}"


@functools.lru_cache(None)
def get_mm_log_filename() -> Optional[str]:
    mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None)
    if not mm_file_name:
        return None

    if "json" not in mm_file_name:
        mm_file_name = f"{mm_file_name}.json"

    return mm_file_name


def append_to_log(filename, data):
    lock_file = filename.replace(".json", ".lock")
    lock = FileLock(lock_file)
    with lock:
        try:
            with open(filename) as f:
                log_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            log_data = []

        log_data.append(data)

        with open(filename, "w") as f:
            json.dump(log_data, f, indent=4)


class DataProcessorChoiceCallerWrapper:
    def __init__(self, wrapped, preprocessor, postprocessor) -> None:
        self._wrapped = wrapped
        if preprocessor is not None:
            self._preprocessor = preprocessor
        else:
            self._preprocessor = lambda x, y: (x, y)
        if postprocessor is not None:
            self._postprocessor = postprocessor
        else:
            self._postprocessor = lambda x: x

    def __getattr__(self, name):
        return getattr(self._wrapped, name)

    def benchmark(self, *args, out) -> float:
        new_args, new_out = self._preprocessor(args, out)
        result = self._wrapped.benchmark(*new_args, out=new_out)
        new_out = self._postprocessor(new_out)
        if out is not new_out:
            out.copy_(new_out)
        return result

    def output_node(self) -> ir.TensorBox:
        result = self._wrapped.output_node()
        return self._postprocessor(result)

    def __repr__(self) -> str:
        return f"DataProcessorChoiceCallerWrapper({self._wrapped})"


class DataProcessorTemplateWrapper:
    """
    A wrapper class for a kernel template.

    This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to
    preprocess and postprocess data before and after using the wrapped template. A typical
    usage is to reorder or filter the input nodes in order to match the expected input of other
    kernel choices like an ATen kernel. A more complicated usage is to prepack the weights.
    See the example from :mod:`cpp_gemm_template` for more details.
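
    Expected callable signatures (a sketch inferred from the defaults below):

        preprocessor(input_nodes, layout) -> (input_nodes, layout)
        postprocessor(output) -> output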
1079    """
1080
1081    def __init__(
1082        self,
1083        wrapped_template_cls,
1084        preprocessor,
1085        postprocessor,
1086        **kwargs,
1087    ) -> None:
1088        if preprocessor is not None:
1089            self._preprocessor = preprocessor
1090        else:
1091            self._preprocessor = lambda x, y: (x, y)
1092        if postprocessor is not None:
1093            self._postprocessor = postprocessor
1094        else:
1095            self._postprocessor = lambda x: x
1096        assert "input_nodes" in kwargs
1097        assert "layout" in kwargs
1098        kwargs["input_nodes"], kwargs["layout"] = preprocessor(
1099            kwargs["input_nodes"], kwargs["layout"]
1100        )
        self._wrapped = wrapped_template_cls(**kwargs)

    def __getattr__(self, name):
        return getattr(self._wrapped, name)

    def maybe_append_choice(self, choices, **kwargs):
        return type(self._wrapped).maybe_append_choice(self, choices, **kwargs)

    def generate(self, **kwargs):
        choice_caller = self._wrapped.generate(**kwargs)
        return DataProcessorChoiceCallerWrapper(
            choice_caller, self._preprocessor, self._postprocessor
        )

    def __repr__(self) -> str:
        return f"DataProcessorTemplateWrapper({self._wrapped})"


class ErrorFromChoice(RuntimeError):
    def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None:
        msg += f"\nFrom choice {choice}\n{inputs_str}"
        super().__init__(msg)
        self.choice = choice


class NoValidChoicesError(RuntimeError):
    pass


@functools.lru_cache(None)
def get_env_num_workers() -> Optional[int]:
    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
    return None


def create_inputs_key(input_nodes) -> str:
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])


def create_precompile_key(
    name: str, inputs_key: str, choices: List[ChoiceCaller]
) -> str:
    return ":".join(
        [
            name,
            inputs_key,
            torch.get_float32_matmul_precision(),
        ]
        + [choice.hash_key() for choice in choices]
    )


class AlgorithmSelectorCache(PersistentCache):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # autotuning will occur in the scheduler, so there is no guarantee
        # that the first lowering for a given key will also be the first to
        # benchmark it. share a single precompilation function for all
        # lowerings of a particular key
        self.precompile_cache: Dict[str, Callable[[], None]] = {}
        # list of callbacks that are called after benchmarking
        self.feedback_saver_fns: List[
            Callable[
                [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
            ]
        ] = []

    def __call__(
        self,
        name,
        choices: List[ChoiceCaller],
        input_nodes,
        layout,
        # optional dict mapping arg indices to the functions
        # generating a torch.Tensor for that input from the
        # corresponding ir.Buffer. if passed for a given
        # arg, the function will be called instead of
        # generating a random torch.Tensor for benchmarking.
        input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
        precompilation_timeout_seconds: int = 60 * 60,
        return_multi_template=False,
    ):
        from .codegen.cuda.cuda_kernel import CUDATemplateCaller

        # Templates selected with input_gen_fns require specific input data to
        # avoid IMA (illegal memory access). Passing custom input gen fns to
        # benchmark_fusion is NYI, so skip deferred template selection.
        # TODO(jgong5): support multi-template on CPU
        if input_gen_fns is not None or layout.device.type == "cpu":
            return_multi_template = False

        # TODO - assert that we have no mutating kernels here

        # TODO(nmacchioni): remove once CI tests are fixed
        choices = [choice for choice in choices if choice is not None]

        if mm_file_name := get_mm_log_filename():
            M, K = input_nodes[-2].get_size()[:2]
            N = input_nodes[-1].get_size()[-1]
            append_to_log(mm_file_name, {"invoke": str((M, K, N))})

        if len(choices) == 0:
            backend_config = (
                "max_autotune_gemm_backends"
                if name != "convolution"
                else "max_autotune_conv_backends"
            )
            raise NoValidChoicesError(
                f"No choices to select, please consider adding ATEN into {backend_config} "
                "config (defined in torch/_inductor/config.py) to allow at least one choice. "
            )
        log.debug("Max autotune selects from %s choices.", str(len(choices)))

        if len(choices) == 1:
            if not isinstance(choices[0], CUDATemplateCaller):
                # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
                return choices[0].output_node()

        @functools.lru_cache(None)
        def make_benchmark_fn():
            return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)

        inputs_key = create_inputs_key(input_nodes)

        def precompile(choices) -> Callable[[], None]:
            def no_op(*args, **kwargs):
                return

            if (
                precompilation_timeout_seconds is None
                or precompilation_timeout_seconds <= 0
            ):
                return no_op

            env_workers = get_env_num_workers()
            num_workers = env_workers if env_workers is not None else (len(choices))

            if num_workers <= 0:
                return no_op

            # https://github.com/python/cpython/issues/106905
            if (
                sys.version_info.major == 3
                and sys.version_info.minor == 11
                and sys.version_info.micro <= 8
            ):
                return no_op

            # check local and global cache before precompiling
            timings = self.lookup(
                choices,
                name,
                inputs_key,
                benchmark=None,
            )

            if timings:
                return no_op

            precompile_key = create_precompile_key(name, inputs_key, choices)
            if precompile_func := self.precompile_cache.get(precompile_key):
                return precompile_func

            log.info(
                "Multithreaded precompilation for %d choices using %d worker threads",
                len(choices),
                num_workers,
            )

            # In rare circumstances, because python threads inherit global state,
            # thread pool executor can race and leave stdout/stderr in a state
            # different than the original values. we explicitly restore the state
            # here to avoid this issue.

            initial_stdout = sys.stdout
            initial_stderr = sys.stderr

            def precompile_with_captured_stdout(choice):
                with restore_stdout_stderr(initial_stdout, initial_stderr):
                    return choice.precompile()

            executor = ThreadPoolExecutor(max_workers=num_workers)

            futures = {}
            for c in choices:
                if hasattr(c, "precompile"):
                    future = executor.submit(precompile_with_captured_stdout, c)
                    futures[future] = c

            @functools.lru_cache(None)
            @restore_stdout_stderr(initial_stdout, initial_stderr)
            def wait_on_futures():
                counters["inductor"]["select_algorithm_precompile"] += 1
                for future in as_completed(
                    futures,
                    timeout=precompilation_timeout_seconds,
                ):
                    if e := future.exception():
                        log.error(
                            "Exception %s for benchmark choice %s", e, futures[future]
                        )

                executor.shutdown(wait=True)

            self.precompile_cache[precompile_key] = wait_on_futures

            return wait_on_futures

        def autotune(choices):
            return make_benchmark_fn()(choices)

        if config.autotune_in_subproc:
            from .autotune_process import tuning_pool

            # do the optional warmup
            tuning_pool.initialize()

        def do_autotuning(precompile_fn):
            precompile_start_ts = time.time()
            precompile_fn()
            precompile_elapse = time.time() - precompile_start_ts

            autotune_start_ts = time.time()
            timings = self.lookup(
                choices,
                name,
                inputs_key,
                autotune,
            )
            autotune_elapse = time.time() - autotune_start_ts

            if timings and all(
                not math.isfinite(timing) for timing in timings.values()
            ):
                raise NoValidChoicesError

            if make_benchmark_fn.cache_info().currsize:
                counters["inductor"]["select_algorithm_autotune"] += 1

            if (
                make_benchmark_fn.cache_info().currsize
                or log.getEffectiveLevel() == logging.DEBUG
                or config.trace.log_autotuning_results
            ):
                self.log_results(
                    name, input_nodes, timings, autotune_elapse, precompile_elapse
                )

            for feedback_fn in self.feedback_saver_fns:
                feedback_fn(timings, name, input_nodes, choices)

            return timings

        precompile_fn = precompile(choices)

        if return_multi_template and (config.max_autotune or config.max_autotune_gemm):

            def get_timings():
                timings = do_autotuning(precompile_fn)
                min_extern_choice = float("inf")
                for choice, timing in timings.items():
                    if isinstance(choice, ExternKernelCaller):
                        min_extern_choice = min(min_extern_choice, timing)

                timings = {
                    choice: time
                    for choice, time in timings.items()
                    if (
                        time <= min_extern_choice
                        or not isinstance(choice, ExternKernelCaller)
                    )
                }

                return timings

            return torch._inductor.ir.TensorBox.create(
                torch._inductor.ir.MultiTemplateBuffer(
                    layout,
                    input_nodes,
                    get_timings,
                )
            )

        # TODO - don't want to precompile if we have a cache hit
1386        timings = do_autotuning(precompile_fn)
1387        if timings == {} or choices[0] not in timings:
1388            return choices[0].output_node()
1389
1390        selected_key = builtins.min(timings, key=timings.__getitem__)
1391        selected_time = timings[selected_key]
1392        selected_choice = selected_key.output_node()
1393        log.debug("selected choice: %s", str(selected_choice))
1394        return selected_choice
1395
1396    @classmethod
1397    def make_benchmark_fn(
1398        cls,
1399        choices,
1400        input_nodes,
1401        layout,
1402        input_gen_fns=None,
1403    ):
1404        if input_gen_fns is None:
1405            input_gen_fns = {}
1406
1407        def get_inputs():
1408            # de-duplicate args
1409            unique_example_inputs = {
1410                x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
1411                for i, x in enumerate(input_nodes)
1412            }
1413            example_inputs = list(unique_example_inputs.values())
1414            example_inputs_extern = [
1415                unique_example_inputs[input_node.get_name()]
1416                if unique_example_inputs[input_node.get_name()].is_mkldnn
1417                else torch.as_strided(
1418                    unique_example_inputs[input_node.get_name()],
1419                    V.graph.sizevars.size_hints(
1420                        input_node.get_size(),
1421                        fallback=config.unbacked_symint_fallback,
1422                    ),
1423                    V.graph.sizevars.size_hints(
1424                        input_node.get_stride(),
1425                        fallback=config.unbacked_symint_fallback,
1426                    ),
1427                    V.graph.sizevars.size_hint(
1428                        input_node.get_layout().offset,
1429                        fallback=config.unbacked_symint_fallback,
1430                    ),
1431                )
1432                for input_node in input_nodes
1433            ]
1434
1435            out = cls.benchmark_example_value(layout)
1436            out_extern = torch.as_strided(
1437                out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
1438            )
1439            expected = None
1440            if VERIFY:
1441                choices[0].benchmark(*example_inputs_extern, out=out_extern)
1442                expected = out_extern.clone()
1443
1444            return example_inputs, example_inputs_extern, out, out_extern, expected
1445
1446        if DEBUG:
1447            print(f"{len(choices)} tuning requests:")
1448
1449        def debug_str(example_inputs, out):
1450            def tensor_repr(x):
1451                return (
1452                    f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
1453                    f"dtype={x.dtype!r}, device={x.device.type!r})"
1454                )
1455
1456            lines = [
1457                "inputs = [",
1458            ]
1459            for x in example_inputs:
1460                lines.append(f"    {tensor_repr(x)},")
1461            lines += ["]", f"out = {tensor_repr(out)}", ""]
1462            return "\n".join(lines)
1463
1464        def benchmark_choice_in_current_process(
1465            choice, example_inputs, example_inputs_extern, out, out_extern, expected
1466        ):
            out.zero_()
            if isinstance(choice, ExternKernelCaller):
                # aten kernels want the offset baked in for sliced tensors
                result = choice.benchmark(*example_inputs_extern, out=out_extern)
            else:
                # triton templates want the base pointer for sliced tensors
                result = choice.benchmark(*example_inputs, out=out)
            if VERIFY and expected is not None:
                torch.testing.assert_close(out_extern, expected, **VERIFY)
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # shake out any CUDA errors
            return result

        def benchmark_in_current_process(choices):
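            """
            Benchmark every choice in this process, mapping expected failure
            modes (compile errors, unsupported configs, out-of-resources) to
            float("inf") so a bad candidate is skipped instead of aborting
            autotuning.
            """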
            inputs = get_inputs()
            example_inputs, _, out, _, _ = inputs
            timings = {}
            for choice in choices:
                try:
                    timing = benchmark_choice_in_current_process(choice, *inputs)
                except CUDACompileError as e:
                    log.error(
                        "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
                        str(e),
                    )
                    timing = float("inf")
                except NotImplementedError as e:
                    log.warning("Not yet implemented: %s", e)
                    timing = float("inf")
                except RuntimeError as e:
                    msg = str(e)
                    if "invalid argument" in msg:
                        msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
                    elif "illegal memory access" in msg:
                        msg += "\n\nEither an error in the template or a triton bug.\n"
                    log.error(
                        "Runtime error during autotuning: \n%s. \nIgnoring this choice.",
                        msg,
                    )
                    timing = float("inf")
                except AssertionError as e:
                    raise AssertionError(  # noqa: B904
                        f"Incorrect result from choice {choice}\n\n{e}"
                    )
                except Exception as e:
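                    # triton may not be installed, so import OutOfResources
                    # lazily; treat it as a benign failure and re-raise
                    # anything else.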
                    try:
                        from triton.runtime.autotuner import OutOfResources

                        if isinstance(e, OutOfResources):
                            log.warning(e)
                            timing = float("inf")
                        else:
                            raise e
                    except ImportError:
                        raise e from None

                timings[choice] = timing

            return timings
        def benchmark_in_sub_process(choices):
            from . import autotune_process

            # Only benchmark Triton kernels in a subprocess for now;
            # ATen/extern kernels are still benchmarked in the current process.
            extern = [c for c in choices if isinstance(c, ExternKernelCaller)]
            triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]

            timings = benchmark_in_current_process(extern)
            timings.update(autotune_process.benchmark_in_sub_process(triton))
            return timings

        benchmark = (
            benchmark_in_sub_process
            if config.autotune_in_subproc
            else benchmark_in_current_process
        )

        return benchmark

    @staticmethod
    def log_results(
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict[ChoiceCaller, float],
        elapse: float,
        precompile_elapse: float,
    ):
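        """
        Record autotuning results with the debug handler and, when
        max_autotune(_gemm) and PRINT_AUTOTUNE are enabled, print the best
        choices (top 10, or all of them at DEBUG log level) to stderr.
        """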
        V.debug.log_autotuning_results(
            name, input_nodes, timings, elapse, precompile_elapse
        )
        if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE:
            return
        sizes = ", ".join(
            [
                "x".join(
                    map(
                        str,
                        V.graph.sizevars.size_hints(
                            n.get_size(), fallback=config.unbacked_symint_fallback
                        ),
                    )
                )
                for n in input_nodes
            ]
        )

        n = None if log.getEffectiveLevel() == logging.DEBUG else 10
        top_k = sorted(timings, key=timings.__getitem__)[:n]
        best = top_k[0]

        def get_choice_info(choice):
            if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
                return {"type": "cublas", "time": timings[choice]}

            assert isinstance(
                choice, torch._inductor.select_algorithm.TritonTemplateCaller
            )

            info = choice.info_dict()
            tile = info["tile_shape"]
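            # tile_shape is stored as a string like "(BLOCK_M, BLOCK_K, BLOCK_N)";
            # eval turns it back into a tuple.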

            tile_vals = eval(tile)  # type: ignore[arg-type]
            BLOCK_M = tile_vals[0]
            BLOCK_K = tile_vals[1]
            BLOCK_N = tile_vals[2]

            return {
                "type": "triton",
                "time": timings[choice],
                "BLOCK_M": BLOCK_M,
                "BLOCK_K": BLOCK_K,
                "BLOCK_N": BLOCK_N,
                "num_stages": info["num_stages"],
                "num_warps": info["num_warps"],
            }

        mm_filename = get_mm_log_filename()
        if mm_filename and "mm" in name:
            M, K = input_nodes[-2].get_size()[:2]
            N = input_nodes[-1].get_size()[-1]

            out_dict = {
                str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()]
            }

            append_to_log(mm_filename, out_dict)

        best_time = timings[best]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        for choice in top_k:
            result = timings[choice]
            if result:
                kernel_info = (
                    choice.debug_extra if hasattr(choice, "debug_extra") else ""
                )
                sys.stderr.write(
                    f"  {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n"
                )
            else:
                sys.stderr.write(
                    f"  {choice.name} {result:.4f} ms <DIVIDED BY ZERO ERROR>\n"
                )

        autotune_type_str = (
            "SubProcess" if config.autotune_in_subproc else "SingleProcess"
        )
        sys.stderr.write(
            f"{autotune_type_str} AUTOTUNE benchmarking took {elapse:.4f} seconds "
            f"and precompilation took {precompile_elapse:.4f} seconds\n"
        )

    @staticmethod
    def benchmark_example_value(node):
        """
        Convert an ir.Buffer into a concrete torch.Tensor we can use for
        benchmarking.
        """
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)
        # triton templates want the base tensor.
        if isinstance(node, ir.BaseView):
            node = node.unwrap_view()
        return AlgorithmSelectorCache.generate_example_value(
            V.graph.sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            V.graph.sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            node.get_device(),
            node.get_dtype(),
            node.layout.offset,
        )

    @staticmethod
    def generate_example_value(size, stride, device, dtype, extra_size):
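        # extra_size (the layout's storage offset) pads the underlying
        # allocation so that offset views over this tensor stay in bounds.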
        # Preserve the RNG state so that the rand_strided call below does not
        # change the RNG state seen by the real model code.
        with preserve_rng_state():
            return rand_strided(
                size,
                stride,
                device=device,
                dtype=dtype,
                extra_size=extra_size,
            )

    @staticmethod
    def key_of(node):
        """
        Extract the pieces of an ir.Buffer that we should invalidate cached
        autotuning results on.
        """
        sizevars = V.graph.sizevars
        return (
            node.get_device().type,
            str(node.get_dtype()),
            *sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            *sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
        )

    def add_feedback_saver(
        self,
        fn: Callable[
            [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
        ],
    ):
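        """
        Register a callback that receives the timing dict, the op name, the
        input nodes, and the list of choices whenever autotuning results
        become available.
        """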
        self.feedback_saver_fns.append(fn)


_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None


def autotune_select_algorithm(*args, **kwargs):
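    """
    Entry point for algorithm selection: lazily creates the process-wide
    AlgorithmSelectorCache and forwards all arguments to it, defaulting
    return_multi_template to config.benchmark_epilogue_fusion.

    A minimal sketch of a call site (hypothetical variables; real callers
    pass ChoiceCaller lists built from templates and extern kernels):

        node = autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
    """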
    global _ALGORITHM_SELECTOR_CACHE
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()

    if "return_multi_template" not in kwargs:
        kwargs["return_multi_template"] = (
            torch._inductor.config.benchmark_epilogue_fusion
        )

    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)


def add_feedback_saver(
    fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None]
):
    global _ALGORITHM_SELECTOR_CACHE
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn)


def realize_inputs(*args):
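    """
    Realize IR inputs and require a stride-1 innermost dimension so they can
    be used as kernel arguments; returns a single node for one argument, or a
    list of realized nodes for several.
    """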
    if len(args) == 1:
        return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0]))
    return [realize_inputs(x) for x in args]


# ensure lowering is imported so that `extern_kernels.*` is populated
from . import lowering  # noqa: F401