# mypy: allow-untyped-defs
import collections
import copy
import dataclasses
import inspect
import logging
import threading
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import torch
import torch.fx as fx
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


log = logging.getLogger("torch._dynamo")


###############################################################################
# Kernel Side Table


# We cannot put Triton Kernels into the FX graph as the graph nodes
# do not support arbitrary functions.
# Use a side table.
# We use two dicts so that fetching both the kernel and the id is O(1)
class KernelSideTable:
    id_to_kernel: Dict[int, Any] = {}
    kernel_to_id: Dict[Any, int] = {}
    constant_args: Dict[int, Any] = {}
    lock = threading.Lock()

    # Returns the index of the kernel in the table
    def add_kernel(self, kernel) -> int:
        with self.lock:
            if kernel in self.kernel_to_id:
                return self.kernel_to_id[kernel]

            idx = len(self.id_to_kernel)
            self.id_to_kernel[idx] = kernel
            self.kernel_to_id[kernel] = idx
            return idx

    # Returns the triton kernel at the given index
    def get_kernel(self, idx: int):
        # No need to lock here as fetching from dict is atomic
        assert idx in self.id_to_kernel
        return self.id_to_kernel[idx]

    # Not every constant arg can be added to the graph. Use this side table
    # for constant args.
    def add_constant_args(self, args) -> int:
        with self.lock:
            idx = len(self.constant_args)
            self.constant_args[idx] = args
            return idx

    # Returns the constant args at the given index
    def get_constant_args(self, idx: int):
        # No need to lock here as fetching from dict is atomic
        assert idx in self.constant_args
        return self.constant_args[idx]

    # Resets the table (only meant to be used in unit tests)
    # This is only safe assuming single threaded execution
    def reset_table(self) -> None:
        self.id_to_kernel = {}
        self.kernel_to_id = {}
        self.constant_args = {}


kernel_side_table = KernelSideTable()
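
# Illustrative sketch (not executed here; `my_triton_kernel` is a hypothetical
# @triton.jit function): the wrappers below store only integer indices in the
# FX graph and recover the actual objects through this side table.
#
#   kernel_idx = kernel_side_table.add_kernel(my_triton_kernel)
#   constant_args_idx = kernel_side_table.add_constant_args({"BLOCK_SIZE": 128})
#   assert kernel_side_table.get_kernel(kernel_idx) is my_triton_kernel
#   assert kernel_side_table.get_constant_args(constant_args_idx) == {"BLOCK_SIZE": 128}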


###############################################################################
# Mutation Tracker


@dataclasses.dataclass(frozen=True)
class Param:
    idx: int


@dataclasses.dataclass(frozen=True)
class Intermediate:
    idx: int

    def fake(self):
        return self.idx < 0


@dataclasses.dataclass(frozen=True)
class Op:
    name: str
    fn_call_name: Optional[str]
    args: List[Union[Param, Intermediate]]
    ret: Intermediate = dataclasses.field(repr=False)

    def __post_init__(self):
        if self.name == "tt.call":
            assert self.fn_call_name is not None
        else:
            assert self.fn_call_name is None

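# Illustrative sketch (assumed SSA ids, not real TTIR output): a single TTIR
# operation such as `%5 = tt.load %arg0` is modeled with the dataclasses above
# as an Op whose result is an Intermediate and whose pointer argument is a
# Param (the 0th kernel parameter):
#
#   load_op = Op(
#       name="tt.load",
#       fn_call_name=None,        # only "tt.call" ops carry a callee name
#       args=[Param(idx=0)],      # %arg0
#       ret=Intermediate(idx=5),  # %5; a negative idx marks a fake result
#   )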

def generate_ttir(kernel, kwargs):
    """
    Uses Triton's internal code generation to create TTIR
    """
    import sympy
    import triton
    from triton.compiler.compiler import ASTSource
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    import torch
    import torch._inductor.ir
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(kernel, Autotuner):
        if len(kernel.configs) > 0:
            # If we are autotuning, then it doesn't matter which version gets
            # picked for tracing purposes, so let's pick the first one
            kwargs = {**kwargs, **kernel.configs[0].kwargs}
        kernel = kernel.fn

    assert isinstance(kernel, JITFunction)

    if len(kwargs) != len(kernel.arg_names):
        raise ValueError("Incorrect number of arguments passed to kernel")

    # Replace all SymExprs with a regular value for TTIR generation
    # Replace all FakeTensor/TensorBox with real tensors
    # These replacements are needed for triton's type, key and config functions
    ordered_args: Dict[str, Any] = {}
    for name in kernel.arg_names:
        a = kwargs[name]
        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)):
            ordered_args[name] = 2
        elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
            with torch._C._DisableTorchDispatch():
                ordered_args[name] = torch.empty(2, dtype=a.dtype)
        else:
            ordered_args[name] = a

    ordered_tensor_names = [
        name for name, arg in ordered_args.items() if isinstance(arg, Tensor)
    ]
    specialization = kernel._get_config(*ordered_args.values())
    constants = {
        name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
    }

    # Build kernel signature -- doesn't include constexpr arguments.
    signature = {
        name: kernel._type_of(kernel._key_of(arg))
        for i, (name, arg) in enumerate(ordered_args.items())
        if i not in kernel.constexprs
    }

    context = triton._C.libtriton.ir.context()
    target = triton.runtime.driver.active.get_current_target()
    backend = triton.compiler.compiler.make_backend(target)
    options = backend.parse_options({})
    triton._C.libtriton.ir.load_dialects(context)
    backend.load_dialects(context)

    src = ASTSource(kernel, signature, constants, specialization)

    # Across Triton versions, ASTSource.make_ir takes 2, 3, or 4 arguments.
    # Handle backward compatibility here.
    make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
    if make_ir_sig_params == 2:
        ttir_module = src.make_ir(options, context)
    elif make_ir_sig_params == 3:
        codegen_fns = backend.get_codegen_implementation()
        ttir_module = src.make_ir(options, codegen_fns, context)
    else:
        codegen_fns = backend.get_codegen_implementation()
        module_map = backend.get_module_map()
        ttir_module = src.make_ir(options, codegen_fns, module_map, context)
    if not ttir_module.verify():
        raise RuntimeError("Verification for TTIR module has failed")

    return ttir_module, ordered_tensor_names
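
# Illustrative sketch (hypothetical kernel and arguments): generate_ttir is
# normally reached via identify_mutated_tensors below, but conceptually it is
# called with a @triton.jit kernel and a name -> value mapping that covers
# every kernel argument:
#
#   ttir_module, tensor_arg_names = generate_ttir(
#       add_kernel,  # hypothetical @triton.jit function
#       {"in_ptr": x, "out_ptr": y, "n_elements": 1024, "BLOCK_SIZE": 64},
#   )
#   # ttir_module is Triton's MLIR module; tensor_arg_names lists the
#   # arguments that were tensors (here: ["in_ptr", "out_ptr"]).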


def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
    """
    Walk the `ttir_module` bottom up to mine the `functions` from
    the structured MLIR entities representing the Triton kernel
    (mlir::Operation, mlir::Block, mlir::Region).
    """
    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}

    # block id --> op result (Intermediate) --> one or more ops
    op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict(
        lambda: defaultdict(list)
    )
    region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list)
    block_id_to_block_arg_ids: Dict[int, List[int]] = {}
    replacements: Dict[int, Union[Intermediate, Param]] = {}
    reindex_map: Dict[int, int] = {}
    next_fake_intermediate = 0

    def reindex(idx):
        if idx not in reindex_map:
            reindex_map[idx] = len(reindex_map)
        return reindex_map[idx]

    def mlir_to_functions(op) -> None:
        name: str = op.get_name()
        if name == "builtin.module":
            # this wraps all tt.func ops
            return

        operand_ids: List[int] = [
            reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
        ]
        result_ids: List[int] = [
            reindex(op.get_result(i).id()) for i in range(op.get_num_results())
        ]

        child_block_ids: List[int] = []
        for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
            # as the walk is bottom-up, the region_id_to_block_ids[i]
            # must be populated by the time we process the enclosing op
            child_block_ids.extend(region_id_to_block_ids[i])

        parent_block_id = -1
        parent_block = op.get_block()
        if parent_block is not None:
            parent_block_id = parent_block.id()
            if parent_block_id not in block_id_to_block_arg_ids:
                block_id_to_block_arg_ids[parent_block_id] = []
                for i in range(parent_block.get_num_arguments()):
                    block_id_to_block_arg_ids[parent_block_id].append(
                        reindex(parent_block.get_argument(i).id()),
                    )
                # the region info is collected via ops' parent blocks to be
                # used later when the region's enclosing op is traversed
                parent_region = parent_block.get_parent()
                if parent_region is not None:
                    region_id_to_block_ids[parent_region.id()].append(parent_block_id)

        nonlocal next_fake_intermediate

        if name == "tt.func":
            # for function ops: gather and inline
            # the ops from all child blocks
            fn_ops = defaultdict(list)
            for child_block_id in child_block_ids:
                for result, block_fn_ops in op_stack.pop(child_block_id).items():
                    for block_fn_op in block_fn_ops:
                        fn_ops[result].append(block_fn_op)

            # replace the corresponding Intermediates in the
            # child op args with the function args (Params)
            for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]):
                replacements[idx] = Param(i)

            for fn_op_list in fn_ops.values():
                for fn_op in fn_op_list:
                    for i in range(len(fn_op.args)):
                        arg = fn_op.args[i]
                        seen = set()  # to break cycles
                        # there can be transitive replacements, but likely
                        # no cycles (we keep the `seen` set just in case)
                        while (
                            isinstance(arg, Intermediate)
                            and arg.idx in replacements
                            and arg.idx not in seen
                        ):
                            seen.add(arg.idx)
                            arg = fn_op.args[i] = replacements[arg.idx]

            # next function capture starts
            # with empty replacements
            replacements.clear()

            fn_name = op.get_str_attr("sym_name")
            functions[fn_name] = fn_ops
        elif child_block_ids:
            if name in {"scf.if", "scf.for", "scf.while", "tt.reduce", "tt.scan"}:
                # for blocked ops: inline the enclosed ops into
                # the parent block + rewire the last op in each
                # child block to return the block result
                return_ops = []
                for block_id in child_block_ids:
                    if name == "scf.for":
                        # example:
                        # %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (i32) ...
                        # block args: 2 (%iv, %arg)
                        # op operands: 4 (%lb, %ub, %step, %init)
                        # `%arg` maps to `%init`
                        for i, idx in enumerate(block_id_to_block_arg_ids[block_id]):
                            if i == 0:
                                next_fake_intermediate -= 1
                                replacements[idx] = Intermediate(next_fake_intermediate)
                            else:
                                replacements[idx] = Intermediate(operand_ids[i + 2])
                    elif name == "scf.while":
                        # example:
                        # %3:3 = scf.while (%arg2 = %1, %arg3 = %2, %arg4 = %c0_i32_8) ...
                        # block args: 3 (%arg2, %arg3, %arg4)
                        # op operands: 3 (%1, %2, %c0_i32_8)
                        # `%arg2` maps to `%1`, `%arg3` maps to `%2`, ...
                        for i, idx in enumerate(block_id_to_block_arg_ids[block_id]):
                            replacements[idx] = Intermediate(operand_ids[i])
                    elif name == "scf.if":
                        # the scf block args are ignored by the pass. but, as they
                        # may be used as operands of the ops inside the block
                        # (and nested blocks inlined in the current block by now),
                        # they are replaced by new fake Intermediates to avoid "this
                        # operand is not returned by any other op in the fn" error
                        # in the downstream analysis
                        for idx in block_id_to_block_arg_ids[block_id]:
                            next_fake_intermediate -= 1
                            replacements[idx] = Intermediate(next_fake_intermediate)
                    else:
                        assert name in ("tt.reduce", "tt.scan")
                        # wire the block arguments to the op arguments
                        num_operands = len(operand_ids)
                        block_arg_ids = block_id_to_block_arg_ids[block_id]
                        assert len(block_arg_ids) == 2 * num_operands, (
                            f"{name} is expected to have twice as "
                            "many block arguments as op arguments: "
                            f"{operand_ids=}, {block_arg_ids=}."
                        )
                        for i, idx in enumerate(block_arg_ids):
                            # for a tt.reduce/tt.scan op with N arguments, the block
                            # arguments comprise N reduced values followed by
                            # N current values corresponding to the N op args
                            replacements[idx] = Intermediate(
                                operand_ids[i % num_operands]
                            )

                    if block_id in op_stack:
                        block_ops = op_stack.pop(block_id)
                        if not block_ops:
                            continue
                        last_ret, last_ops = block_ops.popitem()
                        if all(
                            op.name
                            in ("scf.yield", "tt.reduce.return", "tt.scan.return")
                            for op in last_ops
                        ):
                            # if last_ops are all return ops, treat them separately
                            return_ops.extend(last_ops)
                        else:
                            # otherwise, return last_ops to the block
                            block_ops[last_ret] = last_ops
                        for op_result, child_ops in block_ops.items():
                            op_stack[parent_block_id][op_result].extend(child_ops)

                scf_results = [Intermediate(idx) for idx in result_ids]
                for scf_result in scf_results:
                    for return_op in return_ops:
                        op_stack[parent_block_id][scf_result].append(return_op)
            else:
                raise RuntimeError(
                    f"Unknown blocked function: {name}. Can't capture the TTIR."
                )
        else:
            callee = None
            if name == "tt.call":
                callee = op.get_flat_symbol_ref_attr("callee")
            args: List[Union[Param, Intermediate]] = [
                Intermediate(operand) for operand in operand_ids
            ]
            block_ops = op_stack[parent_block_id]
            if result_ids:
                for result_id in result_ids:
                    res = Intermediate(result_id)
                    block_ops[res].append(Op(name, callee, args, res))
            else:
                next_fake_intermediate -= 1
                fake_res = Intermediate(next_fake_intermediate)
                block_ops[fake_res].append(Op(name, callee, args, fake_res))

    ttir_module.walk(mlir_to_functions)

    return functions
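
# Illustrative sketch (assumed ids): for a trivial kernel that loads from one
# pointer and stores to another, ttir_to_functions returns roughly
#
#   {
#       "add_kernel": {  # one entry per tt.func
#           Intermediate(3): [Op("tt.load", None, [Param(0)], Intermediate(3))],
#           Intermediate(-1): [Op("tt.store", None, [Param(1), Intermediate(3)], Intermediate(-1))],
#       }
#   }
#
# where negative Intermediate ids are the fake results assigned to ops that
# return nothing.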


class MemoizeWithCycleCheck:
    def __init__(self, fn):
        self.fn = fn
        self.reset()

    def __call__(self, functions, fn_name, num_args):
        key = (fn_name, num_args)
        if key not in self.cache:
            self.cache[key] = None
            self.cache[key] = self.fn(functions, fn_name, num_args)
        if self.cache[key] is None:
            raise RuntimeError("Recursion is not supported")
        return self.cache[key]

    def reset(self):
        self.cache = {}
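
# Illustrative sketch: the memoizer plants a None placeholder before computing,
# so a call chain that re-enters the same (fn_name, num_args) key sees None and
# raises instead of recursing forever:
#
#   @MemoizeWithCycleCheck
#   def probe(functions, fn_name, num_args):  # hypothetical self-recursive fn
#       return probe(functions, fn_name, num_args)
#
#   probe({}, "f", 0)  # -> RuntimeError: Recursion is not supported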


@MemoizeWithCycleCheck
def analyze_kernel_mutations(functions, fn_name, num_args):
    """
    Analyzes the graph to detect all sinks from a predefined list of sinks
    by using triton's MemWrite trait list. NOTE: What if triton exposed this?
    From each sink, it traverses the CFG backwards to identify all the input
    pointers that are mutated.
    """
    # Name of mutation op to mutated parameter indices
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait.
    # What if Triton exposed this?
    MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
    # Ops that we want to bail out on
    UNKNOWN_OPS = {"tt.elementwise_inline_asm"}

    stack: List[Union[Param, Intermediate]] = []
    visited = set()
    ops = functions[fn_name]
    for op_list in ops.values():
        for op in op_list:
            if op.name in UNKNOWN_OPS:
                raise RuntimeError(
                    f"ttir analysis hit an op we do not know how to analyze: {op.name}"
                )

            if op.name == "tt.call":
                assert op.fn_call_name in functions
                mutations = analyze_kernel_mutations(
                    functions, op.fn_call_name, len(op.args)
                )
                stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
            else:
                for idx in MUTATION_OPS.get(op.name, []):
                    stack.append(op.args[idx])

    # The following is an iterative DFS algorithm
    mutated = [False] * num_args
    while stack:
        arg = stack.pop()
        if arg in visited:
            continue

        visited.add(arg)

        if isinstance(arg, Param):
            if arg.idx >= num_args:
                # This is an argument defined in the kernel, not passed in
                continue
            mutated[arg.idx] = True
        elif isinstance(arg, Intermediate) and not arg.fake():
            for op in ops[arg]:
                # Skip arguments to load
                if op.name != "tt.load":
                    stack.extend(op.args)
    return mutated
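
# Illustrative sketch (assumed ids): with the two-op function from the
# ttir_to_functions example above, the only sink is the tt.store whose pointer
# operand is Param(1), so for a kernel with two tensor arguments
#
#   analyze_kernel_mutations(functions, "add_kernel", 2)  # -> [False, True]
#
# i.e. only the second argument (the store destination) is reported as mutated.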


def identify_mutated_tensors(kernel, kwargs):
    """
    Given a triton kernel and the arguments for this kernel, this function
    1) Retrieves the TTIR-converted version of the kernel from Triton's API.
    2) Parses the TTIR and creates a control flow graph.
    3) Analyzes the graph to detect all input tensor mutations.
    """

    ttir_module = None
    functions = None
    try:
        ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)

        # extract functions from TTIR using MLIR bindings exposed by Triton code
        functions = ttir_to_functions(ttir_module)

        assert functions is not None
        kernel_name = next(iter(functions.keys()))
        # Triton codegen modifies the name
        assert kernel.fn.__name__ in kernel_name
        # Reset the cache between top level invocations
        # The cache for analyze kernel mutations is mainly used for cycle
        # detection, so each top level invocation needs a clean cache
        analyze_kernel_mutations.reset()
        mutations = analyze_kernel_mutations(
            functions, kernel_name, len(ordered_tensor_names)
        )

        return [
            ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated
        ]
    except Exception as e:
        log.warning(
            "Encountered an exception in identify_mutated_tensors, assuming every input is mutated",
            exc_info=True,
        )
        if ttir_module is not None:
            log.debug("TTIR:\n%s", str(ttir_module))
        if functions is not None:
            log.debug("functions:")
            for name, fn in functions.items():
                log.debug("===\t%s\t===", name)
                for ret, ops in fn.items():
                    log.debug("%s\t=>\t%s", ret, ops)
        return [key for key, value in kwargs.items() if isinstance(value, Tensor)]
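
# Illustrative sketch (hypothetical kernel): putting the pipeline together for
# a simple copy kernel that loads from `in_ptr` and stores to `out_ptr`,
#
#   identify_mutated_tensors(
#       add_kernel,  # hypothetical @triton.jit function
#       {"in_ptr": x, "out_ptr": y, "n_elements": 1024, "BLOCK_SIZE": 64},
#   )
#
# would return ["out_ptr"]; on any analysis failure it falls back to treating
# every tensor argument as mutated (here: ["in_ptr", "out_ptr"]).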


###############################################################################
# Triton Kernel Wrappers


# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_mutation")

    def __call__(self, kernel_idx, constant_args_idx, grid, kwargs):
        return super().__call__(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            kwargs=kwargs,
        )


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()


# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_functional")

    def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone):
        return super().__call__(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            kwargs=kwargs,
            tensors_to_clone=tensors_to_clone,
        )


triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
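
# Illustrative sketch (assumed side-table indices): both HOPs are invoked with
# indices into the kernel side table rather than with the kernel object itself,
# so the FX graph only ever contains graphable values:
#
#   triton_kernel_wrapper_mutation(
#       kernel_idx=kernel_idx,                # from kernel_side_table.add_kernel(...)
#       constant_args_idx=constant_args_idx,  # from kernel_side_table.add_constant_args(...)
#       grid=[(n_blocks, 1, 1)],
#       kwargs={"in_ptr": x, "out_ptr": y},
#   )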


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(
    *, kernel_idx, constant_args_idx, grid, kwargs
):
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

    kernel = kernel_side_table.get_kernel(kernel_idx)
    constant_args = kernel_side_table.get_constant_args(constant_args_idx)

    if len(grid) == 1:
        grid_fn = grid[0]
    else:
        fn_name, code = user_defined_kernel_grid_fn_code(
            kernel.fn.__name__, kernel.configs, grid
        )
        namespace: Dict[str, Any] = {}
        exec(code, namespace)
        grid_fn = namespace[fn_name]

    kernel[grid_fn](**kwargs, **constant_args)


@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs
):
    with mode:
        return None


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta)
def _(*, kernel_idx, constant_args_idx, grid, kwargs):
    return None


def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args):
    with disable_proxy_modes_tracing():
        out = func_overload(**node_args)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        func_overload,
        (),
        proxy_args,
        name=func_overload.__name__ + "_proxy",
    )
    ret = track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
    return ret


@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs
):
    trace_triton_kernel_wrapper(
        mode,
        triton_kernel_wrapper_mutation,
        {
            "kernel_idx": kernel_idx,
            "constant_args_idx": constant_args_idx,
            "grid": grid,
            "kwargs": kwargs,
        },
    )

    return None


@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(
    ctx, kernel_idx, constant_args_idx, grid, kwargs
):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    kernel = kernel_side_table.get_kernel(kernel_idx)
    constant_args = kernel_side_table.get_constant_args(constant_args_idx)
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
    # other, and one gets mutated in kernel, and later another gets mutated,
    # they are no longer equal. Fix this by graph breaking on this condition
    # earlier in dynamo.
    tensors_to_clone = identify_mutated_tensors(
        kernel, {**unwrapped_kwargs, **constant_args}
    )
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )

    assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys()))
    for key, output_arg in unwrapped_outputs.items():
        if not isinstance(output_arg, Tensor):
            continue
        input_arg = kwargs[key]
        assert isinstance(input_arg, Tensor)

        ctx.replace(input_arg, output_arg)
        # indicate that above replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
        ctx.commit_update(input_arg)
        ctx.sync(input_arg)
    return None
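
# Illustrative sketch of the functionalization above: an in-place call such as
#
#   triton_kernel_wrapper_mutation(..., kwargs={"x": x, "out": out})
#
# is re-emitted as the functional HOP, which clones only the tensors detected
# as mutated and returns them, after which the results are written back into
# the original inputs (hidden from autograd):
#
#   outputs = triton_kernel_wrapper_functional(
#       ..., kwargs={"x": x, "out": out}, tensors_to_clone=["out"]
#   )
#   # roughly out.copy_(outputs["out"]), via ctx.replace/commit_update/sync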


@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
    *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
):
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    kwargs = {
        key: (clone_preserve_strides(val) if key in tensors_to_clone else val)
        for key, val in kwargs.items()
    }
    triton_kernel_wrapper_mutation(
        kernel_idx=kernel_idx,
        constant_args_idx=constant_args_idx,
        grid=grid,
        kwargs=kwargs,
    )
    return {key: val for key, val in kwargs.items() if key in tensors_to_clone}


@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
):
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    with mode:
        return {
            key: clone_preserve_strides(val)
            for key, val in kwargs.items()
            if key in tensors_to_clone
        }


@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
):
    return trace_triton_kernel_wrapper(
        mode,
        triton_kernel_wrapper_functional,
        {
            "kernel_idx": kernel_idx,
            "constant_args_idx": constant_args_idx,
            "grid": grid,
            "kwargs": kwargs,
            "tensors_to_clone": tensors_to_clone,
        },
    )


@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
    ctx, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
        return ctx.wrap_tensors(outputs)


triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU)


###############################################################################
# The "TritonHOPifier": a class that transforms a call to a triton kernel into
# a call to the triton_kernel_wrapper_mutation HOP.


class TritonHOPifier:
    """Orchestrator for converting a user-defined triton kernel into a call
    to the triton_kernel_wrapper_mutation HOP.

    It has two main use cases.

    1. When Dynamo sees a triton kernel, it wraps it into a TritonKernelVariable
    and uses the TritonHOPifier to convert calls to the TritonKernelVariable
    into a call to the HOP.

    2. In order to capture a user-defined triton kernel while performing
    tracing (via make_fx or non-strict export), a user must annotate their
    triton kernel with the `capture_triton` decorator. The decorator uses
    TritonHOPifier to convert calls to the triton kernel into a call
    to the HOP (which can then be traced).

    Because Dynamo has its own calling conventions (e.g. for invoking a
    user-defined function), TritonHOPifier is an abstract class that is
    overridden by its subclasses.
    """

    def raise_unsupported(self, msg):
        raise NotImplementedError("abstract method")

    def is_callable(self, maybe_callable):
        raise NotImplementedError("abstract method")

    def get_value(self, val):
        raise NotImplementedError("abstract method")

    def call_grid(self, grid, meta, tx):
        raise NotImplementedError("abstract method")

    def call_HOP(self, variable, grids, combined_args, tx):
        raise NotImplementedError("abstract method")

    def check_grid(self, grid):
        raise NotImplementedError("abstract method")

    def init_variable(self, variable, kernel, kernel_idx, grid):
        from triton.runtime.autotuner import Autotuner

        assert kernel is not None

        variable.kernel = kernel
        variable.kernel_idx = kernel_side_table.add_kernel(kernel)

        assert kernel_idx is None or variable.kernel_idx == kernel_idx

        variable.grid = grid

        if isinstance(kernel, Autotuner):
            import torch
            import torch._dynamo

            # We only support the configs and keys arguments of triton.autotune;
            # make sure other arguments are left at their defaults
            defaults = inspect.signature(Autotuner.__init__).parameters

            # Newer versions of Triton changed the attribute names from warmup to
            # num_warmups and from rep to num_reps. The call to get_first_attr is
            # to maintain backward compatibility.
            if (
                not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
                and (
                    (
                        "warmup" in defaults
                        and defaults["warmup"].default
                        != torch._dynamo.utils.get_first_attr(
                            kernel, "num_warmups", "warmup"
                        )
                    )
                    or (
                        "rep" in defaults
                        and defaults["rep"].default
                        != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
                    )
                    or (
                        "prune_configs_by" in defaults
                        and defaults["prune_configs_by"].default
                        != kernel.early_config_prune
                    )
                    # Set via reset_to_zero argument
                    or len(kernel.reset_idx) != 0
                    or len(kernel.restore_idx) != 0
                    or (
                        "use_cuda_graph" in defaults
                        and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
                    )
                )
            ):
                self.raise_unsupported(
                    "Only configs and keys are supported for triton.autotune"
                )

    def call_getitem(self, variable, args):
        # __getitem__ should only be called if we don't already have a grid
        # Only grid needs to be passed
        if variable.grid is not None or len(args) != 1:
            self.raise_unsupported(
                "Triton kernels should be called with only a single grid"
            )

        return type(variable)(
            kernel=variable.kernel,
            kernel_idx=variable.kernel_idx,
            grid=args[0],
        )

    def call_run(self, variable, args, kwargs, tx):
        if "grid" not in kwargs:
            self.raise_unsupported("Triton kernel requires to be called with a grid")
        grid = kwargs.pop("grid")
        kwargs.pop("warmup", None)
        # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
        return self.call_triton_kernel(
            type(variable)(
                kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid
            ),
            args,
            kwargs,
            tx,
        )

    def call_triton_kernel(self, variable, args, kwargs, tx):
        from triton.runtime.autotuner import autotune, Autotuner, Config

        if "num_ctas" in kwargs:
            self.raise_unsupported(
                "Passing num_ctas directly to the Triton kernel is not supported. "
                "Please use a Config in @triton.autotune instead."
            )

        special_kwargs = {}
        for name in ("num_warps", "num_stages"):
            if name in kwargs:
                # remove special kwargs from `kwargs`
                val = kwargs.pop(name)
                special_kwargs[name] = self.get_value(val)

        if special_kwargs:
            if isinstance(variable.kernel, Autotuner):
                # if there is an Autotuner already, set the
                # special kwargs on each of its configs
                new_configs = copy.deepcopy(variable.kernel.configs)
                for config in new_configs:
                    config.__dict__.update(special_kwargs)
                new_kernel = autotune(configs=new_configs, key=[])(variable.kernel.fn)
            else:
                # if there is no Autotuner, wrap the kernel into a
                # new one with a single config with the special kwargs
                new_config = Config(kwargs={}, **special_kwargs)
                new_kernel = autotune(configs=[new_config], key=[])(variable.kernel)

            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        if variable.grid is None:
            self.raise_unsupported("Triton kernels should always be called with a grid")

        # Both for the grid's meta as well as for the kernel, we need the
        # args and kwargs combined and normalized
        combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}

        configs = (
            [config.kwargs for config in variable.kernel.configs]
            if isinstance(variable.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, then let's execute it and convert it to
            # a list
            grid = variable.grid
            if self.is_callable(grid):
                # Populate the special "meta" argument to call the grid function
                meta = {**combined_args_raw, **config_args}
                grid = self.call_grid(grid, meta, tx)
            grids.append(self.check_grid(grid))

        for i in range(len(grids)):
            if not isinstance(grids[i], tuple):
                self.raise_unsupported("Only tuple grids are supported")
            # inductor expects all grids to be 3-tuples, so let's make it so
            if len(grids[i]) == 1:
                grids[i] = (grids[i][0], 1, 1)
            elif len(grids[i]) == 2:
                grids[i] = (grids[i][0], grids[i][1], 1)
            elif len(grids[i]) > 3:
                self.raise_unsupported("Grid can have at most rank 3")

        assert len(grids) != 0

        def intify(x):
            if isinstance(x, torch.SymInt):
                return int(x)
            else:
                return x

        if len(set(pytree.tree_map(intify, grids))) == 1:
            # If there's only one unique grid, let's simplify
            grids = [grids[0]]

        return self.call_HOP(variable, grids, combined_args_raw, tx)


###############################################################################
# Helpers for capture_triton API that makes a user-defined triton kernel traceable into
# a graph via make_fx or non-strict export (coming soon)


class TracingTritonHOPifier(TritonHOPifier):
    def raise_unsupported(self, msg):
        raise RuntimeError(msg)

    def is_callable(self, maybe_callable):
        return callable(maybe_callable)

    def get_value(self, val):
        return val

    def call_grid(self, grid, meta, tx):
        assert tx is None
        return grid(meta)

    def check_grid(self, grid):
        if not isinstance(grid, collections.abc.Sequence):
            raise RuntimeError(
                "capture_triton can only handle grids that resolve to Sequence[int]."
            )
        # normalize to tuple
        return tuple(grid)

    def call_HOP(self, variable, grids, combined_args, tx):
        assert tx is None

        def is_graphable(val):
            return isinstance(val, fx.node.base_types)

        non_graphable_args = {
            k: v for k, v in combined_args.items() if not is_graphable(v)
        }
        graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}

        constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
        return triton_kernel_wrapper_mutation(
            kernel_idx=variable.kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grids,
            kwargs=graphable_args,
        )


tracing_triton_hopifier_singleton = TracingTritonHOPifier()


class TraceableTritonKernelWrapper:
    def __init__(self, kernel, kernel_idx, grid):
        self.kernel = None
        self.grid = None
        tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
        assert self.kernel is not None

    def __getitem__(self, *args):
        return tracing_triton_hopifier_singleton.call_getitem(self, args)

    def run(self, *args, **kwargs):
        from torch._library.triton import is_capture_triton_enabled

        if is_capture_triton_enabled():
            return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
        else:
            assert self.kernel is not None
            return self.kernel.run(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        from torch._library.triton import is_capture_triton_enabled

        if is_capture_triton_enabled():
            return tracing_triton_hopifier_singleton.call_triton_kernel(
                self, args, kwargs, None
            )
        else:
            assert self.kernel is not None
            return self.kernel[self.grid](*args, **kwargs)
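

# Illustrative sketch (assuming capture_triton is importable from
# torch._library.triton, as the imports above suggest; add_kernel is a
# hypothetical @triton.jit kernel): wrapping the kernel makes its calls
# traceable by make_fx, which records triton_kernel_wrapper_mutation nodes
# instead of eagerly launching the kernel.
#
#   import triton
#   from torch._library.triton import capture_triton
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def f(x, y):
#       n = x.numel()
#       grid = (triton.cdiv(n, 64),)
#       capture_triton(add_kernel)[grid](x, y, n, BLOCK_SIZE=64)
#       return y
#
#   gm = make_fx(f)(torch.randn(256, device="cuda"), torch.randn(256, device="cuda"))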
1045