xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/cpp_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3import copy
4import functools
5import math
6import sys
7from collections import namedtuple
8from typing import Any, Callable, Dict, List, Optional, Set, Tuple
9from unittest.mock import patch
10
11import sympy
12
13import torch
14from torch._prims_common import is_integer_dtype
15from torch.utils._sympy.symbol import symbol_is_type, SymT
16from torch.utils._sympy.value_ranges import ValueRanges
17
18from .. import ir
19from ..loop_body import LoopBody
20from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
21from ..virtualized import ops, OpsValue, V
22from .common import (
23    CSEVariable,
24    deduce_output_dtype_by_name,
25    ExprPrinter,
26    Kernel,
27    KernelArgs,
28    OptimizationContext,
29)
30
31
# Maps a torch dtype to the name of the C++ scalar type that is written into
# generated code for values of that dtype.
DTYPE_TO_CPP = {
    torch.float32: "float",
    torch.float64: "double",
    torch.float16: "half",
    torch.int64: "int64_t",
    torch.int32: "int32_t",
    torch.int16: "int16_t",
    torch.int8: "int8_t",
    torch.uint64: "uint64_t",
    torch.uint32: "uint32_t",
    torch.uint16: "uint16_t",
    torch.uint8: "uint8_t",
    torch.bool: "bool",
    torch.bfloat16: "bfloat16",
    torch.complex64: "c10::complex<float>",
    torch.float8_e4m3fn: "float8_e4m3fn",
    torch.float8_e5m2: "float8_e5m2",
}
50
51DTYPE_TO_ATEN = {
52    torch.float32: "at::kFloat",
53    torch.float64: "at::kDouble",
54    torch.float16: "at::kHalf",
55    torch.int64: "at::kLong",
56    torch.int32: "at::kInt",
57    torch.int16: "at::kShort",
58    torch.int8: "at::kChar",
59    torch.uint64: "at::kUInt64",
60    torch.uint32: "at::kUInt32",
61    torch.uint16: "at::kUInt16",
62    torch.uint8: "at::kByte",
63    torch.uint32: "at::kUInt32",
64    torch.uint64: "at::kUInt64",
65    torch.bool: "at::kBool",
66    torch.bfloat16: "at::kBFloat16",
67    torch.complex32: "at::kComplexHalf",
68    torch.complex64: "at::kComplexFloat",
69    torch.complex128: "at::kComplexDouble",
70    torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
71    torch.float8_e5m2: "at::kFloat8_e5m2",
72    torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
73    torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
74}
75
# Maps a device type string to the spelling of the matching ATen device constant.
DEVICE_TO_ATEN = {
    "cpu": "at::kCPU",
    "cuda": "at::kCUDA",
}

# Maps a torch layout to the spelling of the matching ATen layout constant.
LAYOUT_TO_ATEN = {
    torch.strided: "at::kStrided",
    torch._mkldnn: "at::kMkldnn",  # type: ignore[attr-defined]
}

# True when the current interpreter runs on Windows.
_IS_WINDOWS = sys.platform == "win32"

# C++ integer type that index expressions are cast to in generated code.
INDEX_TYPE = "int64_t"

# Triple of block sizes (block_m, block_n, block_k).
# NOTE(review): presumably the blocking along the M/N/K dimensions of a GEMM — confirm.
GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
91
92
def get_promote_dtype(args):
    """
    Return the common promoted dtype of the `CppCSEVariable`s found in `args`.

    Entries of `args` that are not `CppCSEVariable` are ignored. The dtypes of
    the remaining entries are folded with `torch.promote_types`.

    Returns `None` when there is not enough information to compute a result:
    either some `CppCSEVariable` has no dtype yet, or `args` contains no
    `CppCSEVariable` at all (reducing an empty sequence would otherwise raise
    `TypeError`).
    """
    dtypes = [n.dtype for n in args if isinstance(n, CppCSEVariable)]
    if not dtypes or any(dtype is None for dtype in dtypes):
        # not enough info to calculate the promote dtype
        return None
    return functools.reduce(
        torch.promote_types,  # type: ignore[arg-type]
        dtypes,
    )
102
103
def promote_args(new_args):
    """
    Cast the `CppCSEVariable`s in `new_args` to their common promoted dtype.

    The target dtype comes from `get_promote_dtype`. When it is available and
    every `CppCSEVariable` has a known dtype, a new list is returned in which
    each `CppCSEVariable` whose dtype differs from the target has been
    converted with `ops.to_dtype`; all other entries are passed through.
    Otherwise `new_args` is returned unchanged.
    """
    target_dtype = get_promote_dtype(new_args)

    def cast_if_needed(arg):
        needs_cast = (
            isinstance(arg, CppCSEVariable)
            and arg.dtype
            and target_dtype
            and arg.dtype != target_dtype
        )
        if not needs_cast:
            return arg
        casted = ops.to_dtype(arg, target_dtype)
        # `ops` may hand back an OpsValue wrapper; unwrap to the CSE variable
        if isinstance(casted, OpsValue):
            casted = casted.value
        casted.dtype = target_dtype
        return casted

    all_dtypes_known = all(
        candidate.dtype is not None
        for candidate in new_args
        if isinstance(candidate, CppCSEVariable)
    )
    if all_dtypes_known and target_dtype:
        return [cast_if_needed(arg) for arg in new_args]
    return new_args
132
133
def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext:
    """
    Look up the optimization context stored in `node.meta` under
    `OptimizationContext.key`; the result is `None` when the key is absent.
    """
    node_meta = node.meta
    return node_meta.get(OptimizationContext.key)
136
137
def get_current_node_opt_ctx() -> OptimizationContext:
    """
    Return the optimization context of the node the interpreter is currently
    processing. Asserts that such a node exists.
    """
    current_node = V.interpreter.current_node
    assert current_node
    return get_opt_ctx(current_node)
141
142
def deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs):
    """
    Work out the dtype of the result of the op called `name`.

    Three sources are tried in order:
      1. `deduce_output_dtype_by_name`, when it yields a dtype;
      2. for the "masked" op, the dtype recorded in the optimization context
         of the interpreter's current node;
      3. otherwise, the promotion (`torch.promote_types`) of the dtypes of all
         `CppCSEVariable` arguments, each of which must already have a dtype.
    """
    dtype_from_name = deduce_output_dtype_by_name(name, *args, **kwargs)
    if dtype_from_name is not None:
        return dtype_from_name

    if name == "masked":
        # <TODO> Leslie: perhaps we can also deduce the masked dtype by
        # inputs' CppCseVariable like other. Let's check it if any
        # unexpected failures.
        assert (
            hasattr(V.interpreter, "current_node")
            and V.interpreter.current_node.target.startswith("masked_subblock")
            and get_current_node_opt_ctx() is not None
        )
        return get_current_node_opt_ctx().dtype

    # deduce output dtype by inputs' dtype
    assert all(
        arg.dtype is not None for arg in args if isinstance(arg, CppCSEVariable)
    )
    input_dtypes = [arg.dtype for arg in args if isinstance(arg, CppCSEVariable)]
    return functools.reduce(
        torch.promote_types,  # type: ignore[arg-type]
        input_dtypes,
    )
171
172
class CppCSEVariable(CSEVariable):
    """
    CSE variable that additionally records whether it is a vector value
    (`is_vec`), its dtype, and the set of iteration variables it depends on.
    """

    def __init__(self, name, bounds: ValueRanges[Any]) -> None:
        super().__init__(name, bounds)
        self.is_vec = False
        self.dtype: Optional[torch.dtype] = None
        self.dependent_itervars: Set[sympy.Symbol] = set()

    def __repr__(self) -> str:
        head = f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, "
        tail = f"dependent_itervars: {self.dependent_itervars})"
        return head + tail

    def update_on_args(self, name, args, kwargs):
        """
        Refresh itervar dependencies, `is_vec` and `dtype` of this variable
        from the op `name` and the arguments it was produced from.
        """
        if name == "load":
            # args[2] is index
            self._set_dependent_itervars(args[2])
        else:
            # propagate relevant itervars and is_vec from args
            cse_args = [arg for arg in args if isinstance(arg, CppCSEVariable)]
            for cse_arg in cse_args:
                self.dependent_itervars.update(cse_arg.dependent_itervars)
            if name == "index_expr":
                self._set_dependent_itervars(args[0])
            if any(cse_arg.is_vec for cse_arg in cse_args):
                self.is_vec = True
        # NOTE [Deduce dtype of CppCSEVariable at runtime]
        self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs)
        assert self.dtype is not None

    def _set_dependent_itervars(self, index: sympy.Expr):
        """
        Record the itervars that `index` depends on: those appearing directly
        as free symbols, plus the dependencies of any CSE variable whose name
        appears as a free symbol of `index`.
        """
        for sym in index.free_symbols:
            if sym in V.kernel.itervars:
                self.dependent_itervars.add(sym)  # type: ignore[arg-type]
                continue
            sym_name = sym.name  # type: ignore[attr-defined]
            if sym_name in V.kernel.cse.varname_map:
                referenced_var = V.kernel.cse.varname_map[sym_name]
                self.dependent_itervars.update(referenced_var.dependent_itervars)

    def depends_on(self, itervar: sympy.Symbol):
        """Tell whether `itervar` is among the recorded dependencies."""
        return itervar in self.dependent_itervars
223
224
class CppPrinter(ExprPrinter):
    """
    Expression printer that renders sympy expressions as C++ source text.

    Each `_print_<Kind>` method returns the C++ string for one kind of
    expression node. Several of them wrap integer-valued results in a
    `static_cast` to `INDEX_TYPE`.
    """

    def _print_Integer(self, expr):
        # Integer literals get an `LL` suffix on macOS and Windows and an `L`
        # suffix on every other platform.
        return (
            f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
        )

    def _print_Where(self, expr):
        # Rendered as the C++ conditional operator `cond ? a : b`, with each
        # operand passed through `self.paren`.
        c = self.paren(self.doprint(expr.args[0]))
        p = self.paren(self.doprint(expr.args[1]))
        q = self.paren(self.doprint(expr.args[2]))
        return f"{c} ? {p} : {q}"

    def _print_ModularIndexing(self, expr):
        # Rendered as `(x floor-div div) % mod`; the floor division is emitted
        # only when `div` is not 1.
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        if div != 1:
            div = self.paren(self.doprint(div))
            if expr.is_integer:
                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
            else:
                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
        mod = self.paren(self.doprint(mod))
        return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"

    def _print_FloorDiv(self, expr):
        # Integer expressions use c10::div_floor_integer, everything else
        # c10::div_floor_floating.
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        if expr.is_integer:
            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

    def _print_floor(self, expr):
        # std::floor, cast to INDEX_TYPE when the expression is integer-valued.
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_FloorToInt(self, expr):
        # Same rendering as `_print_floor`.
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_TruncToInt(self, expr):
        # std::trunc, always cast to INDEX_TYPE.
        assert len(expr.args) == 1
        r = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})"

    def _print_TruncToFloat(self, expr):
        # std::trunc without any cast.
        assert len(expr.args) == 1
        return f"std::trunc({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr):
        # Conversion to floating point is rendered as a cast to double.
        assert len(expr.args) == 1
        return f"static_cast<double>({self._print(expr.args[0])})"

    # TODO: This is wrong if one of the inputs is negative.  This is hard to
    # tickle though, as the inputs are typically positive (and if we can prove
    # they are positive, we will have used Mod instead, for which this codegen
    # is right).
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_CMod(self, expr):
        # Operands joined with the C++ `%` operator.
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_IntTrueDiv(self, expr):
        # Both operands are cast to double before dividing.
        lhs, rhs = expr.args
        # TODO: This is only accurate up to 2**53
        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"

    # TODO: PowByNatural: we need to implement our own int-int pow.  Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatTrueDiv(self, expr):
        # Plain `/` between the parenthesized operands.
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"std::pow({self._print(base)}, {self._print(exp)})"

    def _print_Pow(self, expr):
        # Exponents of +/-0.5 become std::sqrt, integer exponents are expanded
        # into repeated multiplication (reciprocal for negative exponents),
        # and any other exponent falls back to std::pow.
        # Uses float constants to perform FP div
        base, exp = expr.args
        base = self._print(base)

        if exp == 0.5 or exp == -0.5:
            return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
        if exp.is_integer:
            exp = int(exp)
            if exp > 0:
                r = "*".join([self.paren(base)] * exp)
            elif exp < 0:
                r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
            else:  # exp == 0
                r = "1.0"

            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        else:
            # TODO: float vs double
            return f"std::pow({base}, {float(exp)})"

    def _print_Rational(self, expr):
        # Uses float constants to perform FP div
        if expr.q == 1:
            r = f"{expr.p}"
        else:
            r = f"{expr.p}.0/{expr.q}.0"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_ceiling(self, expr):
        # std::ceil, cast to INDEX_TYPE when the expression is integer-valued.
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_CeilToInt(self, expr):
        # Same rendering as `_print_ceiling`.
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Min(self, expr):
        # With exactly two operands both are cast to INDEX_TYPE; with any other
        # count the operands are passed uncast in a braced initializer list.
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::min({il})"

    def _print_Max(self, expr):
        # Mirrors `_print_Min`, using std::max.
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::max({il})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"std::abs({self._print(expr.args[0])})"

    # The `_print_OpaqueUnaryFn_*` methods below each render a call to the
    # `std::` function of the same name on the single printed argument.
    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"std::cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"std::cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"std::acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"std::sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"std::sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"std::asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"std::tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"std::tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"std::atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        # Unlike its siblings, this method does not assert the argument count.
        return f"std::sqrt({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        # Rounds `number` to `ndigits` decimal places by scaling with
        # 1e<ndigits>, applying std::nearbyint, and scaling back.
        # Integer inputs are rejected with ValueError.
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # NOTE(review): the assert below expects only negative ndigits to
            # reach this point; presumably integer inputs with non-negative
            # ndigits are simplified away before printing — confirm.
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"

    def _print_BooleanTrue(self, expr):
        return "true"

    def _print_BooleanFalse(self, expr):
        return "false"
431
432
# Module-level printing function: `cexpr(expr)` renders a sympy expression
# (including a bare symbol) as C++ text using one shared CppPrinter instance.
cexpr = CppPrinter().doprint
435
436
def cexpr_index(index):
    """Render `index` with `cexpr` and wrap the result in a cast to `INDEX_TYPE`."""
    printed = cexpr(index)
    return f"static_cast<{INDEX_TYPE}>({printed})"
439
440
def value_to_cpp(value, cpp_type):
    """
    Render the Python scalar `value` as a C++ expression of type `cpp_type`.

    Infinities and NaN map to the matching `std::numeric_limits` member,
    booleans to `true`/`false`, and everything else to the value's `repr`,
    the last two wrapped in a `static_cast` to `cpp_type`.
    """
    if value == float("-inf"):
        return f"-std::numeric_limits<{cpp_type}>::infinity()"
    if value == float("inf"):
        return f"std::numeric_limits<{cpp_type}>::infinity()"
    if isinstance(value, bool):
        # checked before isnan so that booleans keep their own spelling
        literal = str(value).lower()
        return f"static_cast<{cpp_type}>({literal})"
    if math.isnan(value):
        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
    literal = repr(value)
    return f"static_cast<{cpp_type}>({literal})"
452
453
def rewrite_index_for_function(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """
    Rewrite `index` (an index into the global buffer `global_buf_name`) so it
    only uses the innermost iteration symbols covered by the local buffer.

    Every free symbol whose name starts with "x" and that is not one of the
    last `rank(local buffer)` call-range symbols is substituted with zero.
    """
    # Local buffer at the inner dimensions
    defining_op = V.graph.scheduler.name_to_buf[global_buf_name].defining_op
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    # `max` returns a reduction node when there is one, otherwise the first node
    representative = max(
        defining_op.get_nodes(), key=lambda node: int(node.is_reduction())
    )
    _, (group, reduction_group) = representative.group
    num_call_ranges = len(tuple(group) + tuple(reduction_group))
    local_rank = len(local_buf.get_layout().size)
    kept_names = {f"x{num_call_ranges - (dim + 1)}" for dim in range(local_rank)}
    # Only keep index used by local buffer
    zeroed = {
        sym: sympy.core.numbers.Zero()
        for sym in sorted(index.free_symbols, key=lambda s: s.name)  # type: ignore[attr-defined]
        if sym.name.startswith("x") and sym.name not in kept_names  # type: ignore[attr-defined]
    }
    return sympy_subs(index, zeroed)  # type: ignore[arg-type]
479
480
def rewrite_index_for_nodes(
    localize_buffer_handler: "LocalizeBufferHandler",
    index: sympy.Expr,
    global_buf_name: str,
):
    """
    Build the index into the local buffer that replaces `global_buf_name`.

    For each dimension of the local buffer, the matching INDEX symbol is used
    when it occurs in `index` and 0 otherwise; the resulting list is fed to the
    local buffer layout's indexer.
    """
    used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
    local_buf = localize_buffer_handler.global_to_local[global_buf_name]
    dim_symbols = (
        sympy_index_symbol_with_prefix(SymT.INDEX, dim)
        for dim in range(len(local_buf.get_size()))
    )
    index_vars = [sym if sym in used_vars else 0 for sym in dim_symbols]
    return local_buf.layout.make_indexer()(index_vars)
494
495
class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops handler wrapper that redirects loads, stores and reduction stores on
    registered global buffers to their local replacement buffers, rewriting
    the index with the supplied `rewrite_index` callable.
    """

    def __init__(
        self,
        inner,
        global_to_local: Dict[str, ir.Buffer],
        rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr],
    ) -> None:
        super().__init__(inner)
        self.global_to_local = global_to_local
        self.rewrite_index = rewrite_index

    def localize(self, name: str, index: sympy.Expr):
        """
        Translate `(name, index)` into the local buffer's name and index when
        `name` is a registered global buffer; otherwise return them unchanged.
        """
        if not self.global_to_local or name not in self.global_to_local:
            return name, index
        assert self.rewrite_index is not None
        local_index = self.rewrite_index(self, index, name)
        local_name = self.global_to_local[name].get_name()
        return local_name, local_index

    def load(self, name: str, index: sympy.Expr):
        target_name, target_index = self.localize(name, index)
        return self._inner.load(target_name, target_index)

    def store(self, name, index, value, mode=None):
        target_name, target_index = self.localize(name, index)
        result = self._inner.store(target_name, target_index, value, mode)
        was_redirected = self.global_to_local and name in self.global_to_local
        if was_redirected and isinstance(V.kernel, Kernel):
            # Remove name of local buffer from Kernel.store_buffer_names
            # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
            V.kernel.store_buffer_names.discard(target_name)
        return result

    def store_reduction(self, name, index, value):
        target_name, target_index = self.localize(name, index)
        return self._inner.store_reduction(target_name, target_index, value)
532
533
class LocalBufferContext:
    """
    Context for generating code from Inductor IR that uses function-local
    buffers.

    Local buffers are created during codegen to hold intermediate results such
    as local accumulators. They are neither added to `V.graph` (they are not
    global) nor passed as function arguments, so while this context is active
    the relevant codegen hooks are patched to recognize them without exposing
    them to the outside world.
    """

    def __init__(self, kernel_args: KernelArgs) -> None:
        self.kernel_args = kernel_args
        self.exit_stack = contextlib.ExitStack()
        # local buffer name -> local buffer
        self.local_buffers: Dict[str, ir.Buffer] = {}
        # global buffer name -> global buffer
        self.global_buffers: Dict[str, ir.Buffer] = {}
        # global buffer name -> local buffer standing in for it
        self.global_to_local: Dict[str, ir.Buffer] = {}

    def __enter__(self):
        stack = self.exit_stack
        stack.__enter__()

        # dtype lookups: answer for local buffers, defer to the graph otherwise
        graph_get_dtype = V.graph.get_dtype

        def get_dtype_with_locals(name):
            if name not in self.local_buffers:
                return graph_get_dtype(name)
            return self.local_buffers[name].get_dtype()

        stack.enter_context(
            patch.object(V.graph, "get_dtype", get_dtype_with_locals)
        )

        def name_passthrough(fallback):
            # a local buffer is referred to by its own name; anything else goes
            # through the kernel-args method captured before patching
            def lookup(name):
                return name if name in self.local_buffers else fallback(name)

            return lookup

        stack.enter_context(
            patch.object(
                self.kernel_args, "input", name_passthrough(self.kernel_args.input)
            )
        )
        stack.enter_context(
            patch.object(
                self.kernel_args, "output", name_passthrough(self.kernel_args.output)
            )
        )

        # Set current LocalBufferContext into V
        stack.enter_context(V.set_local_buffer_context(self))

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.local_buffers.clear()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def add_local_buffer(
        self, local_buffer: ir.Buffer, global_buffers: Optional[List[ir.Buffer]] = None
    ):
        """
        Register `local_buffer`, optionally as the replacement for each buffer
        in `global_buffers`; the replaced global buffers are added to
        `V.graph.removed_buffers`.
        """
        local_name = local_buffer.get_name()
        assert local_name not in self.local_buffers
        self.local_buffers[local_name] = local_buffer
        for global_buffer in global_buffers or ():
            global_name = global_buffer.get_name()
            assert (
                global_name not in self.global_buffers
                and global_name not in self.global_to_local
            )
            self.global_buffers[global_name] = global_buffer
            self.global_to_local[global_name] = local_buffer
            V.graph.removed_buffers.add(global_name)

    def localize_function(
        self,
        fn: Callable[..., Any],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_function,
    ):
        """
        Wrap `fn` so that it runs with a `LocalizeBufferHandler` installed on
        top of whatever ops handler is active at call time.
        """

        def inner(*args, **kwargs):
            handler = LocalizeBufferHandler(
                V.get_ops_handler(),
                global_to_local=self.global_to_local,
                rewrite_index=rewrite_index,
            )
            with V.set_ops_handler(handler):
                return fn(*args, **kwargs)

        return inner

    def localize_nodes(
        self,
        nodes: List[ir.IRNode],
        rewrite_index: Callable[
            ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr
        ] = rewrite_index_for_nodes,
    ) -> List[ir.IRNode]:
        """
        Given `local_buf` and `global_buf` registered in the current
        `LocalBufferContext` through `add_local_buffer`, localize `global_buf`
        to `local_buf` for the given `nodes` and return a new list of IR nodes
        that work on `local_buf` instead of `global_buf`, i.e., all the loads
        and stores are redirected to `local_buf`. This helps the fused loops to
        work on smaller-sized local buffers for better data locality.

        The data access of `local_buf` is assumed to be contiguous with the
        same order as the `global_buf`.
        """
        assert len(nodes) > 0

        def localized_copy(node: ir.IRNode):
            is_computed_buffer = isinstance(node, ir.ComputedBuffer)
            loops = node.data if is_computed_buffer else node
            assert isinstance(loops, ir.Loops)
            # shallow copy so the original loops keep their inner_fn
            loops_copy = copy.copy(loops)
            if is_computed_buffer:
                result = ir.ComputedBuffer(
                    node.get_name(), node.get_layout(), loops_copy
                )
            else:
                result = loops_copy  # type: ignore[assignment]

            loops_copy.inner_fn = self.localize_function(
                loops_copy.inner_fn,
                rewrite_index,
            )
            return result

        return [localized_copy(node) for node in nodes]
665
666
def unify_mask_base_type(
    buffer: IndentedBuffer,
    vars: Tuple[CSEVariable, ...],
    dtype=torch.float,
):
    """
    Cast each CSE variable in `vars` to the mask base dtype `dtype`.

    For every variable, the cast expression produced by
    `V.kernel._get_mask_cast` is registered through `V.kernel.cse.generate`
    with `buffer`.

    Returns a tuple with one casted CSE variable per input, in input order.
    The tuple is built eagerly, so the casts are generated when this function
    is called (not when the result is first iterated) and the result can be
    iterated or indexed more than once.
    """
    return tuple(
        V.kernel.cse.generate(
            buffer,
            f"{V.kernel._get_mask_cast(var, dtype)}",
        )
        for var in vars
    )
684
685
def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32):
    """
    Write into `code` a C++ lambda that stores the vector `offset` to a scalar
    array, runs `rand_function` once per element in a loop, and returns the
    `result` array loaded back as a vector of `dst_dtype`.

    `offset` must have an integer dtype. Returns `code`.
    """
    assert is_integer_dtype(offset.dtype)
    code.writeline("[&]()")
    with code.indent():
        offset_ctype = DTYPE_TO_CPP[offset.dtype]
        tiling_factor = V.kernel.tiling_factor
        code.writeline(f"{offset_ctype} offset[{tiling_factor}];")
        result_ctype = DTYPE_TO_CPP[dst_dtype]
        code.writeline(f"{result_ctype} result[{tiling_factor}];")
        code.writeline(f"{offset}.store(offset);")
        code.writeline(
            f"for( {offset_ctype} offset_idx = 0; offset_idx < {tiling_factor}; offset_idx++ )"
        )
        with code.indent():
            code.writeline(rand_function)
        num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype)
        # a single vector uses Vectorized<T>, several use VectorizedN<T, N>
        if num_vectors == 1:
            vector_type = f"at::vec::Vectorized<{result_ctype}>"
        else:
            vector_type = f"at::vec::VectorizedN<{result_ctype}, {num_vectors}>"
        code.writeline(f"return {vector_type}::loadu(result);")
    code.writeline("()")
    return code
711
712
def get_gemm_template_output_and_compute_dtype(input_dtype):
    """
    Return the `(output dtype, compute dtype)` pair for a GEMM template with
    the given input dtype: int32 for both when the input is uint8, float32 for
    both otherwise.
    """
    chosen = torch.int32 if input_dtype == torch.uint8 else torch.float32
    return (chosen, chosen)
718
719
def create_epilogue_with_attr(input_buffer, attr, **kwargs):
    """
    Build an `ir.Pointwise` node that applies the epilogue named `attr`
    element-wise to `input_buffer`.

    Supported values of `attr` and the keyword arguments they require:
      - "relu", "swish", "sigmoid", "tanh", "hardswish", "hardsigmoid": none
      - "gelu": `algorithm`, either "none" (erf form) or "tanh"
      - "leaky_relu": `scalars` with one entry, the negative slope
      - "hardtanh": `scalars` with two entries, the min and max values
      - "add", "sub", "mul": `other`, the second operand buffer
      - "bias_add": `other` (the bias), `beta` and `dtype`

    The result uses the device and size of `input_buffer`. Its dtype is the
    dtype of `input_buffer`, except for "bias_add" where `kwargs["dtype"]` is
    used instead.

    Raises:
        ValueError: if `attr` is not one of the supported values.
    """
    input_loader = input_buffer.make_loader()
    dtype = input_buffer.get_dtype()
    if attr == "relu":

        def inner_fn(index):
            # max(input, 0) in the input's own dtype
            input = input_loader(index)
            zero = ops.constant(0, dtype)
            return ops.maximum(input, zero)

    elif attr == "gelu":
        assert "algorithm" in kwargs
        if kwargs["algorithm"] == "none":

            def inner_fn(index):
                # computed in float: 0.5 * x * (1 + erf(x / sqrt(2)))
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                # 0.7071067811865476 == 1 / sqrt(2)
                const = ops.constant(0.7071067811865476, torch.float)
                result = input * half * (ops.erf(input * const) + one)
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

        else:
            assert kwargs["algorithm"] == "tanh"

            def inner_fn(index):
                # computed in float:
                # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                input = input_loader(index)
                if dtype != torch.float:
                    input = ops.to_dtype(input, torch.float)
                half = ops.constant(0.5, torch.float)
                one = ops.constant(1.0, torch.float)
                # 0.7978845608028654 == sqrt(2 / pi)
                const1 = ops.constant(0.7978845608028654, torch.float)
                const2 = ops.constant(0.044715, torch.float)
                result = (
                    half
                    * input
                    * (
                        one
                        + ops.tanh(const1 * (input + const2 * input * input * input))
                    )
                )
                if dtype != torch.float:
                    result = ops.to_dtype(result, dtype)
                return result

    elif attr == "swish":

        def inner_fn(index):
            # x * sigmoid(x), without any dtype conversion
            input = input_loader(index)
            result = input * ops.sigmoid(input)
            return result

    elif attr == "sigmoid":

        def inner_fn(index):
            return ops.sigmoid(input_loader(index))

    elif attr == "tanh":

        def inner_fn(index):
            return ops.tanh(input_loader(index))

    elif attr == "hardswish" or attr == "hardsigmoid":

        def hardsigmoid_float(input):
            # min(max(x + 3, 0), 6) * (1/6), on a float input
            zero = ops.constant(0, torch.float)
            six = ops.constant(6, torch.float)
            three = ops.constant(3, torch.float)
            one_over_six = ops.constant(0.16666666666666666, torch.float)
            max = ops.maximum(input + three, zero)
            min = ops.minimum(max, six)
            return min * one_over_six

        def inner_fn(index):
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = hardsigmoid_float(input)
            # hardswish is x * hardsigmoid(x)
            if attr == "hardswish":
                result = input * result
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "leaky_relu":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 1
        negative_slope = kwargs["scalars"][0]

        def inner_fn(index):
            # computed in float: x if x > 0 else x * negative_slope
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            zero = ops.constant(0, torch.float)
            result = ops.where(
                input > zero, input, input * ops.constant(negative_slope, torch.float)
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr == "hardtanh":
        assert "scalars" in kwargs
        assert len(kwargs["scalars"]) == 2
        min_value = kwargs["scalars"][0]
        max_value = kwargs["scalars"][1]

        def inner_fn(index):
            # computed in float: min(max(x, min_value), max_value)
            input = input_loader(index)
            if dtype != torch.float:
                input = ops.to_dtype(input, torch.float)
            result = ops.minimum(
                ops.maximum(input, ops.constant(min_value, torch.float)),
                ops.constant(max_value, torch.float),
            )
            if dtype != torch.float:
                result = ops.to_dtype(result, dtype)
            return result

    elif attr in ["add", "sub", "mul"]:
        assert "other" in kwargs
        other = kwargs["other"]
        num_input_dims = len(input_buffer.get_size())
        num_other_dims = len(other.get_size())
        dims_diff = num_input_dims - num_other_dims
        other_loader = other.make_loader()

        def inner_fn(index):
            # `attr` doubles as the name of the binary op on `ops`
            op = getattr(ops, attr)
            if dims_diff != 0:
                # `other` has a different rank: index it with the input index
                # minus its first `dims_diff` entries
                return op(input_loader(index), other_loader(index[dims_diff:]))
            else:
                return op(input_loader(index), other_loader(index))

    elif attr == "bias_add":
        assert "other" in kwargs
        assert "beta" in kwargs
        assert "dtype" in kwargs
        beta = kwargs["beta"]
        other = kwargs["other"]
        # overrides the input buffer's dtype for the returned Pointwise node
        dtype = kwargs["dtype"]
        bias_loader = other.make_loader()

        def inner_fn(index):
            # beta * bias + input; the multiplication is skipped when beta is 1
            bias = bias_loader(index)
            input = input_loader(index)
            if beta != 1:
                result = ops.constant(beta, torch.float) * bias + input
            else:
                result = bias + input
            return result

    else:
        raise ValueError(f"Unsupported epilogue attribute: {attr}")
    return ir.Pointwise(
        device=input_buffer.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input_buffer.get_size(),
    )
884
885
886def _get_loop_body(fn_list):
887    if all(isinstance(fn, LoopBody) for fn in fn_list):
888        loop_bodies = fn_list
889    else:
890        if hasattr(fn_list[0], "original_fn"):
891            # For the case of local buffer, we wrap the fn with localize_function
892            assert all(hasattr(fn, "original_fn") for fn in fn_list)
893            assert all(
894                isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list
895            )
896            loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list]
897        else:
898            assert all(isinstance(fn, functools.partial) for fn in fn_list)
899            assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list)
900            loop_bodies = [fn.args[0]._body for fn in fn_list]
901    assert loop_bodies is not None
902    return loop_bodies
903
904
905def _get_dtype_from_loopbodies(loop_bodies):
906    dtypes = set()
907    for loop_body in loop_bodies:
908        graphs = [loop_body.root_block.graph] + [
909            body.graph for body in list(loop_body.subblocks.values())
910        ]
911        for graph in graphs:
912            for node in graph.nodes:
913                if node.op != "call_method":
914                    continue
915                dtypes.add(node.meta[OptimizationContext.key].dtype)
916    return dtypes
917