xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/cpp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3import dataclasses
4import functools
5import itertools
6import math
7import re
8import sys
9import warnings
10from copy import copy, deepcopy
11from enum import Enum
12from typing import cast, Dict, List, Optional, Sequence, Set, Tuple, Union
13
14import sympy
15
16import torch
17import torch.fx
18from torch._inductor import dependencies
19from torch._prims_common import is_float_dtype, is_integer_dtype
20from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
21from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
22
23from ..._dynamo.utils import counters
24from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics
25from ..loop_body import LoopBody
26from ..scheduler import (
27    BaseSchedulerNode,
28    BaseScheduling,
29    ForeachKernelSchedulerNode,
30    FusedSchedulerNode,
31    Scheduler,
32    SchedulerNode,
33)
34from ..utils import (
35    cache_on_self,
36    get_bounds_index_expr,
37    get_fused_kernel_name,
38    has_free_symbols,
39    is_welford_reduction,
40    parallel_num_threads,
41    Placeholder,
42    sympy_index_symbol,
43    sympy_index_symbol_with_prefix,
44    sympy_product,
45    sympy_subs,
46)
47from ..virtualized import NullKernelHandler, ops, OpsValue, V
48from .common import (
49    BackendFeature,
50    BracesBuffer,
51    CppWrapperKernelArgs,
52    CSE,
53    CSEVariable,
54    DataTypePropagation,
55    DeferredLine,
56    DTYPE_TO_COMPUTATION_DTYPE,
57    IndentedBuffer,
58    Kernel,
59    KernelArgs,
60    OpOverrides,
61    OptimizationContext,
62)
63from .cpp_utils import (
64    _get_dtype_from_loopbodies,
65    _get_loop_body,
66    cexpr,
67    cexpr_index,
68    codegen_rand,
69    CppCSEVariable,
70    DTYPE_TO_CPP,
71    INDEX_TYPE,
72    LocalBufferContext,
73    promote_args,
74    unify_mask_base_type,
75    value_to_cpp,
76)
77
78
79_IS_WINDOWS = sys.platform == "win32"
80
81
82def get_export_declaration():
83    return "__declspec(dllexport)" if _IS_WINDOWS else ""
84
85
86schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
87
88NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"}
89RTYPE_TO_CPP = {
90    "sum": "+",
91    "prod": "*",
92    "xor_sum": "^",
93    "min": "min",
94    "max": "max",
95    "argmin": "argmin",
96    "argmax": "argmax",
97    "any": "||",
98    "welford_reduce": "welford",
99    "welford_combine": "welford",
100}
101VECTORIZABLE_RTYPES = {
102    "max",
103    "min",
104    "sum",
105    "prod",
106    "xor_sum",
107    "welford_reduce",
108    "welford_combine",
109    "argmin",
110    "argmax",
111    "any",
112}
113
114PYTHON_TO_CPP = {
115    "Tensor": "at::Tensor",
116    "int": "long",
117    "float": "double",
118    "bool": "bool",
119    "str": "std::string",
120    "ScalarType": "c10::ScalarType",
121    "MemoryFormat": "at::MemoryFormat",
122    "Layout": "at::Layout",
123    "Device": "at::Device",
124    "number": "at::Scalar",
125}
126
127CONTAINER_PYTHON_TO_CPP = {
128    "List": "std::vector",
129    "Optional": "std::optional",
130}
131
132DTYPE_LOWP_FP = [
133    torch.bfloat16,
134    torch.float16,
135]
136
137VECTORIZABLE_DTYPES: List[torch.dtype] = [
138    torch.float64,
139    torch.float,
140    torch.bfloat16,
141    torch.float16,
142    torch.bool,
143    torch.uint8,
144    torch.int8,
145    torch.int32,
146    torch.int64,
147]
148
149MASKED_VECTORIZABLE_DTYPES: List[torch.dtype] = [
150    torch.float,
151    torch.bfloat16,
152    torch.float16,
153    torch.uint8,
154    torch.int8,
155]
156
157
158def reduction_init(reduction_type, dtype):
159    if dtype in DTYPE_LOWP_FP:
160        # Since load promotes all half-precision inputs to float, the initial
161        # constant for reduction must be promoted as well
162        dtype = torch.float32
163    if reduction_type in ("xor_sum", "sum", "any"):
164        return 0
165    if reduction_type == "prod":
166        return 1
167    if reduction_type in ("max", "argmax", "min", "argmin"):
168        cdtype = DTYPE_TO_CPP[dtype]
169        min_var = (
170            f"-std::numeric_limits<{cdtype}>::infinity()"
171            if is_float_dtype(dtype)
172            else f"std::numeric_limits<{cdtype}>::min()"
173        )
174        max_var = (
175            f"std::numeric_limits<{cdtype}>::infinity()"
176            if is_float_dtype(dtype)
177            else f"std::numeric_limits<{cdtype}>::max()"
178        )
179        init_var = min_var if reduction_type in ("max", "argmax") else max_var
180        return (
181            init_var
182            if reduction_type in ("max", "min")
183            else f"IndexValue<{cdtype}>{{0, {init_var}}}"
184        )
185    if is_welford_reduction(reduction_type):
186        return f"Welford<{DTYPE_TO_CPP[dtype]}>()"
187    raise AssertionError(reduction_type)
188
189
190def reduction_acc_type(reduction_type, dtype):
191    scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]]
192    if is_welford_reduction(reduction_type):
193        return f"Welford<{scalar_type}>"
194    if reduction_type in {"argmin", "argmax"}:
195        return f"IndexValue<{scalar_type}>"
196    return scalar_type
197
198
199def reduction_combine(
200    reduction_type,
201    var,
202    next_value,
203    index: Optional[sympy.Symbol] = None,
204    src_dtype=None,
205):
206    is_bool = src_dtype == torch.bool
207    if reduction_type == "sum":
208        conjunction = "|" if is_bool else "+"
209        return f"{var} {conjunction} {next_value}"
210    if reduction_type == "prod":
211        return f"{var} * {next_value}"
212    if reduction_type == "xor_sum":
213        return f"{var} ^ {next_value}"
214    if reduction_type == "any":
215        return f"{var} || {next_value}"
216    if reduction_type in ("min", "max"):
217        return f"{reduction_type}_propagate_nan({var}, {next_value})"
218    if reduction_type == "welford_reduce":
219        return f"welford_combine({var}, {next_value})"
220    if reduction_type == "welford_combine":
221        if isinstance(next_value, tuple):
222            mean, m2, weight = next_value
223        else:
224            mean, m2, weight = reduction_project(reduction_type, next_value)
225        return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
226    if reduction_type in ("argmin", "argmax"):
227        if index is not None:
228            return f"{reduction_type}_combine({var}, {next_value}, {index})"
229        else:
230            return f"{reduction_type}_combine({var}, {next_value})"
231    raise AssertionError(reduction_type)
232
233
234def reduction_project(reduction_type, acc):
235    if is_welford_reduction(reduction_type):
236        return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight"
237    elif reduction_type in {"argmin", "argmax"}:
238        return f"{acc}.index"
239    return acc
240
241
242@functools.lru_cache
243def stride_at(index: sympy.Expr, var: sympy.Symbol):
244    if not index.has(var):
245        # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu
246        # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation.
247        # in this case, there is no dependencies between index and var.
248        return sympy.Integer(0)
249    replacement = {var: var + 1}
250    new_index = sympy_subs(index, replacement)  # type: ignore[arg-type]
251    return sympy.simplify(new_index - index)
252
253
254@functools.lru_cache
255def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int):
256    """
257    Simplifies the index expression within the range of a vectorized loop.
258    Given a vectorized loop variable `var` in the range of a loop with `vec_length`,
259    this function transforms the `index` into an equivalent form. It handles
260    simplifications for cases where `var` can be expressed as `vec_length * a + b`,
261    where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences
262    of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations.
263
264    NOTE:
265    The simplified index expression is intended for analysis purposes only, not
266    for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables
267    which are not dependent on the loop variable `var` in the vectorized range. Check
268    https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details.
269
270    Examples:
271    1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then
272       `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable
273       when `div` is divisible by 16.
274    2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free
275       variable when `mod` is divisible by 16.
276    """
277
278    div_freevar_id = 0
279    mod_freevar_id = 0
280
281    def visit_indexing_div(divisor):
282        nonlocal div_freevar_id
283        result = FloorDiv(var, divisor)
284        if sympy.gcd(divisor, vec_length) == vec_length:
285            result = sympy.Symbol(f"{var}_div_c{div_freevar_id}")
286            div_freevar_id += 1
287        return result
288
289    def visit_modular_indexing(divisor, modulus):
290        nonlocal mod_freevar_id
291        result = ModularIndexing(var, divisor, modulus)
292        if sympy.gcd(divisor, vec_length) == vec_length:
293            result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
294            mod_freevar_id += 1
295        elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length:
296            result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
297            mod_freevar_id += 1
298        return result
299
300    original_index = index
301
302    div = sympy.Wild("divisor", integer=True)
303    if index.has(FloorDiv):
304        index = index.replace(FloorDiv(var, div), visit_indexing_div)
305
306    mod = sympy.Wild("modulus", integer=True)
307    if index.has(ModularIndexing):
308        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)
309
310    index = sympy.simplify(index)
311    if index != original_index:
312        return simplify_index_in_vec_range(index, var, vec_length)
313
314    return index
315
316
317@functools.lru_cache
318def stride_at_vec_range(
319    index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None
320):
321    if vec_length:
322        index = simplify_index_in_vec_range(index, var, vec_length)
323    return stride_at(index, var)
324
325
326class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
327    @classmethod
328    def fuse(  # type: ignore[override]
329        cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth
330    ):
331        assert node1.scheduler is node2.scheduler
332        assert all(
333            type(node)
334            in (
335                OuterLoopFusedSchedulerNode,
336                SchedulerNode,
337                FusedSchedulerNode,
338            )
339            for node in (node1, node2)
340        )
341        if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
342            return cls(
343                node1.scheduler,
344                (
345                    list(node1.get_outer_nodes())
346                    if type(node1) is OuterLoopFusedSchedulerNode
347                    else [
348                        node1,
349                    ]
350                )
351                + (
352                    list(node2.get_outer_nodes())
353                    if type(node2) is OuterLoopFusedSchedulerNode
354                    else [
355                        node2,
356                    ]
357                ),
358                outer_loop_fusion_depth,
359            )
360        else:
361            return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth)  # type: ignore[list-item]
362
363    def __init__(
364        self,
365        scheduler: "Scheduler",
366        outer_fused_nodes: List[Union[FusedSchedulerNode, SchedulerNode]],
367        outer_loop_fusion_depth,
368    ):
369        self.outer_fused_nodes: List[
370            Union[FusedSchedulerNode, SchedulerNode]
371        ] = outer_fused_nodes
372        self.outer_loop_fusion_depth = outer_loop_fusion_depth
373        flatten_snodes = []
374        for _node in self.outer_fused_nodes:
375            assert isinstance(_node, (SchedulerNode, FusedSchedulerNode))
376            flatten_snodes.extend(list(_node.get_nodes()))
377        super().__init__(scheduler, flatten_snodes)  # type: ignore[arg-type]
378
379    def get_outer_nodes(self):
380        return self.outer_fused_nodes
381
382    def check_outer_fusion_loop_level_attr(
383        self, cpp_kernel_proxy_list, outer_loop_fusion_depth
384    ):
385        # This function ensures that the same tiling split is applied at each loop level within the outer loop fusion depth.
386        # In the fusion stage, we only examine nodes with same vars and reduce.
387        # However, for nodes with same vars and reduce, the loops may still have different tile splits.
388        # For example (test_expr_vec_non_contiguous in test_cpu_repro.py):
389        #   * buf0 tiling along the 2nd loop level, buf1 tiling along the 3rd loop level.
390        # If the check failed, we should fall back to standard loop codegen.
391        def _inner(
392            left_loop_level: LoopLevel,
393            right_loop_level: LoopLevel,
394            loop_fusion_depth: int,
395        ) -> bool:
396            # Check if same loop level attr
397            outer_loops_attr_compare_list = [
398                "var",
399                "size",
400                "offset",
401                "steps",
402            ]
403            if not (
404                all(
405                    getattr(left_loop_level, attr_compare)
406                    == getattr(right_loop_level, attr_compare)
407                    for attr_compare in outer_loops_attr_compare_list
408                )
409            ):
410                return False
411
412            assert loop_fusion_depth >= 1
413            if (loop_fusion_depth := loop_fusion_depth - 1) > 0:
414                # If the next loop level is expected to undergo outer loop fusion,
415                # there should be no kernel present at the current loop level.
416                assert (
417                    left_loop_level.kernel is None and right_loop_level.kernel is None
418                )
419                # Check next loop level attr
420                if any(
421                    # Assume no main/tail loop split at any outer loop fusion depth
422                    # Given no clear performance benefit for this complex case
423                    len(loop_level.inner) != 1
424                    for loop_level in [left_loop_level, right_loop_level]
425                ) or not _inner(
426                    left_loop_level.inner[0],
427                    right_loop_level.inner[0],
428                    loop_fusion_depth,
429                ):
430                    return False
431
432            return True
433
434        for idx in range(len(cpp_kernel_proxy_list) - 1):
435            left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest
436            right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest
437            if any(
438                # Assume no main/tail loop split at any outer loop fusion depth
439                len(loop_nest.root) != 1
440                for loop_nest in [left_loop_nest, right_loop_nest]
441            ) or not _inner(
442                left_loop_nest.root[0], right_loop_nest.root[0], outer_loop_fusion_depth
443            ):
444                return False
445
446        return True
447
448    def merge_outer_fusion_kernels(
449        self,
450        cpp_kernel_proxy_list,
451    ):
452        loop_nest_list: List[LoopNestWithSplit] = [
453            kernel.loop_nest for kernel in cpp_kernel_proxy_list
454        ]
455        kernel_group = cpp_kernel_proxy_list[0].kernel_group
456
457        def _merge_outer_fusion_loop_levels(
458            loop_level_nested_list: List[List["LoopLevel"]],
459            outer_loop_fusion_depth,
460        ):
461            assert outer_loop_fusion_depth >= 1
462            # Assume no main/tail loop split at any outer loop fusion depth
463            assert all(
464                len(loop_level_list) == 1 for loop_level_list in loop_level_nested_list
465            )
466            if (outer_loop_fusion_depth := outer_loop_fusion_depth - 1) >= 1:
467                # Further merge the next loop level
468                next_loop_level_nested_list = [
469                    loop_level_list[0].inner
470                    for loop_level_list in loop_level_nested_list
471                ]
472                _merge_outer_fusion_loop_levels(
473                    next_loop_level_nested_list,
474                    outer_loop_fusion_depth,
475                )
476            else:
477                outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group)
478                loop_level_of_first_kernel = loop_level_nested_list[0][0]
479                for kernel_idx in range(len(loop_level_nested_list)):
480                    outer_loop_fused_kernel.inner.append(
481                        deepcopy(loop_level_nested_list[kernel_idx][0]),
482                    )
483                loop_level_of_first_kernel.inner = []
484                loop_level_of_first_kernel.kernel = outer_loop_fused_kernel
485
486        # Merge the List[LoopNestWithSplit] from cpp_kernel_proxy_list
487        # into cpp_kernel_proxy_list[0].loop_nest
488        _merge_outer_fusion_loop_levels(
489            [_loop_nest.root for _loop_nest in loop_nest_list],  # type: ignore[misc]
490            self.outer_loop_fusion_depth,
491        )
492        return cpp_kernel_proxy_list[0]
493
494
495class RecordOptimizationContext:
496    def __init__(self, func_name: str = ""):
497        self.func_name = func_name
498        self.current_node: Optional[torch.fx.Node] = None
499        self.opt_ctx: Optional[OptimizationContext] = None
500
501    def __enter__(self):
502        assert V.interpreter
503        assert V.interpreter.current_node
504
505        self.current_node = V.interpreter.current_node
506        assert self.current_node is not None
507        if OptimizationContext.key in self.current_node.meta:
508            self.opt_ctx = self.current_node.meta[OptimizationContext.key]
509        else:
510            self.opt_ctx = OptimizationContext()
511        assert self.opt_ctx is not None
512        self.opt_ctx.ops_name = self.func_name
513        return self
514
515    def __exit__(self, exc_type, exc_val, exc_tb):
516        assert self.current_node
517        assert self.opt_ctx
518        self.current_node.meta[OptimizationContext.key] = self.opt_ctx
519
520    def get_opt_ctx(self):
521        return self.opt_ctx
522
523    def get_fx_node(self):
524        assert self.current_node
525        return self.current_node
526
527
528class CppOverrides(OpOverrides):
529    """Map element-wise ops to C++"""
530
531    @staticmethod
532    def add(a, b):
533        return f"decltype({a})({a} + {b})"
534
535    @staticmethod
536    def sub(a, b):
537        return f"decltype({a})({a} - {b})"
538
539    @staticmethod
540    def mul(a, b):
541        return f"decltype({a})({a} * {b})"
542
543    @staticmethod
544    def to_dtype(x, dtype, src_dtype=None, use_compute_types=True):
545        assert isinstance(x, CppCSEVariable)
546        if src_dtype is None:
547            src_dtype = x.dtype
548        expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype)
549        csevar = V.kernel.cse.generate(V.kernel.compute, expr)
550        csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype})
551        if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float:
552            """
553            https://github.com/pytorch/pytorch/issues/115260
554            For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is
555            in low-precision floating point data type. When the output of node1 also serves as the output of the
556            kernel, the result of nodes would be different from the case when output of node1 is not the output
557            of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on
558            storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type
559            to the cse cache.
560
561            Example (pseudo code):
562                node1_output = ...
563                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
564                store(buf, node1_output_lowp)
565                node2_input_lowp = load(buf)
566                node2_input = to_dtype(node2_input_lowp, dtype=torch.float)
567
568            Without cse cache trick:
569                node1_output = ...
570                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
571                store(buf, node1_output_lowp)
572                node2_input_lowp = node_output_lowp # hit store cache
573                node2_input = to_dtype(node2_input_lowp, dtype=torch.float)
574
575            With cse cache trick:
576                node1_output = ...
577                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
578                # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache
579                store(buf, node1_output_lowp)
580                node2_input_lowp = node_output_lowp # hit store cache
581                node2_input = node1_output # hit cse cache
582            """
583            V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype)
584        return csevar
585
586    @staticmethod
587    def to_dtype_bitcast(x, dtype, src_dtype):
588        assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP"
589        if src_dtype in (torch.float16, torch.bfloat16):
590            # c10::bit_cast requires the source and target have the bitwidth.
591            # Because the input tensor's dtype could be promoted, e.g. from float16 to
592            # float, we have to cast the tensor to its original source dtype before
593            # invoking bit_cast. We also need to convert the bit-casted tensor
594            # back to float to make sure we keep using higher precision values
595            # for the rest of the computation.
596            cast_x = f"c10::convert<{DTYPE_TO_CPP[src_dtype]}>({x})"
597            cast_x = f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({cast_x})"
598            return f"c10::convert<{DTYPE_TO_CPP[torch.float32]}>({cast_x})"
599        else:
600            return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})"
601
602    @staticmethod
603    def abs(x):
604        return f"std::abs({x})"
605
606    @staticmethod
607    def sin(x):
608        return f"std::sin({x})"
609
610    @staticmethod
611    def cos(x):
612        return f"std::cos({x})"
613
614    @staticmethod
615    def neg(x):
616        return f"decltype({x})(-{x})"
617
618    @staticmethod
619    def exp(x):
620        # return f"Sleef_expf_u10({x})"
621        return f"std::exp({x})"
622
623    @staticmethod
624    def exp2(x):
625        return f"std::exp2({x})"
626
627    @staticmethod
628    def expm1(x):
629        return f"std::expm1({x})"
630
631    @staticmethod
632    def erf(x):
633        return f"std::erf({x})"
634
635    @staticmethod
636    def erfc(x):
637        return f"std::erfc({x})"
638
639    @staticmethod
640    def erfinv(x):
641        return f"calc_erfinv({x})"
642
643    @staticmethod
644    def sqrt(x):
645        return f"std::sqrt({x})"
646
647    @staticmethod
648    def rsqrt(x):
649        return f"1 / std::sqrt({x})"
650
651    @staticmethod
652    def log1p(x):
653        bug = config.cpp.inject_log1p_bug_TESTING_ONLY
654        if bug == "accuracy":
655            return f"{x} + decltype({x})(1)"
656        elif bug is None:
657            return f"std::log1p({x})"
658        else:
659            raise AssertionError(
660                f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}"
661            )
662
663    @staticmethod
664    def tan(x):
665        return f"std::tan({x})"
666
667    @staticmethod
668    def tanh(x):
669        return f"std::tanh({x})"
670
671    @staticmethod
672    def signbit(x):
673        """
674        On windows std::signbit only support float type.
675        Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170
676        """
677        return (
678            f"std::signbit(static_cast<float>({x}))"
679            if _IS_WINDOWS
680            else f"std::signbit({x})"
681        )
682
683    @staticmethod
684    def pow(a, b):
685        return f"std::pow({a}, {b})"
686
687    @staticmethod
688    def log(x):
689        return f"std::log({x})"
690
691    @staticmethod
692    def round(x):
693        return f"std::nearbyint({x})"
694
695    @staticmethod
696    def floor(x):
697        return f"std::floor({x})"
698
699    @staticmethod
700    def floordiv(a, b):
701        # a and b are integer type
702        quot = f"{a} / {b}"
703        rem = f"{a} % {b}"
704        return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})"
705
706    @staticmethod
707    def ceil(x):
708        return f"std::ceil({x})"
709
710    @staticmethod
711    def trunc(x):
712        return f"std::trunc({x})"
713
714    @staticmethod
715    def truncdiv(a, b):
716        # a and b are integer type
717        return f"{a} / {b}"
718
719    @staticmethod
720    def fmod(a, b):
721        return f"std::fmod({a}, {b})"
722
723    @staticmethod
724    def isinf(x):
725        return f"std::isinf({x})"
726
727    @staticmethod
728    def isnan(x):
729        return f"std::isnan({x})"
730
731    @staticmethod
732    def lgamma(x):
733        return f"std::lgamma({x})"
734
735    @staticmethod
736    def acos(x):
737        return f"std::acos({x})"
738
739    @staticmethod
740    def acosh(x):
741        return f"std::acosh({x})"
742
743    @staticmethod
744    def cosh(x):
745        return f"std::cosh({x})"
746
747    @staticmethod
748    def sinh(x):
749        return f"std::sinh({x})"
750
751    @staticmethod
752    def asin(x):
753        return f"std::asin({x})"
754
755    @staticmethod
756    def asinh(x):
757        return f"std::asinh({x})"
758
759    @staticmethod
760    def atan2(x, y):
761        return f"std::atan2({x}, {y})"
762
763    @staticmethod
764    def atan(x):
765        return f"std::atan({x})"
766
767    @staticmethod
768    def atanh(x):
769        return f"std::atanh({x})"
770
771    @staticmethod
772    def copysign(x, y):
773        return f"std::copysign({x}, {y})"
774
775    @staticmethod
776    def frexp(x):
777        cache_keys = f"frexp({x})[0]", f"frexp({x})[1]"
778        if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys):
779            return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys)
780
781        code = BracesBuffer()
782        exponent = V.kernel.cse.newvar()
783        mantissa = V.kernel.cse.newvar()
784        code.writeline(f"int32_t {exponent};")
785        code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});")
786        V.kernel.compute.splice(code)
787        cse_vars = (mantissa, exponent)
788        for cache_key, cse_var in zip(cache_keys, cse_vars):
789            V.kernel.cse.cache[cache_key] = cse_var
790        return mantissa, exponent
791
792    @staticmethod
793    def hypot(x, y):
794        return f"std::hypot({x}, {y})"
795
796    @staticmethod
797    def log10(x):
798        return f"std::log10({x})"
799
800    @staticmethod
801    def log2(x):
802        return f"std::log2({x})"
803
804    @staticmethod
805    def nextafter(x, y):
806        return f"std::nextafter({x}, {y})"
807
808    @staticmethod
809    def relu(x):
810        bug = config.cpp.inject_relu_bug_TESTING_ONLY
811        if bug == "compile_error":
812            return "compile error!"
813        elif bug == "runtime_error":
814            return f"{x}; throw 1"
815        elif bug == "accuracy":
816            return f"{x} + decltype({x})(1)"
817        elif bug is None:
818            return f"std::max({x}, decltype({x})(0))"
819        else:
820            raise AssertionError(
821                f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}"
822            )
823
824    @staticmethod
825    def minimum(a, b):
826        return f"min_propagate_nan({a}, {b})"
827
828    @staticmethod
829    def maximum(a, b):
830        return f"max_propagate_nan({a}, {b})"
831
832    @staticmethod
833    def where(a, b, c):
834        return f"{a} ? {b} : {c}"
835
836    @staticmethod
837    def mod(a, b):
838        return f"mod({a}, {b})"
839
840    @staticmethod
841    def constant(val, dtype):
842        if dtype in DTYPE_LOWP_FP:
843            # Since load promotes all half-precision inputs to float, constants
844            # must be promoted as well
845            dtype = torch.float32
846        return value_to_cpp(val, DTYPE_TO_CPP[dtype])
847
848    @staticmethod
849    def index_expr(expr, dtype):
850        idx_str = cexpr(V.kernel.rename_indexing(expr))
851        var = V.kernel.cse.generate(
852            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
853        )
854        return ops.to_dtype(var, dtype)
855
856    @staticmethod
857    def masked(mask, body, other):
858        code = BracesBuffer()
859
860        # Write masked operation into a lambda
861        body_var = V.kernel.cse.newvar()
862        code.writeline(f"auto {body_var} = [&]")
863        with V.kernel.swap_buffers(code), code.indent():
864            result = body()
865            code.writeline(f"return {result};")
866        code.writeline(";")
867        V.kernel.compute.splice(code)
868
869        # Use the lambda's return type as the type of other
870        other_code = value_to_cpp(other, f"decltype({body_var}())")
871        return f"{mask} ? {body_var}() : {other_code}"
872
873    @staticmethod
874    def logical_and(a, b):
875        return f"{a} && {b}"
876
877    @staticmethod
878    def logical_not(a):
879        return f"!{a}"
880
881    @staticmethod
882    def logical_or(a, b):
883        return f"{a} || {b}"
884
885    @staticmethod
886    def logical_xor(a, b):
887        return f"{a} != {b}"
888
889    @staticmethod
890    def bitwise_and(a, b):
891        return f"decltype({a})({a} & {b})"
892
893    @staticmethod
894    def bitwise_not(a):
895        return f"decltype({a})(~{a})"
896
897    @staticmethod
898    def bitwise_or(a, b):
899        return f"decltype({a})({a} | {b})"
900
901    @staticmethod
902    def bitwise_xor(a, b):
903        return f"decltype({a})({a} ^ {b})"
904
905    @staticmethod
906    def bitwise_left_shift(a, b):
907        return f"decltype({a})({a} << {b})"
908
909    @staticmethod
910    def bitwise_right_shift(a, b):
911        return f"decltype({a})({a} >> {b})"
912
913    @staticmethod
914    def rand(seed: sympy.Expr, offset: sympy.Expr):
915        return f"normalized_rand_cpu({seed}, {offset})"
916
917    @staticmethod
918    def randn(seed: sympy.Expr, offset: sympy.Expr):
919        return f"randn_cpu({seed}, {offset})"
920
921    @staticmethod
922    def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high):
923        return f"randint64_cpu({seed}, {offset}, {low}, {high})"
924
925    @staticmethod
926    def sigmoid(x):
927        return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))"
928
929    @staticmethod
930    def sign(x):
931        code = BracesBuffer()
932        scalar_zero = f"decltype({x})(0)"
933        scalar_one = f"decltype({x})(1)"
934        code.writeline("[&]()")
935        with code.indent():
936            code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};")
937            code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};")
938            code.writeline("return left - right;")
939        code.writeline("()")
940        return code
941
942
943CppOverrides._initialize_pointwise_overrides("cpp")
944
945
946class CppVecOverrides(CppOverrides):
947    """Map element-wise ops to aten vectorization C++"""
948
949    def __new__(cls, *args, **kargs):
950        self = super().__new__(cls)
951
952        def wrap(func):
953            # `CppVecKernel` generates both scalar ops and vector ops according to
954            # whether the inputs are scalars or vectors while all ops in `CppVecOverrides`
955            # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in
956            # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to
957            # `CppOverrides` when all inputs are scalars.
958            #
959            # Notes on ops handled separately in their own functions:
960            # `ops.masked`:
961            #     needs recursive handling of masked body.
962            # `ops.index_expr`:
963            #     needs to further analyze the dependency of the index expression on
964            #     the tiling itervar.
965            def wrapper(*args, **kwargs):
966                scalars = [
967                    arg
968                    for arg in args
969                    if isinstance(arg, (int, sympy.Expr))
970                    or (isinstance(arg, CppCSEVariable) and not arg.is_vec)
971                ]
972                vectors = [
973                    arg
974                    for arg in args
975                    if isinstance(arg, CppCSEVariable) and arg.is_vec
976                ]
977                new_args = list(args)
978                if scalars and vectors:
979                    new_args = []
980                    for arg in args:
981                        if isinstance(arg, (int, sympy.Expr)):
982                            if isinstance(arg, sympy.Expr) and not arg.is_number:
983                                arg = ops.index_expr(arg, torch.int64)
984                            else:
985                                arg = ops.constant(arg, torch.int64)
986                            arg = arg.value if isinstance(arg, OpsValue) else arg
987                        new_args.append(arg)
988
989                # DType Promotion
990                if vectors:
991                    # We have saw several data type mismatch issues related with index_expr in
992                    # the lowering phase of torch.int8. torch.int32, torch.int64.
993                    # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu
994                    # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu
995                    # 3. int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu
996                    if len(new_args) == 2:
997                        new_args = promote_args(new_args)
998                    elif func == CppVecOverrides.where:
999                        new_args[1:] = promote_args(new_args[1:])
1000
1001                # Broadcast scalar args to vector
1002                if scalars and vectors:
1003                    assert isinstance(V.kernel, CppVecKernel)
1004                    new_args = [
1005                        V.kernel.broadcast(new_arg)
1006                        if (
1007                            isinstance(new_arg, CppCSEVariable)
1008                            and not new_arg.is_vec
1009                            and func
1010                            not in [
1011                                CppVecOverrides.rand,
1012                                CppVecOverrides.randn,
1013                                CppVecOverrides.randint64,
1014                            ]
1015                        )
1016                        else new_arg
1017                        for new_arg in new_args
1018                    ]
1019
1020                if vectors:
1021                    return func(*new_args, **kwargs)
1022                else:
1023                    # fallback to scalar ops
1024                    scalar_ops = super(CppVecOverrides, self)
1025                    scalar_func = getattr(
1026                        scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__)  # type: ignore[attr-defined]
1027                    )
1028                    assert scalar_func is not None
1029                    return scalar_func(*args, **kwargs)
1030
1031            return wrapper
1032
1033        for name, method in vars(CppVecOverrides).items():
1034            if getattr(method, "__class__", None) == staticmethod and name not in [
1035                "masked",
1036                "index_expr",
1037            ]:
1038                setattr(self, name, wrap(method.__func__))
1039
1040        return self
1041
1042    @staticmethod
1043    def add(a, b):
1044        return f"{a} + {b}"
1045
1046    @staticmethod
1047    def sub(a, b):
1048        return f"{a} - {b}"
1049
1050    @staticmethod
1051    def mul(a, b):
1052        return f"{a} * {b}"
1053
1054    @staticmethod
1055    def truediv(a, b):
1056        return f"{a} / {b}"
1057
1058    @staticmethod
1059    def abs(x):
1060        return f"{x}.abs()"
1061
1062    @staticmethod
1063    def sin(x):
1064        return f"{x}.sin()"
1065
1066    @staticmethod
1067    def cos(x):
1068        return f"{x}.cos()"
1069
1070    @staticmethod
1071    def exp(x):
1072        return f"{x}.exp()"
1073
1074    @staticmethod
1075    def exp2(x):
1076        return f"{x}.exp2()"
1077
1078    @staticmethod
1079    def expm1(x):
1080        # decompose for a better performance
1081        vec_one = f"decltype({x})(1)"
1082        return f"{x}.exp() - {vec_one}"
1083
1084    @staticmethod
1085    def erf(x):
1086        return f"{x}.erf()"
1087
1088    @staticmethod
1089    def erfc(x):
1090        return f"{x}.erfc()"
1091
1092    @staticmethod
1093    def erfinv(x):
1094        return f"{x}.erfinv()"
1095
1096    @staticmethod
1097    def sqrt(x):
1098        return f"{x}.sqrt()"
1099
1100    @staticmethod
1101    def eq(x, y):
1102        assert isinstance(V.kernel, CppVecKernel)
1103        assert isinstance(x, CppCSEVariable)
1104        assert x.dtype is not None
1105        return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})"
1106
1107    @staticmethod
1108    def ne(x, y):
1109        assert isinstance(V.kernel, CppVecKernel)
1110        assert isinstance(x, CppCSEVariable)
1111        if x.dtype == torch.bool:
1112            assert y.dtype == torch.bool
1113            x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y))
1114            return f"{x_cast} != {y_cast}"
1115        else:
1116            assert x.dtype is not None
1117            return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})"
1118
1119    @staticmethod
1120    def lt(x, y):
1121        assert isinstance(V.kernel, CppVecKernel)
1122        assert isinstance(x, CppCSEVariable)
1123        assert x.dtype is not None
1124        return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})"
1125
1126    @staticmethod
1127    def gt(x, y):
1128        assert isinstance(V.kernel, CppVecKernel)
1129        assert isinstance(x, CppCSEVariable)
1130        assert x.dtype is not None
1131        return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})"
1132
1133    @staticmethod
1134    def le(x, y):
1135        assert isinstance(V.kernel, CppVecKernel)
1136        assert isinstance(x, CppCSEVariable)
1137        assert x.dtype is not None
1138        return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})"
1139
1140    @staticmethod
1141    def ge(x, y):
1142        assert isinstance(V.kernel, CppVecKernel)
1143        assert isinstance(x, CppCSEVariable)
1144        assert x.dtype is not None
1145        return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})"
1146
1147    @staticmethod
1148    def and_(x, y):
1149        return f"{x} & {y}"
1150
1151    @staticmethod
1152    def rsqrt(x):
1153        return f"{x}.rsqrt()"
1154
1155    @staticmethod
1156    def pow(a, b):
1157        return f"{a}.pow({b})"
1158
1159    @staticmethod
1160    def log(x):
1161        return f"{x}.log()"
1162
1163    @staticmethod
1164    def round(x):
1165        return f"{x}.round()"
1166
1167    @staticmethod
1168    def floor(x):
1169        return f"{x}.floor()"
1170
1171    @staticmethod
1172    def ceil(x):
1173        return f"{x}.ceil()"
1174
1175    @staticmethod
1176    def trunc(x):
1177        return f"{x}.trunc()"
1178
1179    @staticmethod
1180    def fmod(a, b):
1181        return f"{a}.fmod({b})"
1182
1183    @staticmethod
1184    def lgamma(x):
1185        return f"{x}.lgamma()"
1186
1187    @staticmethod
1188    def logical_and(a, b):
1189        return f"{a} & {b}"
1190
1191    @staticmethod
1192    def logical_not(a):
1193        return f"~{a}"
1194
1195    @staticmethod
1196    def logical_or(a, b):
1197        return f"{a} | {b}"
1198
1199    @staticmethod
1200    def logical_xor(a, b):
1201        return f"{a} ^ {b}"
1202
1203    @staticmethod
1204    def bitwise_and(a, b):
1205        return f"{a} & {b}"
1206
1207    @staticmethod
1208    def bitwise_not(a):
1209        return f"~{a}"
1210
1211    @staticmethod
1212    def bitwise_or(a, b):
1213        return f"{a} | {b}"
1214
1215    @staticmethod
1216    def bitwise_xor(a, b):
1217        return f"{a} ^ {b}"
1218
1219    @staticmethod
1220    def bitwise_left_shift(a, b):
1221        return f"{a} << {b}"
1222
1223    @staticmethod
1224    def bitwise_right_shift(a, b):
1225        return f"{a} >> {b}"
1226
1227    @staticmethod
1228    def load_seed(name, offset):
1229        assert isinstance(V.kernel, CppVecKernel)
1230        return f"{V.kernel.load(name, offset)}"
1231
1232    @staticmethod
1233    def rand(seed, offset):
1234        assert isinstance(V.kernel, CppVecKernel)
1235        code = BracesBuffer()
1236        rand_function = (
1237            f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);"
1238        )
1239        return codegen_rand(offset, code, rand_function)
1240
1241    @staticmethod
1242    def randn(seed, offset):
1243        assert isinstance(V.kernel, CppVecKernel)
1244        code = BracesBuffer()
1245        rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);"
1246        return codegen_rand(offset, code, rand_function)
1247
1248    @staticmethod
1249    def randint64(seed, offset, low, high):
1250        assert isinstance(V.kernel, CppVecKernel)
1251        code = BracesBuffer()
1252        rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});"
1253        return codegen_rand(offset, code, rand_function, torch.int64)
1254
1255    @staticmethod
1256    def remainder(a, b):
1257        assert (
1258            a.dtype == b.dtype
1259        ), "remainder vec implementation expect the same inputs' dtype."
1260        return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"
1261
1262    @staticmethod
1263    def tan(a):
1264        return f"{a}.tan()"
1265
1266    @staticmethod
1267    def tanh(a):
1268        vec_one = f"decltype({a})(1)"
1269        vec_two = f"decltype({a})(2)"
1270        vec_minus_two = f"decltype({a})(-2)"
1271        return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}"
1272
1273    @staticmethod
1274    def reciprocal(a):
1275        return f"{a}.reciprocal()"
1276
1277    @staticmethod
1278    def atan(x):
1279        return f"{x}.atan()"
1280
1281    @staticmethod
1282    def acos(x):
1283        return f"{x}.acos()"
1284
1285    @staticmethod
1286    def asin(x):
1287        return f"{x}.asin()"
1288
1289    @staticmethod
1290    def cosh(x):
1291        return f"{x}.cosh()"
1292
1293    @staticmethod
1294    def sinh(x):
1295        return f"{x}.sinh()"
1296
1297    @staticmethod
1298    def log10(x):
1299        return f"{x}.log10()"
1300
1301    @staticmethod
1302    def log2(x):
1303        return f"{x}.log2()"
1304
1305    @staticmethod
1306    def nextafter(x, y):
1307        return f"{x}.nextafter({y})"
1308
1309    @staticmethod
1310    def copysign(a, b):
1311        return f"{a}.copysign({b})"
1312
1313    @staticmethod
1314    def atan2(a, b):
1315        return f"{a}.atan2({b})"
1316
1317    @staticmethod
1318    def hypot(a, b):
1319        return f"{a}.hypot({b})"
1320
1321    @staticmethod
1322    def atanh(x):
1323        # For real x, atanh(x) = 1/2 * log((1+x)/(1-x))
1324        vec_one = f"decltype({x})(1)"
1325        vec_one_half = f"decltype({x})(0.5)"
1326        return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()"
1327
1328    @staticmethod
1329    def asinh(x):
1330        # For real x, asinh(x) = log(x + sqrt(1 + x**2))
1331        vec_one = f"decltype({x})(1)"
1332        return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()"
1333
1334    @staticmethod
1335    def acosh(x):
1336        return f"{x}.acosh()"
1337
1338    @staticmethod
1339    def relu(x):
1340        bug = config.cpp.inject_relu_bug_TESTING_ONLY
1341        if bug == "compile_error":
1342            return "compile error!"
1343        elif bug == "runtime_error":
1344            return f"{x}; throw 1"
1345        elif bug == "accuracy":
1346            return f"{x} + decltype({x})(1)"
1347        elif bug is None:
1348            return f"at::vec::clamp_min({x}, decltype({x})(0))"
1349        else:
1350            raise AssertionError(
1351                f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}"
1352            )
1353
1354    # TODO: this seems to be dead
1355    @staticmethod
1356    def sigmoid(x):
1357        return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())"
1358
1359    @staticmethod
1360    def neg(x):
1361        return f"{x}.neg()"
1362
1363    @staticmethod
1364    def floordiv(a, b):
1365        if is_float_dtype(a.dtype):
1366            assert (
1367                a.dtype == b.dtype
1368            ), "div_floor_floating_vec implementation expect the same inputs' dtype."
1369            return f"div_floor_floating_vec({a}, {b})"
1370        else:
1371            assert all(is_integer_dtype(item.dtype) for item in [a, b])
1372            # a and b are integer type
1373            _t = f"decltype({a})"
1374            if V.kernel._get_raw_num_vectors(b.dtype) < 1:
1375                # Doing blend to set the remaining bits of b to non-zero
1376                b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})"
1377            quot = f"{a} / {b}"
1378            has_rem = f"({a} % {b} != {_t}(0))"
1379            is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))"
1380            return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})"
1381
1382    @staticmethod
1383    def truncdiv(a, b):
1384        # a and b are integer type
1385        if V.kernel._get_raw_num_vectors(b.dtype) < 1:
1386            # Doing blend to set the remaining bits of b to non-zero
1387            _t = f"decltype({b})"
1388            b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})"
1389        return f"{a} / {b}"
1390
1391    @staticmethod
1392    def minimum(a, b):
1393        if a.dtype == torch.bool:
1394            assert b.dtype == torch.bool
1395            a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b))
1396            return f"{a_cast} & {b_cast}"
1397        else:
1398            return f"at::vec::minimum({a}, {b})"
1399
1400    @staticmethod
1401    def maximum(a, b):
1402        if a.dtype == torch.bool:
1403            assert b.dtype == torch.bool
1404            a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b))
1405            return f"{a_cast} | {b_cast}"
1406        else:
1407            return f"at::vec::maximum({a}, {b})"
1408
1409    @staticmethod
1410    def square(a):
1411        return f"{a} * {a}"
1412
1413    @staticmethod
1414    def where(a, b, c):
1415        assert isinstance(V.kernel, CppVecKernel)
1416        if b.dtype == torch.bool:
1417            assert c.dtype == torch.bool
1418            blendv_a, blendv_b, blendv_c = unify_mask_base_type(
1419                V.kernel.compute, (a, b, c)
1420            )
1421            return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})"
1422        else:
1423            return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})"
1424
1425    @staticmethod
1426    def sign(x):
1427        code = BracesBuffer()
1428        vec_zero = f"decltype({x})(0)"
1429        vec_one = f"decltype({x})(1)"
1430        blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})"
1431        blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})"
1432        code.writeline("[&]()")
1433        with code.indent():
1434            code.writeline(f"auto left = {blendv_l};")
1435            code.writeline(f"auto right = {blendv_r};")
1436            code.writeline("return left - right;")
1437        code.writeline("()")
1438        return code
1439
1440    @staticmethod
1441    def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True):
1442        assert dtype in [
1443            torch.bool,
1444            torch.float64,
1445            torch.float,
1446            torch.bfloat16,
1447            torch.float16,
1448            torch.uint8,
1449            torch.int8,
1450            torch.int32,
1451            torch.int64,
1452        ], f"{__name__} does not support {dtype}"
1453        assert isinstance(x, CppCSEVariable)
1454        src_dtype = x.dtype
1455        expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype)
1456        csevar = V.kernel.cse.generate(V.kernel.compute, expr)
1457        csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype})
1458        if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float:
1459            V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype)
1460        return csevar
1461
1462    @staticmethod
1463    def log1p(x):
1464        bug = config.cpp.inject_log1p_bug_TESTING_ONLY
1465        if bug == "accuracy":
1466            return f"{x} + decltype({x})(1)"
1467        elif bug is None:
1468            return f"{x}.log1p()"
1469        else:
1470            raise AssertionError(
1471                f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}"
1472            )
1473
1474    @staticmethod
1475    def masked(mask, body, other):
1476        assert isinstance(V.kernel, CppVecKernel)
1477        code = BracesBuffer()
1478        var = V.kernel.cse.newvar()
1479        with V.kernel.masked(mask) as new_mask:
1480            code.writeline(f"auto {var} = [&]")
1481            with V.kernel.swap_buffers(code), code.indent():
1482                result = body()
1483                code.writeline(f"return {result};")
1484        code.writeline(";")
1485        V.kernel.compute.splice(code)
1486
1487        dtype = result.dtype
1488        body_code = f"{var}()"
1489        body_code_vec = (
1490            body_code
1491            if result.is_vec
1492            else f"{V.kernel._get_vec_type(dtype)}({body_code})"
1493        )
1494        other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype])
1495        # loading bool as VecMask<float, N>
1496        other_code_vec = (
1497            f"{V.kernel._get_mask_type()}::from({other_code})"
1498            if dtype == torch.bool
1499            else f"{V.kernel._get_vec_type(dtype)}({other_code})"
1500        )
1501        assert isinstance(new_mask, CppCSEVariable), new_mask
1502        if new_mask.is_vec:
1503            code = BracesBuffer()
1504            code.writeline("[&]")
1505            with V.kernel.swap_buffers(code), code.indent():
1506                code.writeline(f"if ({new_mask}.all_zero())")
1507                with code.indent():
1508                    code.writeline(f"return {other_code_vec};")
1509                code.writeline("else")
1510                with code.indent():
1511                    # Create cse variable to reuse kernel.overrides.where
1512                    body_vec_var = V.kernel.cse.generate(
1513                        V.kernel.compute,
1514                        body_code_vec,
1515                    )
1516                    other_vec_var = V.kernel.cse.generate(
1517                        V.kernel.compute,
1518                        other_code_vec,
1519                    )
1520                    assert isinstance(body_vec_var, CppCSEVariable), body_vec_var
1521                    assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
1522                    body_vec_var.dtype = dtype
1523                    other_vec_var.dtype = dtype
1524                    code.writeline(
1525                        f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};"
1526                    )
1527            code.writeline("()")
1528            csevar = V.kernel.cse.generate(
1529                V.kernel.compute,
1530                code,
1531            )
1532        elif result.is_vec:
1533            csevar = V.kernel.cse.generate(
1534                V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}"
1535            )
1536        else:
1537            csevar = V.kernel.cse.generate(
1538                V.kernel.compute, f"{mask} ? {body_code} : {other_code}"
1539            )
1540        # `result` is explicitly added to the args for correct propagation
1541        # of relevant itervars and vectorization status.
1542        csevar.update_on_args("masked", (mask, body, other, result), {})
1543        return csevar
1544
1545    @staticmethod
1546    def index_expr(expr, dtype):
1547        assert isinstance(V.kernel, CppVecKernel)
1548        index = V.kernel.rename_indexing(expr)
1549        tiling_var = V.kernel.itervars[V.kernel.tiling_idx]
1550        stride = V.kernel._try_get_const_stride(index, tiling_var)
1551        if stride == 0:
1552            return CppOverrides.index_expr(expr, dtype)
1553        elif stride is not None:
1554            idx = V.kernel.cse.generate(
1555                V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
1556            )
1557            value = ops.to_dtype(idx, dtype)
1558            if isinstance(value, OpsValue):
1559                value = value.value
1560            csevar = V.kernel.arange(value, stride)
1561        else:
1562            csevar = V.kernel._load_or_store_non_contiguous(  # type: ignore[assignment]
1563                None, index, dtype, V.kernel.compute
1564            )
1565        csevar.update_on_args("index_expr", (expr, dtype), {})
1566        return csevar
1567
1568    @staticmethod
1569    def frexp(x):
1570        cache_keys = f"frexp({x})[0]", f"frexp({x})[1]"
1571        if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys):
1572            return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys)
1573
1574        cdtype = DTYPE_TO_CPP[x.dtype]
1575        size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor
1576        code = BracesBuffer()
1577        exponent = V.kernel.cse.newvar()
1578        mantissa = V.kernel.cse.newvar()
1579        exponent.update_on_args("frexp", (x,), kwargs={})
1580        mantissa.update_on_args("frexp", (x,), kwargs={})
1581        n_vec = V.kernel._get_num_vectors(x.dtype)
1582        mantissa_t = (
1583            f"at::vec::Vectorized<{cdtype}>"
1584            if n_vec == 1
1585            else f"at::vec::VectorizedN<{cdtype}, {n_vec}>"
1586        )
1587        code.writeline(
1588            f"at::vec::Vectorized<int32_t> {exponent};"
1589            if n_vec == 1
1590            else f"at::vec::VectorizedN<int32_t, {n_vec}> {exponent};"
1591        )
1592        code.writeline(f"{mantissa_t} {mantissa};")
1593        code.writeline("[&]()")
1594        with code.indent():
1595            code.writeline(
1596                f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;"
1597            )
1598            code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});")
1599            code.writeline(
1600                f"__at_align__ std::array<int32_t, {V.kernel.tiling_factor}> tmpbuf_exponent;"
1601            )
1602            code.writeline(
1603                f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;"
1604            )
1605            code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)")
1606            with code.indent():
1607                code.writeline(
1608                    "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);"
1609                )
1610            code.writeline(
1611                f"{exponent} = at::vec::Vectorized<int32_t>::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});"
1612                if n_vec == 1
1613                else f"{exponent} = at::vec::VectorizedN<int32_t, {n_vec}>::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});"
1614            )
1615            code.writeline(
1616                f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});"
1617            )
1618        code.writeline("();")
1619        V.kernel.compute.splice(code)
1620        cse_vars = (mantissa, exponent)
1621        for cache_key, cse_var in zip(cache_keys, cse_vars):
1622            V.kernel.cse.cache[cache_key] = cse_var
1623        return mantissa, exponent
1624
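    # Illustrative sketch (not in the original source): frexp above has no direct
    # vector counterpart, so it spills the input to an aligned buffer, applies
    # std::frexp per lane, and reloads the results as vectors. For a float input
    # `tmp0` with n_vec == 1 and tiling_factor == 16, the emitted C++ looks roughly
    # like this (CSE variable names are hypothetical):
    #
    #   at::vec::Vectorized<int32_t> tmp1;
    #   at::vec::Vectorized<float> tmp2;
    #   [&]()
    #   {
    #       __at_align__ std::array<float, 16> tmpbuf;
    #       tmp0.store(tmpbuf.data(), 16);
    #       __at_align__ std::array<int32_t, 16> tmpbuf_exponent;
    #       __at_align__ std::array<float, 16> tmpbuf_mantissa;
    #       for (int i = 0; i < 16; i++)
    #       {
    #           tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);
    #       }
    #       tmp1 = at::vec::Vectorized<int32_t>::loadu(tmpbuf_exponent.data(), 16);
    #       tmp2 = at::vec::Vectorized<float>::loadu(tmpbuf_mantissa.data(), 16);
    #   }
    #   ();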
1625    @classmethod
1626    def scalarize(cls, scalar_func):
1627        def inner(*args, **kwargs):
1628            assert not kwargs
1629            kernel = V.kernel
1630            assert isinstance(kernel, CppVecKernel)
1631            code = BracesBuffer()
1632            code.writeline("[&]()")
1633            vec_dtype = args[0].dtype
1634            n_vec = kernel._get_num_vectors(vec_dtype)
1635            size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor
1636            scalar_args = []
1637            cdtype = DTYPE_TO_CPP[vec_dtype]
1638            output_mask = scalar_func.__name__ in (
1639                "isinf",
1640                "isnan",
1641                "signbit",
1642            )
1643            octype = "bool" if output_mask else cdtype
1644            octype = (
1645                DTYPE_TO_CPP[args[-2]]
1646                if (scalar_func.__name__ == "to_dtype_bitcast")
1647                else octype
1648            )
1649            with code.indent():
1650                for argidx, arg in enumerate(args):
1651                    if isinstance(arg, CppCSEVariable):
1652                        assert arg.is_vec
1653                        assert arg.dtype == vec_dtype
1654                        code.writeline(
1655                            f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};"
1656                        )
1657                        code.writeline(
1658                            f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});"
1659                        )
1660                        scalar_args.append(f"tmpbuf{argidx}[i]")
1661                    else:
1662                        scalar_args.append(arg)
1663                code.writeline(
1664                    f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;"
1665                )
1666                res = scalar_func(*scalar_args)
1667                code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)")
1668                with code.indent():
1669                    code.writeline(f"tmpbuf_out[i] = {res};")
1670                if output_mask:
1671                    assert not kernel.tail_size
1672                    load_args = "tmpbuf_out.data()"
1673                    load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from"
1674                else:
1675                    load_args = f"tmpbuf_out.data(), {cexpr_index(size)}"
1676                    if n_vec == 1:
1677                        load_fn = f"at::vec::Vectorized<{octype}>::loadu"
1678                    else:
1679                        load_fn = f"at::vec::VectorizedN<{octype}, {n_vec}>::loadu"
1680                code.writeline(f"return {load_fn}({load_args});")
1681            code.writeline("()")
1682            return code
1683
1684        return inner
1685
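    # Illustrative sketch (not in the original source): scalarize above wraps a
    # scalar CppOverrides op so it can run lane by lane inside a vector kernel.
    # Each vector argument is stored into an aligned temporary array, the scalar
    # op is applied per lane into `tmpbuf_out`, and the result is reloaded as a
    # vector (or a VecMask for boolean-producing ops). For a hypothetical
    # one-argument op `foo` on a float vector `tmp0`, the generated C++ is roughly:
    #
    #   [&]()
    #   {
    #       __at_align__ std::array<float, 16> tmpbuf0;
    #       tmp0.store(tmpbuf0.data(), 16);
    #       __at_align__ std::array<float, 16> tmpbuf_out;
    #       for (int i = 0; i < 16; i++)
    #       {
    #           tmpbuf_out[i] = foo(tmpbuf0[i]);
    #       }
    #       return at::vec::Vectorized<float>::loadu(tmpbuf_out.data(), 16);
    #   }
    #   ()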
1686    @classmethod
1687    def _initialize_scalarize(cls):
1688        for name, method in vars(CppOverrides).items():
1689            if getattr(method, "__class__", None) == staticmethod and name not in vars(
1690                CppVecOverrides
1691            ):
1692                func = cls.scalarize(method.__func__)
1693                func.__name__ = name
1694                setattr(cls, name, staticmethod(func))
1695
1696
1697CppVecOverrides._initialize_pointwise_overrides("cppvec")
1698CppVecOverrides._initialize_scalarize()
1699
1700
1701class CppTile2DOverrides(CppVecOverrides):
1702    @staticmethod
1703    def index_expr(expr, dtype):
1704        assert isinstance(V.kernel, CppTile2DKernel)
1705        expr = V.kernel.transform_indexing(expr)
1706        return CppVecOverrides.index_expr(expr, dtype)
1707
1708
1709class CppKernel(Kernel):
1710    overrides = CppOverrides  # type: ignore[assignment]
1711    sexpr = cexpr
1712    newvar_prefix = "auto "
1713    suffix = ";"
1714
1715    def __init__(self, args, num_threads):
1716        super().__init__(args)
1717        self.call_ranges: Optional[Tuple[sympy.Expr, ...]] = None
1718        self.ranges: List[sympy.Expr] = []
1719        self.itervars: List[sympy.Symbol] = []
1720        self.reduction_depth = None
1721        self.reduction_prefix = IndentedBuffer()
1722        self.reduction_suffix = IndentedBuffer()
1723        self.parallel_reduction_prefix = IndentedBuffer()
1724        self.parallel_reduction_suffix = IndentedBuffer()
1725        self.local_reduction_init = IndentedBuffer()
1726        self.local_reduction_stores = IndentedBuffer()
1727        self.is_reduction = False
1728        self.non_parallel_reduction_prefix = IndentedBuffer()
1729        self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
1730        self.weight_recps_cse = CSE(
1731            self.newvar_prefix, self.suffix, name_prefix="wrecps"
1732        )
1733        self.preloads = IndentedBuffer()
1734        self.poststores = IndentedBuffer()
1735        self.num_threads = num_threads  # the number of threads the kernel is specialized for
1736        self.reduction_omp_dec: Dict[Tuple[str, str], str] = {}
1737
1738    def _gen_parallel_reduction_buffers(
1739        self,
1740        acc,
1741        acc_type,
1742        reduction_type,
1743        dtype,
1744        reduction_combine_fn=reduction_combine,
1745        reduction_init_fn=reduction_init,
1746        welford_weight_reciprocal_vec_fn=None,
1747    ):
1748        if config.cpp.dynamic_threads and not self.parallel_reduction_prefix:
1749            self.parallel_reduction_prefix.writeline(
1750                "int max_threads = omp_get_max_threads();"
1751            )
1752        acc_local = f"{acc}_local"
1753        num_threads = (
1754            "max_threads" if config.cpp.dynamic_threads else parallel_num_threads()
1755        )
1756        acc_per_thread_var_name = f"{acc}_arr"
1757        acc_per_thread = f"{acc_per_thread_var_name}[{num_threads}]"
1758        """
1759        MSVC does not support dynamic arrays (VLA), so we use std::unique_ptr instead.
1760        Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler
1761        MSVC is the only supported compiler without VLA support, and it does not reach good
1762        inductor performance anyway, so using unique_ptr keeps the code working on MSVC.
1763        For other compilers, we continue to use VLA to get the best performance.
1764        """
1765        acc_per_thread_unique_ptr_decl = f"auto {acc_per_thread_var_name} = std::make_unique<{acc_type}[]>({num_threads})"
1766        acc_per_thread_vla_decl = f"{acc_per_thread_var_name}[{num_threads}]"
1767        acc_local_in_array = acc_per_thread.replace(f"[{num_threads}]", "[tid]")
1768        self.local_reduction_init.writeline(
1769            f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};"
1770        )
1771        self.parallel_reduction_prefix.writeline(
1772            f"{acc_per_thread_unique_ptr_decl};"
1773            if cpp_builder.is_msvc_cl()
1774            else f"{acc_type} {acc_per_thread_vla_decl};"
1775        )
1776        self.parallel_reduction_prefix.writelines(
1777            [
1778                f"for (int tid = 0; tid < {num_threads}; tid++)",
1779                "{",
1780                f"    {acc_local_in_array} = {reduction_init_fn(reduction_type, dtype)};",
1781                "}",
1782            ],
1783        )
1784        self.local_reduction_stores.writelines(
1785            [
1786                f"{acc_local_in_array} = {acc_local};",
1787            ]
1788        )
1789        self.parallel_reduction_suffix.writelines(
1790            [
1791                f"for (int tid = 0; tid < {num_threads}; tid++)",
1792                "{",
1793                f"    {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};",
1794                "}",
1795            ],
1796        )
1797
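    # Illustrative sketch (not in the original source): for a non-MSVC build with a
    # static thread count of 8, a "sum" reduction accumulator named `tmp_acc0` of
    # type float roughly expands into the following pieces (the init and combine
    # expressions come from reduction_init/reduction_combine and are shown
    # schematically as <init> and <combine>):
    #
    #   // parallel_reduction_prefix
    #   float tmp_acc0_arr[8];
    #   for (int tid = 0; tid < 8; tid++)
    #   {
    #       tmp_acc0_arr[tid] = <init>;
    #   }
    #   // local_reduction_init (start of each thread's work)
    #   float tmp_acc0_local = <init>;
    #   // local_reduction_stores (end of each thread's work)
    #   tmp_acc0_arr[tid] = tmp_acc0_local;
    #   // parallel_reduction_suffix (after the parallel region)
    #   for (int tid = 0; tid < 8; tid++)
    #   {
    #       tmp_acc0 = <combine>(tmp_acc0, tmp_acc0_arr[tid]);
    #   }
    #
    # On MSVC the per-thread array is allocated with std::make_unique instead of a VLA.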
1798    def get_reduction_var_pattern(self, line: str):
1799        return re.search("tmp_acc[0-9]+", line)
1800
1801    def update_stores_with_parallel_reduction(self):
1802        for i, line in enumerate(self.stores._lines):
1803            if isinstance(line, str):
1804                m = self.get_reduction_var_pattern(line)
1805                if m:
1806                    var_name = m.group(0)
1807                    self.stores._lines[i] = line.replace(var_name, f"{var_name}_local")
1808
1809    @contextlib.contextmanager
1810    def masked(self, mask):
1811        """Context manager to add an additional mask to loads and stores."""
1812        prior = self._load_mask
1813        if prior:
1814            mask = ops.and_(mask, prior)
1815            if isinstance(mask, OpsValue):
1816                mask = mask.value
1817                assert isinstance(mask, CppCSEVariable)
1818                # see NOTE [dtype of CppCSEVariable]
1819                # mask's dtype should be bool
1820                mask.dtype = torch.bool
1821
1822        self._load_mask = mask
1823        try:
1824            yield mask
1825        finally:
1826            self._load_mask = prior
1827
1828    def scale_index_with_offset(
1829        self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0
1830    ):
1831        var = self.itervars[itervar_idx]
1832        replacement = {var: var * scale + offset}
1833        new_index = sympy_subs(index, replacement)
1834        return new_index
1835
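    # Illustrative example (not in the original source): scale_index_with_offset
    # substitutes `var -> var * scale + offset` for the chosen itervar. For
    # index = 64*x0 + x1 with itervar_idx=-1 (i.e. x1), scale=16 and offset=x1_inner,
    # the result is 64*x0 + 16*x1 + x1_inner. The non-contiguous load/store helper
    # below uses it with the default scale of 1 and the inner-loop itervar as offset.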
1836    def index_to_str(self, index: sympy.Expr) -> str:
1837        """
1838        Convert an index expr to a string that can be used in cpp code.
1839        e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel.
1840        """
1841        return cexpr(self.rename_indexing(index))
1842
1843    def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
1844        """
1845        Check if an index has free symbol CppCSEVariable that depends on `itervar`.
1846        """
1847        return any(
1848            self.cse.varname_map[s.name].depends_on(itervar)  # type: ignore[attr-defined]
1849            for s in index.free_symbols
1850            if s.name in self.cse.varname_map  # type: ignore[attr-defined]
1851            and isinstance(self.cse.varname_map[s.name], CppCSEVariable)  # type: ignore[attr-defined]
1852        )
1853
1854    def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
1855        return itervar in index.free_symbols or self.index_indirect_depends_on(
1856            index, itervar
1857        )
1858
1859    def var_ranges(self):
1860        return dict(zip(self.itervars, self.ranges))
1861
1862    def check_bounds(
1863        self,
1864        expr: sympy.Expr,
1865        size: sympy.Expr,
1866        lower: bool,
1867        upper: bool,
1868    ):
1869        if not (lower or upper):
1870            return
1871
1872        indirect = free_symbol_is_type(expr, SymT.TMP)
1873        if indirect:
1874            # indexing in compute
1875            csevar = ops.index_expr(expr, torch.int64).value
1876            buffer = V.kernel.compute
1877        else:
1878            # indexing in loads
1879            prior_compute = V.kernel.compute
1880            try:
1881                V.kernel.compute = self.loads
1882                csevar = ops.index_expr(expr, torch.int64).value
1883            finally:
1884                V.kernel.compute = prior_compute
1885            buffer = self.loads
1886
1887        size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None
1888
1889        line = self.indirect_assert(
1890            csevar, "0" if lower else None, size_str, self._load_mask
1891        )
1892        self.cse.generate(buffer, line, assignment=False)
1893
1894    def load(self, name: str, index: sympy.Expr):
1895        var = self.args.input(name)
1896        index = self.rename_indexing(index)
1897        line = f"{var}[{cexpr_index(index)}]"
1898        csevar = self.cse.generate(self.loads, line)
1899        csevar.update_on_args("load", (self, name, index), {})
1900        return csevar
1901
1902    def store(self, name, index, value, mode=None):
1903        assert "buf" in name
1904        var = self.args.output(name)
1905        index = self.rename_indexing(index)
1906        if mode is None:
1907            line = f"{var}[{cexpr_index(index)}] = {value};"
1908        elif mode == "atomic_add":
1909            if not config.cpp.dynamic_threads and self.num_threads == 1:
1910                line = f"{var}[{cexpr_index(index)}] += {value};"
1911            else:
1912                dtype = V.graph.get_dtype(name)
1913                # mirroring static_cast<float>(...) in load:
1914                value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})"
1915                line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});"
1916        else:
1917            raise NotImplementedError(f"store mode={mode}")
1918        self.stores.writeline(DeferredLine(name, line))
1919
1920    def reduction(self, dtype, src_dtype, reduction_type, value):
1921        argmax_or_argmin = reduction_type in {"argmax", "argmin"}
1922        reduction_key = src_dtype, reduction_type, value
1923        if reduction_key in self.reduction_cse.reduction_cache:
1924            return self.reduction_cse.reduction_cache[reduction_key]
1925
1926        acc = self.reduction_cse.generate(
1927            self.loads, f"reduction {reduction_key}", write=False
1928        )
1929        self.is_reduction = True
1930        init_dtype = src_dtype if argmax_or_argmin else dtype
1931        acc_type = reduction_acc_type(reduction_type, init_dtype)
1932        self.reduction_prefix.writeline(
1933            f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};"
1934        )
1935        assert self.reduction_depth is not None
1936        index = self.itervars[self.reduction_depth]
1937        for i in range(self.reduction_depth + 1, len(self.itervars)):
1938            index = index * self.ranges[i] + self.itervars[i]
1939        self.stores.writeline(
1940            f"{acc} = {reduction_combine(reduction_type, acc, value, index)};"
1941        )
1942        self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype)
1943        result = reduction_project(reduction_type, acc)
1944        self.reduction_cse.reduction_cache[reduction_key] = result
1945        return result
1946
1947    def store_reduction(self, name, index, value):
1948        index = self.rename_indexing(index)
1949        var = self.args.output(name)
1950        self.reduction_suffix.writeline(
1951            DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};")
1952        )
1953
1954    def set_ranges(self, lengths, reduction_lengths):
1955        if self.call_ranges:
1956            assert self.call_ranges == tuple(lengths) + tuple(
1957                reduction_lengths
1958            ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
1959            assert self.reduction_depth == len(lengths)
1960        else:
1961            self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
1962            self.ranges = [self.rename_indexing(x) for x in self.call_ranges]
1963            self.itervars = [
1964                sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
1965                for n in range(len(self.ranges))
1966            ]
1967            self.reduction_depth = len(lengths)
1968        return (
1969            self.itervars[: self.reduction_depth],
1970            self.itervars[self.reduction_depth :],
1971        )
1972
1973    def size_hint(self):
1974        return V.graph.sizevars.size_hint(
1975            sympy_product(self.call_ranges), fallback=8192
1976        )
1977
1978    def codegen_loops_impl(self, loop_nest, code, worksharing):
1979        threads = parallel_num_threads()
1980        assert self.call_ranges is not None
1981        kernels = loop_nest.get_kernels()
1982        has_outer_loop_kernel = any(
1983            isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels
1984        )
1985        if has_outer_loop_kernel:
1986            assert len(kernels) == 1
1987            assert isinstance(kernels[0], OuterLoopFusedKernel)
1988            par_depth = kernels[0].decide_parallel_depth(
1989                loop_nest.max_parallel_depth(), threads
1990            )
1991        else:
1992            par_depth = self.decide_parallel_depth(
1993                loop_nest.max_parallel_depth(), threads
1994            )
1995
1996        with contextlib.ExitStack() as stack:
1997            if par_depth:
1998                if loop_nest.is_reduction_only():
1999                    # need to close the worksharing scope to define reduction vars outside it
2000                    worksharing.close()
2001                else:
2002                    worksharing.parallel(threads)
2003                loop_nest.mark_parallel(par_depth)
2004            elif threads > 1:
2005                if worksharing.single():
2006                    stack.enter_context(code.indent())
2007
2008            def gen_loop_kernel(loop: LoopLevel):
2009                def is_parallel_reduction(loop):
2010                    root = loop.get_root()
2011                    return root.is_reduction and root.parallel
2012
2013                kernels = loop.get_kernels()
2014                assert len(kernels) == 1
2015                if not isinstance(
2016                    kernels[0], OuterLoopFusedKernel
2017                ) and is_parallel_reduction(loop):
2018                    kernels[0].update_stores_with_parallel_reduction()
2019                gen_kernel(kernels[0])
2020
2021            def gen_kernel(kernel):
2022                if isinstance(kernel, OuterLoopFusedKernel):
2023                    for loop in kernel.inner:
2024                        if loop.inner:
2025                            gen_loops(loop.inner, loop.is_reduction)
2026                        else:
2027                            with contextlib.ExitStack() as stack:
2028                                # If any kernel exists at the final outer loop fusion level,
2029                                # its code should be placed within its own indent to prevent
2030                                # duplicate variable definitions.
2031                                stack.enter_context(code.indent())
2032                                gen_loop_kernel(loop)
2033                else:
2034                    with contextlib.ExitStack() as stack:
2035                        assert kernel
2036                        if hasattr(kernel, "codegen_inner_loops"):
2037                            code.splice(kernel.preloads)
2038                            kernel.codegen_inner_loops(code)
2039                            stack.enter_context(code.indent())
2040                        code.splice(kernel.loads)
2041                        code.splice(kernel.compute)
2042                        code.splice(kernel.stores)
2043                    if hasattr(kernel, "codegen_inner_loops"):
2044                        code.splice(kernel.poststores)
2045
2046            def get_reduction_code_buffer(loops, buffer="prefix"):
2047                assert buffer in ("prefix", "suffix", "local")
2048                for loop in loops:
2049                    for kernel in loop.get_kernels():
2050                        if buffer == "local":
2051                            return (
2052                                kernel.local_reduction_init,
2053                                kernel.local_reduction_stores,
2054                            )
2055                        elif buffer == "suffix":
2056                            suffix = kernel.reduction_suffix
2057                            if loop.parallel:
2058                                suffix = kernel.parallel_reduction_suffix + suffix
2059                            return suffix
2060                        else:
2061                            prefix = kernel.reduction_prefix
2062                            if loop.parallel:
2063                                prefix = prefix + kernel.parallel_reduction_prefix
2064                            else:
2065                                prefix = prefix + kernel.non_parallel_reduction_prefix
2066                            return prefix
2067
2068            def gen_loops(loops: List[LoopLevel], in_reduction=False):
2069                with contextlib.ExitStack() as stack_outer:
2070                    local_reduction_init = local_reduction_stores = None
2071                    if loops:
2072                        loop = loops[0]
2073                        if loop.is_reduction and not in_reduction:
2074                            reduction_prefix = get_reduction_code_buffer(loops)
2075                            if reduction_prefix:
2076                                stack_outer.enter_context(code.indent())
2077                            code.splice(reduction_prefix)
2078                        if loop_nest.is_reduction_only() and loop.parallel:
2079                            (
2080                                local_reduction_init,
2081                                local_reduction_stores,
2082                            ) = get_reduction_code_buffer(loops, "local")
2083                            worksharing.parallel(threads)
2084                            if local_reduction_init:
2085                                assert local_reduction_stores
2086                                code.splice(local_reduction_init)
2087
2088                    for loop in loops:
2089                        gen_loop(loop)
2090
2091                    if loops:
2092                        loop = loops[0]
2093                        if loop_nest.is_reduction_only() and loop.parallel:
2094                            if local_reduction_stores:
2095                                code.splice(local_reduction_stores)
2096                            worksharing.close()
2097                        if loop.is_reduction and not in_reduction:
2098                            code.splice(get_reduction_code_buffer(loops, "suffix"))
2099
2100            def gen_loop(loop: LoopLevel):
2101                with contextlib.ExitStack() as stack:
2102                    loop_lines = loop.lines()
2103                    if loop_lines is None:
2104                        return
2105                    code.writelines(loop_lines)
2106                    stack.enter_context(code.indent())
2107                    # generate inner loops or loop body
2108                    if loop.inner:
2109                        gen_loops(loop.inner, loop.is_reduction)
2110                    else:
2111                        gen_loop_kernel(loop)
2112
2113            stack.enter_context(code.indent())
2114            if loop_nest.root:
2115                if (
2116                    has_outer_loop_kernel
2117                    and isinstance(V.local_buffer_context, LocalBufferContext)
2118                    and V.local_buffer_context.local_buffers
2119                ):
2120                    # Allocate local buffer
2121                    local_buffers = V.local_buffer_context.local_buffers
2122                    for local_buffer in local_buffers.values():
2123                        # For dynamic sizes, rename symbols from s* to ks*
2124                        local_buf_size = sympy_product(
2125                            [
2126                                self.rename_indexing(size_val)
2127                                for size_val in local_buffer.get_layout().size
2128                            ]
2129                        )
2130                        local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype]
2131                        allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})"
2132                        local_buffer_name = local_buffer.get_name()
2133                        code.splice(
2134                            f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};"
2135                        )
2136                        code.splice(
2137                            f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();"
2138                        )
2139                gen_loops(loop_nest.root)
2140            else:
2141                gen_kernel(loop_nest.kernel)
2142
2143    def codegen_loops(self, code, worksharing):
2144        loop_nest = LoopNestWithSplit.build(self)
2145        self.codegen_loops_impl(loop_nest, code, worksharing)
2146
2147    @property
2148    def assert_function(self) -> str:
2149        if V.graph.aot_mode:
2150            # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
2151            # compared with JIT Inductor which uses TORCH_CHECK
2152            return "AOTI_TORCH_CHECK"
2153        else:
2154            return "TORCH_CHECK"
2155
2156    def decide_parallel_depth(self, max_parallel_depth, threads):
2157        assert self.call_ranges is not None
2158        ranges = self.call_ranges[:max_parallel_depth]
2159        seq = self.size_hint()
2160        par = 1
2161        depth = 0
2162        for expr in ranges:
2163            hint = V.graph.sizevars.size_hint(expr, fallback=8192)
2164            if par >= 2 * threads or par == threads:
2165                break
2166            if seq // threads < config.cpp.min_chunk_size:
2167                # not enough work
2168                break
2169            depth += 1
2170            par *= hint
2171            seq /= hint
2172        # If the thread count is dynamic, make sure we have at
2173        # least one parallel scope and let the OMP runtime manage
2174        # the serial vs. parallel decision.
2175        if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0:
2176            depth = 1
2177        return depth
2178
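    # Worked example (illustrative, not in the original source): with 8 threads,
    # call_ranges hinted at (1024, 1024) and assuming config.cpp.min_chunk_size is
    # 4096, the loop starts with seq = 1024 * 1024 = 1048576 and par = 1. The first
    # range passes both checks (par < 8 and seq // 8 >= 4096), so depth becomes 1,
    # par becomes 1024 and seq becomes 1024. On the second range par >= 2 * threads,
    # so the loop breaks and only the outermost level is parallelized.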
2179    @contextlib.contextmanager
2180    def write_to_suffix(self):
2181        prior = (self.loads, self.compute, self.stores, self.cse)
2182        self.loads = IndentedBuffer()
2183        self.compute = IndentedBuffer()
2184        self.stores = IndentedBuffer()
2185        self.cse = self.cse.clone()
2186        yield
2187        self.reduction_suffix.splice(self.loads)
2188        self.reduction_suffix.splice(self.compute)
2189        self.reduction_suffix.splice(self.stores)
2190        (self.loads, self.compute, self.stores, self.cse) = prior
2191
2192    def create_cse_var(self, *args, **kwargs):
2193        return CppCSEVariable(*args, **kwargs)
2194
2195    def get_to_dtype_expr(self, src, dtype, src_dtype):
2196        return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})"
2197
2198    def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype):
2199        expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype)
2200        self.cse.cache[expr] = dst
2201
2202
2203class CppVecKernel(CppKernel):
2204    overrides = CppVecOverrides  # type: ignore[assignment]
2205
2206    def __init__(
2207        self,
2208        args,
2209        num_threads,
2210        tiling_factor,
2211        tiling_idx,
2212        tail_size=None,
2213    ):
2214        super().__init__(args, num_threads)
2215        self.vec_isa = cpu_vec_isa.pick_vec_isa()
2216        assert self.vec_isa
2217        assert tiling_factor > 0, "Expected a non-zero tiling_factor to be passed in explicitly"
2218        self.tiling_factor = tiling_factor
2219        self.tiling_idx = tiling_idx
2220        self.tail_size = tail_size
2221        self.num_elems = tail_size if tail_size else tiling_factor
2222
2223    def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol):
2224        if self.index_indirect_depends_on(index, itervar):
2225            return None
2226        for indirect_var in (
2227            self.cse.varname_map[s.name]  # type: ignore[attr-defined]
2228            for s in index.free_symbols
2229            if symbol_is_type(s, SymT.TMP)
2230        ):
2231            assert isinstance(indirect_var, CppCSEVariable)
2232            if indirect_var.is_vec:
2233                return None
2234        stride = stride_at_vec_range(index, itervar, self.tiling_factor)
2235        return stride if stride.is_number else None
2236
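    # Illustrative example (not in the original source): the stride here is the
    # coefficient of `itervar` in `index` over the vectorized range. For
    # index = 128*x0 + x1 and itervar x1 the stride is 1 (contiguous), for
    # index = x0 it is 0 (a broadcast scalar), and for index = 128*x1 + x0 it is
    # 128 (non-contiguous). Indices that depend on vectorized indirect variables
    # return None so callers fall back to the non-contiguous path.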
2237    def _get_num_vectors(self, dtype: torch.dtype) -> int:
2238        num_vectors = math.ceil(
2239            self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width()
2240        )
2241        assert num_vectors >= 1
2242        return num_vectors
2243
2244    def _get_raw_num_vectors(self, dtype: torch.dtype) -> float:
2245        # This utility function is used to check whether the vector lanes have been
2246        # fully utilized. For example, uint8 will only use 1/4 of the vector lanes.
2247        return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width()
2248
2249    def _get_vec_type(self, dtype: torch.dtype) -> str:
2250        num_vectors = self._get_num_vectors(dtype)
2251        if num_vectors == 1:
2252            return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>"
2253        else:
2254            return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>"
2255
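    # Worked example (illustrative, not in the original source): on a 256-bit ISA
    # with tiling_factor = 16, _get_num_vectors(torch.float) is
    # ceil(16 * 4 * 8 / 256) = 2, so _get_vec_type(torch.float) returns
    # "at::vec::VectorizedN<float,2>". For torch.uint8 the raw value is
    # 16 * 1 * 8 / 256 = 0.5, which rounds up to one vector with only half of the
    # lanes used, which is exactly the case _get_raw_num_vectors is meant to detect.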
2256    def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str:
2257        if dtype == torch.bool:
2258            return ""
2259        num_vectors = self._get_num_vectors(dtype)
2260        return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>"
2261
2262    def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str:
2263        assert mask.dtype == torch.bool, repr(mask)
2264        num_vectors = self._get_num_vectors(dtype)
2265        return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()"
2266
2267    def get_reduction_var_pattern(self, line: str):
2268        return re.search("tmp_acc[0-9]+_vec", line)
2269
2270    def _get_vec_load_line(
2271        self,
2272        var: str,
2273        index: sympy.Expr,
2274        dtype: torch.dtype,
2275        load_mask: Optional[CppCSEVariable] = None,
2276    ):
2277        """
2278        Get a load line str that loads a vector from `var` at `index` of type `dtype`.
2279        If `load_mask` is not None, we do a masked load accordingly.
2280        Notes on the `dtype`:
2281        1. We always load `self.tiling_factor` elements regardless of the `dtype`.
2282           This means we load half of the vector lanes for 16-bit data types and a quarter of the
2283           vector lanes for 8-bit data types.
2284        2. `torch.bool` and `torch.uint8` may represent masks, and we load them as float mask vectors.
2285        """
2286        cpp_type = DTYPE_TO_CPP[dtype]
2287        num_vectors = self._get_num_vectors(dtype)
2288        load_mask_str = None
2289        if load_mask:
2290            if not load_mask.is_vec:
2291                # TODO: avoid hard-coding torch.float
2292                load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})"
2293            else:
2294                load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}"
2295        loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var
2296        if dtype == torch.bool:
2297            # TODO: should we consider load mask here?
2298            line = f"{self._get_mask_type()}::from({loadbuf})"
2299        else:
2300            line = (
2301                f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})"
2302                if load_mask_str
2303                else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})"
2304            )
2305        return line
2306
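    # Illustrative sketch (not in the original source): for a contiguous float load
    # from a hypothetical buffer `in_ptr0` at index `16*x0` with num_elems == 16,
    # the returned line is roughly
    #
    #   at::vec::Vectorized<float>::loadu(in_ptr0 + 16L*x0, 16)
    #
    # and with a vectorized bool mask `tmp_mask` it becomes a masked load of the form
    #
    #   tmp_mask.template cast<float,1>().template loadu<float,1>(in_ptr0 + 16L*x0)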
2307    def _load_or_store_non_contiguous(
2308        self,
2309        var: Optional[str],
2310        index: sympy.Expr,
2311        dtype: torch.dtype,
2312        buffer: Optional[IndentedBuffer] = None,
2313        store_value: Optional[Union[str, CppCSEVariable]] = None,
2314        accu_store: bool = False,
2315    ) -> Optional[CppCSEVariable]:
2316        """
2317        Load or store a vector in a non-contiguous way. The vector is initialized from an array that is
2318        filled in an inner loop over the tiling factor.
2319        :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index
2320                    as index expression, i.e. `transformed(index)`.
2321        :param index: index into the `var`, or the index expression itself if `var` is None.
2322                      The `index` could contain indirect indexing or the tiling itervar. When used in
2323                      the inner loop, the index is transformed as follows:
2324                      1. the index is linearized along the tiling dim.
2325                      2. the indirect indexing vector variables are transformed into arrays over the tiling dim.
2326        :param dtype: data type of `var` or `index` if `var` is None.
2327        :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`.
2328        :param store_value: the value to store. If None, we load the vector.
2329        :param accu_store: whether to accumulate `store_value` into `var` instead of overwriting it. If True, `store_value` must be provided.
2330        :return: a CppCSEVariable that represents the loaded vector or None if it is a store.
2331        """
2332        assert not store_value or var is not None, "store var must be provided"
2333        if accu_store:
2334            assert store_value
2335        if buffer is None:
2336            buffer = self.loads
2337
2338        def get_result_size(dtype: torch.dtype) -> int:
2339            if dtype.itemsize < 4:
2340                return self.num_elems * (4 // dtype.itemsize)
2341            else:
2342                return self.num_elems
2343
2344        def get_tiling_size(dtype: torch.dtype) -> int:
2345            if dtype.itemsize < 4:
2346                return self.tiling_factor * (4 // dtype.itemsize)
2347            else:
2348                return self.tiling_factor
2349
2350        def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
2351            assert vec_var.is_vec
2352            code = BracesBuffer()
2353            code.writeline("[&]")
2354            with code.indent():
2355                vec_dtype = vec_var.dtype
2356                assert vec_dtype is not None
2357                if vec_dtype == torch.bool:
2358                    vec_dtype = torch.float
2359                result_size = get_result_size(vec_dtype)
2360                tiling_size = get_tiling_size(vec_dtype)
2361                code.writeline(
2362                    f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;"
2363                )
2364                line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});"
2365                code.writeline(line)
2366                code.writeline("return tmpbuf;")
2367            code.writeline("()")
2368            csevar = self.cse.generate(buffer, code)
2369            assert isinstance(csevar, CppCSEVariable)
2370            return csevar
2371
2372        code = BracesBuffer()
2373        code.writeline("[&]")
2374        with code.indent():
2375            result_size = get_result_size(dtype)
2376            tiling_size = get_tiling_size(dtype)
2377            result_declare = (
2378                f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;"
2379            )
2380            code.writeline(result_declare)
2381            if store_value:
2382                code.writeline(
2383                    f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});"
2384                )
2385            itervar_inner = sympy_index_symbol(
2386                f"{self.itervars[self.tiling_idx]}_inner"
2387            )
2388            replacements = {}
2389            for indirect_var in (
2390                self.cse.varname_map[s.name]  # type: ignore[attr-defined]
2391                for s in index.free_symbols
2392                if symbol_is_type(s, SymT.TMP)
2393            ):
2394                assert isinstance(indirect_var, CppCSEVariable)
2395                if indirect_var.is_vec:
2396                    array_var = vec_to_array(indirect_var)
2397                    replacements[indirect_var] = f"{array_var}[{itervar_inner}]"
2398            index = self.scale_index_with_offset(
2399                index, itervar_idx=self.tiling_idx, offset=itervar_inner
2400            )
2401            load_mask = None
2402            if self._load_mask is not None:
2403                assert not store_value, "unexpected store with load mask"
2404                assert isinstance(self._load_mask, CppCSEVariable), self._load_mask
2405                if self._load_mask.is_vec:
2406                    load_mask = f"{self._load_mask}.is_masked({itervar_inner})"
2407                else:
2408                    load_mask = f"{self._load_mask} != 0"
2409            if cpp_builder.is_gcc():
2410                code.writeline(f"#pragma GCC unroll {self.tiling_factor}")
2411            else:
2412                code.writeline(f"#pragma unroll {self.tiling_factor}")
2413            code.writeline(
2414                f"for (long {itervar_inner} = 0; "
2415                + f"{itervar_inner} < {cexpr_index(self.num_elems)}; "
2416                + f"{itervar_inner}++)"
2417            )
2418            with code.indent(), contextlib.ExitStack() as stack:
2419                index_c = cexpr_index(index)
2420                for indirect_var in replacements:
2421                    index_c = re.sub(
2422                        r"\b" + f"{indirect_var}" + r"\b",
2423                        replacements[indirect_var],
2424                        index_c,
2425                    )
2426                rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}"
2427                if load_mask:
2428                    code.writeline(f"if ({load_mask})")
2429                    stack.enter_context(code.indent())
2430                if store_value:
2431                    conjunction = "+=" if accu_store else "="
2432                    code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];")
2433                else:
2434                    code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
2435            if not store_value:
2436                load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype)  # type: ignore[arg-type]
2437                code.writeline(f"return {load_line};")
2438        code.writeline("()")
2439        if store_value:
2440            code.writeline(";")
2441            buffer.splice(code)
2442            return None
2443        else:
2444            csevar = self.cse.generate(buffer, code)
2445            assert isinstance(csevar, CppCSEVariable)
2446            csevar.is_vec = True
2447            return csevar
2448
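    # Illustrative sketch (not in the original source): a non-contiguous float load
    # from a hypothetical `in_ptr0` at index 128*x1 + x0, tiled over x1 with
    # tiling_factor == 16, expands roughly to (on GCC):
    #
    #   [&]
    #   {
    #       __at_align__ std::array<float, 16> tmpbuf;
    #       #pragma GCC unroll 16
    #       for (long x1_inner = 0; x1_inner < 16; x1_inner++)
    #       {
    #           tmpbuf[x1_inner] = in_ptr0[128L*(x1 + x1_inner) + x0];
    #       }
    #       return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
    #   }
    #   ()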
2449    def load(self, name: str, index: sympy.Expr):
2450        var = self.args.input(name)
2451        index = self.rename_indexing(index)
2452        dtype = V.graph.get_dtype(name)
2453        tiling_var = self.itervars[self.tiling_idx]
2454        stride = self._try_get_const_stride(index, tiling_var)
2455        if stride == 0:
2456            # load scalar and lazily broadcast it on demand
2457            return super().load(name, index)
2458        elif stride == 1:
2459            # load contiguously
2460            line = self._get_vec_load_line(var, index, dtype, self._load_mask)
2461            csevar = self.cse.generate(self.loads, line)  # type: ignore[assignment]
2462        else:
2463            csevar = self._load_or_store_non_contiguous(var, index, dtype)  # type: ignore[assignment]
2464        assert isinstance(csevar, CppCSEVariable)
2465        csevar.update_on_args("load", (self, name, index), {})
2466        csevar.is_vec = True
2467        return csevar
2468
2469    def _get_store_line(
2470        self,
2471        value: Union[str, CppCSEVariable],
2472        var: str,
2473        index: sympy.Expr,
2474        dtype: torch.dtype,
2475        accu_store: bool = False,
2476    ):
2477        """
2478        Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles
2479        both contiguous and non-contiguous store cases.
2480        :param value: a vectorized value whose type is templated on `dtype`.
2481        :param var: buffer to store into.
2482        :param index: index into the `var`.
2483        """
2484        # When `value` is a str (e.g., for a Welford reduction), the caller should
2485        # make sure it is a vector.
2486        assert isinstance(value, str) or (
2487            isinstance(value, CppCSEVariable) and value.is_vec
2488        ), value
2489        tiling_var = self.itervars[self.tiling_idx]
2490        var_expr = f"{var} + {cexpr_index(index)}"
2491        stride = self._try_get_const_stride(index, tiling_var)
2492        code = IndentedBuffer()
2493        if stride == 1:
2494            if dtype == torch.float and self.tail_size is None:
2495                code.writeline(f"{value}.store({var_expr});")
2496            else:
2497                code.writeline(
2498                    f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});"
2499                )
2500        else:
2501            self._load_or_store_non_contiguous(
2502                var, index, dtype, buffer=code, store_value=value, accu_store=accu_store
2503            )
2504        return code
2505
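    # Illustrative sketch (not in the original source): a contiguous float store of
    # a vector `tmp4` into a hypothetical `out_ptr0` at index 16*x0 with no tail
    # emits the single line
    #
    #   tmp4.store(out_ptr0 + 16L*x0);
    #
    # while any other dtype, a tail tile, or a non-unit stride goes through the
    # size-qualified store or the non-contiguous helper above.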
2506    def store(self, name, index, value, mode=None):
2507        assert "buf" in name
2508        assert isinstance(value, CppCSEVariable), value
2509        if not value.is_vec:
2510            # this happens when we store a scalar into a vectorized buffer like "fill"
2511            value = self.broadcast(value)
2512        var = self.args.output(name)
2513        index = self.rename_indexing(index)
2514        dtype = V.graph.get_dtype(name)
2515        if mode is None:
2516            code = self._get_store_line(value, var, index, dtype)
2517            self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
2518        elif mode == "atomic_add":
2519            if not config.cpp.dynamic_threads and self.num_threads == 1:
2520                code = self._get_store_line(
2521                    f"{value}",
2522                    var,
2523                    index,
2524                    dtype,
2525                    accu_store=True,
2526                )
2527                self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
2528            else:
2529                n_src = self._get_num_vectors(dtype)
2530                n_idx = self._get_num_vectors(torch.int64)
2531                cdtype = DTYPE_TO_CPP[dtype]
2532                index = ops.index_expr(index, torch.int64).value
2533                assert index.is_vec
2534                line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});"
2535                self.stores.writeline(DeferredLine(name, line))
2536        else:
2537            raise NotImplementedError(f"store mode={mode}")
2538
2539    def reduction(self, dtype, src_dtype, reduction_type, value):
2540        assert reduction_type in VECTORIZABLE_RTYPES
2541        argmax_or_argmin = reduction_type in {"argmax", "argmin"}
2542        horizontal_reduction = self.tiling_idx >= self.reduction_depth
2543        init_dtype = src_dtype if argmax_or_argmin else dtype
2544        assert isinstance(value, CppCSEVariable), value
2545
2546        if not value.is_vec:
2547            value = self.broadcast(value)
2548
2549        reduction_key = src_dtype, reduction_type, value
2550        if reduction_key in self.reduction_cse.reduction_cache:
2551            return self.reduction_cse.reduction_cache[reduction_key]
2552
2553        vec_ns = "at::vec"
2554        vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
2555        acc_type = reduction_acc_type(reduction_type, init_dtype)
2556        acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype)
2557
2558        acc = self.reduction_cse.generate(
2559            self.loads, f"reduction {reduction_key}", write=False
2560        )
2561        acc_vec = f"{acc}_vec"
2562        self.is_reduction = True
2563        self.reduction_prefix.writeline(
2564            f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};"
2565        )
2566        self.reduction_prefix.writeline(
2567            f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, init_dtype)};"
2568        )
2569        if reduction_type == "welford_reduce":
2570            # save the reciprocal of weights for welford reduce
2571            assert self.reduction_depth is not None
2572            # use masked acc_vec for tail vec kernel
2573            self.reduction_prefix.writeline(
2574                f"{acc_type_vec} masked_{acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};"
2575            )
2576            reduction_size = functools.reduce(
2577                lambda x, y: x * y, self.ranges[self.reduction_depth :]
2578            )
2579            reduction_factor = (
2580                self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1
2581            )
2582            self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor)
2583            if self.weight_recp_vec_range not in self.weight_recps_cse.reduction_cache:
2584                self.weight_recps_val = self.weight_recps_cse.generate(
2585                    self.compute, f"reduction {self.weight_recp_vec_range}", write=False
2586                )
2587                self.weight_recps_cse.reduction_cache[
2588                    self.weight_recp_vec_range
2589                ] = self.weight_recps_val
2590                self.non_parallel_reduction_prefix.writeline(
2591                    self.welford_weight_reciprocal_vec(dtype)
2592                )
2593                # generate weight_recps for parallel reduction
2594                num_threads = (
2595                    "max_threads"
2596                    if config.cpp.dynamic_threads
2597                    else parallel_num_threads()
2598                )
2599                self.local_reduction_init.writeline(
2600                    self.welford_weight_reciprocal_vec(dtype, num_threads)
2601                )
2602            else:
2603                self.weight_recps_val = self.weight_recps_cse.reduction_cache[
2604                    self.weight_recp_vec_range
2605                ]
2606            # use masked acc_vec for tail vec kernel
2607            acc_vec_ = f"masked_{acc_vec}" if self.tail_size else acc_vec
2608            self.stores.writeline(
2609                f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, True)};"
2610            )
2611        else:
2612            assert self.reduction_depth is not None
2613            index = self.itervars[self.reduction_depth]
2614            for i in range(self.reduction_depth + 1, len(self.itervars)):
2615                index = index * self.ranges[i] + self.itervars[i]
2616            combine = self.reduction_combine_vec(
2617                reduction_type,
2618                acc_vec,
2619                value,
2620                index=index,
2621                horizontal_reduction=horizontal_reduction,
2622                src_dtype=src_dtype,
2623            )
2624            self.stores.writeline(f"{acc_vec} = {combine};")
2625        self._gen_parallel_reduction_buffers(
2626            acc,
2627            acc_type,
2628            reduction_type,
2629            init_dtype,
2630        )
2631        self._gen_parallel_reduction_buffers(
2632            acc_vec,
2633            acc_type_vec,
2634            reduction_type,
2635            init_dtype,
2636            reduction_combine_fn=self.reduction_combine_vec,
2637            reduction_init_fn=self.reduction_init_vec,
2638        )
2639        if reduction_type == "welford_reduce":
2640            # use masked acc_vec for tail vec kernel
2641            self._gen_parallel_reduction_buffers(
2642                f"masked_{acc_vec}",
2643                acc_type_vec,
2644                reduction_type,
2645                dtype,
2646                reduction_combine_fn=self.reduction_combine_vec,
2647                reduction_init_fn=self.reduction_init_vec,
2648            )
2649        tmpvar: Union[str, CSEVariable]
2650        is_bool = dtype == torch.bool
2651        if horizontal_reduction:
2652            # Horizontal reduction
2653            if is_welford_reduction(reduction_type):
2654                assert self._get_num_vectors(dtype) in [
2655                    1,
2656                    2,
2657                ], "Welford reduction does not support VectorizedN (N>2)"
2658                next_value = f"welford_vec_reduce_all({acc_vec})"
2659                masked_next_value = f"welford_vec_reduce_all(masked_{acc_vec})"
2660                self.reduction_suffix.writeline(
2661                    f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};"
2662                )
2663            elif argmax_or_argmin:
2664                next_value = f"{reduction_type}_vec_reduce_all({acc_vec})"
2665            elif is_bool:
2666                if reduction_type in (
2667                    "any",
2668                    "sum",
2669                    "max",
2670                ):
2671                    next_value = f"!{acc_vec}.all_zero()"
2672                else:
2673                    assert reduction_type == "min"
2674                    next_value = f"{acc_vec}.all_masked()"
2675            else:
2676                reduce_all_body = (
2677                    "{ return "
2678                    + self.reduction_combine_vec(reduction_type, "x", "y")
2679                    + "; }"
2680                )
2681                is_bool = dtype == torch.bool
2682                # we are using at::vec::VecMask<float, N> for bool
2683                vec_dtype = torch.float if is_bool else dtype
2684                vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>"
2685                vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>"
2686                next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})"
2687
2688            self.reduction_suffix.writeline(
2689                f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};"
2690            )
2691            tmpvar = acc
2692        else:
2693            tmpvar = acc_vec
2694            if is_welford_reduction(reduction_type):
2695                masked_tmpvar = f"masked_{tmpvar}"
2696                self.reduction_suffix.writeline(
2697                    f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};"
2698                )
2699
2700        result = reduction_project(reduction_type, tmpvar)
2701        self.reduction_cse.reduction_cache[reduction_key] = result
2702        return result
2703
2704    def store_reduction(self, name, index, value):
2705        index = self.rename_indexing(index)
2706        var = self.args.output(name)
2707        out_dtype = V.graph.get_dtype(name)
2708        dtype = (
2709            (out_dtype if out_dtype == torch.double else torch.float)
2710            if out_dtype.is_floating_point
2711            else torch.int64
2712        )
2713        out_num_vectors = V.kernel._get_num_vectors(out_dtype)
2714        src_num_vectors = V.kernel._get_num_vectors(dtype)
2715        code = IndentedBuffer()
2716        if self.tiling_idx >= self.reduction_depth:
2717            # Horizontal reduction
2718            code.writeline(
2719                f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});"
2720            )
2721        else:
2722            # Vertical reduction
2723            if out_dtype != dtype:
2724                converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}"
2725                if out_dtype == torch.bool:
2726                    convert = f"{value}.template cast<bool,{self._get_num_vectors(torch.bool)}>()"
2727                else:
2728                    if src_num_vectors == out_num_vectors == 1:
2729                        convert = (
2730                            f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})"
2731                        )
2732                    else:
2733                        convert = (
2734                            f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]},"
2735                            f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})"
2736                        )
2737                code.writeline(f"auto {converted_value} = {convert};")
2738                value = converted_value
2739            code.splice(self._get_store_line(value, var, index, out_dtype))
2740        self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x)))
2741
2742    def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable:
2743        assert not scalar_var.is_vec
2744        if scalar_var.dtype == torch.bool:
2745            vec_var = self.cse.generate(
2746                self.compute, f"{self._get_mask_type()}::from({scalar_var.name})"
2747            )
2748        else:
2749            assert scalar_var.dtype is not None
2750            vec_var = self.cse.generate(
2751                self.compute,
2752                f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})",
2753            )
2754        assert isinstance(vec_var, CppCSEVariable)
2755        vec_var.dtype = scalar_var.dtype
2756        vec_var.dependent_itervars = scalar_var.dependent_itervars
2757        vec_var.is_vec = True
2758        return vec_var
2759
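    # Illustrative example (not in the original source): broadcasting a scalar float
    # CSE variable `tmp3` emits something like `at::vec::Vectorized<float>(tmp3)`
    # (or a VectorizedN for multi-vector tiles), while a scalar bool becomes a mask
    # via `at::vec::VecMask<float,1>::from(tmp3)`, matching _get_mask_type's default
    # float base type.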
2760    def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable:
2761        assert not index.is_vec
2762        assert index.dtype is not None
2763        csevar = self.cse.generate(
2764            self.compute,
2765            f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})",
2766        )
2767        assert isinstance(csevar, CppCSEVariable)
2768        csevar.dtype = index.dtype
2769        csevar.is_vec = True
2770        return csevar
2771
2772    def reduction_init_vec(self, reduction_type, dtype):
2773        scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
2774        vec_type = self._get_vec_type(scalar_type)
2775
2776        if is_welford_reduction(reduction_type):
2777            return f"Welford<{vec_type}>()"
2778
2779        if reduction_type in {"argmin", "argmax"}:
2780            cdtype = DTYPE_TO_CPP[scalar_type]
2781            acc_type = self.reduction_acc_type_vec(reduction_type, dtype)
2782            if reduction_type == "argmin":
2783                val = (
2784                    f"std::numeric_limits<{cdtype}>::infinity()"
2785                    if is_float_dtype(dtype)
2786                    else f"std::numeric_limits<{cdtype}>::max()"
2787                )
2788            else:
2789                val = (
2790                    f"-std::numeric_limits<{cdtype}>::infinity()"
2791                    if is_float_dtype(dtype)
2792                    else f"std::numeric_limits<{cdtype}>::min()"
2793                )
2794            return f"{acc_type}({val})"
2795
2796        if reduction_type == "any":
2797            return f"{self._get_mask_type()}::from(0)"
2798
2799        scalar_init = reduction_init(reduction_type, dtype)
2800        vec_init = f"{vec_type}({scalar_init})"
2801        if dtype == torch.bool:
2802            assert reduction_type in ("min", "max", "sum")
2803            return f"{self._get_mask_type()}::from({scalar_init})"
2804        return vec_init
2805
2806    def reduction_acc_type_vec(self, reduction_type, dtype):
2807        scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
2808        vec_type = self._get_vec_type(scalar_type)
2809        if is_welford_reduction(reduction_type):
2810            return f"Welford<{vec_type}>"
2811        if reduction_type in {"argmin", "argmax"}:
2812            n_src = self._get_num_vectors(scalar_type)
2813            n_idx = self._get_num_vectors(torch.int64)
2814            return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>"
2815        if dtype == torch.bool:
2816            assert reduction_type in ("min", "max", "any", "sum")
2817            return f"{self._get_mask_type()}"
2818        return vec_type
2819
2820    def welford_weight_reciprocal_vec(self, dtype, num_threads=None):
2821        vec_num_range_thread = (
2822            CeilDiv(self.weight_recp_vec_range, num_threads)
2823            if num_threads
2824            else self.weight_recp_vec_range
2825        )
2826        vec_num_range_thread_expr = cexpr_index(vec_num_range_thread)
2827        return (
2828            f"static WeightRecp<{self._get_vec_type(dtype)}> {self.weight_recps_val}"
2829            f"("
2830            f"{vec_num_range_thread_expr}"
2831            f");"
2832        )
2833
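    # Illustrative sketch of the declaration emitted above (the variable name and range are
    # placeholders; the real ones come from self.weight_recps_val and the loop split):
    #     static WeightRecp<at::vec::Vectorized<float>> wrecps0(128);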
2834    def reduction_combine_vec(
2835        self,
2836        reduction_type,
2837        var,
2838        next_value,
2839        use_weight_recps=False,
2840        index: Optional[sympy.Symbol] = None,
2841        horizontal_reduction: Optional[bool] = None,
2842        src_dtype: Optional[torch.dtype] = torch.float32,
2843    ):
2844        is_bool = src_dtype == torch.bool
2845        if reduction_type == "max":
2846            if self.tail_size:
2847                return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
2848            else:
2849                return (
2850                    f"{var} | {next_value}"
2851                    if is_bool
2852                    else f"at::vec::maximum({var}, {next_value})"
2853                )
2854        elif reduction_type == "min":
2855            if self.tail_size:
2856                return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
2857            else:
2858                return (
2859                    f"{var} & {next_value}"
2860                    if is_bool
2861                    else f"at::vec::minimum({var}, {next_value})"
2862                )
2863        elif reduction_type == "sum":
2864            if self.tail_size:
2865                return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
2866            else:
2867                conjunction = "|" if is_bool else "+"
2868                return f"{var} {conjunction} {next_value}"
2869        elif reduction_type == "prod":
2870            if self.tail_size:
2871                return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
2872            else:
2873                return f"{var} * {next_value}"
2874        elif reduction_type == "xor_sum":
2875            if self.tail_size:
2876                return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})"
2877            else:
2878                return f"{var} ^ {next_value}"
2879        elif reduction_type == "welford_reduce":
2880            if use_weight_recps:
2881                if self.tail_size:
2882                    return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{self.weight_recps_val})"
2883                else:
2884                    return f"welford_combine({var}, {next_value}, &{self.weight_recps_val})"
2885            else:
2886                if self.tail_size:
2887                    return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})"
2888                else:
2889                    return f"welford_combine({var}, {next_value})"
2890        elif reduction_type == "welford_combine":
2891            if isinstance(next_value, tuple):
2892                # When reading a value from Inductor IR we have a tuple of variable names
2893                mean, m2, weight = next_value
2894            else:
2895                # When combining intermediate accumulators we have a Welford<T> struct
2896                mean, m2, weight = reduction_project(reduction_type, next_value)
2897            if self.tail_size:
2898                return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})"
2899            else:
2900                return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
2901        elif reduction_type in ("argmin", "argmax"):
2902            assert src_dtype is not None
2903            cdtype = DTYPE_TO_CPP[src_dtype]
2904            n_src = self._get_num_vectors(src_dtype)
2905            n_idx = self._get_num_vectors(torch.int64)
2906            t_extra = ""
2907            arg_extra = ""
2908            if index is not None:
2909                assert horizontal_reduction is not None
2910                t_extra = f", {str(horizontal_reduction).lower()}"
2911                arg_extra = f", {index}"
2912            if self.tail_size:
2913                return (
2914                    f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>"
2915                    f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})"
2916                )
2917            else:
2918                return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})"
2919        elif reduction_type == "any":
2920            return f"{var} | {next_value}"
2921        else:
2922            raise NotImplementedError
2923
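    # Illustrative combine expressions built above (no tail, non-bool float32 source):
    #     "sum"            -> f"{var} + {next_value}"
    #     "max"            -> f"at::vec::maximum({var}, {next_value})"
    #     "welford_reduce" -> f"welford_combine({var}, {next_value})"
    # When a masked tail loop is being generated (self.tail_size is set), the
    # *_masked_reduce / masked welford_combine variants are used instead.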
2924    def indirect_assert(self, var, lower, upper, mask=None):
2925        assert isinstance(var, CppCSEVariable)
2926        assert var.dtype is not None
2927        if not var.is_vec:
2928            if isinstance(mask, CppCSEVariable) and mask.is_vec:
2929                mask = f"({mask}).all_masked()"
2930            return super().indirect_assert(var, lower, upper, mask)
2931        lower_scalar = lower
2932        upper_scalar = upper
2933        if lower:
2934            lower = f"{self._get_vec_type(var.dtype)}({lower})"
2935        if upper:
2936            upper = f"{self._get_vec_type(var.dtype)}({upper})"
2937        if lower and upper:
2938            cond = f"({lower} <= {var}) & ({var} < {upper})"
2939            cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
2940        elif lower:
2941            cond = f"{lower} <= {var}"
2942            cond_print = f"{lower_scalar} <= {var}"
2943        else:
2944            assert upper
2945            cond = f"{var} < {upper}"
2946            cond_print = f"{var} < {upper_scalar}"
2947        cond = f"{self._get_mask_type(var.dtype)}({cond})"
2948        if mask:
2949            if not mask.is_vec:
2950                mask = f"{self._get_mask_type(var.dtype)}({mask})"
2951            # We need not check when the mask is False
2952            cond = f"({cond}) | ~({mask})"
2953        if self.tail_size:
2954            cond = (
2955                f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)"
2956                f", ({cond}), {cexpr_index(self.tail_size)})"
2957            )
2958        cond = f"({cond}).all_masked()"
2959        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
2960
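    # Illustrative shape of the vectorized bounds check emitted above (no mask, no tail,
    # both bounds present; placeholders follow the local f-string pieces):
    #     {assert_function}(({mask_type}(({lower} <= {var}) & ({var} < {upper}))).all_masked(),
    #                       "index out of bounds: {lower_scalar} <= {var} < {upper_scalar}")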
2961    def get_to_dtype_expr(self, src, dtype, src_dtype):
2962        assert isinstance(src, CppCSEVariable)
2963        if not src.is_vec:
2964            return super().get_to_dtype_expr(src, dtype, src_dtype)
2965        src_cpp_type = DTYPE_TO_CPP[src_dtype]
2966        src_num_vectors = self._get_num_vectors(src_dtype)
2967        dst_cpp_type = DTYPE_TO_CPP[dtype]
2968        dst_num_vectors = self._get_num_vectors(dtype)
2969        expr = f"({src})"
2970        if src_dtype != torch.bool and dtype == torch.bool:
2971            expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})"
2972        elif src_dtype == torch.bool and dtype != torch.bool:
2973            expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()"
2974        elif src_dtype != dtype:
2975            if src_num_vectors == dst_num_vectors == 1:
2976                expr = f"at::vec::convert<{dst_cpp_type}>({src})"
2977            else:
2978                expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})"
2979        return expr
2980
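    # Illustrative conversion expressions for a vectorized `src` (sketch using the local
    # names above): float -> bool goes through
    # f"{mask_type}::from<{src_cpp_type},{src_num_vectors}>({src})", bool -> float uses
    # f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()", and any other dtype change lowers to
    # an at::vec::convert call.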
2981
2982class CppTile2DKernel(CppVecKernel):
2983    """
2984    A vector kernel that handles 2d tiles with the tile size defined by `tiling_factor` on
2985    the inner-most loop level and one of the outer loop levels (`outer_tiling_idx`). When the data
2986    tile is accessed contiguously along the outer loop axis, a transposition is applied to the
2987    tile to make the access contiguous along the inner-most loop axis. Then the same vectorization
2988    logic from the parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile loads
2989    and stores are generated into the kernel.preloads and kernel.poststores buffers.
2990
2991    The loop structure looks like below:
2992    for ...
2993      for i_outer ...
2994        for ...
2995          for inner_most ...
2996            // generated by CppTile2DKernel
2997            float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads
2998            float tmp1[16*16]; // into kernel.preloads
2999            for i_inner ... { // the kernel inner loop
3000              vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores
3001            }
3002            at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores
3003          for inner_most ... (tail)
3004            // generated by CppVecKernel
3005            ...
3006      for i_outer ... (tail)
3007        for ...
3008          for ...
3009            // generated by CppKernel
3010            ...
3011    """
3012
3013    overrides = CppTile2DOverrides  # type: ignore[assignment]
3014
3015    def __init__(
3016        self,
3017        args,
3018        num_threads,
3019        tiling_factor,
3020        tiling_indices,
3021        inner_tail_size=None,
3022        outer_tail_size=None,
3023    ):
3024        super().__init__(
3025            args,
3026            num_threads,
3027            tiling_factor,
3028            tiling_indices[1],
3029            inner_tail_size,
3030        )
3031        self.tiling_indices = tiling_indices
3032        self.inner_tail_size = inner_tail_size
3033        self.outer_tail_size = outer_tail_size
3034        self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor
3035        self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor
3036        self.inner_is_tiling_idx = True
3037
3038    def inner_itervar(self):
3039        return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner")
3040
3041    def need_vec_transpose(self, index):
3042        outer_var = self.itervars[self.outer_idx]
3043        inner_var = self.itervars[self.tiling_idx]
3044        outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor)
3045        inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor)
3046        return (
3047            self._load_mask is None  # TODO: support transposition with mask
3048            and outer_stride == 1
3049            and index.has(inner_var)
3050            and not inner_stride.has(inner_var)
3051            and not inner_stride.has(outer_var)
3052        )
3053
3054    def gen_transposed_tile_load_store(self, name, var, index, is_store):
3055        # transposed tile load/store outside the kernel inner loop
3056        dtype = V.graph.get_dtype(name)
3057        factor = self.tiling_factor
3058        src = f"{var} + {cexpr_index(index)}"
3059        dst = "__place_holder__"
3060        ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}"
3061        ld_dst = f"{cexpr_index(self.num_elems)}"
3062        if is_store:
3063            src, dst = dst, src
3064            ld_src, ld_dst = ld_dst, ld_src
3065
3066        need_define = True
3067        if self.inner_is_tiling_idx ^ is_store:
3068            M, N = self.inner_num_elems, self.outer_num_elems
3069        else:
3070            M, N = (
3071                self.outer_num_elems,
3072                self.inner_num_elems,
3073            )
3074        if (isinstance(M, sympy.Expr) and not M.is_number) or (
3075            isinstance(N, sympy.Expr) and not N.is_number
3076        ):
3077            load_or_store = (
3078                f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>"
3079                f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});"
3080            )
3081        else:
3082            load_or_store = (
3083                f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>"
3084                f"({src}, {ld_src}, {dst}, {ld_dst});"
3085            )
3086        if is_store:
3087            tile_var = self.cse.newvar()
3088        elif load_or_store not in self.cse.cache:
3089            tile_var = self.cse.generate(self.preloads, load_or_store, write=False)
3090        else:
3091            need_define = False
3092            tile_var = self.cse.cache[load_or_store]
3093
3094        if need_define:
3095            define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}];"
3096            self.preloads.writeline(define_line)
3097
3098        load_or_store = load_or_store.replace("__place_holder__", str(tile_var))
3099        if is_store:
3100            self.poststores.writeline(DeferredLine(name, load_or_store))
3101        else:
3102            self.preloads.writeline(load_or_store)
3103
3104        return tile_var
3105
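    # Illustrative output of gen_transposed_tile_load_store for a float load (not a store)
    # with tiling_factor 16 and static tile sizes (names and strides are placeholders):
    #     alignas(16) float tmp2[16*16];                                         // kernel.preloads
    #     at::vec::transpose_mxn<float,16,16>(in_ptr0 + ..., ld_src, tmp2, 16);  // kernel.preloads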
3106    def load(self, name: str, index: sympy.Expr):
3107        var = self.args.input(name)
3108        index = self.rename_indexing(index)
3109
3110        inner = self.inner_itervar()
3111        if self.need_vec_transpose(index):
3112            tile_var = self.gen_transposed_tile_load_store(
3113                name, var, index, is_store=False
3114            )
3115            # vector load inside the kernel inner loop
3116            loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}"
3117            dtype = V.graph.get_dtype(name)
3118            line = self._get_vec_load_line(loadbuf, 0, dtype)  # type: ignore[arg-type]
3119            csevar = self.cse.generate(self.loads, line)
3120            csevar.update_on_args("load", (self, name, index), {})
3121            assert isinstance(csevar, CppCSEVariable)
3122            csevar.is_vec = True
3123            return csevar
3124        else:
3125            new_index = self.transform_indexing(index)
3126            return super().load(name, new_index)
3127
3128    def store(self, name, index, value, mode=None):
3129        assert "buf" in name
3130        var = self.args.output(name)
3131
3132        inner = self.inner_itervar()
3133        index = self.rename_indexing(index)
3134        assert mode is None
3135        if self.need_vec_transpose(index):
3136            tile_var = self.gen_transposed_tile_load_store(
3137                name, var, index, is_store=True
3138            )
3139            # vector store inside the kernel inner loop
3140            storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}"
3141            if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [
3142                torch.uint8,
3143                torch.int8,
3144            ]:
3145                line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});"
3146            else:
3147                line = f"{value}.store({storebuf});"
3148            self.stores.writeline(DeferredLine(name, line))
3149        else:
3150            new_index = self.transform_indexing(index)
3151            super().store(name, new_index, value, mode)
3152
3153    def codegen_inner_loops(self, code):
3154        inner = self.inner_itervar()
3155        if self.inner_is_tiling_idx:
3156            code.writeline(
3157                f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)"
3158            )
3159        else:
3160            code.writeline(
3161                f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)"
3162            )
3163
3164    def set_ranges(self, group, reduction_group):
3165        vars = super().set_ranges(group, reduction_group)
3166        # do vertical reduction as the tail loop
3167        self.outer_idx, self.tiling_idx = (
3168            self.tiling_indices
3169            if self.tiling_indices[1] < self.reduction_depth
3170            else reversed(self.tiling_indices)
3171        )
3172        if self.tiling_idx == self.tiling_indices[0]:
3173            self.tail_size = self.outer_tail_size
3174            self.num_elems = self.outer_num_elems
3175            self.inner_is_tiling_idx = False
3176        else:
3177            self.tail_size = self.inner_tail_size
3178            self.num_elems = self.inner_num_elems
3179            self.inner_is_tiling_idx = True
3180        return vars
3181
3182    def transform_indexing(self, index: sympy.Expr) -> sympy.Expr:
3183        return self.scale_index_with_offset(
3184            index,
3185            itervar_idx=self.outer_idx,
3186            offset=self.inner_itervar(),
3187        )
3188
3189
3190def get_loop_body_lowp_fp(_body: LoopBody) -> Tuple[Optional[torch.dtype], bool]:
3191    """
3192    Returns a tuple: the low precision data type (torch.float16/torch.bfloat16) found in the
3193    nodes (or None if there is none), and a flag that is True if any node must be computed in
3194    float (i.e. the nodes cannot all be codegen'd directly with the low precision data type).
3195    """
3196    sub_blocks = [_body.root_block] + list(_body.subblocks.values())
3197
3198    _lowp_fp_type: Optional[torch.dtype] = None
3199    _use_fp32 = False
3200    for sub_block in sub_blocks:
3201        for _node in sub_block.graph.nodes:
3202            if _node.op == "placeholder" or _node.target in (
3203                "get_index",
3204                "index_expr",
3205            ):
3206                continue
3207
3208            # Fast path if all operations can support bf16/fp16 without converting to fp32
3209            if _node.target not in [
3210                "load",
3211                "store",
3212                "abs",
3213                "neg",
3214                "output",
3215            ]:
3216                _use_fp32 = True
3217
3218            if hasattr(_node, "meta") and _node.meta:
3219                assert OptimizationContext.key in _node.meta
3220                opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key]
3221                if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP:
3222                    _use_fp32 = True
3223                elif _lowp_fp_type is not None:
3224                    if _lowp_fp_type != opt_ctx.dtype:
3225                        warnings.warn("bf16 and fp16 are mixed in the scheduler node.")
3226                else:
3227                    _lowp_fp_type = opt_ctx.dtype
3228            else:
3229                _use_fp32 = True
3230
3231    return _lowp_fp_type, _use_fp32
3232
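# Illustrative behavior of get_loop_body_lowp_fp (sketch): if every load/store in the body is
# bf16 and the only other ops are abs/neg/output, it returns (torch.bfloat16, False); as soon
# as any other compute op or a non-low-precision dtype appears, the second element becomes
# True, i.e. the body has to be computed in fp32.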
3233
3234class TilingSelect:
3235    """
3236    Implement the heuristic to select the tiling factors and tiling indices.
3237    In the future, we can implement advanced heuristic in a subclass.
3238    """
3239
3240    def __init__(self):
3241        super().__init__()
3242
3243    def select_tiling(
3244        self,
3245        fn_list,
3246        var_sizes_list,
3247    ) -> Tuple[List[int], List[int]]:
3248        # TODO(jgong5): support alternative tiling factors and data types
3249        loop_bodies = _get_loop_body(fn_list)
3250        all_dtypes = _get_dtype_from_loopbodies(loop_bodies)
3251        assert all_dtypes
3252        if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes):
3253            return [], []
3254        dtype = torch.float
3255        _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0]
3256        if _lowp_fp_dtype and all(
3257            (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype)
3258            for loop_body in loop_bodies[1:]
3259        ):
3260            dtype = _lowp_fp_dtype
3261
3262        tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype)
3263        tiling_indices = self._select_tiling_indices(
3264            fn_list, var_sizes_list, tiling_factor
3265        )
3266
3267        if tiling_indices:
3268            group, reduction_group = max(
3269                var_sizes_list, key=lambda sizes: len(sizes[1])
3270            )
3271            call_ranges = tuple(group) + tuple(reduction_group)
3272
3273            if config.cpp.enable_tiling_heuristics:
3274
3275                def _try_get_stride(
3276                    index,
3277                    itervars,
3278                    tiling_factor,
3279                    tiling_indices,
3280                ):
3281                    itervar = itervars[tiling_indices[0]]
3282                    stride = stride_at_vec_range(index, itervar, tiling_factor)
3283                    return stride if stride.is_number else None
3284
3285                def _update_negative_op_count(
3286                    node_name, non_contig_indexing_op_counter
3287                ):
3288                    if node_name not in non_contig_indexing_op_counter:
3289                        non_contig_indexing_op_counter[node_name] = 1
3290                    else:
3291                        non_contig_indexing_op_counter[node_name] += 1
3292
3293                def _is_valid_indices(
3294                    itervars,
3295                    tiling_indices,
3296                ):
3297                    return (
3298                        len(tiling_indices) == 1
3299                        and len(itervars) > 0
3300                        and (
3301                            tiling_indices[0]
3302                            if tiling_indices[0] >= 0
3303                            else tiling_indices[0] + len(itervars)
3304                        )
3305                        < len(itervars)
3306                    )
3307
3308                itervars = [
3309                    sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
3310                    for n in range(len(call_ranges))
3311                ]
3312                reduction_depth = len(group)
3313                vars, reduction_vars = (
3314                    itervars[:reduction_depth],
3315                    itervars[reduction_depth:],
3316                )
3317                op_counter: Dict[str, int] = {}
3318                # Count ops that may cause vectorization overhead, such as
3319                # non-contiguous index_expr, load and store
3320                non_contig_indexing_op_counter: Dict[str, int] = {}
3321                for _body in loop_bodies:
3322                    sub_blocks = [_body.root_block] + list(_body.subblocks.values())
3323                    for sub_block in sub_blocks:
3324                        for _node in sub_block.graph.nodes:
3325                            if _node.target in ["index_expr", "load", "store"]:
3326                                # get the index and replace prefix from z to x
3327                                arg_idx = 1 if _node.target == "index_expr" else 2
3328                                index = sub_block.body.indexing_from_args(
3329                                    (vars, reduction_vars)
3330                                )[_node.args[arg_idx].args[0]]
3331                                if _is_valid_indices(itervars, tiling_indices):
3332                                    stride = _try_get_stride(
3333                                        index, itervars, tiling_factor, tiling_indices
3334                                    )
3335                                    if (
3336                                        stride is None
3337                                        if _node.target == "index_expr"
3338                                        else stride not in [0, 1]
3339                                    ):
3340                                        _update_negative_op_count(
3341                                            _node.target, non_contig_indexing_op_counter
3342                                        )
3343                            if isinstance(_node.target, str) and not (
3344                                _node.target.startswith("masked_subblock")
3345                                or _node.target
3346                                in ["ops", "output", "constant", "get_index"]
3347                            ):
3348                                if _node.target not in op_counter:
3349                                    op_counter[_node.target] = 1
3350                                else:
3351                                    op_counter[_node.target] += 1
3352
3353                op_num = sum(op_counter.values())
3354                non_contig_indexing_op_num = sum(
3355                    non_contig_indexing_op_counter.values()
3356                )
3357                threshold = 0.08
3358                if op_num > 0 and non_contig_indexing_op_num / op_num >= threshold:
3359                    # Too many non-contiguous loads/stores/index_exprs hurt
3360                    # vectorization performance. Disable vectorization when the ratio
3361                    # exceeds the threshold.
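                    # Worked example (illustrative): with 25 counted ops of which 2 are
                    # non-contiguous accesses, 2 / 25 = 0.08 >= threshold, so we bail out here.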
3362                    return [], []
3363
3364                if (
3365                    not reduction_group
3366                    and group
3367                    and len(tiling_indices) == 1
3368                    and not has_free_symbols(
3369                        [
3370                            group[tiling_indices[0]],
3371                        ]
3372                    )
3373                    and group[tiling_indices[0]] < tiling_factor / 2
3374                ):
3375                    # For the multi-thread AMP static-shape case of pyhpc_isoneutral_mixing,
3376                    # the inner loop range doesn't have enough elements to vectorize
3377                    # explicitly, and we found that `#pragma GCC ivdep` performs better than
3378                    # `#pragma omp simd simdlen(8)`. Disable vectorization for this case.
3379                    # <TODO> Leslie: maybe we can always disable vectorization when the loop range is
3380                    # less than the tiling factor and enable `#pragma omp simd simdlen(8)` for the
3381                    # scalar kernel when needed.
3382                    return [], []
3383
3384            if dtype in DTYPE_LOWP_FP:
3385                # For lower precision data type, if the call_range is not long enough,
3386                # use tiling_factor // 2 for better performance
3387                factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype)
3388                for tiling_indice in tiling_indices:
3389                    if tiling_indice < 0:
3390                        tiling_indice = tiling_indice + len(call_ranges)
3391                    if tiling_indice < 0 or tiling_indice >= len(call_ranges):
3392                        continue
3393                    if has_free_symbols(call_ranges):
3394                        call_range = V.graph.sizevars.size_hint(
3395                            call_ranges[tiling_indice], fallback=0
3396                        )
3397                        if call_range < factor_lowp:
3398                            V.graph.sizevars.guard_lt(call_range, factor_lowp)
3399                            tiling_factor = factor_lowp // 2
3400                            break
3401                    elif call_ranges[tiling_indice] < factor_lowp:
3402                        tiling_factor = factor_lowp // 2
3403                        break
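                # Illustrative: assuming an ISA with 32 bf16 lanes (e.g. AVX512) and a tiled
                # call_range of only 20, factor_lowp is 32 > 20, so tiling_factor becomes 16.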
3404
3405            if len(tiling_indices) == 1:
3406                return [tiling_factor], tiling_indices
3407            if len(tiling_indices) == 2:
3408                return [tiling_factor, tiling_factor], tiling_indices
3409        return [], []
3410
3411    def _select_tiling_indices(
3412        self,
3413        fn_list,
3414        var_sizes_list,
3415        tiling_factor,
3416    ):
3417        all_index = []
3418        for fn, var_sizes in zip(fn_list, var_sizes_list):
3419            rw = dependencies.extract_read_writes(fn, *var_sizes)
3420            all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)]
3421        contig_vars = set()
3422        contig_vars_list = []
3423        non_contig_stride_const = set()
3424        non_contig_stride_other = set()
3425        for index in all_index:
3426            for var in index.free_symbols:
3427                if not re.search(r"^d\d+$", var.name):
3428                    continue
3429                stride = stride_at_vec_range(index, var, tiling_factor)
3430                if stride == 0:
3431                    continue
3432                elif stride == 1:
3433                    contig_vars.add(int(var.name[1:]))
3434                    contig_vars_list.append(int(var.name[1:]))
3435                elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols):
3436                    non_contig_stride_const.add(int(var.name[1:]))
3437                else:
3438                    non_contig_stride_other.add(int(var.name[1:]))
3439        contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other
3440        group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1]))
3441        num_itervars = len(group) + len(reduction_group)
3442        if len(contig_vars) == 0:
3443            # no contiguous vars
3444            return [num_itervars - 1]
3445        if contig_only:
3446            return sorted(contig_only)[-1:]
3447        contig_and_const_stride = (
3448            contig_vars & non_contig_stride_const
3449        ) - non_contig_stride_other
3450        contig_vars_sorted = sorted(contig_vars)
3451        if (
3452            len(contig_vars_sorted) == 2
3453            and contig_vars_sorted[-1] in contig_and_const_stride
3454            and contig_vars_sorted[-1] == num_itervars - 1
3455        ):
3456            return contig_vars_sorted
3457        return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:]
3458
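    # Illustrative sketch of the index selection above: for an access like 1024*d0 + d1,
    # d1 has stride 1 (contiguous) and d0 has stride 1024, so d1 is preferred as the tiling
    # index; if no variable is contiguous in any index, the innermost itervar
    # (num_itervars - 1) is chosen.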
3459
3460class CppKernelProxy(CppKernel):
3461    def __init__(self, kernel_group):
3462        super().__init__(kernel_group.args, kernel_group.ws.num_threads)
3463        self.kernel_group = kernel_group
3464        self.loop_nest = None
3465        self.call_ranges = None
3466        self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa()
3467
3468    def data_type_propagation(self, nodes):
3469        for _node in nodes:
3470            assert isinstance(_node, SchedulerNode)
3471            DataTypePropagation.propagate_scheduler_node(_node)
3472
3473    # Check if all the nodes of a given fx graph can support BF16/FP16
3474    def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode):
3475        if not isinstance(scheduler_node._body, LoopBody):
3476            return True
3477        # Propagate the dtype to check if all the fx node is bf16/fp16
3478        DataTypePropagation.propagate_scheduler_node(scheduler_node)
3479        return (
3480            get_loop_body_lowp_fp(scheduler_node._body)[0] is not None
3481            and not get_loop_body_lowp_fp(scheduler_node._body)[1]
3482        )
3483
3484    def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody):
3485        def add_to_dtype(sub_graph: torch.fx.Graph):
3486            def is_lowp_fp_load(node: torch.fx.Node):
3487                if node.target not in ["load"]:
3488                    return False
3489                assert len(node.args) == 3
3490                load_dtype = V.graph.get_dtype(node.args[1])  # type: ignore[arg-type]
3491                return load_dtype in DTYPE_LOWP_FP
3492
3493            def is_lowp_fp_store(node: torch.fx.Node):
3494                if node.target != "store":
3495                    return False
3496                _, store_var, _, _, _ = node.args
3497                store_dtype = V.graph.get_dtype(store_var)  # type: ignore[arg-type]
3498                return store_dtype in DTYPE_LOWP_FP
3499
3500            sub_graph_nodes = list(sub_graph.nodes)
3501            to_lowp_fp_legalized_nodes = []
3502            for _node in sub_graph_nodes:
3503                if is_lowp_fp_load(_node):
3504                    # No need to promote to float if all users are direct stores
3505                    if all(user.target == "store" for user in _node.users):
3506                        continue
3507                    ops = _node.args[0]
3508                    with sub_graph.inserting_after(_node):
3509                        to_type_node = sub_graph.call_method(
3510                            "to_dtype", args=(ops, _node, torch.float)
3511                        )
3512                        to_type_node_args = to_type_node.args
3513                        _node.replace_all_uses_with(to_type_node)
3514                        to_type_node.args = to_type_node_args
3515                        metrics.cpp_to_dtype_count += 1
3516                elif is_lowp_fp_store(_node):
3517                    ops, name, _, value_var, _ = _node.args
3518                    # No need to promote to float if the value is a load whose users are all direct stores
3519                    if value_var.target == "load" and all(
3520                        user.target == "store" for user in value_var.users
3521                    ):
3522                        continue
3523                    dtype = V.graph.get_dtype(name)
3524                    with sub_graph.inserting_before(_node):
3525                        to_type_node = sub_graph.call_method(
3526                            "to_dtype", args=(ops, value_var, dtype)
3527                        )
3528                        _node.replace_input_with(value_var, to_type_node)
3529                        metrics.cpp_to_dtype_count += 1
3530                elif _node.target == "reduction":
3531                    (
3532                        ops,
3533                        dtype,
3534                        src_dtype,
3535                        reduction_type,
3536                        value,
3537                    ) = _node.args
3538                    if src_dtype in DTYPE_LOWP_FP:
3539                        # Since we always convert the load/store value to float when the tensor is
3540                        # bfloat16/float16, the reduction should never operate on bfloat16/float16
3541                        # values. Hence, we update the bfloat16/float16 reduction by
3542                        #     1) updating the src_dtype to float
3543                        # and 2) updating the dtype to float if it is bfloat16/float16.
3544                        assert dtype in [
3545                            torch.float,
3546                            torch.bfloat16,
3547                            torch.float16,
3548                            torch.int64,
3549                        ]
3550                        _node.args = (
3551                            ops,
3552                            torch.float if dtype in DTYPE_LOWP_FP else dtype,
3553                            torch.float,
3554                            reduction_type,
3555                            value,
3556                        )
3557                elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP:
3558                    (ops, x, _) = _node.args
3559                    # The legalization always loads the BF16/FP16 tensor as FP32 for computation
3560                    # and converts back to BF16/FP16 after the computation.
3561                    # Hence, there should be no computation w/ BF16/FP16.
3562                    # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32.
3563                    # Save the legalized to_dtype node for elimination (eliminate_to_dtype step):
3564                    #  1) Eliminate the redundant to_dtype node if we have a pattern as follows:
3565                    #     graph():
3566                    #       %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float))
3567                    #       %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16))
3568                    # The first to_dtype is redundant because its users convert back to
3569                    # torch.bfloat16/torch.float16 via the second to_dtype anyway.
3570                    # Hence, we remove the first to_dtype.
3571                    to_lowp_fp_legalized_nodes.append(_node)
3572                    _node.args = (ops, x, torch.float)
3573                else:
3574                    pass
3575
3576            def eliminate_to_dtype(sub_graph: torch.fx.Graph):
3577                def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph):
3578                    # Eliminate the redundant to_dtype node. Let's consider a pattern as follows:
3579                    #   graph():
3580                    #     %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {})
3581                    #     %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {})
3582                    # Regarding the first to_dtype, it is redundant because the second to_dtype also converts to
3583                    # torch.float. Hence, we remove the first to_dtype.
3584                    def _used_by_to(to_node: torch.fx.Node):
3585                        return all(usr.target == "to_dtype" for usr in to_node.users)
3586
3587                    all_to_nodes = [
3588                        node for node in sub_graph.nodes if node.target == "to_dtype"
3589                    ]
3590                    all_to_nodes_and_users = [
3591                        {node: node.users} for node in all_to_nodes if _used_by_to(node)
3592                    ]
3593                    for node_users in all_to_nodes_and_users:
3594                        for node, users in node_users.items():
3595                            if node in sub_graph.nodes and (
3596                                all(usr.args[-1] == node.args[-1] for usr in users)
3597                                or (
3598                                    node in to_lowp_fp_legalized_nodes
3599                                    and all(
3600                                        usr.args[-1] in DTYPE_LOWP_FP for usr in users
3601                                    )
3602                                )
3603                            ):
3604                                val_node = node.all_input_nodes[-1]
3605                                node.replace_all_uses_with(val_node)
3606                                sub_graph.erase_node(node)
3607
3608                    # In debug mode, the graph of LoopBody attaches a new GraphModule as its
3609                    # owning_module for debugging, while release mode does not. Lint checks
3610                    # whether the graph has an owning_module to decide if it needs to validate
3611                    # call_module nodes. LoopBody may contain get_index as a module call, but it
3612                    # is just a function, so the graph cannot pass the lint check in debug mode.
3613                    # Hence we only lint when the owning_module is None. Eventually, we should
3614                    # call get_index via call_function instead of call_module.
3615                    if sub_graph.owning_module is None:
3616                        sub_graph.lint()
3617
3618                _eliminate_duplicate_to_node(sub_graph)
3619
3620            eliminate_to_dtype(sub_graph)
3621
3622        sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
3623        for sub_block in sub_blocks:
3624            add_to_dtype(sub_block.graph)
3625
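    # Illustrative effect of the legalization above on a bf16 pointwise body (sketch):
    #     before: load(bf16) -> add -> store(bf16)
    #     after:  load(bf16) -> to_dtype(fp32) -> add -> to_dtype(bf16) -> store(bf16)
    # Pure memory copies (loads whose users are all direct stores) are left in bf16.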
3626    def legalize_lowp_fp_dtype(self, nodes):
3627        if all(
3628            isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node)
3629            for _node in nodes
3630        ):
3631            # Mark the load node to load bf16/fp16
3632            for _node in nodes:
3633                sub_blocks = [_node._body.root_block] + list(
3634                    _node._body.subblocks.values()
3635                )
3636                for sub_block in sub_blocks:
3637                    for fx_node in sub_block.graph.nodes:
3638                        if fx_node.target in ["load", "store"]:
3639                            assert fx_node.meta
3640                            assert OptimizationContext.key in fx_node.meta
3641                            opt_ctx: OptimizationContext = fx_node.meta[
3642                                OptimizationContext.key
3643                            ]
3644                            assert opt_ctx.dtype in DTYPE_LOWP_FP
3645
3646            # Bypass the legalization as the kernel can run with bf16/fp16 directly
3647            return
3648
3649        for _node in nodes:
3650            assert isinstance(_node, SchedulerNode)
3651            assert isinstance(_node._body, LoopBody)
3652            body: LoopBody = _node._body
3653            if not body.is_memory_copy():
3654                self.legalize_lowp_fp_dtype_loopbody(body)
3655
3656    def codegen_functions(self, fn_list, var_sizes_list):
3657        assert len(fn_list) == len(var_sizes_list)
3658        kernel_group = self.kernel_group
3659        group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1]))
3660
3661        self.set_ranges(group, reduction_group)
3662
3663        def codegen_kernel(cls, *args):
3664            with kernel_group.new_kernel(cls, *args) as kernel:
3665                # Ugly hack to maintain the metrics kernel count since
3666                # we only count in CppKernelProxy, not those contained in it
3667                metrics.generated_kernel_count -= 1
3668
3669                run(kernel)
3670                return kernel
3671
3672        def run(kernel):
3673            vars, reduction_vars = kernel.set_ranges(group, reduction_group)
3674            in_suffix = False
3675            for fn, var_sizes in zip(fn_list, var_sizes_list):
3676                if var_sizes in [
3677                    (group, reduction_group),
3678                    (tuple(itertools.chain(group, reduction_group)), ()),
3679                ]:
3680                    assert not in_suffix
3681                    fn(vars, reduction_vars)
3682                else:
3683                    in_suffix = True
3684                    assert var_sizes == (
3685                        group,
3686                        (),
3687                    ), f"unexpected group: {var_sizes} != {group}, {reduction_group}"
3688                    # we can fuse in some extra pointwise into the suffix
3689                    with kernel.write_to_suffix():
3690                        fn(vars, ())
3691
3692        scalar_kernel = codegen_kernel(CppKernel)
3693        V.graph.removed_buffers |= scalar_kernel.removed_buffers
3694        V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove
3695        self.loop_nest = LoopNestWithSplit.build(scalar_kernel)
3696
3697        if not self.picked_vec_isa:
3698            return
3699
3700        if not self.itervars:
3701            # not a loop
3702            return
3703
3704        # Kernels share global contexts such as V.graph.wrapper_code and V.kernel.args.
3705        # The generated scalar kernel has already updated these global contexts, so the other
3706        # kernels should not update them again to avoid conflicts. For now, we only control
3707        # config.inplace_buffers. In the future, we could maintain more contexts.
3708        with torch._inductor.config.patch(inplace_buffers=False):
3709            tiling_select = TilingSelect()
3710            tiling_factors, tiling_indices = tiling_select.select_tiling(
3711                fn_list, var_sizes_list
3712            )
3713            assert len(tiling_factors) == len(tiling_indices)
3714            # <TODO> This should be removed after full support for vectorization is implemented.
3715            could_masked_vec = True
3716            all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list))
3717            if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes):
3718                # can be removed after masked vectorizable dtype are same with vectorizable dtype
3719                could_masked_vec = False
3720
3721            if len(tiling_indices) == 1:
3722                vec_kernel = codegen_kernel(
3723                    CppVecKernel, tiling_factors[0], tiling_indices[0]
3724                )
3725                metrics.generated_cpp_vec_kernel_count += 1
3726                main_loop, tail_loop = self.loop_nest.split_with_tiling(
3727                    tiling_indices[0], factor=tiling_factors[0]
3728                )
3729                main_loop.set_kernel(vec_kernel)
3730                main_loop.simd_vec = True
3731                if config.cpp.enable_loop_tail_vec and could_masked_vec:
3732                    tail_loop.steps = tail_loop.size - tail_loop.offset
3733                    masked_vec_kernel = codegen_kernel(
3734                        CppVecKernel,
3735                        tiling_factors[0],
3736                        tiling_indices[0],
3737                        tail_loop.steps,
3738                    )
3739                    tail_loop.set_kernel(masked_vec_kernel)
3740                    tail_loop.simd_vec = True
3741                else:
3742                    tail_loop.set_kernel(scalar_kernel)
3743                    tail_loop.simd_omp = True
3744                # We split the loop into two parts by nelements: the main loop and the tail loop.
3745                # The main loop is straightforward to vectorize with nelements, but the tail
3746                # loop can still be vectorized as well. For example, if nelements is
3747                # 8 (256 bits), then the tail loop can still be vectorized
3748                # with 4 (128 bits).
3749                tail_loop.simd_nelements = tiling_factors[0] // 2
3750            elif len(tiling_indices) == 2:
3751                assert (
3752                    tiling_indices[1] == len(self.itervars) - 1
3753                    and tiling_factors[0] == tiling_factors[1]
3754                )
3755
3756                metrics.generated_cpp_vec_kernel_count += 2
3757                outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling(
3758                    tiling_indices[0], factor=tiling_factors[0]
3759                )
3760                (
3761                    inner_main_loop,
3762                    inner_tail_loop,
3763                ) = outer_main_loop.split_with_tiling(
3764                    tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
3765                )
3766                tile2d_kernel = codegen_kernel(
3767                    CppTile2DKernel, tiling_factors[0], tiling_indices
3768                )
3769                inner_main_loop.set_kernel(tile2d_kernel)
3770
3771                if config.cpp.enable_loop_tail_vec and could_masked_vec:
3772                    (
3773                        inner_main_loop_of_outer_tail_loop,
3774                        inner_tail_loop_of_outer_tail_loop,
3775                    ) = outer_tail_loop.split_with_tiling(
3776                        tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
3777                    )
3778
3779                    for tail_loop in (
3780                        inner_tail_loop,
3781                        outer_tail_loop,
3782                        inner_tail_loop_of_outer_tail_loop,
3783                    ):
3784                        tail_loop.steps = tail_loop.size - tail_loop.offset
3785
3786                    for tail_loop, inner_tail_size, outer_tail_size in (
3787                        (inner_tail_loop, inner_tail_loop.steps, None),
3788                        (
3789                            inner_main_loop_of_outer_tail_loop,
3790                            None,
3791                            outer_tail_loop.steps,
3792                        ),
3793                        (
3794                            inner_tail_loop_of_outer_tail_loop,
3795                            inner_tail_loop_of_outer_tail_loop.steps,
3796                            outer_tail_loop.steps,
3797                        ),
3798                    ):
3799                        masked_tile2d_kernel = codegen_kernel(
3800                            CppTile2DKernel,
3801                            tiling_factors[0],
3802                            tiling_indices,
3803                            inner_tail_size,
3804                            outer_tail_size,
3805                        )
3806                        tail_loop.set_kernel(masked_tile2d_kernel)
3807                else:
3808                    vec_kernel = codegen_kernel(
3809                        CppVecKernel, tiling_factors[0], tiling_indices[0]
3810                    )
3811                    inner_tail_loop.set_kernel(vec_kernel)
3812
3813                    outer_tail_loop.set_kernel(scalar_kernel)
3814
3815    def codegen_loop_bodies(self, loop_bodies, var_sizes_list):
3816        for body in loop_bodies:
3817            self.legalize_lowp_fp_dtype_loopbody(body)
3818            DataTypePropagation.propagate_loopbody(body)
3819        self.codegen_functions(loop_bodies, var_sizes_list)
3820
3821    def codegen_nodes(self, nodes: List[SchedulerNode]):
3822        # Legalize BF16 node by adding to_dtype explicitly
3823        self.legalize_lowp_fp_dtype(nodes)
3824        self.data_type_propagation(nodes)
3825        assert len(nodes) >= 1
3826
3827        def fn(node, *index_vars):
3828            node.decide_inplace_update()
3829            node.mark_run()
3830            if isinstance(V.kernel, NullKernelHandler):
3831                return node._body(*index_vars)
3832            else:
3833                return node.codegen(index_vars)
3834
3835        fn_list = [functools.partial(fn, node) for node in nodes]
3836
3837        if (
3838            isinstance(V.local_buffer_context, LocalBufferContext)
3839            and V.local_buffer_context.local_buffers
3840        ):
3841
3842            def wrap_fn(fn):
3843                wrapped_fn = V.local_buffer_context.localize_function(
3844                    fn,
3845                )
3846                wrapped_fn.original_fn = fn
3847                return wrapped_fn
3848
3849            fn_list = [wrap_fn(fn) for fn in fn_list]
3850
3851        var_sizes_list = [node.group[1] for node in nodes]
3852        self.codegen_functions(fn_list, var_sizes_list)
3853
3854    def codegen_loops(self, code, worksharing):
3855        self.codegen_loops_impl(self.loop_nest, code, worksharing)
3856
3857
3858class OuterLoopFusedKernel(CppKernel):
3859    def __init__(self, kernel_group):
3860        super().__init__(kernel_group.args, kernel_group.ws.num_threads)
3861        self.inner: List[LoopLevel] = []
3862
3863    def decide_parallel_depth(self, max_parallel_depth, threads) -> int:
3864        kernels_parallel_depth = []
3865        nested_kernels: List[List[CppKernel]] = [
3866            loop.get_kernels() for loop in self.inner
3867        ]
3868        for kernels in nested_kernels:
3869            # For any ScalarKernel, VecKernel, or Tile2DKernel,
3870            # they should all have the same call_ranges
3871            call_ranges = kernels[0].call_ranges
3872            assert call_ranges is not None
3873            assert all(kernel.call_ranges == call_ranges for kernel in kernels)
3874            kernels_parallel_depth.append(
3875                kernels[0].decide_parallel_depth(len(call_ranges), threads)
3876            )
3877        return min(
3878            max_parallel_depth,
3879            max(kernels_parallel_depth),
3880        )
3881
3882
3883class ReasonFusedNodes(Enum):
3884    SAME_VARS_REDUCE = "same_vars_reduce"
3885    COMPATIBLE_REDUCTION = "compatible_reduction"
3886    COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction"
3887
3888
3889class CppScheduling(BaseScheduling):
3890    # ctypes limits the number of args to 1024, refer to:
3891    # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237
3892    # We set a conservative threshold here.
3893    MAX_FUSED_KERNEL_ARGS_NUM = 500
3894    backend_features = dict.fromkeys(
3895        [
3896            BackendFeature.INPLACE_BUFFERS,
3897            BackendFeature.REDUCE_TO_SINGLE_ELEMENT,
3898        ]
3899    )
3900
3901    @classmethod
3902    def get_backend_features(cls, device: torch.device):
3903        return cls.backend_features
3904
3905    def __init__(self, scheduler):
3906        super().__init__()
3907        self.scheduler = scheduler
3908        if scheduler:
3909            self.reset_kernel_group()
3910        self._ready_to_flush = False
3911
3912    def _set_flush_status(self, status: bool):
3913        self._ready_to_flush = status
3914
3915    def group_fn(self, sizes):
3916        return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
3917
3918    def reset_kernel_group(self):
3919        from .cpp_wrapper_cpu import CppWrapperCpu
3920
3921        self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup]
3922        if isinstance(V.graph.wrapper_code, CppWrapperCpu):
3923            self.kernel_group = CppWrapperKernelGroup()
3924        else:
3925            self.kernel_group = KernelGroup()
3926
3927    def fuse(self, node1, node2):
3928        if node1.is_foreach() or node2.is_foreach():
3929            return ForeachKernelSchedulerNode.fuse(node1, node2)
3930        elif node1.is_template():
3931            assert not node2.is_template()
3932            return FusedSchedulerNode.fuse(node1, node2)
3933        else:
3934            if (
3935                self._why_fuse_nodes(node1, node2)
3936                == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION
3937            ):
3938                assert isinstance(node1, (SchedulerNode, FusedSchedulerNode))
3939                assert isinstance(node2, (SchedulerNode, FusedSchedulerNode))
3940
3941                _, (vars1, reduce1) = node1.group
3942                _, (vars2, reduce2) = node2.group
3943                assert reduce1 == () and reduce2 == (), (reduce1, reduce2)
3944
3945                def get_indexing_ranges_exprs(node):
3946                    if isinstance(node, FusedSchedulerNode):
3947                        assert len(node.snodes) > 0, node.snodes
3948                        var_ranges = None
3949                        indexing_exprs = set()
3950                        for snode in node.snodes:
3951                            v, exprs = get_indexing_ranges_exprs(snode)
3952                            if var_ranges is None:
3953                                var_ranges = v
3954                            assert var_ranges == v, (var_ranges, v, node.snodes)
3955                            indexing_exprs.update(exprs)
3956                        return var_ranges, list(indexing_exprs)
3957                    else:
3958                        assert isinstance(node, SchedulerNode)
3959                        comp_buffer = node.node
3960                        assert isinstance(comp_buffer, ir.ComputedBuffer)
3961                        _, body, _ = comp_buffer.get_default_sizes_body()
3962                        return body.var_ranges, list(body.indexing_exprs.values())
3963
3964                node_to_recomp = node1 if len(vars1) < len(vars2) else node2
3965                assert isinstance(node_to_recomp, SchedulerNode)
3966
3967                ref_node = node2 if len(vars1) < len(vars2) else node1
3968
3969                extra_indexing_constraints = get_indexing_ranges_exprs(ref_node)
3970
3971                node_to_recomp.recompute_size_and_body(
3972                    extra_indexing_constraints=extra_indexing_constraints
3973                )
3974
3975                _, (vars1, _) = node1.group
3976                _, (vars2, _) = node2.group
3977                assert vars1 == vars2, (vars1, vars2)
3978                return FusedSchedulerNode.fuse(node1, node2)
3979            elif self.can_fuse_vertical_outer_loop(node1, node2):
3980                return OuterLoopFusedSchedulerNode.fuse(
3981                    node1, node2, self._get_outer_loop_fusion_depth(node1, node2)
3982                )
3983            else:
3984                return FusedSchedulerNode.fuse(node1, node2)
3985
3986    def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]:
3987        _, (vars1, reduce1) = node1.group
3988        _, (vars2, reduce2) = node2.group
3989
3990        if vars1 == vars2 and reduce1 == reduce2:
3991            return ReasonFusedNodes.SAME_VARS_REDUCE
3992        if reduce1 == () and vars1 == vars2 + reduce2:
3993            return ReasonFusedNodes.COMPATIBLE_REDUCTION
3994        if self._can_fuse_nodes_with_compatible_ranges(node1, node2):
3995            return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION
3996        # TODO(jansel): allow fusion pointwise (vars1, ()) suffix?
3997        return None
3998
3999    def _can_fuse_nodes_with_compatible_ranges(self, node1, node2):
4000        # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges
4001        # e.g. (s0, s1, s2) and (s0 * s1 * s2)
4002        _, (vars1, reduce1) = node1.group
4003        _, (vars2, reduce2) = node2.group
4004
4005        c1 = reduce1 == () and reduce2 == ()
4006        c2 = math.prod(vars1) == math.prod(vars2)
4007        c3 = len(vars1) == 1 or len(vars2) == 1
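        # c1: no reductions on either side; c2: both sides cover the same total number of
        # elements; c3: one side has already been flattened to a single dimension.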
4008        if not (c1 and c2 and c3):
4009            return False
4010
4011        node_to_recomp = node1 if len(vars1) < len(vars2) else node2
4012        ref_node = node2 if len(vars1) < len(vars2) else node1
4013
4014        # We cannot recompute sizes and body for nodes other than SchedulerNode
4015        # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode
4016        if isinstance(node_to_recomp, FusedSchedulerNode):
4017            return False
4018
4019        # It may happen that node1 and node2 have a compatible number of elements
4020        # but different original ranges, for example:
4021        # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2}
4022        # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details
4023        # TODO: we can fix if it allows us to CSE at least one of the variables
4024
4025        assert isinstance(node_to_recomp, SchedulerNode)
4026        if isinstance(node_to_recomp.node, ir.TemplateBuffer):
4027            return False
4028        assert isinstance(node_to_recomp.node, ir.ComputedBuffer)
4029        # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges
4030        # but without the variable names
4031        ranges2 = node_to_recomp.node.data.get_size()
4032        ranges1 = None
4033        if isinstance(ref_node, FusedSchedulerNode):
4034            ranges_set = set()
4035            for snode in ref_node.snodes:
4036                if isinstance(snode.node, ir.TemplateBuffer):
4037                    break
4038                assert isinstance(snode.node, ir.ComputedBuffer)
4039                ranges_set.add(tuple(snode.node.data.get_size()))
4040
4041            if len(ranges_set) != 1:
4042                return False
4043
4044            ranges1 = list(next(iter(ranges_set)))
4045        else:
4046            assert isinstance(ref_node, SchedulerNode)
4047            assert isinstance(ref_node.node, ir.ComputedBuffer)
4048            ranges1 = ref_node.node.data.get_size()
4049
4050        if ranges1 != ranges2:
4051            return False
4052
4053        return True
4054
4055    def _can_fuse_horizontal_impl(self, node1, node2):
4056        assert isinstance(node1, (FusedSchedulerNode, SchedulerNode))
4057        assert isinstance(node2, (FusedSchedulerNode, SchedulerNode))
4058        if any(
4059            isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2)
4060        ):
4061            return False
4062        return self._why_fuse_nodes(node1, node2) is not None
4063
4064    def can_fuse_horizontal(self, node1, node2):
4065        if node1.is_template() or node2.is_template():
4066            return False
4067        if (
4068            len(node1.get_nodes()) + len(node2.get_nodes())
4069            > config.cpp.max_horizontal_fusion_size
4070        ):
4071            return False
4072
4073        return self._can_fuse_horizontal_impl(node1, node2)
4074
4075    def _get_outer_loop_fusion_depth(self, node1, node2):
4076        DISABLE_OUTER_LOOP_FUSION = 0
4077        if not all(
4078            type(node)
4079            in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode)
4080            for node in (node1, node2)
4081        ):
4082            return DISABLE_OUTER_LOOP_FUSION
4083
4084        _node1 = (
4085            node1.get_outer_nodes()[-1]
4086            if isinstance(node1, OuterLoopFusedSchedulerNode)
4087            else node1
4088        )
4089        assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode))
4090        _node2 = (
4091            node2.get_outer_nodes()[0]
4092            if isinstance(node2, OuterLoopFusedSchedulerNode)
4093            else node2
4094        )
4095        assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode))
4096
4097        _, (vars1, reduce1) = _node1.group
4098        _, (vars2, reduce2) = _node2.group
4099        if vars1 == () and vars2 == () and reduce1 != () and reduce2 != ():
4100            # Reduction only
4101            return DISABLE_OUTER_LOOP_FUSION
4102        if all(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
4103            return (
4104                node1.outer_loop_fusion_depth
4105                if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth
4106                else DISABLE_OUTER_LOOP_FUSION
4107            )
4108        outer_loop_fusion_depth = min(len(vars1), len(vars2))
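        # e.g. vars1 = (8, 16, 32) and vars2 = (8, 16) share the outer prefix (8, 16),
        # giving a candidate outer loop fusion depth of 2.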
4109        if (
4110            outer_loop_fusion_depth >= 1
4111            and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth]
4112        ):
4113            if any(
4114                type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)
4115            ):
4116                _compare_node = (
4117                    node1 if type(node1) is OuterLoopFusedSchedulerNode else node2
4118                )
4119                if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth:
4120                    # Same outer loop fusion depth as the previous nodes in the OuterLoopFusedSchedulerNode
4121                    return outer_loop_fusion_depth
4122                else:
4123                    return DISABLE_OUTER_LOOP_FUSION
4124            else:
4125                # First two nodes that will form an OuterLoopFusedSchedulerNode
4126                return outer_loop_fusion_depth
4127        return DISABLE_OUTER_LOOP_FUSION
4128
4129    def can_fuse_vertical_outer_loop(self, node1, node2):
4130        return (
4131            not node1.is_template()
4132            and not node2.is_template()
4133            and node1.get_operation_names() & node2.ancestors
4134            and not (
4135                self._can_fuse_horizontal_impl(node1, node2)
4136                and not node1.is_reduction()
4137            )
4138            and self._get_outer_loop_fusion_depth(node1, node2) >= 1
4139        )
4140
4141    def get_fusion_pair_priority(self, node1, node2):
4142        if self.can_fuse_vertical_outer_loop(node1, node2):
4143            # Outer loop fusion with lower priority
4144            return 1
4145        else:
4146            return 0
4147
4148    def can_fuse_vertical(self, node1, node2):
4149        if node2.is_template():
4150            # TODO(jgong5): support pre-op fusion with template
4151            return False
4152        if node1.is_template():
4153            return not node2.is_reduction()
4154        return (
4155            self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
4156        ) or self.can_fuse_vertical_outer_loop(node1, node2)
4157
4158    def try_loop_split(self, nodes: List[SchedulerNode]):
4159        """
4160        Apply loop split optimization.
4161        When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop
4162        to avoid non-contiguous loads, subject to the following conditions:
4163            1. No reduction and no modular indexing for all nodes.
4164            2. Exactly one indexing_exprs of exactly one node contains a division; from this
4165               expression we can derive the dimension that needs to be split, and the split
4166               dimension must be contiguous in all other indexing_exprs.
4167
4168        For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs:
4169        {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2},
4170        we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to
4171        {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to
4172        {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}.
4173        """
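        # For example, with the indexing_exprs from the docstring above, the detection below
        # works directly on sympy expressions: for expr = 32*z0 + FloorDiv(z2, 30),
        # expr.count(FloorDiv) == 1 and expr.find(FloorDiv).pop().args == (z2, 30),
        # which yields split_var = z2 and split_number = 30.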
4174
4175        # No reduction and no modular indexing
4176        if any(
4177            len(node.group[1][1]) != 0
4178            or any(
4179                expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values()
4180            )
4181            for node in nodes
4182        ):
4183            return nodes
4184
4185        split_var = None
4186        split_number = None
4187        divide_index_name = None
4188        num_div = 0
4189        match_div = False
4190        matched_node = None
4191
4192        for node in nodes:
4193            assert isinstance(node.node, ir.ComputedBuffer)
4194            _, original_body, _ = node.node.get_default_sizes_body()
4195            for name, expr in original_body.indexing_exprs.items():
4196                num_div += expr.count(FloorDiv)
4197                if num_div > 1:
4198                    return nodes
4199                if expr.count(FloorDiv) == 1:
4200                    div_expr = expr.find(FloorDiv).pop()
4201                    split_var = div_expr.args[0]
4202                    split_number = div_expr.args[1]
4203                    divide_index_name = name
4204                    if (
4205                        isinstance(split_number, sympy.core.numbers.Integer)
4206                        and isinstance(split_var, sympy.core.symbol.Symbol)
4207                        and split_var in original_body.iter_vars
4208                        and divide_index_name is not None
4209                        and all(
4210                            stride_at_vec_range(expr, split_var) == 1
4211                            for name, expr in original_body.indexing_exprs.items()
4212                            if name != divide_index_name
4213                        )
4214                    ):
4215                        match_div = True
4216                        matched_node = node
4217
4218        # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs.
4219        if not match_div:
4220            return nodes
4221
4222        extra_indexing_constraints = None
4223
4224        def loop_split(sizes, body, vars):
4225            index_size, reduce_size = sizes
4226            index_vars, reduce_vars = vars
4227            split_idx = index_vars.index(split_var)
4228            new_index_size = index_size.copy()
4229            new_index_size[split_idx] = index_size[split_idx] // split_number
4230            new_index_size.insert(split_idx + 1, split_number)
4231            (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze(
4232                new_index_size, reduce_size, prefix="y"
4233            )
4234            iter_vars = new_index_vars.copy()
4235            divisor_var = iter_vars.pop(split_idx + 1)
4236            iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var
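            # Continuing the docstring example (index sizes [2, 9216, 960], split_var z2,
            # split_number 30): new_index_size becomes [2, 9216, 32, 30] and, with the fresh
            # 'y' iteration variables, the split dimension is addressed as 30*y2 + y3, so the
            # original indexing expressions still reach the same elements.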
4237            body = ir.LoopBody(
4238                body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars
4239            )
4240            nonlocal extra_indexing_constraints
4241            if not extra_indexing_constraints:
4242                extra_indexing_constraints = (
4243                    body.var_ranges,
4244                    list(body.indexing_exprs.values()),
4245                )
4246            return (
4247                (new_index_size, reduce_size),
4248                body,
4249                (new_index_vars, reduce_vars),
4250            )
4251
4252        # Decide the final loop order: the matched node is recomputed first and provides the constraints for the rest
4253        for node in nodes:
4254            if node == matched_node:
4255                node.recompute_size_and_body(recompute_sizes_body_func=loop_split)
4256        for node in nodes:
4257            if node != matched_node:
4258                node.recompute_size_and_body(
4259                    extra_indexing_constraints=extra_indexing_constraints,
4260                    recompute_sizes_body_func=loop_split,
4261                )
4262
4263        return nodes
4264
4265    def codegen_outer_loop_node(
4266        self,
4267        node: OuterLoopFusedSchedulerNode,
4268    ):
4269        """
4270        Generate the code for the outer loop fused scheduler node.
4271        1. Codegen with the fused outer loop, based on the analysis of
4272            the outer loop fused scheduler node, with or without local buffers.
4273        2. If that fails, fall back to standard codegen.
4274        """
4275        kernel_group = self.kernel_group
4276        generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count
4277        cpp_kernel_proxy_list: List[CppKernelProxy] = []
4278        nodes_list: List[List[SchedulerNode]] = []
4279        assert isinstance(node, OuterLoopFusedSchedulerNode)
4280
4281        def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode):
4282            """
4283            Codegen with a fused outer loop and local buffers.
4284            """
4285            assert isinstance(node, OuterLoopFusedSchedulerNode)
4286            cpp_kernel_proxy_list.clear()
4287            nodes_list.clear()
4288
4289            def get_call_ranges(node: BaseSchedulerNode):
4290                assert isinstance(node, (SchedulerNode, FusedSchedulerNode))
4291                nodes: List[SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
4292                _, (group, reduction_group) = max(
4293                    nodes, key=lambda x: int(x.is_reduction())
4294                ).group
4295                call_ranges = tuple(group) + tuple(reduction_group)
4296                return call_ranges
4297
4298            local_buffers: List[ir.Buffer] = []
4299            # Map local buffer name to a list of global buffers
4300            local_to_global_buffers: Dict[str, List[ir.Buffer]] = {}
4301            if all(
4302                len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1
4303                for _node in node.get_outer_nodes()
4304            ):
4305                # Refer to the typical case of a local buffer
4306                # in https://github.com/pytorch/pytorch/blob/
4307                # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159
4308                # where the buffer has the size of the last dim and is contiguous.
4309                # Only this typical case is supported for now.
4310                visited_scheduler_nodes: Set[str] = set()
4311                for scheduler_node in node.get_nodes():
4312                    # all users inside same OuterLoopFusedSchedulerNode
4313                    assert isinstance(scheduler_node, SchedulerNode)
4314                    visited_scheduler_nodes.add(scheduler_node.get_name())
4315                    if (
4316                        scheduler_node.is_reduction()
4317                        or len(scheduler_node.get_outputs()) != 1
4318                    ):
4319                        continue
4320
4321                    scheduler_buffer = scheduler_node.get_outputs()[0]
4322                    if all(
4323                        user.node in node.get_nodes() for user in scheduler_buffer.users
4324                    ):
4325                        global_buffer = scheduler_buffer.node
4326                        assert isinstance(global_buffer, ir.ComputedBuffer)
4327                        global_buffer_layout = global_buffer.get_layout()
4328                        size_offset = node.outer_loop_fusion_depth - len(
4329                            get_call_ranges(scheduler_node)
4330                        )
4331
4332                        def is_all_write_read_contiguous():
4333                            contiguous_index_expr = 0
4334                            stride = 1
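                            # Build the index expression of a fully contiguous traversal,
                            # e.g. for var_ranges {x0: 4, x1: 8} this is 8*x0 + x1; every
                            # write and read of the buffer must match it exactly.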
4335                            for var, range in reversed(
4336                                scheduler_node._body.var_ranges.items()
4337                            ):
4338                                contiguous_index_expr += stride * var
4339                                stride *= range
4340                            write_index_expr = scheduler_node._body.get_write_expr(
4341                                scheduler_buffer.get_name()
4342                            )
4343
4344                            def is_contiguous_index(x):
4345                                return x == contiguous_index_expr
4346
4347                            return is_contiguous_index(write_index_expr) and all(
4348                                isinstance(user.node, SchedulerNode)
4349                                and is_contiguous_index(
4350                                    user.node._body.get_read_expr(
4351                                        scheduler_buffer.get_name()
4352                                    ),
4353                                )
4354                                for user in scheduler_buffer.users
4355                            )
4356
4357                        if not (
4358                            global_buffer_layout.is_contiguous()
4359                            and is_all_write_read_contiguous()
4360                        ):
4361                            continue
4362                        # Local Buffer is a view of global buffer
4363                        local_buffer_layout = ir.FixedLayout(
4364                            global_buffer_layout.device,
4365                            global_buffer_layout.dtype,
4366                            global_buffer_layout.size[size_offset:],
4367                            global_buffer_layout.stride[size_offset:],
4368                        )
4369
4370                        def try_share_local_buffer(local_buffer_layout, local_buffers):
4371                            for local_buf in local_buffers:
4372                                if local_buffer_layout == local_buf.layout and all(
4373                                    all(
4374                                        user.node.get_name() in visited_scheduler_nodes
4375                                        for user in V.graph.scheduler.name_to_buf[
4376                                            global_buffer.name
4377                                        ].users
4378                                    )
4379                                    for global_buffer in local_to_global_buffers[
4380                                        local_buf.name
4381                                    ]
4382                                    if global_buffer.name is not None
4383                                ):
4384                                    return local_buf
4385                            return None
4386
4387                        local_buf_prefix = "local_buffer_data"
4388                        # Share existing local buffer
4389                        local_buffer_used = try_share_local_buffer(
4390                            local_buffer_layout, local_buffers
4391                        )
4392                        if not local_buffer_used:
4393                            # Create new local buffer
4394                            local_buffer_used = ir.Buffer(
4395                                f"{local_buf_prefix}_{len(local_buffers)}",
4396                                local_buffer_layout,
4397                            )
4398                            local_buffers.append(local_buffer_used)
4399                            local_to_global_buffers[local_buffer_used.name] = []
4400                        local_to_global_buffers[local_buffer_used.name].append(
4401                            global_buffer,
4402                        )
4403
4404            with LocalBufferContext(kernel_group.args) as scope:
4405                if len(local_buffers) > 0:
4406                    for local_buffer in local_buffers:
4407                        assert local_buffer.name is not None
4408                        scope.add_local_buffer(
4409                            local_buffer, local_to_global_buffers[local_buffer.name]
4410                        )
4411                for _node in node.get_outer_nodes():
4412                    assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
4413                    cpp_kernel_proxy = CppKernelProxy(kernel_group)
4414                    cpp_kernel_proxy.codegen_nodes(_node.get_nodes())  # type: ignore[arg-type]
4415                    cpp_kernel_proxy_list.append(cpp_kernel_proxy)
4416                    nodes_list.append(_node.get_nodes())  # type: ignore[arg-type]
4417
4418                if not node.check_outer_fusion_loop_level_attr(
4419                    cpp_kernel_proxy_list, node.outer_loop_fusion_depth
4420                ):
4421                    return False
4422                metrics.cpp_outer_loop_fused_inner_counts.append(
4423                    metrics.CppOuterLoopFusedCount(
4424                        len(cpp_kernel_proxy_list),
4425                        local_buffer_number=len(scope.local_buffers),
4426                    )
4427                )
4428                outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
4429                    cpp_kernel_proxy_list,
4430                )
4431                kernel_group.finalize_kernel(
4432                    outer_fusion_cpp_kernel_proxy,
4433                    [_node for _nodes in nodes_list for _node in _nodes],
4434                )
4435
4436            return True
4437
4438        if not try_outer_loop_fusion_with_local_buf(node):
4439            # Reset generated_cpp_vec_kernel_count to codegen again
4440            metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count
4441            cpp_kernel_proxy_list.clear()
4442            nodes_list.clear()
4443            # Similar to the comment in
4444            # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272
4445            # Kernels share the same global contexts, such as V.graph.wrapper_code and V.kernel.args.
4446            with torch._inductor.config.patch(inplace_buffers=False):
4447                for _node in node.get_outer_nodes():
4448                    assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
4449                    _nodes: List[SchedulerNode] = _node.get_nodes()  # type: ignore[assignment]
4450                    cpp_kernel_proxy = CppKernelProxy(kernel_group)
4451                    cpp_kernel_proxy.codegen_nodes(_nodes)
4452                    kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes)
4453
4454    def codegen_node(
4455        self,
4456        node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
4457    ):
4458        """
4459        Turn a set of pre-fused nodes into a C++ kernel.
4460        """
4461        kernel_group = self.kernel_group
4462
4463        if isinstance(node, OuterLoopFusedSchedulerNode):
4464            self.codegen_outer_loop_node(node)
4465        else:
4466            nodes: List[SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
4467            nodes = self.try_loop_split(nodes)
4468            cpp_kernel_proxy = CppKernelProxy(kernel_group)
4469            cpp_kernel_proxy.codegen_nodes(nodes)
4470            kernel_group.finalize_kernel(cpp_kernel_proxy, nodes)
4471
4472        args_num = self._get_scheduled_num_args()
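        # Once the accumulated argument count exceeds MAX_FUSED_KERNEL_ARGS_NUM, mark the
        # group ready to flush so that a single generated C++ function does not end up with
        # an excessive number of parameters.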
4473        if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
4474            self._set_flush_status(True)
4475
4476    def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
4477        return isinstance(node, SchedulerNode) and isinstance(
4478            node.node, ir.CppTemplateBuffer
4479        )
4480
4481    def codegen_template(
4482        self,
4483        template_node: BaseSchedulerNode,
4484        epilogue_nodes: Sequence[BaseSchedulerNode],
4485    ):
4486        """
4487        Codegen a CPP template, possibly with fused epilogues
4488        """
4489        counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
4490        assert self.is_cpp_template(
4491            template_node
4492        ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
4493        template_node = cast(SchedulerNode, template_node)
4494        _, (_, rnumel) = template_node.group
4495        assert rnumel == ()
4496        ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
4497        epilogue_ir_nodes: List[Optional[ir.Operation]] = [
4498            n.node for n in epilogue_nodes
4499        ]
4500        assert all(
4501            isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
4502        ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
4503
4504        def template_buffer_has_other_users(
4505            template_buffer, outputs_by_name, epilogue_nodes
4506        ):
4507            assert template_buffer.get_name() in outputs_by_name
4508            users = outputs_by_name[template_buffer.get_name()].users
4509            return not all(
4510                isinstance(user.node, BaseSchedulerNode)
4511                and user.node.node in epilogue_nodes
4512                for user in users
4513            )
4514
4515        flag_template_buffer_has_other_users = template_buffer_has_other_users(
4516            ctb, template_node.outputs_by_name, epilogue_ir_nodes
4517        )
4518        kernel, render = ctb.make_kernel_render(
4519            ctb,
4520            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
4521            epilogue_nodes=epilogue_ir_nodes,
4522        )
4523        with kernel:
4524            for node in [template_node, *epilogue_nodes]:
4525                node.mark_run()  # type: ignore[attr-defined]
4526            src_code = render()
4527
4528        with V.set_kernel_handler(kernel):
4529            node_schedule = [template_node, *epilogue_nodes]
4530            kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
4531        kernel.call_kernel(kernel_name, ctb)
4532        V.graph.removed_buffers |= kernel.removed_buffers
4533        self.scheduler.free_buffers()
4534
4535    def _get_scheduled_num_args(self):
4536        return self.kernel_group.get_num_args()
4537
4538    def ready_to_flush(self):
4539        return self._ready_to_flush
4540
4541    def codegen_sync(self):
4542        pass
4543
4544    def define_kernel(self, src_code, nodes, kernel_args=None):
4545        wrapper = V.graph.wrapper_code
4546        fused_name = (
4547            get_fused_kernel_name(nodes, config.cpp.descriptive_names)
4548            if config.cpp.descriptive_names
4549            else ""
4550        )
4551        kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
4552        kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
4553        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
4554        src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
4555        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
4556        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
4557        src_code = src_code.replace("#pragma CMT", "//")
4558
4559        compile_wrapper = IndentedBuffer()
4560        args = self.kernel_group.args if kernel_args is None else kernel_args
4561        _, _, arg_types = args.cpp_argdefs()
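        # arg_types holds the C++ parameter types (e.g. ["const float*", "float*"]); in the
        # default Python wrapper the kernel source is registered through
        # async_compile.cpp_pybinding(arg_types, '''...''').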
4562        if not V.graph.cpp_wrapper:
4563            compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
4564        compile_wrapper.splice(src_code, strip=True)
4565        if not V.graph.cpp_wrapper:
4566            compile_wrapper.writeline("''')")
4567        wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), cuda=False)
4568        return kernel_name
4569
4570    def flush(self):
4571        src_code = self.kernel_group.codegen_group()
4572        if src_code:
4573            kernel_name = self.define_kernel(
4574                src_code, self.kernel_group.scheduled_nodes
4575            )
4576            self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
4577        self.reset_kernel_group()
4578        self._set_flush_status(False)
4579
4580
4581class KernelGroup:
4582    def __init__(self):
4583        super().__init__()
4584        self.args = KernelArgs()
4585        self.loops_code = BracesBuffer()
4586        self.ws = WorkSharing(self.loops_code)
4587        self.stack = contextlib.ExitStack()
4588        self.stack.enter_context(self.ws)
4589        self.scheduled_nodes = []
4590
4591    def new_kernel(self, cls, *args):
4592        return cls(self.args, parallel_num_threads(), *args)
4593
4594    def finalize_kernel(self, new_kernel, nodes):
4595        self.scheduled_nodes += nodes
4596        code = self.loops_code
4597        ws = self.ws
4598        new_kernel.codegen_loops(code, ws)
4599
4600    def get_num_args(self):
4601        arg_defs, call_args, arg_types = self.args.cpp_argdefs()
4602        args_num = len(arg_defs)
4603        return args_num
4604
4605    def codegen_group(self, name=None) -> str:
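        """
        Assemble the C++ source for all kernels scheduled in this group: header includes,
        an extern "C" function whose parameters come from self.args, optional
        RECORD_FUNCTION profiling, alias definitions, and the accumulated loop code,
        yielding roughly: extern "C" void kernel(const float* in_ptr0, float* out_ptr0) {...}.
        """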
4606        self.stack.close()
4607        if not self.scheduled_nodes:
4608            return ""
4609        code = BracesBuffer()
4610        # 1. Include header files
4611        # TODO: support kernel profile on other platforms
4612        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
4613            "linux",
4614            "win32",
4615        ]
4616        if enable_kernel_profile:
4617            code.writelines(["#include <ATen/record_function.h>"])
4618        code.writeline(codecache.cpp_prefix())
4619
4620        # 2. Function definition
4621        kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name
4622        kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name
4623        arg_defs, _, _ = self.args.cpp_argdefs()
4624        arg_defs = ",\n".ljust(25).join(arg_defs)
4625        func_export_decl = get_export_declaration()
4626        code.writeline(
4627            f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})'
4628        )
4629
4630        # 3. Function body
4631        with code.indent():
4632            if enable_kernel_profile:
4633                graph_id = V.graph.graph_id
4634                prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
4635                code.writelines(
4636                    [
4637                        f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
4638                    ]
4639                )
4640            for old, new in self.args.aliases():
4641                code.writeline(f"auto {old} = {new};")
4642            code.splice(self.loops_code)
4643        return code.getvalue()
4644
4645    def call_kernel(self, wrapper, kernel_name):
4646        _, call_args, arg_types = self.args.cpp_argdefs()
4647        wrapper.generate_kernel_call(
4648            kernel_name, call_args, cuda=False, arg_types=arg_types
4649        )
4650
4651
4652class CppWrapperKernelGroup(KernelGroup):
4653    def __init__(self):
4654        super().__init__()
4655        self.args = CppWrapperKernelArgs()
4656
4657
4658class WorkSharing:
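    """
    Tracks whether the generated code is inside an OpenMP parallel region and opens one
    on demand, emitting e.g. "#pragma omp parallel num_threads(8)" (or just
    "#pragma omp parallel" when config.cpp.dynamic_threads is set) followed by
    "int tid = omp_get_thread_num();" inside the region.
    """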
4659    def __init__(self, code):
4660        self.code = code
4661        self.in_parallel = False
4662        self.num_threads = None
4663        self.stack = contextlib.ExitStack()
4664
4665    def parallel(self, threads):
4666        if self.in_parallel and threads != self.num_threads:
4667            # thread count changed; close the current parallel region and start a new one
4668            self.close()
4669        if not self.in_parallel:
4670            self.num_threads = threads
4671            self.in_parallel = True
4672            if config.cpp.dynamic_threads:
4673                self.code.writeline("#pragma omp parallel")
4674            else:
4675                self.code.writeline(f"#pragma omp parallel num_threads({threads})")
4676            self.stack.enter_context(self.code.indent())
4677            self.code.writeline(
4678                "int tid = omp_get_thread_num();",
4679            )
4680
4681    def single(self):
4682        if self.in_parallel:
4683            self.code.writeline("#pragma omp single")
4684        return self.in_parallel
4685
4686    def close(self):
4687        self.stack.close()
4688        self.in_parallel = False
4689
4690    def __enter__(self):
4691        self.stack.__enter__()
4692        return self
4693
4694    def __exit__(self, exc_type, exc_val, exc_tb):
4695        self.stack.__exit__(exc_type, exc_val, exc_tb)
4696
4697
4698@dataclasses.dataclass
4699class LoopLevel:
4700    var: Optional[sympy.Expr] = None
4701    size: Optional[sympy.Expr] = None
4702    offset: sympy.Expr = sympy.Integer(0)
4703    steps: sympy.Expr = sympy.Integer(1)
4704    parallel: int = 0
4705    simd_omp: bool = False
4706    simd_vec: bool = False
4707    collapsed: bool = False
4708    is_reduction: bool = False
4709    parent: Optional["LoopLevel"] = None
4710    # the next inner level of the loop, empty if it is inner-most
4711    # contains >1 LoopLevel if the inner level of loop is split
4712    inner: List["LoopLevel"] = dataclasses.field(default_factory=list)
4713    # kernel assigned to this loop level, only valid when it is a leaf
4714    kernel: Optional[CppKernel] = None
4715
4716    def __post_init__(self):
4717        # Regarding the C++/OpenMP backend: `cpu_vec_isa.pick_vec_isa()`, which checks the
4718        # vectorization ISA, is a time-consuming, one-shot operation. Running it at class
4719        # definition time would make importing the `codegen.cpp` package slower, because
4720        # `LoopLevel` is decorated with `@dataclasses.dataclass` and the decorator would
4721        # invoke `cpu_vec_isa.pick_vec_isa()` to initialize the `simd_nelements` of the
4722        # `LoopLevel`. That would also introduce additional compilation overhead for the
4723        # Triton backend. Therefore, we moved the initialization of `simd_nelements` to
4724        # `__post_init__`.
4725        picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa()
4726        self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0
4727
4728    def get_kernels(self) -> List[CppKernel]:
4729        """Get all kernel objects under this loop level"""
4730        if self.kernel:
4731            return [self.kernel]
4732        kernels = []
4733        for loop in self.inner:
4734            kernels += loop.get_kernels()
4735        return kernels
4736
4737    def get_root(self):
4738        """Get the outermost (root) loop level that this loop level belongs to"""
4739        root = self
4740        while root.parent:
4741            root = root.parent
4742        return root
4743
4744    def set_kernel(self, kernel: CppKernel):
4745        """
4746        Set the kernel under this loop level. No split is allowed under
4747        this loop level.
4748        """
4749        if not self.inner:
4750            self.kernel = kernel
4751            loop: Optional[LoopLevel] = self
4752            assert loop is not None
4753            return
4754        assert len(self.inner) == 1
4755        self.inner[0].set_kernel(kernel)
4756
4757    def get_loops_at(self, depth) -> List["LoopLevel"]:
4758        if depth == 0:
4759            return [self]
4760        else:
4761            loops = []
4762            for loop in self.inner:
4763                loops += loop.get_loops_at(depth - 1)
4764            return loops
4765
4766    def split_with_tiling(self, depth, factor):
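        """
        Split the loop at the given `depth` into a main loop that steps by `factor` over
        [0, FloorDiv(size, factor) * factor) and a tail loop over the remainder,
        e.g. offset 0, size 10 and factor 4 give a main loop over [0, 8) with step 4 and
        a tail loop over [8, 10).
        """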
4767        def clone_inner():
4768            inner = []
4769            if self.inner:
4770                for loop in self.inner:
4771                    inner.append(loop.clone())
4772            return inner
4773
4774        def do_split_with_tiling():
4775            sympy_factor = sympy.Integer(factor)
4776
4777            offset = FloorDiv(self.size, sympy_factor) * sympy_factor
4778            main_loop = LoopLevel(self.var, offset)
4779            main_loop.steps = sympy_factor
4780            main_loop.parallel = self.parallel
4781            main_loop.collapsed = False
4782            main_loop.is_reduction = self.is_reduction
4783            main_loop.inner = clone_inner()
4784            if main_loop.inner:
4785                for loop in main_loop.inner:
4786                    loop.parent = main_loop
4787
4788            tail_loop = LoopLevel(self.var, self.size)
4789            tail_loop.offset = offset
4790            tail_loop.parallel = self.parallel
4791            tail_loop.collapsed = False
4792            tail_loop.is_reduction = self.is_reduction
4793            tail_loop.inner = clone_inner()
4794            if tail_loop.inner:
4795                for loop in tail_loop.inner:
4796                    loop.parent = tail_loop
4797
4798            return main_loop, tail_loop
4799
4800        if depth == 0:
4801            main_loop, tail_loop = do_split_with_tiling()
4802            parent = self.parent
4803            if parent:
4804                parent.inner = [main_loop, tail_loop]
4805                main_loop.parent = parent
4806                tail_loop.parent = parent
4807            return main_loop, tail_loop
4808        else:
4809            assert len(self.inner) == 1
4810            return self.inner[0].split_with_tiling(depth - 1, factor)
4811
4812    def clone(self):
4813        loop = copy(self)
4814        loop.inner = []
4815        if self.inner:
4816            for inner_loop in self.inner:
4817                inner_loop_clone = inner_loop.clone()
4818                inner_loop_clone.parent = loop
4819                loop.inner.append(inner_loop_clone)
4820        loop.kernel = deepcopy(self.kernel)
4821        return loop
4822
4823    def lines(self):
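        """
        Emit the C++ source lines for this loop level: an optional pragma (OpenMP and/or
        GCC ivdep) followed by the for-statement, e.g. with parallel == 2 this is roughly
        ["#pragma omp for collapse(2)", "for(<INDEX_TYPE> x0=0; x0<128; x0+=1)"].
        """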
4824        offset_expr = cexpr_index(self.offset)
4825        size_expr = cexpr_index(self.size)
4826        if config.cpp.no_redundant_loops and offset_expr == size_expr:
4827            return None
4828        simd = (
4829            f"simd simdlen({self.simd_nelements}) "
4830            if self.simd_omp and self.simd_nelements > 1
4831            else ""
4832        )
4833        if self.parallel:
4834            # TODO(jansel): look into chunk size and other schedules
4835            line1 = "#pragma omp for"
4836            if self.parallel > 1:
4837                line1 += f" collapse({self.parallel})"
4838            if self.simd_omp:
4839                line1 = line1.replace(" for ", f" for {simd}")
4840        elif self.simd_vec:
4841            line1 = ""
4842        elif self.simd_omp:
4843            line1 = f"#pragma omp {simd}"
4844        elif not self.is_reduction and cpp_builder.is_gcc():
4845            line1 = "#pragma GCC ivdep"
4846        else:
4847            line1 = ""
4848        offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}"
4849        size_str = f"{self.var}<{size_expr}"
4850        if self.steps.is_number:
4851            steps_str = f"{self.var}+={cexpr_index(self.steps)}"
4852        else:
4853            # If the step size is 0, change it to 1 because a step size of 0
4854            # will cause a floating point exception (core dump) during parallelization.
4855            steps_str = (
4856                f"{self.var}+=({cexpr_index(self.steps)} == 0 ? "
4857                f"1 : {cexpr_index(self.steps)})"
4858            )
4859        line2 = f"for({offset_str}; {size_str}; {steps_str})"
4860        if self.collapsed or not line1:
4861            return [line2]
4862        return [line1, line2]
4863
4864
4865@dataclasses.dataclass
4866class LoopNestWithSplit:
4867    """
4868    A loop-nest-like structure, but with some loop levels split along
4869    the loop range into a main tiling loop and a tail. It is built
4870    with the `build` method as a loop nest and then split with
4871    `split_with_tiling` at some depth.
4872
4873    A typical case is vectorization, where we split at the innermost
4874    loop level. A more complicated case is 2D tiling, where we split at
4875    both the innermost and an outer level.
4876    """
4877
4878    root: Optional[List[LoopLevel]] = None
4879    kernel: Optional[CppKernel] = None
4880
4881    @staticmethod
4882    def build(kernel: CppKernel):
4883        """Build a LoopNest with the given `kernel` as the leaf"""
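        # Each (itervar, range) pair becomes one LoopLevel nested under the previous one;
        # levels at or beyond `reduction_depth` inherit the kernel's reduction flag, the
        # innermost level carries the kernel, and if there are no itervars the kernel is
        # attached to the LoopNestWithSplit itself.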
4884        itervars = kernel.itervars
4885        ranges = kernel.ranges
4886        reduction_depth = kernel.reduction_depth
4887        assert reduction_depth is not None
4888
4889        root: List[LoopLevel] = []
4890        levels: List[LoopLevel] = root
4891        loop: Optional[LoopLevel] = None
4892        for loop_idx, (var, size) in enumerate(zip(itervars, ranges)):
4893            loop = LoopLevel(var, size, parent=loop)
4894            if loop_idx >= reduction_depth:
4895                loop.is_reduction = kernel.is_reduction
4896            levels.append(loop)
4897            levels = loop.inner
4898        loop_nest = LoopNestWithSplit(root)
4899        if loop:
4900            loop.kernel = kernel
4901        else:
4902            loop_nest.kernel = kernel
4903        return loop_nest
4904
4905    def __bool__(self):
4906        return bool(self.root)
4907
4908    def get_loops_at(self, depth) -> List[LoopLevel]:
4909        """Get all the loop levels at the given `depth` (most outer loop has depth 0)"""
4910        loops: List[LoopLevel] = []
4911        assert self.root is not None
4912        for loop in self.root:
4913            loops += loop.get_loops_at(depth)
4914        return loops
4915
4916    @cache_on_self
4917    def max_parallel_depth(self):
4918        """
4919        Maximal allowed depth for parallelism: the outermost run of loop levels that
4920        1) are not split and
4921        2) are either all reduction or all non-reduction levels.
4922        When the loop is split at the top level, the max depth is 1.
4923        """
4924        max_depth = 0
4925        assert self.root is not None
4926        loops = self.root
4927        if len(loops) > 1:
4928            return 1
4929        is_reduction = loops[0].is_reduction if loops else False
4930        while len(loops) == 1 and loops[0].is_reduction == is_reduction:
4931            max_depth += 1
4932            loops = loops[0].inner
4933        return max_depth
4934
4935    def is_reduction_only(self):
4936        """
4937        Whether all the loops are for reduction. Reduction loops
4938        are always the innermost ones.
4939        """
4940        return (
4941            self.root is not None and len(self.root) > 0 and self.root[0].is_reduction
4942        )
4943
4944    def mark_parallel(self, par_depth):
4945        assert (
4946            par_depth <= self.max_parallel_depth()
4947        ), "Parallel depth cannot exceed the maximal allowed parallel depth"
4948        assert self.root is not None
4949        loops = self.root
4950        for loop in loops:
4951            loop.parallel = par_depth
4952        for i in range(1, par_depth):
4953            loops = loops[0].inner
4954            loops[0].collapsed = True
4955
4956    def split_with_tiling(self, depth, factor):
4957        """
4958        Split the loop into main and tail loops at the given `depth` so that the
4959        main loop covers the range `floor_div(range, factor) * factor` and
4960        the tail loop handles the remainder. The main loop is tiled
4961        according to the `factor`.
4962        """
4963        loops = self.get_loops_at(depth)
4964        assert len(loops) == 1
4965        split_loops = loops[0].split_with_tiling(0, factor)
4966        if depth == 0:
4967            self.root = split_loops
4968        return split_loops
4969
4970    def get_kernels(self) -> List[CppKernel]:
4971        """Get all kernel objects under this loop nest"""
4972        if self.kernel:
4973            return [self.kernel]
4974        kernels: List[CppKernel] = []
4975        assert self.root is not None
4976        for loop in self.root:
4977            kernels += loop.get_kernels()
4978        return kernels
4979