xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/halide.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import dataclasses
5import functools
6import itertools
7import logging
8import re
9from collections import defaultdict
10from math import inf
11from typing import (
12    Any,
13    Callable,
14    Dict,
15    List,
16    Optional,
17    Sequence,
18    Tuple,
19    TYPE_CHECKING,
20    Union,
21)
22
23import sympy
24
25import torch
26import torch._logging
27
28from ..._prims_common import is_integer_dtype
29from ...utils._sympy.functions import FloorDiv, ModularIndexing
30from ...utils._sympy.symbol import symbol_is_type, SymT
31from ...utils._sympy.value_ranges import ValueRanges
32from .. import config, ir
33from ..codecache import HalideCodeCache
34from ..ir import get_reduction_combine_fn
35from ..metrics import is_metric_table_enabled, log_kernel_metadata
36from ..ops_handler import AddParenHandler, MockHandler
37from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint
38from ..utils import (
39    get_bounds_index_expr,
40    get_kernel_metadata,
41    parallel_num_threads,
42    sympy_index_symbol,
43    sympy_subs,
44)
45from ..virtualized import _ops as ops, OpsHandler, V
46from .common import (
47    BackendFeature,
48    CSEVariable,
49    DeferredLine,
50    IndentedBuffer,
51    OpOverrides,
52    PythonPrinter,
53    SizeArg,
54    TensorArg,
55)
56from .cpp import DTYPE_TO_CPP
57from .cpp_utils import cexpr
58from .simd import constant_repr, SIMDKernel, SIMDScheduling
59
60
61if TYPE_CHECKING:
62    from torch.utils._ordered_set import OrderedSet
63
64    from ..ops_handler import ReductionType, StoreMode
65
66log = logging.getLogger(__name__)
67
68
def halide_constant(val):
    """Format a Python scalar as Halide expression source text.

    Ints outside the signed 32-bit range become explicit 64-bit Halide
    constants (with the iinfo extremes spelled via hl.Int(64).min()/max(),
    since their literal forms can overflow); floats become f64 constants;
    everything else falls back to repr().
    """
    int32_min, int32_max = -(2**31), 2**31 - 1
    if isinstance(val, int) and not (int32_min <= val <= int32_max):
        int64_info = torch.iinfo(torch.int64)
        if val == int64_info.min:
            return "hl.Int(64).min()"
        if val == int64_info.max:
            return "hl.Int(64).max()"
        return f"hl.i64({val!r})"
    if isinstance(val, float):
        return f"hl.f64({constant_repr(val)})"
    return repr(val)
80
81
class Unsupported(RuntimeError):
    """Raised when a lowering requests an operation the Halide backend cannot emit."""

    def __init__(self, thing) -> None:
        super().__init__(f"halide backend does not support: {thing}")
85
86
class HalidePrinter(PythonPrinter):
    """Sympy expression printer that emits Halide (``hl.*``) source text.

    The returned strings are pasted verbatim into generated Halide kernels,
    so every printer must produce a syntactically valid Python/Halide
    expression.
    """

    @staticmethod
    def cast_index(expr):
        # Cast to the kernel's configured index dtype (Int(32) or Int(64)).
        return f"hl.cast({V.kernel.index_dtype}, {expr})"

    @staticmethod
    def cast_float(expr):
        return f"hl.cast(hl.Float(32), {expr})"

    def _print_Float(self, expr):
        return f"hl.f32({expr})"

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"hl.f32({self._print(expr.args[0])})"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.floor({self._print(expr.args[0])})")

    def _print_Trunc(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.trunc({self._print(expr.args[0])})")

    _print_TruncToInt = _print_Trunc

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.ceil({self._print(expr.args[0])})")

    def _helper_sqrt(self, expr):
        return f"hl.sqrt({self.cast_float(self._print(expr))})"

    def _print_Where(self, expr):
        c = self.doprint(expr.args[0])
        p = self.doprint(expr.args[1])
        q = self.doprint(expr.args[2])
        return f"hl.select({c}, {p}, {q})"

    def _print_Min(self, expr):
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        # hl.min is binary, so reduce the n-ary sympy Min by splitting
        mid = len(expr.args) // 2
        a = self._print(sympy.Min(*expr.args[:mid]))
        b = self._print(sympy.Min(*expr.args[mid:]))
        return f"hl.min({a}, {b})"

    def _print_Max(self, expr):
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        # hl.max is binary, so reduce the n-ary sympy Max by splitting
        mid = len(expr.args) // 2
        a = self._print(sympy.Max(*expr.args[:mid]))
        b = self._print(sympy.Max(*expr.args[mid:]))

        return f"hl.max({a}, {b})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.abs({self._print(expr.args[0])})")

    # NOTE: these previously emitted an extra "(" (e.g. "hl.cos((x)"),
    # producing unbalanced parentheses in the generated kernel; fixed.
    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"hl.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"hl.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"hl.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"hl.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"hl.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"hl.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"hl.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"hl.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"hl.atan({self._print(expr.args[0])})"

    def _print_FloorDiv(self, expr):
        if expr.is_integer:
            return super()._print_FloorDiv(expr)

        # Non-integer floordiv: compute in fp32 then floor back to index type
        x, div = expr.args
        x = self.cast_float(self.paren(self.doprint(x)))
        div = self.cast_float(self.paren(self.doprint(div)))
        return self.cast_index(f"hl.floor({x} / {div})")

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.round({self._print(expr.args[0])})")

    _print_RoundToInt = _print_Round

    def _print_IntTrueDiv(self, expr):
        a, b = expr.args
        # Print args through this printer (previously interpolated with
        # default str(), which need not be valid Halide source for
        # compound expressions); "+hl.f32(0)" forces a cast to float.
        return f"({self._print(a)}) / ({self._print(b)}+hl.f32(0))"

    def _print_RoundDecimal(self, expr):
        val, n = expr.args
        val = self._print(val)
        n = int(n)
        # round(val, n) == round(val * 10^n) * 10^-n
        return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))"
210
211
# Singleton printer entry points: texpr renders sympy exprs as Halide source,
# pexpr renders them as plain Python (used for host-side sizes/strides).
texpr = HalidePrinter().doprint
pexpr = PythonPrinter().doprint
214
215
216_halide_type = {
217    torch.bool: "hl.Bool()",
218    torch.bfloat16: "hl.BFloat(16)",
219    torch.float16: "hl.Float(16)",
220    torch.float32: "hl.Float(32)",
221    torch.float64: "hl.Float(64)",
222    torch.int8: "hl.Int(8)",
223    torch.int16: "hl.Int(16)",
224    torch.int32: "hl.Int(32)",
225    torch.int64: "hl.Int(64)",
226    torch.uint8: "hl.UInt(8)",
227    torch.uint16: "hl.UInt(16)",
228    torch.uint32: "hl.UInt(32)",
229    torch.uint64: "hl.UInt(64)",
230}
231
232
233def halide_type(dtype):
234    return _halide_type[dtype]
235
236
237def halide_acc_type(dtype):
238    if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64:
239        dtype = torch.int32
240    if dtype in (torch.float16, torch.bfloat16):
241        dtype = torch.float32
242    return halide_type(dtype)
243
244
class HalideOverrides(OpOverrides):
    """Maps inductor ops onto strings of Halide (``hl.*``) source code.

    Each method returns text that is pasted verbatim into the generated
    kernel, so the exact string content is load-bearing.  Several methods
    branch on ``hasattr(x, "name")``: when the argument is a CSE variable
    they can emit a build-time Python conditional that inspects
    ``<name>.type()`` of the Halide Func when the kernel is constructed.
    """

    @staticmethod
    def to_dtype(
        x,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types=True,
    ):
        # bool is represented as a comparison rather than a cast
        if dtype == torch.bool:
            return f"({x} != 0)"
        return f"hl.cast({halide_type(dtype)}, {x})"

    @staticmethod
    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
        if src_dtype in (torch.float16, torch.bfloat16):
            x = f"hl.cast({halide_type(src_dtype)}, {x})"  # body compute is upcast to fp32
        line = f"hl.reinterpret({halide_type(dtype)}, {x})"
        if dtype in (torch.float16, torch.bfloat16):
            # re-upcast so downstream compute stays in fp32
            line = f"hl.cast(hl.Float(32), {line})"
        return line

    @classmethod
    def constant(cls, value, dtype):
        return cls.to_dtype(halide_constant(value), dtype)

    @staticmethod
    def abs(x):
        return f"hl.abs({x})"

    @staticmethod
    def exp(x):
        if not hasattr(x, "name"):
            return f"hl.exp({x})"
        # build-time conditional: fast_exp for <=32-bit types, exact exp otherwise
        return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})"

    @staticmethod
    def libdevice_exp(x):
        return f"hl.exp({x})"  # higher precision that ops.exp

    @staticmethod
    def sqrt(x):
        return f"hl.sqrt({x})"

    @staticmethod
    def minimum(a, b):
        # return f"hl.min({a}, {b})"  <== handles nan wrong
        if not hasattr(a, "name"):
            return f"hl.min({a}, {b})"
        b = f"hl.cast({a.name}.type(), {b})"
        # NaN-propagating min for float types, plain hl.min otherwise
        return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})"

    @staticmethod
    def maximum(a, b):
        # return f"hl.max({a}, {b})"  <== handles nan wrong
        if not hasattr(a, "name"):
            return f"hl.max({a}, {b})"
        b = f"hl.cast({a.name}.type(), {b})"
        # NaN-propagating max for float types, plain hl.max otherwise
        return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})"

    @staticmethod
    def where(a, b, c):
        # cast c to b's type to avoid accidental promotion in hl.select
        if hasattr(b, "name"):
            c = f"hl.cast({b.name}.type(), {c})"
        return f"hl.select({a}, {b}, {c})"

    @staticmethod
    def cos(x):
        return f"hl.cos({x})"

    @staticmethod
    def sin(x):
        return f"hl.sin({x})"

    @staticmethod
    def lgamma(x):
        raise Unsupported("lgamma")

    @staticmethod
    def erf(x):
        return f"hl.erf({x})"

    @staticmethod
    def cosh(x):
        return f"hl.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"hl.sinh({x})"

    @staticmethod
    def acos(x):
        return f"hl.acos({x})"

    @staticmethod
    def acosh(x):
        return f"hl.acosh({x})"

    @staticmethod
    def asin(x):
        return f"hl.asin({x})"

    @staticmethod
    def asinh(x):
        return f"hl.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"hl.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"hl.atan({x})"

    @staticmethod
    def atanh(x):
        return f"hl.atanh({x})"

    @staticmethod
    def copysign(x, y):
        raise Unsupported("copysign")

    @staticmethod
    def erfinv(x):
        raise Unsupported("erfinv")

    @staticmethod
    def hypot(x, y):
        return f"hl.hypot({x}, {y})"

    @staticmethod
    def nextafter(x, y):
        raise Unsupported("nextafter")

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_not(a):
        return f"{a} == 0"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def logical_xor(a, b):
        return f"({a} ^ {b})"

    @staticmethod
    def bitwise_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def bitwise_not(a):
        return f"~{a}"

    @staticmethod
    def bitwise_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def bitwise_xor(a, b):
        return f"{a} ^ {b}"

    @staticmethod
    def bitwise_left_shift(a, b):
        return f"{a} << {b}"

    @staticmethod
    def bitwise_right_shift(a, b):
        return f"{a} >> {b}"

    @staticmethod
    def rand(seed, offset):
        return f"halide_helpers.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset):
        return f"halide_helpers.randn({seed}, {offset})"

    @staticmethod
    def randint64(seed, offset, low, high):
        return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})"

    @staticmethod
    def load_seed(name, offset):
        return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}"

    @staticmethod
    def rsqrt(x):
        # return f"hl.fast_inverse_sqrt({x})"  <== accuracy issues
        return f"1./hl.sqrt({x})"

    @staticmethod
    def tan(x):
        return f"hl.tan({x})"

    @staticmethod
    def tanh(x):
        return f"hl.tanh({x})"

    @staticmethod
    def signbit(x):
        # extract the IEEE-754 sign bit via a bitwise reinterpret of fp32
        return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0"

    @staticmethod
    def fmod(a, b):
        # TODO(jansel): find a better way to do this, builtin % has wrong sign
        return f"{a} - hl.trunc({a}/{b})*{b}"

    @staticmethod
    def pow(a, b):
        return f"hl.pow({a}, {b})"  # hl.fast_pow fails accuracy

    @staticmethod
    def log(x):
        return f"hl.log({x})"  # hl.fast_log fails accuracy

    @staticmethod
    def isinf(x):
        # workaround https://github.com/halide/Halide/issues/8309
        return f"hl.is_inf(hl.cast(hl.Float(32), {x}))"

    @staticmethod
    def isnan(x):
        # workaround https://github.com/halide/Halide/issues/8309
        return f"hl.is_nan(hl.cast(hl.Float(32), {x}))"

    @staticmethod
    def round(x):
        return f"hl.round({x})"

    @staticmethod
    def floor(x):
        return f"hl.floor({x})"

    @staticmethod
    def int_truediv(a, b):
        # "+hl.f32(0)" forces the division to happen in floating point
        return f"({a}) / ({b} + hl.f32(0))"

    @staticmethod
    def floordiv(a, b):
        # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work
        return (
            f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )

    @classmethod
    def sign(cls, x):
        # sign(x) = (0 < x) - (x < 0), computed in int8 then cast back
        left = ops.to_dtype(ops.lt("0", x), torch.int8)
        right = ops.to_dtype(ops.lt(x, "0"), torch.int8)
        sub = ops.sub(left, right)
        return f"hl.cast({x.name}.type(), {sub})"

    @staticmethod
    def trunc(x):
        return f"hl.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # this causes crashes with floating point exception, see test_div_zero_dim_cpu
        # return f"hl.div_round_to_zero({a}, {b})"
        return (
            f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )

    @staticmethod
    def ceil(x):
        return f"hl.ceil({x})"

    @staticmethod
    def relu(x):
        return f"hl.max({x}, 0)"

    @classmethod
    def index_expr(cls, expr, dtype):
        # Materialize a sympy index expression as a generated Func
        index = V.kernel.prepare_indexing(expr)
        var = V.kernel.genfunc(
            V.kernel.index_to_str(index),
            V.kernel.used_dims_from_index(index),
            bounds=get_bounds_index_expr(expr),
        )
        if dtype not in {torch.int32, torch.int64}:
            return ops.to_dtype(var, dtype)
        return var

    @classmethod
    def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True):
        # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow
        index_var = ops.to_dtype(index_var, torch.int32)
        index_var = ops.halide_clamp(index_var, size, check)
        # remember the bound for later lookups via sym_size()
        index_var.indirect_indexing_size = size
        return sympy_index_symbol(str(index_var))

    @classmethod
    def halide_clamp(cls, value, size, check):
        end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1)
        if not isinstance(size, (int, sympy.Integer)):
            # symbolic bound: cast it to the value's type before clamping
            end = f"hl.cast({value.name}.type(), {end})"
        # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692
        # return f"hl.unsafe_promise_clamped({value}, 0, {end})"
        return f"hl.clamp({value}, 0, {end})"

    @staticmethod
    def masked(mask, body, other):
        # Evaluate body under the mask; out-of-mask lanes get `other`
        with V.kernel.mask_loads(mask, other) as new_mask:
            result = body()

        if result.bounds.is_bool:
            other = bool(other)

        # Take dtype from result to prevent accidental promotion
        other = V.kernel.genfunc(
            f"hl.cast({result.name}.type(), {halide_constant(other)})",
            [],
            bounds=ValueRanges.wrap(other),
        )
        # TODO(jansel): look into removing the where in the same places triton does
        return ops.where(new_mask, result, other)
565
566
# Use mypy to check protocol implemented correctly
def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]:
    """Static-analysis-only helper; never called at runtime."""
    return h
570
571
class HalideCSEVariable(CSEVariable):
    """A CSE variable backed by an hl.Func, tracking which dims index it."""

    # Matches "tmpN[?]" placeholders whose used_dims were not yet known
    undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")

    def __init__(self, name, bounds: ValueRanges[Any]) -> None:
        super().__init__(name, bounds)
        # Dims this value varies over; None until update_on_args() runs
        self.used_dims: Optional[List[sympy.Symbol]] = None

    def update_on_args(self, name, args, kwargs):
        """Union this var's used_dims with those of every CSE-var input."""
        used = set(self.used_dims or ())
        for arg in itertools.chain(args, kwargs.values()):
            if isinstance(arg, HalideCSEVariable):
                assert arg.used_dims is not None, (name, arg, args)
                used.update(arg.used_dims)
        self.used_dims = V.kernel.sort_used_dims(used)

    def index_str(self, dims):
        if len(dims) == 0:
            # zero-dimensional Func access
            return f"{self.name}[()]"
        # NOTE(review): original comment said "Reversed since Halide is column
        # major", but no reversal happens here — presumably dims arrive already
        # ordered by the caller; confirm before relying on ordering.
        return f"{self.name}[{', '.join(map(str, dims))}]"

    def __str__(self) -> str:
        if self.used_dims is None:
            # This will get recomputed and replaced in codegen_kernel()
            return f"{self.name}[?]"
        return self.index_str(self.used_dims)

    def subs_str(self, replacements):
        """index_str with each used dim substituted via ``replacements``."""
        assert self.used_dims is not None and all(
            isinstance(x, sympy.Expr) for x in self.used_dims
        )
        return self.index_str([replacements.get(n, n) for n in self.used_dims])
604
605
@dataclasses.dataclass
class DimensionInfo:
    """One dimension of a Halide buffer view: index expr, extent, and stride."""

    expr: Optional[sympy.Expr]
    size: sympy.Expr
    stride: sympy.Expr

    def __init__(self, expr, size, stride) -> None:
        super().__init__()
        # Canonicalize provably-negative strides by negating the index expr
        if V.graph.sizevars.statically_known_lt(stride, 0):
            stride = -stride
            expr = -expr
        self.expr = expr
        self.size = size
        self.stride = stride

    def index_str(self, replacements=None, zero_vars=False):
        """Render this dimension's index expression as Halide source text."""
        assert self.expr is not None
        expr = self.expr
        if zero_vars and expr == 0:
            # an anonymous Var stands in for a constant-zero index
            return "hl.Var()"
        if replacements:
            replacements = {**replacements}
            # Expand indirect (tmpN) symbols via their CSE variables first,
            # so the substitution below sees concrete index text
            for sym in expr.free_symbols:
                if symbol_is_type(sym, SymT.TMP):
                    assert isinstance(sym, sympy.Symbol)
                    var = V.kernel.lookup_cse_var(sym.name)
                    assert isinstance(var, HalideCSEVariable)
                    replacements[sym] = sympy_index_symbol(var.subs_str(replacements))
            expr = sympy_subs(expr, replacements)
        return V.kernel.index_to_str(expr)
636
637
def eq(left, right):
    """Best-effort symbolic equality check.

    Returns True when the sizevars prove equality statically, or when the
    size hints agree (installing a guard in that case).  Unbacked symints
    without hints are treated as not-equal.
    """
    sizevars = V.graph.sizevars
    if sizevars.statically_known_equals(left, right):
        return True
    try:
        hint_left = sizevars.size_hint(left)
        hint_right = sizevars.size_hint(right)
    except TypeError:  # unbacked symints
        return False
    if hint_left != hint_right:
        return False
    sizevars.guard_equals(left, right)
    return True
649
650
def lt(left, right):
    """Best-effort symbolic less-than check.

    Returns True when statically provable or when the size hints say so
    (installing a guard).  For unbacked symints, falls back to a gcd-based
    divisibility argument: if left divides right they can only differ by
    left < right.
    """
    sizevars = V.graph.sizevars
    if sizevars.statically_known_lt(left, right):
        return True
    try:
        hint_left = sizevars.size_hint(left)
        hint_right = sizevars.size_hint(right)
    except TypeError:  # unbacked symints
        if sympy.gcd(left, right) == left:
            return left != right
        return False
    if hint_left >= hint_right:
        return False
    sizevars.guard_lt(left, right)
    return True
665
666
class HalideKernel(SIMDKernel):
    """SIMD kernel implementation that generates Halide source code."""

    overrides = HalideOverrides  # type: ignore[assignment]
    # expression printer used for all generated index math
    kexpr: Callable[[sympy.Expr], str] = texpr

    def __init__(
        self,
        *groups,
        index_dtype: str,
        mutations: Optional[OrderedSet[str]] = None,
        pid_cache=None,
        reduction_hint=ReductionHint.DEFAULT,
        override_persistent_reduction=None,
    ) -> None:
        super().__init__(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            reduction_hint=reduction_hint,
            pid_cache=pid_cache,
            override_persistent_reduction=override_persistent_reduction,
        )
        # For halide, we just write directly to the body
        self.compute = self.body
        self.loads = self.body
        self.stores = self.body
        self.indexing_code_dom = IndentedBuffer()
        self.needs_dom_indexing = self.inside_reduction
        self.has_reduction = self.inside_reduction
        # per-buffer dimension decomposition and base offsets
        self.buffer_dimensions: Dict[str, List[DimensionInfo]] = {}
        self.buffer_offsets: Dict[str, sympy.Expr] = {}
        # {h0: size1, h1: size2, ...}
        self.halide_vars: Dict[sympy.Symbol, sympy.Expr] = {}
        # {x0: h0, x1: h1+10*h2, ...}
        self.index_replacements: Dict[sympy.Expr, sympy.Expr] = {}
        # {h1: hr1, ...}
        self.reduction_renames: Dict[sympy.Symbol, sympy.Symbol] = {}
        # {"i": {h0: hi0}, "o": ...}
        self.dom_renames: Dict[str, Dict[sympy.Symbol, sympy.Symbol]] = {}
        # {"in_ptr0": ["in_ptr0_view0"], ...}
        self.buffer_aliases: Dict[str, List[str]] = defaultdict(list)
        # set by finalize_indexing() when any index uses an INDIRECT symbol
        self.has_indirect_indexing = False
708
    def create_cse_var(self, name, bounds=None):
        """Declare a fresh hl.Func in the kernel body and wrap it as a CSE var."""
        self.body.writeline(f"{name} = hl.Func({name!r})")
        return HalideCSEVariable(name, bounds)
712
    def finalize_indexing(self, indices: Sequence[sympy.Expr]):
        """
        Hook called right before codegen with every index that will be
        used in the fused kernel.

        This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing
        scheme that avoids using divide and modulus.  Instead of xindex/yindex/rindex
        we base indexing on a larger number of vars whose product combines to those.

        This function populates self.halide_vars, self.index_replacements, and self.reduction_renames
        """
        assert not (
            self.index_replacements or self.halide_vars or self.reduction_renames
        )
        size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf)  # type: ignore[arg-type]
        # dedupe while preserving order
        indices = dict.fromkeys(map(super().prepare_indexing, indices))
        all_used_symbols = set()
        sym_to_node = {
            n.symbol(): n
            for n in itertools.chain.from_iterable(
                [tree.nodes.values() for tree in self.range_trees]
            )
        }

        def simplify(expr):
            return sympy.simplify(
                V.graph.sizevars.remove_precomputed_replacements(expr)
            )

        # The two visitors below are invoked via index.replace() purely for
        # their side effect of adding split range-tree symbols to
        # all_used_symbols; their return values are discarded.
        def visit_modular_indexing(base, divisor, modulus):
            if base in sym_to_node:
                node = sym_to_node[base]
                all_used_symbols.add(
                    node.root.lookup(
                        node.divisor * divisor,
                        V.graph.sizevars.evaluate_min(
                            modulus, FloorDiv(node.length, divisor)
                        ),
                    ).symbol()
                )

        def visit_floor_div(base, divisor):
            if base in sym_to_node:
                node = sym_to_node[base]
                all_used_symbols.add(
                    node.root.lookup(
                        node.divisor * divisor,
                        FloorDiv(node.length, divisor),
                    ).symbol()
                )

        # first figure out all_used_symbols to do dead symbol elimination
        for index in indices:
            if index.has(ModularIndexing):
                index.replace(
                    ModularIndexing(
                        sympy.Wild("base"),
                        sympy.Wild("divisor"),
                        sympy.Wild("modulus"),
                    ),
                    visit_modular_indexing,
                )
            if index.has(FloorDiv):
                index.replace(
                    FloorDiv(
                        sympy.Wild("base"),
                        sympy.Wild("divisor"),
                    ),
                    visit_floor_div,
                )
            all_used_symbols.update(super().prepare_indexing(index).free_symbols)

        self.has_indirect_indexing = any(
            symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols
        )

        had_fallback = False
        for tree in reversed(self.range_trees):
            nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols]
            nodes.sort(key=lambda n: size_hint(n.divisor))
            if not nodes:
                nodes.append(tree.lookup(1, tree.numel))
            handled_count = 0
            divisor = sympy.Integer(1)
            added_sym_size = []
            # decide on a minimal set of symbols and put them in self.halide_vars
            while handled_count < len(nodes) and not eq(tree.numel, divisor):
                sizes_to_add = [
                    simplify(n.length) for n in nodes if eq(n.divisor, divisor)
                ]
                handled_count += len(sizes_to_add)
                assert sizes_to_add, nodes
                end = divisor * functools.reduce(
                    V.graph.sizevars.evaluate_max, sizes_to_add
                )
                sizes_to_add.extend(
                    [
                        simplify(n.divisor / divisor)
                        for n in nodes
                        if lt(divisor, n.divisor) and lt(n.divisor, end)
                    ]
                )
                while sizes_to_add:
                    # gcd of the pending sizes gives the largest var that
                    # evenly tiles all of them
                    next_size = functools.reduce(sympy.gcd, sizes_to_add)
                    if eq(next_size, 1):
                        # sizes share no common factors, e.g [2, 21, 42, 441, 889056]
                        # TODO(jansel): we should just prevent fusion in cases that hit this
                        next_size = simplify(tree.numel / divisor)
                        assert not eq(next_size, 1)
                        sizes_to_add = []
                        handled_count = len(nodes)
                        had_fallback = True
                    sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
                    if tree.prefix == "r":
                        # reduction vars get an hrN alias for RDom-based updates
                        self.reduction_renames[sym] = sympy_index_symbol(
                            f"hr{len(self.halide_vars)}"
                        )
                    self.halide_vars[sym] = next_size
                    added_sym_size.append((sym, next_size))
                    divisor *= next_size
                    new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)]
                    handled_count += len(new_sizes)
                    prior_len = len(sizes_to_add)
                    sizes_to_add = [
                        sympy.simplify(s / next_size)
                        for s in sizes_to_add
                        if not eq(s, next_size)
                    ]
                    assert len(sizes_to_add) < prior_len or prior_len == 0
                    sizes_to_add.extend(new_sizes)

            # create a mapping to the new set of symbols in self.index_replacements
            for node in nodes:
                try:
                    # skip vars below the node's divisor, then accumulate
                    # vars covering the node's length into a linear expr
                    idx = 0
                    divisor = 1
                    while not eq(node.divisor, divisor):
                        sym, size = added_sym_size[idx]
                        idx += 1
                        divisor *= size
                    length = 1
                    expr = sympy.Integer(0)
                    while not eq(node.length, length):
                        sym, size = added_sym_size[idx]
                        idx += 1
                        expr += length * sym
                        length *= size
                    self.index_replacements[node.symbol()] = expr
                except IndexError:
                    # fallback path: express the node via ModularIndexing on
                    # the full flattened index
                    assert had_fallback
                    full_index = sympy.Integer(0)
                    stride = sympy.Integer(1)
                    for sym, size in added_sym_size:
                        full_index += stride * sym
                        stride *= size
                    self.index_replacements[
                        node.symbol()
                    ] = V.graph.sizevars.simplify_with_ranges(
                        ModularIndexing(full_index, node.divisor, node.length),
                        self.halide_vars,  # type: ignore[arg-type]
                    )

        # codegen the variable definitions
        for sym in self.halide_vars:
            self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})")
        if self.reduction_renames:
            self.codegen_rdom(
                "rdom",
                {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()},
            )
883
    def setup_dom_indexing(self):
        """RDom based indexing uses explicit iteration ranges for Func updates"""
        # "i" = inside a reduction (includes reduction vars), "o" = outside
        prefix = "i" if self.inside_reduction else "o"
        if prefix in self.dom_renames:
            # already generated for this context; reuse the cached renames
            return self.dom_renames[prefix]

        renames = {}
        for var in self.halide_vars.keys():
            # outside a reduction, reduction vars are not part of the dom
            if not self.inside_reduction and var in self.reduction_renames:
                continue
            m = re.match(r"^h(\d+)$", var.name)
            assert m
            renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}")

        self.codegen_rdom(
            f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()}
        )

        self.dom_renames[prefix] = renames
        return renames
904
905    def codegen_rdom(self, name, vars):
906        rsizes = [
907            f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})"
908            for size in vars.values()
909        ]
910        self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])")
911        for i, rsym in enumerate(vars.keys()):
912            self.indexing_code.writeline(f"{rsym} = {name}[{i}]")
913
    def prepare_indexing(
        self,
        index: sympy.Expr,
    ):
        """Normalize an indexing expression for Halide codegen.

        Applies the shared SIMD lowering, substitutes any symbols replaced
        during range-tree setup (self.index_replacements), then simplifies
        using the known iteration ranges in self.halide_vars.
        """
        index = super().prepare_indexing(index)
        index = sympy_subs(index, self.index_replacements)
        return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars)  # type: ignore[arg-type]
921
    def sym_size(self, sym):
        """The size of an index symbol"""
        if symbol_is_type(sym, SymT.TMP):
            # indirect (data-dependent) indexing: size is recorded on the CSE var
            return self.lookup_cse_var(sym.name).indirect_indexing_size
        # direct indexing: size comes from the iteration ranges
        return self.halide_vars[sym]
927
    def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool):
        """Convert address-based indexing into dimensions using self.halide_vars

        Splits the flat address expression `index` into per-symbol strided
        dimensions (DimensionInfo) plus a constant offset, then registers the
        result under `var` (possibly creating a `_viewN` alias when the same
        buffer is accessed with incompatible dimensions).  Returns (var, dims).
        """
        # collect the iteration/indirect symbols; anything else must be a size-like symbol
        symbols = []
        for sym in sorted(index.free_symbols, key=lambda x: x.name):  # type: ignore[attr-defined]
            if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)):
                symbols.append(sym)
            else:
                assert symbol_is_type(
                    sym,
                    (
                        SymT.UNBACKED_INT,
                        SymT.SIZE,
                        SymT.PRECOMPUTED_SIZE,
                    ),
                ), sym

        # group the expression by variables used
        offset = sympy.Integer(0)
        split_expr = {s: sympy.Integer(0) for s in symbols}
        split_failed: List[Tuple[List[sympy.Symbol], sympy.Expr]] = []
        index = sympy.expand(self.rename_indexing(index))
        for part in index.args if isinstance(index, sympy.Add) else [index]:
            part_vars = [v for v in part.free_symbols if v in split_expr]
            if len(part_vars) == 0:
                # constant term: accumulate into the base offset
                offset += part
            elif len(part_vars) == 1:
                # single-variable term: accumulate per symbol
                split_expr[part_vars[0]] += part
            else:
                # multi-variable term: merge with any overlapping failed groups
                new_split_failed = []
                for i in range(len(split_failed)):
                    assert split_failed[i] is not None
                    other_vars, other_part = split_failed[i]
                    if set(other_vars) & set(part_vars):
                        part_vars.extend([v for v in other_vars if v not in part_vars])
                        part += other_part
                    else:
                        new_split_failed.append((other_vars, other_part))
                split_failed = [*new_split_failed, (part_vars, part)]

        def expr_to_dimension(expr, syms):
            # Build a DimensionInfo for one group; prefers the simple
            # stride*symbol form, otherwise derives length/stride symbolically.
            expr = sympy.factor(expr)
            if len(syms) == 1:
                stride_wild = sympy.Wild("wild", exclude=symbols)
                m = expr.match(stride_wild * syms[0])
                if m:
                    return DimensionInfo(
                        syms[0], self.sym_size(syms[0]), m[stride_wild]
                    )
            # stores must be expressible in the simple strided form
            assert not is_store, expr
            # length = value at the max index of every symbol, plus one
            length = sympy.simplify(
                sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1
            )
            stride = sympy.Integer(1)
            if isinstance(expr, sympy.Mul):
                # hoist integer factors out of the expression into the stride
                for term in expr.args:
                    if isinstance(term, sympy.Integer):
                        stride *= term
                        expr = sympy.simplify(expr / term)
                        length = sympy.simplify(sympy.ceiling(length / term))
            return DimensionInfo(expr, length, stride)

        # try to turn each group into a strided access
        dims = []
        for syms, expr in split_failed:
            for v in syms:
                expr += split_expr.pop(v)
            dims.append(expr_to_dimension(expr, syms))
        for sym, expr in split_expr.items():
            dims.append(expr_to_dimension(expr, [sym]))
        dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf))  # type: ignore[arg-type]

        if not dims:  # scalar load/store
            if self.has_indirect_indexing:
                # workaround https://github.com/halide/Halide/issues/8338
                dims.append(DimensionInfo(sympy.Integer(0), 1, 1))
        elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1):
            # Halide assumes dimension 0 is stride == 1, so add a dummy dimension
            dims.insert(
                0, DimensionInfo(sympy.Integer(0), 1 if is_store else dims[0].stride, 1)
            )

        if dims and not is_store:
            if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq(
                offset, self.buffer_offsets[var]
            ):
                # reuse the existing offset to avoid needing an input alias
                self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var])
                offset = self.buffer_offsets[var]
            elif V.graph.sizevars.statically_known_gt(
                offset, 0
            ):  # TODO(jansel): negative offsets
                # roll the offset into the dimensions for cleaner indexing
                self.apply_offset_to_dimension(dims, offset)
                offset = 0

        # register dims; on conflict, retry under a fresh `_viewN` alias of the buffer
        orig_var = var
        for i in itertools.count():
            if self.install_dims(var, dims, offset, is_store):
                return var, dims
            assert not is_store
            var = f"{orig_var}_view{i}"
            if var not in self.buffer_aliases[orig_var]:
                self.buffer_aliases[orig_var].append(var)
1031
1032    def install_dims(self, var, dims, offset, is_store):
1033        """Try to set self.buffer_dimensions[var], return True on success"""
1034        if var not in self.buffer_dimensions:
1035            self.buffer_dimensions[var] = dims
1036            self.buffer_offsets[var] = offset
1037            return True
1038        if self.buffer_offsets[var] != offset or len(
1039            self.buffer_dimensions[var]
1040        ) != len(dims):
1041            return False
1042        if is_store:
1043            return self.buffer_dimensions[var] == dims
1044        for old, new in zip(self.buffer_dimensions[var], dims):
1045            if old.stride != new.stride:
1046                return False
1047            if old.size != new.size or old.expr != new.expr:
1048                old.size = V.graph.sizevars.evaluate_max(old.size, new.size)
1049                old.expr = None
1050        return True
1051
    def apply_offset_to_dimension(self, dims, offset):
        """Fold a constant offset into the dimension index expressions.

        Walks dims from last to first (callers sort dims by ascending stride,
        so this visits the largest strides first), peeling off whole multiples
        of each stride.  The offset must be fully consumed by the end.
        """
        if offset == 0:
            return
        for i in reversed(range(len(dims))):
            if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq(
                offset, dims[i].stride
            ):
                # shift this dimension's index by as many strides as fit
                part = FloorDiv(offset, dims[i].stride)
                offset -= part * dims[i].stride
                dims[i].expr += part
        assert offset == 0
1063
1064    def used_dims_from_index(self, index: sympy.Expr):
1065        """Detect which range trees are used to populate HalideCSEVariable.used_dims"""
1066        used_dims = set()
1067        for sym in index.free_symbols:
1068            assert isinstance(sym, sympy.Symbol)
1069            if symbol_is_type(sym, SymT.TMP):
1070                # indirect indexing
1071                cse_var = self.lookup_cse_var(sym.name)
1072                assert (
1073                    isinstance(cse_var, HalideCSEVariable)
1074                    and cse_var.used_dims is not None
1075                )
1076                used_dims.update(cse_var.used_dims)
1077            elif symbol_is_type(sym, SymT.HALIDE):
1078                used_dims.add(sym)
1079            elif symbol_is_type(
1080                sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX)
1081            ):
1082                pass
1083            else:
1084                raise NotImplementedError(f"unhandled symbol {sym}")
1085        return self.sort_used_dims(used_dims)
1086
1087    def sort_used_dims(self, used_dims):
1088        assert all(isinstance(x, sympy.Expr) for x in used_dims)
1089        ordered = [
1090            sym
1091            for sym in itertools.chain(
1092                self.halide_vars, self.reduction_renames.values()
1093            )
1094            if sym in used_dims
1095        ]
1096        assert len(ordered) == len(used_dims)
1097        return ordered
1098
1099    def make_index_str(self, dims, replacements=None, zero_vars=False):
1100        index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims)
1101        if len(dims) == 0:
1102            index_str = "()"
1103        elif len(dims) == 1:
1104            # workaround for https://github.com/halide/Halide/issues/8299
1105            index_str = f"{index_str},"
1106        return index_str
1107
    def load(self, name: str, index: sympy.Expr):
        """Codegen a load from an InputBuffer"""
        var = self.args.input(name)
        index = self.prepare_indexing(index)
        var, dims = self.indexing_to_dimensions(var, index, False)
        line = f"{var}[{self.make_index_str(dims)}]"
        dtype = V.graph.get_dtype(name)
        if dtype in (torch.float16, torch.bfloat16):
            # compute in fp32; the buffer itself keeps the narrow type
            dtype = torch.float32
            line = f"hl.cast(hl.Float(32), {line})"

        if self._load_mask:
            assert (
                isinstance(self._load_mask, HalideCSEVariable)
                and self._load_mask.used_dims is not None
            )
            # the result depends on both the index's dims and the mask's dims
            used_dims = {*self.used_dims_from_index(index), *self._load_mask.used_dims}
            result = self.newfunc(self.sort_used_dims(used_dims))
            if result.used_dims:
                # masked load: a size-1 RDom guarded by where(mask) makes the
                # second definition an update that only fires where the mask
                # holds; elsewhere the `other` default from the first
                # definition is kept.  Adding the RDom var (always 0) is what
                # turns the statement into an update over that domain.
                self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])")
                self.body.writeline(f"{result.name}_mask.where({self._load_mask})")
                other = self.kexpr(self._load_other or 0)  # type: ignore[arg-type]
                self.body.writeline(
                    f"{result} = hl.cast({halide_type(dtype)}, {other})"
                )
                self.body.writeline(
                    f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)"
                )
            else:
                # scalar case
                self.body.writeline(
                    f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))"
                )
            return result
        else:
            return self.genfunc(line, self.used_dims_from_index(index))
1144
1145    def lookup_cse_var(self, name: str):
1146        return self.cse.varname_map[re.sub(r"\[.*", "", name)]
1147
    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        """Codegen a store to an OutputBuffer"""
        assert isinstance(value, HalideCSEVariable)
        var = self.args.output(name)
        index = self.prepare_indexing(index)
        var, dims = self.indexing_to_dimensions(var, index, True)
        if self.is_indirect_indexing(index) or mode is not None:
            # scatter / accumulating stores are expressed as Func updates over
            # an explicit RDom; define the Func as undef first so the update
            # below is the only definition that writes data
            replacements = self.setup_dom_indexing()
            index_str = self.make_index_str(dims, replacements)
            value_str = value.subs_str(replacements)
            undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()"
            self.body.writeline(
                DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())")
            )
        else:
            # plain store: index with the pure Vars directly
            index_str = self.make_index_str(dims, zero_vars=True)
            value_str = str(value)

        dtype = V.graph.get_dtype(name)
        if mode is None:
            line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})"
        elif mode == "atomic_add":
            line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})"
        else:
            raise NotImplementedError(f"store mode={mode}")
        self.body.writeline(DeferredLine(name, line))
1176
    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        """Codegen a reduction operation"""
        assert self.inside_reduction
        assert not self._load_mask
        # identical reductions are cached and reuse the same generated Func
        cache_key = (src_dtype, reduction_type, value)
        if cache_key in self.cse.reduction_cache:
            return self.cse.reduction_cache[cache_key]

        if isinstance(value, tuple):
            # tuple-valued input only occurs for welford_combine
            assert reduction_type == "welford_combine"
            self.cse.reduction_cache[
                cache_key
            ] = result_tuple = self.welford_combine_impl(*value)
            return result_tuple

        assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
        reduction_vars = {*self.reduction_renames}
        # the result Func is defined over the non-reduction dims only
        result_var = self.newfunc(
            [v for v in value.used_dims if v not in reduction_vars]
        )
        if reduction_vars - {*value.used_dims}:
            # broadcast the value over any reduction dims it doesn't mention
            value = self.genfunc(
                f"{value}", self.sort_used_dims({*value.used_dims, *reduction_vars})
            )
        value_str = value.subs_str(self.reduction_renames)
        default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
        acc_type = halide_acc_type(dtype)

        if reduction_type in ("argmax", "argmin"):
            index = f"{result_var.name}_{reduction_type}"
            self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})")
            # turn the N-D argmax index into a 1-D one
            parts = []
            stride = 1
            for i, sym in enumerate(self.reduction_renames):
                parts.append(f"{index}[{i}]")
                if stride != 1:
                    parts[-1] += f"*{stride}"
                stride *= self.halide_vars[sym]
            self.body.writeline(f"{result_var} = {' + '.join(parts)}")
        elif reduction_type == "welford_reduce":
            # TODO(jansel): implement welford_reduce without fallback
            result_var = self.welford_reduce_fallback(dtype, value)
        else:
            combine_fn = get_reduction_combine_fn(reduction_type, acc_type)
            # render the combine expression as plain strings (mock handler, no CSE)
            with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))):
                combine_str = combine_fn(result_var, value_str)  # type: ignore[arg-type]
            default_str = f"hl.cast({acc_type}, {halide_constant(default)})"
            # initial definition with the accumulator default, then the
            # combining update over the reduction domain
            self.body.writeline(f"{result_var} = {default_str}")
            self.body.writeline(f"{result_var} = {combine_str}")

        self.cse.reduction_cache[cache_key] = result_var
        return result_var
1236
    def welford_combine_impl(self, mean, m2, weight):
        """Codegen welford_combine: merge (mean, m2, weight) accumulator
        triples over the reduction domain as a Tuple-valued Func update.

        Returns a tuple of three HalideCSEVariables (mean, m2, weight).
        """
        assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None
        assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None
        assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None
        # fall back to all vars if every input is dimension-less
        used_dims = {*mean.used_dims, *m2.used_dims, *weight.used_dims} or {
            *self.halide_vars
        }
        used_dims -= {*self.reduction_renames}
        result_var = self.newfunc(self.sort_used_dims(used_dims))
        # initial definition: zero accumulators of the matching types
        default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)]
        pfx = result_var.name
        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])")
        # _1 = running accumulator, _2 = incoming value at the reduction point
        self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]")
        self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]")
        self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]")
        self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}")
        self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}")
        self.body.writeline(
            f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}"
        )
        self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1")
        self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2")
        # select() guards against division by zero when both weights are zero
        self.body.writeline(
            f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)"
        )
        update = [
            f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w",
            f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w",
            f"{pfx}_new_weight",
        ]
        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])")

        # unpack the Tuple into three scalar Funcs for downstream use
        unpacked = []
        for i in range(3):
            unpacked.append(self.newfunc(result_var.used_dims))
            self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]")
        return tuple(unpacked)
1274
1275    def scan(
1276        self,
1277        dtypes: Tuple[torch.dtype, ...],
1278        combine_fn: Callable[
1279            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
1280        ],
1281        values_orig: Tuple[CSEVariable, ...],
1282    ) -> Tuple[CSEVariable, ...]:
1283        assert self.inside_reduction
1284        assert len(dtypes) == len(values_orig)
1285        values: List[HalideCSEVariable] = []
1286        all_used_dims = set()
1287        for value in values_orig:
1288            assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
1289            if set(value.used_dims) & set(self.reduction_renames):
1290                values.append(value)
1291            else:
1292                values.append(
1293                    self.genfunc(
1294                        f"{value}", [*value.used_dims, [*self.reduction_renames][:1]]
1295                    )
1296                )
1297            all_used_dims.update(value.used_dims)
1298        result_var = self.newfunc(self.sort_used_dims(all_used_dims))
1299        assert result_var.used_dims and set(result_var.used_dims) & set(
1300            self.reduction_renames
1301        )
1302        initial = [
1303            f"hl.cast({halide_acc_type(dtype)}, {value})"
1304            for dtype, value in zip(dtypes, values)
1305        ]
1306
1307        length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel))
1308        scan_dom = f"{result_var.name}_rdom"
1309        scan = f"{scan_dom}.x"
1310        self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")
1311
1312        assert (
1313            len(self.reduction_renames) == 1
1314        ), "multi-dimensional scan not implemented"
1315        (scan_var,) = [*self.reduction_renames]  # type: ignore[misc]
1316        scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
1317        scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}
1318
1319        if len(values) == 1:
1320
1321            def maybe_tuple(x):
1322                return x[0]
1323
1324            read_left = [result_var.subs_str(scan_renames_pri)]
1325            read_right = [result_var.subs_str(scan_renames_cur)]
1326        else:
1327
1328            def maybe_tuple(x):
1329                return f"hl.Tuple([{', '.join(x)}])"
1330
1331            read_left = [
1332                result_var.subs_str(scan_renames_pri) + f"[{i}]"
1333                for i in range(len(values))
1334            ]
1335            read_right = [
1336                result_var.subs_str(scan_renames_cur) + f"[{i}]"
1337                for i in range(len(values))
1338            ]
1339
1340        self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")
1341
1342        # Disable CSE for update fn
1343        with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))):
1344            combine_str = combine_fn(read_left, read_right)  # type: ignore[arg-type]
1345        self.body.writeline(
1346            f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}"
1347        )
1348
1349        if len(values) == 1:
1350            return (result_var,)
1351
1352        unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values]
1353        for i, v in enumerate(unpack_vars):
1354            self.body.writeline(f"{v} = {result_var}[{i}]")
1355        return tuple(unpack_vars)
1356
1357    def genfunc(
1358        self, line, used_dims, *, bounds=ValueRanges.unknown()
1359    ) -> HalideCSEVariable:
1360        var = self.cse.generate(self.body, line, bounds=bounds)
1361        assert isinstance(var, HalideCSEVariable)
1362        var.used_dims = used_dims
1363        return var
1364
1365    def newfunc(self, used_dims) -> HalideCSEVariable:
1366        var = self.cse.newvar()
1367        assert isinstance(var, HalideCSEVariable)
1368        var.used_dims = used_dims
1369        return var
1370
1371    def halide_buffer_numel(self, name: str):
1372        """
1373        We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch
1374        supports.  If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while
1375        PyTorch's numel excludes them.
1376        """
1377        return V.graph.get_buffer(name).get_layout().storage_size()
1378
1379    def halide_argdefs(self):
1380        """
1381        Halide requires scalar inputs before outputs, so need to reorder args.
1382        """
1383
1384        def arg_order(arg_tuple):
1385            call_str, arg = arg_tuple
1386            if isinstance(arg, SizeArg):
1387                return 1  # this would normally be at the end, move it to middle
1388            elif "out_ptr" in arg.name:
1389                return 2
1390            else:
1391                assert "in_ptr" in arg.name
1392                return 0
1393
1394        result = []
1395        _, a, b, _ = self.args.python_argdefs()
1396        for call_str, arg in sorted(zip(a, b), key=arg_order):
1397            result.append((call_str, arg))
1398            if isinstance(arg, TensorArg):
1399                assert arg.offset == 0 and arg.alias_of is None
1400                for alias in self.buffer_aliases.get(arg.name, ()):
1401                    result.append(
1402                        (
1403                            None,
1404                            TensorArg(
1405                                alias,
1406                                arg.buffer,
1407                                arg.dtype,
1408                                arg.offset,
1409                                alias_of=arg.name,
1410                            ),
1411                        )
1412                    )
1413        return result
1414
1415    def halide_kernel_meta(self) -> HalideMeta:
1416        """Compute metadata required by codecache.py"""
1417        argtypes = []
1418        for _, arg in self.halide_argdefs():
1419            if isinstance(arg, SizeArg):
1420                shape = None
1421                stride = None
1422                offset = None
1423                dtype = "long"
1424            else:
1425                shape = [
1426                    cexpr(self.rename_indexing(x.size))
1427                    for x in self.buffer_dimensions[arg.name]
1428                ]
1429                stride = [
1430                    cexpr(self.rename_indexing(x.stride))
1431                    for x in self.buffer_dimensions[arg.name]
1432                ]
1433                assert len(shape) == len(stride)
1434                offset = cexpr(self.buffer_offsets[arg.name])
1435                dtype = f"{DTYPE_TO_CPP[arg.dtype]}*"
1436            argtypes.append(
1437                HalideInputSpec(
1438                    dtype,
1439                    arg.name,
1440                    shape=shape,
1441                    stride=stride,
1442                    offset=offset,
1443                    alias_of=arg.alias_of,
1444                )
1445            )
1446
1447        current_device = V.graph.scheduler.get_current_device_or_throw()
1448        if current_device.type == "cpu":
1449            target = [config.halide.cpu_target]
1450            schduler = config.halide.scheduler_cpu
1451            scheduler_flags = {
1452                "parallelism": parallel_num_threads(),
1453            }
1454            cuda_device = None
1455        else:
1456            assert current_device.type == "cuda", "only cpu/cuda supported"
1457            assert current_device.index <= 0, "only default device supported"
1458            target = [config.halide.gpu_target]
1459            schduler = config.halide.scheduler_cuda
1460            capability = torch.cuda.get_device_properties(current_device)
1461            if "cuda_capability" not in target[0]:
1462                for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]:
1463                    if capability.major >= major and capability.minor >= minor:
1464                        target.append(f"cuda_capability_{major}{minor}")
1465                        break
1466            target.append("user_context")
1467            scheduler_flags = {
1468                "parallelism": capability.multi_processor_count,
1469                # TODO(jansel): explore other flags, see:
1470                # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp
1471            }
1472            cuda_device = max(0, current_device.index)
1473
1474        # strict_float is requires for correctness
1475        target.append("strict_float")
1476
1477        # without this we will initialize cuda once per kernel and hit errors
1478        target.append("no_runtime")
1479
1480        if not config.halide.asserts:
1481            target.append("no_asserts")
1482
1483        if config.halide.debug:
1484            target.append("debug")
1485
1486        if "64" in self.index_dtype:
1487            # TODO(jansel): it is unclear if this does anything, since input sizes are still int32
1488            target.append("large_buffers")
1489
1490        return HalideMeta(
1491            argtypes,
1492            target="-".join(target),
1493            scheduler=schduler,
1494            scheduler_flags=scheduler_flags,
1495            cuda_device=cuda_device,
1496        )
1497
    def codegen_kernel(self, name=None):
        """Called at the end to generate a final kernel string"""
        if self.args.inplace_buffers:
            raise Unsupported("inplace_buffers")
        meta = self.halide_kernel_meta()  # ensure needed args are added early
        code = IndentedBuffer()
        code.splice(
            """
            import halide as hl
            from torch._inductor.runtime import halide_helpers
            from math import inf, nan

            @hl.generator(name="kernel")
            class Kernel:
        """,
            strip=True,
        )
        code.do_indent()
        # declare the generator's inputs/outputs (scalars and buffers)
        for _, arg in self.halide_argdefs():
            if isinstance(arg, SizeArg):
                code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})")
            else:
                assert arg.buffer, arg
                argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer"
                argtype = halide_type(arg.dtype)
                ndim = len(self.buffer_dimensions[arg.name])
                code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})")
        code.splice(
            """
            def generate(g):
        """
        )
        code.do_indent()
        # rebind generator members as locals, then emit the kernel body
        for _, arg in self.halide_argdefs():
            code.writeline(f"{arg.name} = g.{arg.name}")
        for old, new in self.args.aliases():
            code.writeline(f"{old} = {new}")
        code.splice(self.indexing_code)

        def update_index(m):
            # substitute each not-yet-indexed CSE variable reference with its
            # final rendered form (now that used_dims is known)
            var = self.cse.varname_map[m.group(1)]
            assert var.used_dims is not None, var
            return str(var)

        for line in self.body._lines:
            if isinstance(line, str):
                # fill in missing indices
                line = HalideCSEVariable.undefined_re.sub(update_index, line)
            code.writeline(line)
        code.writeline("")
        code.writeline("assert g.using_autoscheduler()")

        # emit size/stride estimates the autoscheduler needs
        for _, arg in self.halide_argdefs():
            # fallback=1 below because halide requires buffers to be at least as large as the estimates
            # This causes crashes if our estimate is greater than the vector length
            # https://github.com/halide/Halide/issues/3103
            if isinstance(arg, SizeArg):
                hint = V.graph.sizevars.size_hint(arg.expr, fallback=1)
                code.writeline(f"{arg.name}.set_estimate({hint})")
            else:
                dims = self.buffer_dimensions[arg.name]
                range_hints = []
                for i, dim in enumerate(dims):
                    hint = self._autoscheduler_workarounds(
                        V.graph.sizevars.size_hint(dim.size, fallback=1), dims
                    )
                    range_hints.append(f"hl.Range(0, {hint})")
                    if "out" not in arg.name:
                        code.writeline(f"{arg.name}.dim({i}).set_min(0)")
                        try:
                            code.writeline(
                                f"{arg.name}.dim({i}).set_stride({int(dim.stride)})"
                            )
                        except TypeError:
                            pass  # not integer
                        try:
                            code.writeline(
                                f"{arg.name}.dim({i}).set_extent({int(dim.size)})"
                            )
                        except TypeError:
                            pass  # not integer
                code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])")

        code.do_unindent(2)
        # trailing harness: run hl.main() as a script, otherwise compile the
        # generator to a callable (with or without an autoscheduler)
        code.splice(
            """
            if __name__ == "__main__":
                hl.main()
            """.rstrip(),
        )
        if meta.scheduler:
            code.splice(
                f"""
                else:
                    hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
                    target = hl.Target({meta.target!r})
                    autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
                    with hl.GeneratorContext(target, autoscheduler):
                        gen = Kernel()
                        pipeline = gen._build_pipeline()
                        # gen.compile_to_callable() does not run the autoscheduler
                        pipeline.apply_autoscheduler(target, autoscheduler)
                        kernel = pipeline.compile_to_callable([
                                gen._get_input_parameter(a.name)._to_argument()
                                for a in gen._get_arginfos()
                                if a.dir == hl.ArgInfoDirection.Input
                            ], target)
                """,
                strip=True,
            )
        else:
            code.splice(
                f"""
                  else:
                      with hl.GeneratorContext(hl.Target({meta.target!r})):
                          kernel = Kernel().compile_to_callable()
                  """,
                strip=True,
            )
        return code.getvalue()
1618
1619    @staticmethod
1620    def _autoscheduler_workarounds(n, dims):
1621        if (
1622            len(dims) == 1
1623            and config.halide.scheduler_cuda == "Anderson2021"
1624            and V.graph.scheduler.get_current_device_or_throw().type == "cuda"
1625        ):
1626            # workaround https://github.com/halide/Halide/issues/8246
1627            n = max(2, n)
1628        return n
1629
1630    def call_kernel(self, name: str, node=None):
1631        """Codegen a call to this kernel"""
1632        wrapper = V.graph.wrapper_code
1633        call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None]
1634        current_device = V.graph.scheduler.get_current_device_or_throw()
1635        if current_device.type == "cuda":
1636            stream_name = wrapper.write_get_raw_stream(current_device.index, V.graph)
1637            call_args.append(stream_name)
1638        wrapper.generate_kernel_call(
1639            name,
1640            call_args,
1641            cuda=False,  # grid/stream is handled internally in halide
1642        )
1643
1644    def generate_assert(self, check):
1645        return False  # TODO(jansel): support asserts
1646
1647    def check_bounds(
1648        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
1649    ):
1650        pass  # TODO(jansel): support asserts
1651
1652
class HalideScheduling(SIMDScheduling):
    int32_type = "hl.Int(32)"
    # TODO(jansel): Halide doesn't actually support 64 bit indexing...
    int64_type = "hl.Int(64)"
    kernel_type = HalideKernel  # type: ignore[arg-type]

    @classmethod
    def get_backend_features(cls, device: torch.device):
        """Return the backend features Halide supports, as an ordered dict
        keyed by feature (values are all ``None``)."""
        features = [
            BackendFeature.TUPLE_REDUCTION,
            BackendFeature.PREFER_STORE_LOOP_ORDER,
            BackendFeature.REDUCE_TO_SINGLE_ELEMENT,
        ]
        if config.halide.scan_kernels:
            features.append(BackendFeature.SCAN)
        return dict.fromkeys(features)

    def define_kernel(self, src_code, node_schedule, kernel):
        """Codegen kernel definition to go in output wrapper code"""
        wrapper = V.graph.wrapper_code
        # Identical source is deduplicated to a single kernel definition.
        if src_code in wrapper.src_to_kernel:
            return wrapper.src_to_kernel[src_code]

        kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}"
        wrapper.src_to_kernel[src_code] = kernel_name
        wrapper.add_import_once(
            "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec"
        )

        # Wrap the generator source in an async_compile.halide(...) call.
        compile_wrapper = IndentedBuffer()
        compile_wrapper.writeline(
            f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''"
        )
        compile_wrapper.splice(src_code, strip=True)
        compile_wrapper.writeline("''')")

        origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
        wrapper.define_kernel(
            kernel_name,
            compile_wrapper.getvalue(),
            f"{origins}\n{detailed_origins}",
        )
        if is_metric_table_enabled("kernel_metadata"):
            log_kernel_metadata(kernel_name, "", src_code)

        return kernel_name
1700