# torch/_inductor/codegen/common.py
# mypy: allow-untyped-defs
import contextlib
import dataclasses
import functools
import itertools
import logging
import math
import operator
import re
from enum import auto, Enum
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .. import config, metrics
from ..utils import (
    DeferredLineBase,
    generate_assert,
    IndentedBuffer,
    sympy_dot,
    sympy_subs,
    unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V


schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    if schedule_log.isEnabledFor(logging.DEBUG):
        schedule_log.debug("Data type propagation: %s", msg)


@dataclasses.dataclass
class WorkspaceArg:
58    """A temporary buffer used for a single kernel, then discarded.
59
60    Not registered as a traditional buffer since there are no users,
61    so it would be dead code eliminated.
62    """

    nbytes: sympy.Expr
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    name: str
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.Integer(0)  # c++ only
    alias_of: Optional[str] = None  # halide only


@dataclasses.dataclass
class SizeArg:
    name: str
    expr: sympy.Expr

    @property
    def alias_of(self):
        return None


@dataclasses.dataclass
class DeviceCodegen:
    scheduling: Any
    wrapper_codegen: type
    cpp_wrapper_codegen: type = type(None)


KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    def import_get_raw_stream_as(self, name):
        raise NotImplementedError

    def set_device(self, device_idx):
        raise NotImplementedError

    def synchronize(self):
        raise NotImplementedError

    def device_guard(self, device_idx):
        raise NotImplementedError


device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# Any new backend looking to integrate with Inductor must customize these two main
# parts to generate its specific code.
#
# Kernel code generation is determined by the Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes used for code generation, such as CppKernel and TritonKernel, typically form part
# of the logic of either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend: a backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: Any,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    device_codegens[device] = DeviceCodegen(
        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
    )

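# A minimal sketch of out-of-tree registration (the device name and classes
# here are hypothetical, not part of Inductor):
#
#   register_backend_for_device("my_device", MyScheduling, MyWrapperCodeGen)
#   assert get_scheduling_for_device("my_device") is MyScheduling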

class BackendFeature(Enum):
    FOREACH = auto()
    BUCKETIZE = auto()
    INPLACE_BUFFERS = auto()
    MASKED_SCATTER_WITH_INDEX = auto()
    SCAN = auto()
    SORT = auto()
    TUPLE_REDUCTION = auto()
    PREFER_STORE_LOOP_ORDER = auto()
    TRITON_TEMPLATES = auto()
    REDUCE_TO_SINGLE_ELEMENT = auto()


def get_backend_features(device: Union[torch.device, str]):
    init_backend_registration()
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        assert isinstance(device, str)
        device_type = device
        device = torch.device(device_type)
    scheduling = get_scheduling_for_device(device_type)
    return scheduling(None).get_backend_features(device)


def has_backend_feature(device, feature):
    """See also V.graph.has_feature"""
    assert isinstance(feature, BackendFeature)
    return feature in get_backend_features(device)

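# E.g. (a sketch): has_backend_feature("cuda", BackendFeature.SORT) initializes
# the backend registry and reports whether the CUDA scheduling supports sorts
# natively, letting callers choose between a native kernel and a fallback.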

def get_scheduling_for_device(device: str):
    return device_codegens[device].scheduling if device in device_codegens else None


def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    if device in device_codegens:
        wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
        return (
            wrapper_codegen_obj.cpp_wrapper_codegen
            if cpp_wrapper
            else wrapper_codegen_obj.wrapper_codegen
        )
    else:
        return None


@functools.lru_cache(None)
def init_backend_registration():
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cuda import CppWrapperCuda
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .triton import TritonScheduling
    from .wrapper import WrapperCodeGen

    if get_scheduling_for_device("cpu") is None:
        cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
        register_backend_for_device(
            "cpu",
            lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCpu,
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
        register_backend_for_device(
            "cuda",
            lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCuda,
        )

    if get_scheduling_for_device("xpu") is None:
        register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)

    private_backend = torch._C._get_privateuse1_backend_name()
    if (
        private_backend != "privateuseone"
        and get_scheduling_for_device(private_backend) is None
    ):
        from torch.utils.backend_registration import _get_custom_mod_func

        try:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("WrapperCodeGen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen")
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
                )
        except RuntimeError:
            pass


def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    from ..ir import FlexibleLayout

    # the added contiguous index prevents reordering
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    assert isinstance(device, str)

    if not device_op_overrides_dict:
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    if device in device_op_overrides_dict:
        return device_op_overrides_dict[device]


@functools.lru_cache(None)
def boolean_ops():
    return (
        "isinf",
        "isnan",
        "logical_not",
        "signbit",
        "le",
        "lt",
        "ge",
        "gt",
        "eq",
        "ne",
    )


DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}


def deduce_output_dtype_by_name(
    op_name: str,
    *args,
    **kwargs,
) -> Optional[torch.dtype]:
314    """
315    Given op name and a list of input dtypes, deduce the output dtype
316    """
    if op_name in boolean_ops():
        return torch.bool
    elif op_name in (
        "to_dtype",
        "index_expr",
    ):
        return kwargs["dtype"] if "dtype" in kwargs else args[-1]
    elif op_name in (
        "rand",
        "randn",
    ):
        return torch.float
    elif op_name in (
        "get_index",
        "randint64",
        "load_seed",
    ):
        return torch.int64
    elif op_name == "reduction":
        return kwargs["dtype"] if "dtype" in kwargs else args[1]
    elif op_name == "constant":
        dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1]
        return DTYPE_TO_COMPUTATION_DTYPE[dtype]  # type: ignore[index]
    elif op_name in (
        "load",
        "store",
        "store_reduction",
    ):
        buf_name = args[1]
        return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]
    elif op_name == "to_dtype_bitcast":
        return kwargs["dtype"] if "dtype" in kwargs else args[-2]
    return None

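# E.g. (a sketch): deduce_output_dtype_by_name("rand") -> torch.float,
# deduce_output_dtype_by_name("randint64") -> torch.int64, and None for ops
# whose output dtype must instead be promoted from their input dtypes.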

class DataTypePropagation:
    def __init__(self, body) -> None:
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # we can only infer the output node's dtype if it has exactly 1 arg
            return None

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        if (
            output_dtype := deduce_output_dtype_by_name(
                node.target,
                *node.args,
                **node.kwargs,
            )
        ) is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)


# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(Printer):
    @staticmethod
    def paren(string):
        def all_in_parens(string):
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr):
        return self._print(expr.args[0])

    def _print_GreaterThan(self, expr):
        # GreaterThan:          >=
        # StrictlyGreaterThan:  >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"
    # This must be implemented because sympy will collect x * x into Pow(x, 2)
    # without any explicit intervention.  We print it just like x * x (e.g.
    # Pow(x, 3) prints as "x*x*x"); notably, we never generate sympy.Pow with
    # floats.
    #
    # NB: this is pow by natural: you should never have used the builtin
    # sympy.Pow for FloatPow, and a symbolic exponent should be PowByNatural.
    # This means exp is guaranteed to be an integer.
    def _print_Pow(self, expr):
        base, exp = expr.args
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        assert exp >= 0
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        else:  # exp == 0
            return "1"

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr):
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr):
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr):
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr):
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr):
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr):
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr):
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr):
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr):
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )

    def doprint(self, expr, *, simplify: bool = True):
        # TODO: why are people passing strings to the printer here :think:
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)


class PythonPrinter(ExprPrinter):
    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"


class OpOverrides:
    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

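    # remainder computes a Python-style modulus (the result takes the sign of
    # the divisor b) on top of ops.mod, which may be C-style (sign of the
    # dividend) on a given backend: when r is nonzero and its sign differs
    # from b's, adding b back converts truncated mod into floored mod
    # (e.g. -3 mod 4: C-style gives -3, Python-style gives 1).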
    @staticmethod
    def remainder(a, b):
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

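    # _initialize_pointwise_overrides copies the per-backend implementations
    # from pointwise_overrides_data (defined below) onto the subclass as
    # staticmethods, so e.g. the "triton" target picks up the libdevice-backed
    # special functions automatically.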
    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))


@dataclasses.dataclass
class OverridesData:
    name: str
    cpp: Callable[..., str]
    # None when not implemented in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not implemented in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp or triton implementation for entr; it is defined as a decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)


# Use mypy to check that the protocol is implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    return h


class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding its name to V.graph.removed_buffers"""

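    # E.g. (a sketch): a store into "buf0" may be emitted as
    # DeferredLine("buf0", "buf0[i] = tmp3"); if "buf0" is later removed,
    # calling the line returns None and it is dropped from the output.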
    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
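    # indent() wraps the enclosed region in C++ braces and adjusts the
    # indentation; a negative offset temporarily closes open braces instead
    # and reopens them when the context manager exits.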
    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()


class InplacedBuffer(NamedTuple):
    inner_name: str
    other_names: List[str]


class KernelArgs:
    @staticmethod
    def _lookup(prefix, odict, name):
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        self.sizevars = sizevars or {}
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

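    # make_inplace aliases an input and an output to one shared argument:
    # both names map to the same InplacedBuffer (named in_out_ptr{n}), so the
    # kernel reads and writes through a single pointer.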
1248        assert output_name not in self.inplace_buffers
1249        if input_name in self.inplace_buffers:
1250            buf = self.inplace_buffers[input_name]
1251            buf.other_names.append(output_name)
1252            self.inplace_buffers[output_name] = buf
1253        else:
1254            buf = InplacedBuffer(
1255                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
1256                [input_name, output_name],
1257            )
1258            self.inplace_buffers[input_name] = buf
1259            self.inplace_buffers[output_name] = buf
1260
1261    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
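    # All workspace requests within a kernel share one allocation; each call
    # returns the pointer name and this request's byte offset into it. E.g.
    # (a sketch): workspace(512, False) -> ("ws_ptr", 0), then
    # workspace(256, True) -> ("ws_ptr", 512), growing the buffer to 768
    # bytes and upgrading it to zero-filled.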
1262        if self.workspace_arg is None:
1263            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
1264            return "ws_ptr", 0
1265
1266        offset = self.workspace_arg.nbytes
1267        zero_fill = zero_fill or self.workspace_arg.zero_fill
1268        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
1269        return "ws_ptr", offset
1270
1271    def seed_offset(self, name, value):
1272        if value in self.sizevars:
1273            return self.sizevars[value]
1274        if name in self.sizevars.values():
1275            name = (
1276                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
1277            )
1278        self.sizevars[value] = name
1279        return name
1280
1281    def size(self, name):
1282        if str(name) == "seed":
1283            self.sizevars["seed"] = "seed"
1284            return "seed"
1285        return self._lookup("ks", self.sizevars, name)
1286
1287    def call_names(self):
1288        return chain(
1289            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
1290        )
1291
1292    def wrap_ptr_arg(self, buf, dtype):
1293        return buf
1294
1295    def wrap_size_arg(self, size):
1296        return str(size)
1297
1298    def cpp_argdefs(self):
1299        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE
1300
1301        call_args = []
1302        arg_defs = []
1303        arg_types = []
1304        for inplaced in unique(self.inplace_buffers.values()):
1305            if self._buffer_is_marked_removed(inplaced):
1306                continue
1307            outer = inplaced.other_names[-1]
1308            inner = inplaced.inner_name
1309            dtype = V.graph.get_dtype(outer)
1310            cpp_dtype = DTYPE_TO_CPP[dtype]
1311            arg_defs.append(f"{cpp_dtype}* {inner}")
1312            call_args.append(self.wrap_ptr_arg(outer, dtype))
1313            arg_types.append(f"{cpp_dtype}*")
1314        for outer, inner in self.input_buffers.items():
1315            if outer in self.inplace_buffers:
1316                continue
1317            dtype = V.graph.get_dtype(outer)
1318            cpp_dtype = DTYPE_TO_CPP[dtype]
1319            arg_defs.append(f"const {cpp_dtype}* {inner}")
1320            call_args.append(self.wrap_ptr_arg(outer, dtype))
1321            arg_types.append(f"const {cpp_dtype}*")
1322        for outer, inner in self.output_buffers.items():
1323            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
1324                continue
1325            dtype = V.graph.get_dtype(outer)
1326            cpp_dtype = DTYPE_TO_CPP[dtype]
1327            arg_defs.append(f"{cpp_dtype}* {inner}")
1328            call_args.append(self.wrap_ptr_arg(outer, dtype))
1329            arg_types.append(f"{cpp_dtype}*")
1330        for outer, inner in self.sizevars.items():
1331            arg_defs.append(f"const {INDEX_TYPE} {inner}")
1332            call_args.append(self.wrap_size_arg(outer))
1333            arg_types.append(f"const {INDEX_TYPE}")
1334            if V.graph.wrapper_code:
1335                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU"
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        arg_defs: List[str] = []
        call_args: List[str] = []
        arg_types: List[torch.dtype] = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))  # type: ignore[arg-type]
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = OrderedSet()  # type: ignore[var-annotated]
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs


class CSEVariable:
1425    """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
1426    To do so, the backends can simply overload `Kernel.create_cse_var`
1427    The "CSEVariable.update_on_args" method gives you a hook for annotations
1428    See example of TritonCSEVariable in triton.py
1429    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds
        self.use_count = 1  # track how many times this expression is used

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name!r})"


class CppWrapperKernelArgs(KernelArgs):
    def wrap_ptr_arg(self, buf, dtype):
        from .cpp_utils import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form correct call args later in wrapper.generate_kernel_all.
            return buf
        else:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"{size}"


class CSE:
    """Common subexpression elimination"""

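    # A sketch of the caching behavior (names illustrative): generating the
    # same expression string twice returns the same CSEVariable, e.g.
    #   v1 = cse.generate(buf, "a + b")   # emits "tmp0 = a + b"
    #   v2 = cse.generate(buf, "a + b")   # cache hit: v2 is v1, nothing emitted
    # so the assignment is written to the buffer only once.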
1471    def __init__(
1472        self,
1473        prefix="",
1474        suffix="",
1475        name_prefix="tmp",
1476        iter_buffers=None,
1477        store_cache=None,
1478        reduction_cache=None,
1479        varname_map=None,
1480    ):
1481        self.prefix = prefix
1482        self.suffix = suffix
1483        self.cache = {}
1484        self.name_prefix = name_prefix
1485        self.store_cache = store_cache or {}
1486        self.reduction_cache = reduction_cache or {}
1487        self.iter_buffer_ids = iter_buffers or itertools.count()
1488        self.invalidated_stores = OrderedSet()  # type: ignore[var-annotated]
1489        self.varname_map = varname_map or {}
1490
1491    def invalidate(self, keep_vars: OrderedSet[str]):
1492        for name, tmp in list(self.store_cache.items()):
1493            if tmp not in keep_vars:
1494                del self.store_cache[name]
1495                self.invalidated_stores.add(name)
1496        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds)
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            var.bounds = var.bounds.tighten(bounds)
            var.use_count += 1

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var

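# Illustrative CSE behavior (a sketch: newvar() goes through
# V.kernel.create_cse_var, so an active kernel handler is assumed):
#
#     buf = IndentedBuffer()
#     cse = CSE()
#     a = cse.generate(buf, "in_ptr0 + x0")  # emits: tmp0 = in_ptr0 + x0
#     b = cse.generate(buf, "in_ptr0 + x0")  # cache hit: nothing new emitted
#     assert a is b and b.use_count == 2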

class CodeGen:
    def __init__(self) -> None:
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)


class ScopedDict:
    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)

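# Example of the overlay semantics (what swap_buffers relies on below):
# writes land in a scratch dict while reads fall back to the wrapped dict,
# so the original mapping is never mutated.
#
#     base = {"a": 1}
#     scoped = ScopedDict(base)
#     scoped["b"] = 2
#     assert scoped["a"] == 1 and scoped["b"] == 2
#     assert "b" in scoped and "b" not in base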

class Kernel(CodeGen):
    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()

        self.num_load = 0
        self.num_reduction = 0

        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.store_buffer_names = OrderedSet()  # type: ignore[var-annotated]
        self._load_mask = None
        self._load_other = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None

        self.removed_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.inplaced_to_remove = OrderedSet()  # type: ignore[var-annotated]

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        #   the buffer specified by key
        self.inplace_update_buffers = {}
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
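        """Temporarily redirect codegen into the given load/compute/store
        buffers, scoping the CSE caches via ScopedDict so that cache entries
        created inside the context are discarded when it exits."""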
        def scope_cse(cse):
            new_cse = cse.clone()
            new_cse.cache = ScopedDict(cse.cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
            new_cse.store_cache = ScopedDict(cse.store_cache)
            return new_cse

        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = scope_cse(cse)
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        raise NotImplementedError

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have already read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        raise NotImplementedError

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        raise NotImplementedError

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        raise NotImplementedError

    def sort(
        self,
        dtypes: Tuple[torch.dtype, ...],
        values: Tuple[CSEVariable, ...],
        stable: bool,
        descending: bool,
    ) -> Tuple[CSEVariable, ...]:
        raise NotImplementedError

    def var_ranges(self):
        raise NotImplementedError

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError

    @property
    def assert_function(self) -> str:
        raise NotImplementedError

    def indirect_assert(
        self,
        var: Union[CSEVariable, str],
        lower: Optional[str],
        upper: Optional[str],
        mask: Optional[Union[CSEVariable, str]] = None,
    ) -> str:
        if isinstance(var, CSEVariable):
            var = str(var)
        assert isinstance(var, str)
        assert lower is None or isinstance(lower, str)
        assert upper is None or isinstance(upper, str)
        if lower and upper:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by Triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond

        if mask:
            cond = f"({cond}) | ~({mask})"

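        # For example, with a Triton-style assert_function of
        # "tl.device_assert", indirect_assert("tmp3", "0", "ks0", mask="xmask")
        # would return:
        #   tl.device_assert(((0 <= tmp3) & (tmp3 < ks0)) | ~(xmask),
        #                    "index out of bounds: 0 <= tmp3 < ks0")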
        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ):
        raise NotImplementedError

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError

    def __enter__(self):
        # TODO: hoist this to top level
        class CSEProxy:
            # Note: a class body can read the enclosing function's locals, so
            # `self` below is the Kernel instance from __enter__.
            self.name = "CSEProxy"
            vr_analysis = ValueRangeAnalysis()

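            # Fallback dispatch for ops without an explicit override below:
            # e.g. ops.add(a, b) resolves to __getattr__("add"); bounds are
            # inferred via _bound_variable, parent_handler.add(a, b) renders
            # the expression, and do_cse caches it in V.kernel.compute as a
            # fresh tmp variable.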
            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = V.kernel.cse.generate(
                            V.kernel.compute, v, bounds=bounds
                        )
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def _bound_variable(name, *args, **kwargs):
                """
                If the variable comes from an FX node, we forward the bound we have already computed.
                Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds.
                """
                from ..select_algorithm import TritonTemplateKernel

                if isinstance(V.kernel, TritonTemplateKernel):
                    return ValueRanges.unknown()

                fx_node = V.interpreter.current_node
                if fx_node.target == name and self.node_to_bounds is not None:
                    assert isinstance(self.node_to_bounds, dict)
                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                    # These create lots of inner strings. We would need to compute the bounds
                    # at the ops. We will also likely not get much from computing VRs on these nodes.
                    if any(
                        s in fx_node.target
                        for s in ("set_indirect", "reduction", "scan")
                    ):
                        return ValueRanges.unknown()

                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                    # If there is no FX bound but we know how to compute one we do so
                    assert not kwargs

                    def arg_to_bound(x):
                        if isinstance(x, CSEVariable):
                            return x.bounds
                        elif isinstance(x, sympy.Expr):
                            return bound_sympy(x)
                        else:
                            return x

                    arg_bounds = list(map(arg_to_bound, args))
                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
                else:
                    return ValueRanges.unknown()

            @staticmethod
            def indirect_indexing(
                var: CSEVariable,
                size: Union[sympy.Expr, int],
                check: bool = True,
                wrap_neg=True,
            ):
                if isinstance(size, int):
                    size = sympy.Integer(size)
                assert isinstance(size, sympy.Expr), size
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    if wrap_neg:
                        stm = ops.add(var, ops.index_expr(size, torch.long))
                        # Mixed negative and non-negative
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            lt = ops.lt(var, 0)
                            stm = ops.where(lt, stm, var)
                    else:
                        stm = var

                    # Propagate bounds as we know how to compute them properly
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
                        new_bounds = ValueRanges(
                            neg_bounds.lower + size, neg_bounds.upper + size
                        )
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, int_oo)
                            new_bounds = new_bounds | pos
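                        # Worked example: var.bounds == [-4, 3] with size == 10
                        # gives neg_bounds == [-4, -1] -> shifted to [6, 9],
                        # then unioned with the positive part [0, 3] -> [0, 9].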

                    var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                sympy_var = parent_handler.indirect_indexing(var, size, check)
                if generate_assert(check):
                    assert_lower = not (var.bounds.lower >= 0)
                    # value ranges cannot prove x < s when x and s are symbols
                    assert_upper = not isinstance(size, sympy.Number) or not (
                        var.bounds.upper < size
                    )
                    self.check_bounds(sympy_var, size, assert_lower, assert_upper)
                return sympy_var

            @staticmethod
            def check_bounds(
                expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
            ):
                return self.check_bounds(expr, size, lower, upper)

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_is_type(index, SymT.TMP):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                out = self.load(name, index)
                # Count a load that is in neither the store_cache nor the
                # CSE cache.
                if out.use_count == 1:
                    self.num_load += 1
                return out

            @staticmethod
            def _update_store_cache(name: str, value: CSEVariable):
                self.cse.store_cache[name] = value
                if self.current_node and name in V.graph.name_to_buffer:
                    buf = self.current_node.get_output(name)
                    for other_name in buf.get_mutations():
                        self.cse.store_cache[other_name] = value

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    CSEProxy._update_store_cache(name, value)
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                CSEProxy._update_store_cache(name, value)

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                self.num_reduction += 1
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtypes: Tuple[torch.dtype, ...],
                combine_fn: Callable[
                    [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                    Tuple[CSEVariable, ...],
                ],
                values: Tuple[CSEVariable, ...],
            ) -> Tuple[CSEVariable, ...]:
                return self.scan(dtypes, combine_fn, values)

            @staticmethod
            def sort(
                dtypes: Tuple[torch.dtype, ...],
                values: Tuple[CSEVariable, ...],
                stable: bool,
                descending: bool,
            ) -> Tuple[CSEVariable, ...]:
                return self.sort(dtypes, values, stable, descending)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check that the protocol is implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
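        # e.g. an index like s0*64 + s1 gets s0 and s1 registered as size
        # args via self.args.size(...) and comes back rewritten in terms of
        # the generated arg names (e.g. ks0*64 + ks1)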
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if symbol_is_type(
                x,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                ),
            )
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        return CSEVariable(*args, **kwargs)


@dataclasses.dataclass
class OptimizationContext:
    key: ClassVar[str] = "opt_ctx"

    dtype: Optional[torch.dtype] = None
    ops_name: str = ""


@functools.lru_cache(None)
def jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None


class KernelTemplate:
    """
    Base class for defining kernel templates.

    Child classes: TritonTemplate, CUDATemplate
    """

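    # A quick example (registered as a jinja filter in _template_from_string
    # below): indent_except_first("a\nb\nc", 1) == "a\n    b\n    c",
    # i.e. every line after the first gets indents_spacing * num_indents
    # leading spaces.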
    @staticmethod
    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
        lines = source.splitlines(True)
        if len(lines) > 1:
            lines[1:] = [
                (" " * indents_spacing * num_indents) + line for line in lines[1:]
            ]
        return "".join(lines)

    @staticmethod
    def _template_from_string(source):
        env = jinja2_env()
        if env is not None:
            env.filters["indent_except_first"] = KernelTemplate.indent_except_first
            from jinja2 import TemplateSyntaxError

            class DetailedTemplateSyntaxError(TemplateSyntaxError):
                def __init__(self, original_error):
                    super().__init__(
                        original_error.message,
                        original_error.lineno,
                        original_error.name,
                        original_error.filename,
                    )
                    self.original_error = original_error

                def __str__(self):
                    error_info = f"Error in template at line {self.lineno}\n"
                    error_info += f"Error message: {self.message}\n"
                    if hasattr(self.original_error, "source"):
                        lines = self.original_error.source.split("\n")
                        error_info += "Context:\n"
                        start = max(0, self.lineno - 2)
                        end = min(len(lines), self.lineno + 2)
                        for i in range(start, end):
                            if i == self.lineno - 1:
                                error_info += f"{i+1}: --> {lines[i]}\n"
                                if hasattr(self.original_error, "column"):
                                    error_info += (
                                        "     "
                                        + " " * (self.original_error.column - 1)
                                        + "^\n"
                                    )
                            else:
                                error_info += f"{i+1}:     {lines[i]}\n"
                    return error_info

            try:
                return env.from_string(source)
            except TemplateSyntaxError as e:
                raise DetailedTemplateSyntaxError(e) from e

        return None

    @staticmethod
    def _fake_get_dtype(fake_out):
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it to the existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """

        try:
            choices.append(self.generate(**kwargs))
        except NotImplementedError:
            pass

    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
        """
        Generates a ChoiceCaller instance from the given arguments.
        """

        raise NotImplementedError
