1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import dataclasses 5import functools 6import itertools 7import logging 8import re 9from collections import defaultdict 10from math import inf 11from typing import ( 12 Any, 13 Callable, 14 Dict, 15 List, 16 Optional, 17 Sequence, 18 Tuple, 19 TYPE_CHECKING, 20 Union, 21) 22 23import sympy 24 25import torch 26import torch._logging 27 28from ..._prims_common import is_integer_dtype 29from ...utils._sympy.functions import FloorDiv, ModularIndexing 30from ...utils._sympy.symbol import symbol_is_type, SymT 31from ...utils._sympy.value_ranges import ValueRanges 32from .. import config, ir 33from ..codecache import HalideCodeCache 34from ..ir import get_reduction_combine_fn 35from ..metrics import is_metric_table_enabled, log_kernel_metadata 36from ..ops_handler import AddParenHandler, MockHandler 37from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint 38from ..utils import ( 39 get_bounds_index_expr, 40 get_kernel_metadata, 41 parallel_num_threads, 42 sympy_index_symbol, 43 sympy_subs, 44) 45from ..virtualized import _ops as ops, OpsHandler, V 46from .common import ( 47 BackendFeature, 48 CSEVariable, 49 DeferredLine, 50 IndentedBuffer, 51 OpOverrides, 52 PythonPrinter, 53 SizeArg, 54 TensorArg, 55) 56from .cpp import DTYPE_TO_CPP 57from .cpp_utils import cexpr 58from .simd import constant_repr, SIMDKernel, SIMDScheduling 59 60 61if TYPE_CHECKING: 62 from torch.utils._ordered_set import OrderedSet 63 64 from ..ops_handler import ReductionType, StoreMode 65 66log = logging.getLogger(__name__) 67 68 69def halide_constant(val): 70 if isinstance(val, int) and not (-2147483648 <= val <= 2147483647): 71 info = torch.iinfo(torch.int64) 72 if val == info.min: 73 return "hl.Int(64).min()" 74 if val == info.max: 75 return "hl.Int(64).max()" 76 return f"hl.i64({val!r})" 77 if isinstance(val, float): 78 return f"hl.f64({constant_repr(val)})" 79 return repr(val) 80 81 82class Unsupported(RuntimeError): 83 def __init__(self, thing) -> None: 84 super().__init__(f"halide backend does not support: {thing}") 85 86 87class HalidePrinter(PythonPrinter): 88 @staticmethod 89 def cast_index(expr): 90 return f"hl.cast({V.kernel.index_dtype}, {expr})" 91 92 @staticmethod 93 def cast_float(expr): 94 return f"hl.cast(hl.Float(32), {expr})" 95 96 def _print_Float(self, expr): 97 return f"hl.f32({expr})" 98 99 def _print_ToFloat(self, expr): 100 assert len(expr.args) == 1 101 return f"hl.f32({self._print(expr.args[0])})" 102 103 def _print_floor(self, expr): 104 assert len(expr.args) == 1 105 return self.cast_index(f"hl.floor({self._print(expr.args[0])})") 106 107 def _print_Trunc(self, expr): 108 assert len(expr.args) == 1 109 return self.cast_index(f"hl.trunc({self._print(expr.args[0])})") 110 111 _print_TruncToInt = _print_Trunc 112 113 def _print_ceiling(self, expr): 114 assert len(expr.args) == 1 115 return self.cast_index(f"hl.ceil({self._print(expr.args[0])})") 116 117 def _helper_sqrt(self, expr): 118 return f"hl.sqrt({self.cast_float(self._print(expr))})" 119 120 def _print_Where(self, expr): 121 c = self.doprint(expr.args[0]) 122 p = self.doprint(expr.args[1]) 123 q = self.doprint(expr.args[2]) 124 return f"hl.select({c}, {p}, {q})" 125 126 def _print_Min(self, expr): 127 if len(expr.args) == 1: 128 return self._print(expr.args[0]) 129 130 mid = len(expr.args) // 2 131 a = self._print(sympy.Min(*expr.args[:mid])) 132 b = self._print(sympy.Min(*expr.args[mid:])) 133 return f"hl.min({a}, {b})" 134 135 def _print_Max(self, expr): 136 if len(expr.args) == 1: 137 return self._print(expr.args[0]) 138 139 mid = len(expr.args) // 2 140 a = self._print(sympy.Max(*expr.args[:mid])) 141 b = self._print(sympy.Max(*expr.args[mid:])) 142 143 return f"hl.max({a}, {b})" 144 145 def _print_Abs(self, expr): 146 assert len(expr.args) == 1 147 return self.cast_index(f"hl.abs({self._print(expr.args[0])})") 148 149 def _print_OpaqueUnaryFn_cos(self, expr): 150 assert len(expr.args) == 1 151 return f"hl.cos(({self._print(expr.args[0])})" 152 153 def _print_OpaqueUnaryFn_cosh(self, expr): 154 assert len(expr.args) == 1 155 return f"hl.cosh(({self._print(expr.args[0])})" 156 157 def _print_OpaqueUnaryFn_acos(self, expr): 158 assert len(expr.args) == 1 159 return f"hl.acos(({self._print(expr.args[0])})" 160 161 def _print_OpaqueUnaryFn_sin(self, expr): 162 assert len(expr.args) == 1 163 return f"hl.sin(({self._print(expr.args[0])})" 164 165 def _print_OpaqueUnaryFn_sinh(self, expr): 166 assert len(expr.args) == 1 167 return f"hl.sinh(({self._print(expr.args[0])})" 168 169 def _print_OpaqueUnaryFn_asin(self, expr): 170 assert len(expr.args) == 1 171 return f"hl.asin(({self._print(expr.args[0])})" 172 173 def _print_OpaqueUnaryFn_tan(self, expr): 174 assert len(expr.args) == 1 175 return f"hl.tan(({self._print(expr.args[0])})" 176 177 def _print_OpaqueUnaryFn_tanh(self, expr): 178 assert len(expr.args) == 1 179 return f"hl.tanh(({self._print(expr.args[0])})" 180 181 def _print_OpaqueUnaryFn_atan(self, expr): 182 assert len(expr.args) == 1 183 return f"hl.atan(({self._print(expr.args[0])})" 184 185 def _print_FloorDiv(self, expr): 186 if expr.is_integer: 187 return super()._print_FloorDiv(expr) 188 189 x, div = expr.args 190 x = self.cast_float(self.paren(self.doprint(x))) 191 div = self.cast_float(self.paren(self.doprint(div))) 192 return self.cast_index(f"hl.floor({x} / {div})") 193 194 def _print_Round(self, expr): 195 assert len(expr.args) == 1 196 return self.cast_index(f"hl.round({self._print(expr.args[0])})") 197 198 _print_RoundToInt = _print_Round 199 200 def _print_IntTrueDiv(self, expr): 201 a, b = expr.args 202 # force a cast to float 203 return f"({a}) / ({b}+hl.f32(0))" 204 205 def _print_RoundDecimal(self, expr): 206 val, n = expr.args 207 val = self._print(val) 208 n = int(n) 209 return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))" 210 211 212texpr = HalidePrinter().doprint 213pexpr = PythonPrinter().doprint 214 215 216_halide_type = { 217 torch.bool: "hl.Bool()", 218 torch.bfloat16: "hl.BFloat(16)", 219 torch.float16: "hl.Float(16)", 220 torch.float32: "hl.Float(32)", 221 torch.float64: "hl.Float(64)", 222 torch.int8: "hl.Int(8)", 223 torch.int16: "hl.Int(16)", 224 torch.int32: "hl.Int(32)", 225 torch.int64: "hl.Int(64)", 226 torch.uint8: "hl.UInt(8)", 227 torch.uint16: "hl.UInt(16)", 228 torch.uint32: "hl.UInt(32)", 229 torch.uint64: "hl.UInt(64)", 230} 231 232 233def halide_type(dtype): 234 return _halide_type[dtype] 235 236 237def halide_acc_type(dtype): 238 if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64: 239 dtype = torch.int32 240 if dtype in (torch.float16, torch.bfloat16): 241 dtype = torch.float32 242 return halide_type(dtype) 243 244 245class HalideOverrides(OpOverrides): 246 @staticmethod 247 def to_dtype( 248 x, 249 dtype: torch.dtype, 250 src_dtype: Optional[torch.dtype] = None, 251 use_compute_types=True, 252 ): 253 if dtype == torch.bool: 254 return f"({x} != 0)" 255 return f"hl.cast({halide_type(dtype)}, {x})" 256 257 @staticmethod 258 def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): 259 if src_dtype in (torch.float16, torch.bfloat16): 260 x = f"hl.cast({halide_type(src_dtype)}, {x})" # body compute is upcast to fp32 261 line = f"hl.reinterpret({halide_type(dtype)}, {x})" 262 if dtype in (torch.float16, torch.bfloat16): 263 line = f"hl.cast(hl.Float(32), {line})" 264 return line 265 266 @classmethod 267 def constant(cls, value, dtype): 268 return cls.to_dtype(halide_constant(value), dtype) 269 270 @staticmethod 271 def abs(x): 272 return f"hl.abs({x})" 273 274 @staticmethod 275 def exp(x): 276 if not hasattr(x, "name"): 277 return f"hl.exp({x})" 278 return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})" 279 280 @staticmethod 281 def libdevice_exp(x): 282 return f"hl.exp({x})" # higher precision that ops.exp 283 284 @staticmethod 285 def sqrt(x): 286 return f"hl.sqrt({x})" 287 288 @staticmethod 289 def minimum(a, b): 290 # return f"hl.min({a}, {b})" <== handles nan wrong 291 if not hasattr(a, "name"): 292 return f"hl.min({a}, {b})" 293 b = f"hl.cast({a.name}.type(), {b})" 294 return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})" 295 296 @staticmethod 297 def maximum(a, b): 298 # return f"hl.max({a}, {b})" <== handles nan wrong 299 if not hasattr(a, "name"): 300 return f"hl.max({a}, {b})" 301 b = f"hl.cast({a.name}.type(), {b})" 302 return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})" 303 304 @staticmethod 305 def where(a, b, c): 306 if hasattr(b, "name"): 307 c = f"hl.cast({b.name}.type(), {c})" 308 return f"hl.select({a}, {b}, {c})" 309 310 @staticmethod 311 def cos(x): 312 return f"hl.cos({x})" 313 314 @staticmethod 315 def sin(x): 316 return f"hl.sin({x})" 317 318 @staticmethod 319 def lgamma(x): 320 raise Unsupported("lgamma") 321 322 @staticmethod 323 def erf(x): 324 return f"hl.erf({x})" 325 326 @staticmethod 327 def cosh(x): 328 return f"hl.cosh({x})" 329 330 @staticmethod 331 def sinh(x): 332 return f"hl.sinh({x})" 333 334 @staticmethod 335 def acos(x): 336 return f"hl.acos({x})" 337 338 @staticmethod 339 def acosh(x): 340 return f"hl.acosh({x})" 341 342 @staticmethod 343 def asin(x): 344 return f"hl.asin({x})" 345 346 @staticmethod 347 def asinh(x): 348 return f"hl.asinh({x})" 349 350 @staticmethod 351 def atan2(x, y): 352 return f"hl.atan2({x}, {y})" 353 354 @staticmethod 355 def atan(x): 356 return f"hl.atan({x})" 357 358 @staticmethod 359 def atanh(x): 360 return f"hl.atanh({x})" 361 362 @staticmethod 363 def copysign(x, y): 364 raise Unsupported("copysign") 365 366 @staticmethod 367 def erfinv(x): 368 raise Unsupported("erfinv") 369 370 @staticmethod 371 def hypot(x, y): 372 return f"hl.hypot({x}, {y})" 373 374 @staticmethod 375 def nextafter(x, y): 376 raise Unsupported("nextafter") 377 378 @staticmethod 379 def logical_and(a, b): 380 return f"{a} & {b}" 381 382 @staticmethod 383 def logical_not(a): 384 return f"{a} == 0" 385 386 @staticmethod 387 def logical_or(a, b): 388 return f"{a} | {b}" 389 390 @staticmethod 391 def logical_xor(a, b): 392 return f"({a} ^ {b})" 393 394 @staticmethod 395 def bitwise_and(a, b): 396 return f"{a} & {b}" 397 398 @staticmethod 399 def bitwise_not(a): 400 return f"~{a}" 401 402 @staticmethod 403 def bitwise_or(a, b): 404 return f"{a} | {b}" 405 406 @staticmethod 407 def bitwise_xor(a, b): 408 return f"{a} ^ {b}" 409 410 @staticmethod 411 def bitwise_left_shift(a, b): 412 return f"{a} << {b}" 413 414 @staticmethod 415 def bitwise_right_shift(a, b): 416 return f"{a} >> {b}" 417 418 @staticmethod 419 def rand(seed, offset): 420 return f"halide_helpers.rand({seed}, {offset})" 421 422 @staticmethod 423 def randn(seed, offset): 424 return f"halide_helpers.randn({seed}, {offset})" 425 426 @staticmethod 427 def randint64(seed, offset, low, high): 428 return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})" 429 430 @staticmethod 431 def load_seed(name, offset): 432 return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}" 433 434 @staticmethod 435 def rsqrt(x): 436 # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues 437 return f"1./hl.sqrt({x})" 438 439 @staticmethod 440 def tan(x): 441 return f"hl.tan({x})" 442 443 @staticmethod 444 def tanh(x): 445 return f"hl.tanh({x})" 446 447 @staticmethod 448 def signbit(x): 449 return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0" 450 451 @staticmethod 452 def fmod(a, b): 453 # TODO(jansel): find a better way to do this, builtin % has wrong sign 454 return f"{a} - hl.trunc({a}/{b})*{b}" 455 456 @staticmethod 457 def pow(a, b): 458 return f"hl.pow({a}, {b})" # hl.fast_pow fails accuracy 459 460 @staticmethod 461 def log(x): 462 return f"hl.log({x})" # hl.fast_log fails accuracy 463 464 @staticmethod 465 def isinf(x): 466 # workaround https://github.com/halide/Halide/issues/8309 467 return f"hl.is_inf(hl.cast(hl.Float(32), {x}))" 468 469 @staticmethod 470 def isnan(x): 471 # workaround https://github.com/halide/Halide/issues/8309 472 return f"hl.is_nan(hl.cast(hl.Float(32), {x}))" 473 474 @staticmethod 475 def round(x): 476 return f"hl.round({x})" 477 478 @staticmethod 479 def floor(x): 480 return f"hl.floor({x})" 481 482 @staticmethod 483 def int_truediv(a, b): 484 return f"({a}) / ({b} + hl.f32(0))" 485 486 @staticmethod 487 def floordiv(a, b): 488 # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work 489 return ( 490 f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" 491 ) 492 493 @classmethod 494 def sign(cls, x): 495 left = ops.to_dtype(ops.lt("0", x), torch.int8) 496 right = ops.to_dtype(ops.lt(x, "0"), torch.int8) 497 sub = ops.sub(left, right) 498 return f"hl.cast({x.name}.type(), {sub})" 499 500 @staticmethod 501 def trunc(x): 502 return f"hl.trunc({x})" 503 504 @staticmethod 505 def truncdiv(a, b): 506 # this causes crashes with floating point exception, see test_div_zero_dim_cpu 507 # return f"hl.div_round_to_zero({a}, {b})" 508 return ( 509 f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" 510 ) 511 512 @staticmethod 513 def ceil(x): 514 return f"hl.ceil({x})" 515 516 @staticmethod 517 def relu(x): 518 return f"hl.max({x}, 0)" 519 520 @classmethod 521 def index_expr(cls, expr, dtype): 522 index = V.kernel.prepare_indexing(expr) 523 var = V.kernel.genfunc( 524 V.kernel.index_to_str(index), 525 V.kernel.used_dims_from_index(index), 526 bounds=get_bounds_index_expr(expr), 527 ) 528 if dtype not in {torch.int32, torch.int64}: 529 return ops.to_dtype(var, dtype) 530 return var 531 532 @classmethod 533 def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True): 534 # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow 535 index_var = ops.to_dtype(index_var, torch.int32) 536 index_var = ops.halide_clamp(index_var, size, check) 537 index_var.indirect_indexing_size = size 538 return sympy_index_symbol(str(index_var)) 539 540 @classmethod 541 def halide_clamp(cls, value, size, check): 542 end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1) 543 if not isinstance(size, (int, sympy.Integer)): 544 end = f"hl.cast({value.name}.type(), {end})" 545 # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692 546 # return f"hl.unsafe_promise_clamped({value}, 0, {end})" 547 return f"hl.clamp({value}, 0, {end})" 548 549 @staticmethod 550 def masked(mask, body, other): 551 with V.kernel.mask_loads(mask, other) as new_mask: 552 result = body() 553 554 if result.bounds.is_bool: 555 other = bool(other) 556 557 # Take dtype from result to prevent accidental promotion 558 other = V.kernel.genfunc( 559 f"hl.cast({result.name}.type(), {halide_constant(other)})", 560 [], 561 bounds=ValueRanges.wrap(other), 562 ) 563 # TODO(jansel): look into removing the where in the same places triton does 564 return ops.where(new_mask, result, other) 565 566 567# Use mypy to check protocol implemented correctly 568def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]: 569 return h 570 571 572class HalideCSEVariable(CSEVariable): 573 undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") 574 575 def __init__(self, name, bounds: ValueRanges[Any]) -> None: 576 super().__init__(name, bounds) 577 self.used_dims: Optional[List[sympy.Symbol]] = None 578 579 def update_on_args(self, name, args, kwargs): 580 used = set(self.used_dims or ()) 581 for arg in itertools.chain(args, kwargs.values()): 582 if isinstance(arg, HalideCSEVariable): 583 assert arg.used_dims is not None, (name, arg, args) 584 used.update(arg.used_dims) 585 self.used_dims = V.kernel.sort_used_dims(used) 586 587 def index_str(self, dims): 588 if len(dims) == 0: 589 return f"{self.name}[()]" 590 # Reversed since Halide is column major 591 return f"{self.name}[{', '.join(map(str, dims))}]" 592 593 def __str__(self) -> str: 594 if self.used_dims is None: 595 # This will get recomputed and replaced in codegen_kernel() 596 return f"{self.name}[?]" 597 return self.index_str(self.used_dims) 598 599 def subs_str(self, replacements): 600 assert self.used_dims is not None and all( 601 isinstance(x, sympy.Expr) for x in self.used_dims 602 ) 603 return self.index_str([replacements.get(n, n) for n in self.used_dims]) 604 605 606@dataclasses.dataclass 607class DimensionInfo: 608 expr: Optional[sympy.Expr] 609 size: sympy.Expr 610 stride: sympy.Expr 611 612 def __init__(self, expr, size, stride) -> None: 613 super().__init__() 614 if V.graph.sizevars.statically_known_lt(stride, 0): 615 stride = -stride 616 expr = -expr 617 self.expr = expr 618 self.size = size 619 self.stride = stride 620 621 def index_str(self, replacements=None, zero_vars=False): 622 assert self.expr is not None 623 expr = self.expr 624 if zero_vars and expr == 0: 625 return "hl.Var()" 626 if replacements: 627 replacements = {**replacements} 628 for sym in expr.free_symbols: 629 if symbol_is_type(sym, SymT.TMP): 630 assert isinstance(sym, sympy.Symbol) 631 var = V.kernel.lookup_cse_var(sym.name) 632 assert isinstance(var, HalideCSEVariable) 633 replacements[sym] = sympy_index_symbol(var.subs_str(replacements)) 634 expr = sympy_subs(expr, replacements) 635 return V.kernel.index_to_str(expr) 636 637 638def eq(left, right): 639 if V.graph.sizevars.statically_known_equals(left, right): 640 return True 641 try: 642 a = V.graph.sizevars.size_hint(left) 643 b = V.graph.sizevars.size_hint(right) 644 except TypeError: # unbacked symints 645 return False 646 if a == b: 647 V.graph.sizevars.guard_equals(left, right) 648 return a == b 649 650 651def lt(left, right): 652 if V.graph.sizevars.statically_known_lt(left, right): 653 return True 654 try: 655 a = V.graph.sizevars.size_hint(left) 656 b = V.graph.sizevars.size_hint(right) 657 except TypeError: # unbacked symints 658 gcd = sympy.gcd(left, right) 659 if gcd == left: 660 return left != right 661 return False 662 if a < b: 663 V.graph.sizevars.guard_lt(left, right) 664 return a < b 665 666 667class HalideKernel(SIMDKernel): 668 overrides = HalideOverrides # type: ignore[assignment] 669 kexpr: Callable[[sympy.Expr], str] = texpr 670 671 def __init__( 672 self, 673 *groups, 674 index_dtype: str, 675 mutations: Optional[OrderedSet[str]] = None, 676 pid_cache=None, 677 reduction_hint=ReductionHint.DEFAULT, 678 override_persistent_reduction=None, 679 ) -> None: 680 super().__init__( 681 *groups, 682 index_dtype=index_dtype, 683 mutations=mutations, 684 reduction_hint=reduction_hint, 685 pid_cache=pid_cache, 686 override_persistent_reduction=override_persistent_reduction, 687 ) 688 # For halide, we just write directly to the body 689 self.compute = self.body 690 self.loads = self.body 691 self.stores = self.body 692 self.indexing_code_dom = IndentedBuffer() 693 self.needs_dom_indexing = self.inside_reduction 694 self.has_reduction = self.inside_reduction 695 self.buffer_dimensions: Dict[str, List[DimensionInfo]] = {} 696 self.buffer_offsets: Dict[str, sympy.Expr] = {} 697 # {h0: size1, h1: size2, ...} 698 self.halide_vars: Dict[sympy.Symbol, sympy.Expr] = {} 699 # {x0: h0, x1: h1+10*h2, ...} 700 self.index_replacements: Dict[sympy.Expr, sympy.Expr] = {} 701 # {h1: hr1, ...} 702 self.reduction_renames: Dict[sympy.Symbol, sympy.Symbol] = {} 703 # {"i": {h0: hi0}, "o": ...} 704 self.dom_renames: Dict[str, Dict[sympy.Symbol, sympy.Symbol]] = {} 705 # {"in_ptr0": ["in_ptr0_view0"], ...} 706 self.buffer_aliases: Dict[str, List[str]] = defaultdict(list) 707 self.has_indirect_indexing = False 708 709 def create_cse_var(self, name, bounds=None): 710 self.body.writeline(f"{name} = hl.Func({name!r})") 711 return HalideCSEVariable(name, bounds) 712 713 def finalize_indexing(self, indices: Sequence[sympy.Expr]): 714 """ 715 Hook called right before codegen with every index that will be 716 used in the fused kernel. 717 718 This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing 719 scheme that avoids using divide and modulus. Instead of xindex/yindex/rindex 720 we base indexing on a larger number of vars whose product combines to those. 721 722 This function populates self.halide_vars, self.index_replacements, and self.reduction_renames 723 """ 724 assert not ( 725 self.index_replacements or self.halide_vars or self.reduction_renames 726 ) 727 size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type] 728 indices = dict.fromkeys(map(super().prepare_indexing, indices)) 729 all_used_symbols = set() 730 sym_to_node = { 731 n.symbol(): n 732 for n in itertools.chain.from_iterable( 733 [tree.nodes.values() for tree in self.range_trees] 734 ) 735 } 736 737 def simplify(expr): 738 return sympy.simplify( 739 V.graph.sizevars.remove_precomputed_replacements(expr) 740 ) 741 742 def visit_modular_indexing(base, divisor, modulus): 743 if base in sym_to_node: 744 node = sym_to_node[base] 745 all_used_symbols.add( 746 node.root.lookup( 747 node.divisor * divisor, 748 V.graph.sizevars.evaluate_min( 749 modulus, FloorDiv(node.length, divisor) 750 ), 751 ).symbol() 752 ) 753 754 def visit_floor_div(base, divisor): 755 if base in sym_to_node: 756 node = sym_to_node[base] 757 all_used_symbols.add( 758 node.root.lookup( 759 node.divisor * divisor, 760 FloorDiv(node.length, divisor), 761 ).symbol() 762 ) 763 764 # first figure out all_used_symbols to do dead symbol elimination 765 for index in indices: 766 if index.has(ModularIndexing): 767 index.replace( 768 ModularIndexing( 769 sympy.Wild("base"), 770 sympy.Wild("divisor"), 771 sympy.Wild("modulus"), 772 ), 773 visit_modular_indexing, 774 ) 775 if index.has(FloorDiv): 776 index.replace( 777 FloorDiv( 778 sympy.Wild("base"), 779 sympy.Wild("divisor"), 780 ), 781 visit_floor_div, 782 ) 783 all_used_symbols.update(super().prepare_indexing(index).free_symbols) 784 785 self.has_indirect_indexing = any( 786 symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols 787 ) 788 789 had_fallback = False 790 for tree in reversed(self.range_trees): 791 nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols] 792 nodes.sort(key=lambda n: size_hint(n.divisor)) 793 if not nodes: 794 nodes.append(tree.lookup(1, tree.numel)) 795 handled_count = 0 796 divisor = sympy.Integer(1) 797 added_sym_size = [] 798 # decide on a minimal set of symbols and put them in self.halide_vars 799 while handled_count < len(nodes) and not eq(tree.numel, divisor): 800 sizes_to_add = [ 801 simplify(n.length) for n in nodes if eq(n.divisor, divisor) 802 ] 803 handled_count += len(sizes_to_add) 804 assert sizes_to_add, nodes 805 end = divisor * functools.reduce( 806 V.graph.sizevars.evaluate_max, sizes_to_add 807 ) 808 sizes_to_add.extend( 809 [ 810 simplify(n.divisor / divisor) 811 for n in nodes 812 if lt(divisor, n.divisor) and lt(n.divisor, end) 813 ] 814 ) 815 while sizes_to_add: 816 next_size = functools.reduce(sympy.gcd, sizes_to_add) 817 if eq(next_size, 1): 818 # sizes share no common factors, e.g [2, 21, 42, 441, 889056] 819 # TODO(jansel): we should just prevent fusion in cases that hit this 820 next_size = simplify(tree.numel / divisor) 821 assert not eq(next_size, 1) 822 sizes_to_add = [] 823 handled_count = len(nodes) 824 had_fallback = True 825 sym = sympy_index_symbol(f"h{len(self.halide_vars)}") 826 if tree.prefix == "r": 827 self.reduction_renames[sym] = sympy_index_symbol( 828 f"hr{len(self.halide_vars)}" 829 ) 830 self.halide_vars[sym] = next_size 831 added_sym_size.append((sym, next_size)) 832 divisor *= next_size 833 new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)] 834 handled_count += len(new_sizes) 835 prior_len = len(sizes_to_add) 836 sizes_to_add = [ 837 sympy.simplify(s / next_size) 838 for s in sizes_to_add 839 if not eq(s, next_size) 840 ] 841 assert len(sizes_to_add) < prior_len or prior_len == 0 842 sizes_to_add.extend(new_sizes) 843 844 # create a mapping to the new set of symbols in self.index_replacements 845 for node in nodes: 846 try: 847 idx = 0 848 divisor = 1 849 while not eq(node.divisor, divisor): 850 sym, size = added_sym_size[idx] 851 idx += 1 852 divisor *= size 853 length = 1 854 expr = sympy.Integer(0) 855 while not eq(node.length, length): 856 sym, size = added_sym_size[idx] 857 idx += 1 858 expr += length * sym 859 length *= size 860 self.index_replacements[node.symbol()] = expr 861 except IndexError: 862 assert had_fallback 863 full_index = sympy.Integer(0) 864 stride = sympy.Integer(1) 865 for sym, size in added_sym_size: 866 full_index += stride * sym 867 stride *= size 868 self.index_replacements[ 869 node.symbol() 870 ] = V.graph.sizevars.simplify_with_ranges( 871 ModularIndexing(full_index, node.divisor, node.length), 872 self.halide_vars, # type: ignore[arg-type] 873 ) 874 875 # codegen the variable definitions 876 for sym in self.halide_vars: 877 self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})") 878 if self.reduction_renames: 879 self.codegen_rdom( 880 "rdom", 881 {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()}, 882 ) 883 884 def setup_dom_indexing(self): 885 """RDom based indexing uses explicit iteration ranges for Func updates""" 886 prefix = "i" if self.inside_reduction else "o" 887 if prefix in self.dom_renames: 888 return self.dom_renames[prefix] 889 890 renames = {} 891 for var in self.halide_vars.keys(): 892 if not self.inside_reduction and var in self.reduction_renames: 893 continue 894 m = re.match(r"^h(\d+)$", var.name) 895 assert m 896 renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}") 897 898 self.codegen_rdom( 899 f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()} 900 ) 901 902 self.dom_renames[prefix] = renames 903 return renames 904 905 def codegen_rdom(self, name, vars): 906 rsizes = [ 907 f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})" 908 for size in vars.values() 909 ] 910 self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])") 911 for i, rsym in enumerate(vars.keys()): 912 self.indexing_code.writeline(f"{rsym} = {name}[{i}]") 913 914 def prepare_indexing( 915 self, 916 index: sympy.Expr, 917 ): 918 index = super().prepare_indexing(index) 919 index = sympy_subs(index, self.index_replacements) 920 return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars) # type: ignore[arg-type] 921 922 def sym_size(self, sym): 923 """The size of an index symbol""" 924 if symbol_is_type(sym, SymT.TMP): 925 return self.lookup_cse_var(sym.name).indirect_indexing_size 926 return self.halide_vars[sym] 927 928 def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool): 929 """Convert address-based indexing into dimensions using self.halide_vars""" 930 symbols = [] 931 for sym in sorted(index.free_symbols, key=lambda x: x.name): # type: ignore[attr-defined] 932 if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)): 933 symbols.append(sym) 934 else: 935 assert symbol_is_type( 936 sym, 937 ( 938 SymT.UNBACKED_INT, 939 SymT.SIZE, 940 SymT.PRECOMPUTED_SIZE, 941 ), 942 ), sym 943 944 # group the expression by variables used 945 offset = sympy.Integer(0) 946 split_expr = {s: sympy.Integer(0) for s in symbols} 947 split_failed: List[Tuple[List[sympy.Symbol], sympy.Expr]] = [] 948 index = sympy.expand(self.rename_indexing(index)) 949 for part in index.args if isinstance(index, sympy.Add) else [index]: 950 part_vars = [v for v in part.free_symbols if v in split_expr] 951 if len(part_vars) == 0: 952 offset += part 953 elif len(part_vars) == 1: 954 split_expr[part_vars[0]] += part 955 else: 956 new_split_failed = [] 957 for i in range(len(split_failed)): 958 assert split_failed[i] is not None 959 other_vars, other_part = split_failed[i] 960 if set(other_vars) & set(part_vars): 961 part_vars.extend([v for v in other_vars if v not in part_vars]) 962 part += other_part 963 else: 964 new_split_failed.append((other_vars, other_part)) 965 split_failed = [*new_split_failed, (part_vars, part)] 966 967 def expr_to_dimension(expr, syms): 968 expr = sympy.factor(expr) 969 if len(syms) == 1: 970 stride_wild = sympy.Wild("wild", exclude=symbols) 971 m = expr.match(stride_wild * syms[0]) 972 if m: 973 return DimensionInfo( 974 syms[0], self.sym_size(syms[0]), m[stride_wild] 975 ) 976 assert not is_store, expr 977 length = sympy.simplify( 978 sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1 979 ) 980 stride = sympy.Integer(1) 981 if isinstance(expr, sympy.Mul): 982 for term in expr.args: 983 if isinstance(term, sympy.Integer): 984 stride *= term 985 expr = sympy.simplify(expr / term) 986 length = sympy.simplify(sympy.ceiling(length / term)) 987 return DimensionInfo(expr, length, stride) 988 989 # try to turn each group into a strided access 990 dims = [] 991 for syms, expr in split_failed: 992 for v in syms: 993 expr += split_expr.pop(v) 994 dims.append(expr_to_dimension(expr, syms)) 995 for sym, expr in split_expr.items(): 996 dims.append(expr_to_dimension(expr, [sym])) 997 dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf)) # type: ignore[arg-type] 998 999 if not dims: # scalar load/store 1000 if self.has_indirect_indexing: 1001 # workaround https://github.com/halide/Halide/issues/8338 1002 dims.append(DimensionInfo(sympy.Integer(0), 1, 1)) 1003 elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1): 1004 # Halide assumes dimension 0 is stride == 1, so add a dummy dimension 1005 dims.insert( 1006 0, DimensionInfo(sympy.Integer(0), 1 if is_store else dims[0].stride, 1) 1007 ) 1008 1009 if dims and not is_store: 1010 if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq( 1011 offset, self.buffer_offsets[var] 1012 ): 1013 # reuse the existing offset to avoid needing an input alias 1014 self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var]) 1015 offset = self.buffer_offsets[var] 1016 elif V.graph.sizevars.statically_known_gt( 1017 offset, 0 1018 ): # TODO(jansel): negative offsets 1019 # roll the offset into the dimensions for cleaner indexing 1020 self.apply_offset_to_dimension(dims, offset) 1021 offset = 0 1022 1023 orig_var = var 1024 for i in itertools.count(): 1025 if self.install_dims(var, dims, offset, is_store): 1026 return var, dims 1027 assert not is_store 1028 var = f"{orig_var}_view{i}" 1029 if var not in self.buffer_aliases[orig_var]: 1030 self.buffer_aliases[orig_var].append(var) 1031 1032 def install_dims(self, var, dims, offset, is_store): 1033 """Try to set self.buffer_dimensions[var], return True on success""" 1034 if var not in self.buffer_dimensions: 1035 self.buffer_dimensions[var] = dims 1036 self.buffer_offsets[var] = offset 1037 return True 1038 if self.buffer_offsets[var] != offset or len( 1039 self.buffer_dimensions[var] 1040 ) != len(dims): 1041 return False 1042 if is_store: 1043 return self.buffer_dimensions[var] == dims 1044 for old, new in zip(self.buffer_dimensions[var], dims): 1045 if old.stride != new.stride: 1046 return False 1047 if old.size != new.size or old.expr != new.expr: 1048 old.size = V.graph.sizevars.evaluate_max(old.size, new.size) 1049 old.expr = None 1050 return True 1051 1052 def apply_offset_to_dimension(self, dims, offset): 1053 if offset == 0: 1054 return 1055 for i in reversed(range(len(dims))): 1056 if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq( 1057 offset, dims[i].stride 1058 ): 1059 part = FloorDiv(offset, dims[i].stride) 1060 offset -= part * dims[i].stride 1061 dims[i].expr += part 1062 assert offset == 0 1063 1064 def used_dims_from_index(self, index: sympy.Expr): 1065 """Detect which range trees are used to populate HalideCSEVariable.used_dims""" 1066 used_dims = set() 1067 for sym in index.free_symbols: 1068 assert isinstance(sym, sympy.Symbol) 1069 if symbol_is_type(sym, SymT.TMP): 1070 # indirect indexing 1071 cse_var = self.lookup_cse_var(sym.name) 1072 assert ( 1073 isinstance(cse_var, HalideCSEVariable) 1074 and cse_var.used_dims is not None 1075 ) 1076 used_dims.update(cse_var.used_dims) 1077 elif symbol_is_type(sym, SymT.HALIDE): 1078 used_dims.add(sym) 1079 elif symbol_is_type( 1080 sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX) 1081 ): 1082 pass 1083 else: 1084 raise NotImplementedError(f"unhandled symbol {sym}") 1085 return self.sort_used_dims(used_dims) 1086 1087 def sort_used_dims(self, used_dims): 1088 assert all(isinstance(x, sympy.Expr) for x in used_dims) 1089 ordered = [ 1090 sym 1091 for sym in itertools.chain( 1092 self.halide_vars, self.reduction_renames.values() 1093 ) 1094 if sym in used_dims 1095 ] 1096 assert len(ordered) == len(used_dims) 1097 return ordered 1098 1099 def make_index_str(self, dims, replacements=None, zero_vars=False): 1100 index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims) 1101 if len(dims) == 0: 1102 index_str = "()" 1103 elif len(dims) == 1: 1104 # workaround for https://github.com/halide/Halide/issues/8299 1105 index_str = f"{index_str}," 1106 return index_str 1107 1108 def load(self, name: str, index: sympy.Expr): 1109 """Codegen a load from an InputBuffer""" 1110 var = self.args.input(name) 1111 index = self.prepare_indexing(index) 1112 var, dims = self.indexing_to_dimensions(var, index, False) 1113 line = f"{var}[{self.make_index_str(dims)}]" 1114 dtype = V.graph.get_dtype(name) 1115 if dtype in (torch.float16, torch.bfloat16): 1116 dtype = torch.float32 1117 line = f"hl.cast(hl.Float(32), {line})" 1118 1119 if self._load_mask: 1120 assert ( 1121 isinstance(self._load_mask, HalideCSEVariable) 1122 and self._load_mask.used_dims is not None 1123 ) 1124 used_dims = {*self.used_dims_from_index(index), *self._load_mask.used_dims} 1125 result = self.newfunc(self.sort_used_dims(used_dims)) 1126 if result.used_dims: 1127 self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])") 1128 self.body.writeline(f"{result.name}_mask.where({self._load_mask})") 1129 other = self.kexpr(self._load_other or 0) # type: ignore[arg-type] 1130 self.body.writeline( 1131 f"{result} = hl.cast({halide_type(dtype)}, {other})" 1132 ) 1133 self.body.writeline( 1134 f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)" 1135 ) 1136 else: 1137 # scalar case 1138 self.body.writeline( 1139 f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))" 1140 ) 1141 return result 1142 else: 1143 return self.genfunc(line, self.used_dims_from_index(index)) 1144 1145 def lookup_cse_var(self, name: str): 1146 return self.cse.varname_map[re.sub(r"\[.*", "", name)] 1147 1148 def store( 1149 self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None 1150 ) -> None: 1151 """Codegen a store to an OutputBuffer""" 1152 assert isinstance(value, HalideCSEVariable) 1153 var = self.args.output(name) 1154 index = self.prepare_indexing(index) 1155 var, dims = self.indexing_to_dimensions(var, index, True) 1156 if self.is_indirect_indexing(index) or mode is not None: 1157 replacements = self.setup_dom_indexing() 1158 index_str = self.make_index_str(dims, replacements) 1159 value_str = value.subs_str(replacements) 1160 undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()" 1161 self.body.writeline( 1162 DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())") 1163 ) 1164 else: 1165 index_str = self.make_index_str(dims, zero_vars=True) 1166 value_str = str(value) 1167 1168 dtype = V.graph.get_dtype(name) 1169 if mode is None: 1170 line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})" 1171 elif mode == "atomic_add": 1172 line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})" 1173 else: 1174 raise NotImplementedError(f"store mode={mode}") 1175 self.body.writeline(DeferredLine(name, line)) 1176 1177 def reduction( 1178 self, 1179 dtype: torch.dtype, 1180 src_dtype: torch.dtype, 1181 reduction_type: ReductionType, 1182 value: Union[CSEVariable, Tuple[CSEVariable, ...]], 1183 ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: 1184 """Codegen a reduction operation""" 1185 assert self.inside_reduction 1186 assert not self._load_mask 1187 cache_key = (src_dtype, reduction_type, value) 1188 if cache_key in self.cse.reduction_cache: 1189 return self.cse.reduction_cache[cache_key] 1190 1191 if isinstance(value, tuple): 1192 assert reduction_type == "welford_combine" 1193 self.cse.reduction_cache[ 1194 cache_key 1195 ] = result_tuple = self.welford_combine_impl(*value) 1196 return result_tuple 1197 1198 assert isinstance(value, HalideCSEVariable) and value.used_dims is not None 1199 reduction_vars = {*self.reduction_renames} 1200 result_var = self.newfunc( 1201 [v for v in value.used_dims if v not in reduction_vars] 1202 ) 1203 if reduction_vars - {*value.used_dims}: 1204 value = self.genfunc( 1205 f"{value}", self.sort_used_dims({*value.used_dims, *reduction_vars}) 1206 ) 1207 value_str = value.subs_str(self.reduction_renames) 1208 default = ir.Reduction.default_accumulator(reduction_type, src_dtype) 1209 acc_type = halide_acc_type(dtype) 1210 1211 if reduction_type in ("argmax", "argmin"): 1212 index = f"{result_var.name}_{reduction_type}" 1213 self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})") 1214 # turn the N-D argmax index into a 1-D one 1215 parts = [] 1216 stride = 1 1217 for i, sym in enumerate(self.reduction_renames): 1218 parts.append(f"{index}[{i}]") 1219 if stride != 1: 1220 parts[-1] += f"*{stride}" 1221 stride *= self.halide_vars[sym] 1222 self.body.writeline(f"{result_var} = {' + '.join(parts)}") 1223 elif reduction_type == "welford_reduce": 1224 # TODO(jansel): implement welford_reduce without fallback 1225 result_var = self.welford_reduce_fallback(dtype, value) 1226 else: 1227 combine_fn = get_reduction_combine_fn(reduction_type, acc_type) 1228 with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))): 1229 combine_str = combine_fn(result_var, value_str) # type: ignore[arg-type] 1230 default_str = f"hl.cast({acc_type}, {halide_constant(default)})" 1231 self.body.writeline(f"{result_var} = {default_str}") 1232 self.body.writeline(f"{result_var} = {combine_str}") 1233 1234 self.cse.reduction_cache[cache_key] = result_var 1235 return result_var 1236 1237 def welford_combine_impl(self, mean, m2, weight): 1238 assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None 1239 assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None 1240 assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None 1241 used_dims = {*mean.used_dims, *m2.used_dims, *weight.used_dims} or { 1242 *self.halide_vars 1243 } 1244 used_dims -= {*self.reduction_renames} 1245 result_var = self.newfunc(self.sort_used_dims(used_dims)) 1246 default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)] 1247 pfx = result_var.name 1248 self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])") 1249 self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]") 1250 self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]") 1251 self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]") 1252 self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}") 1253 self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}") 1254 self.body.writeline( 1255 f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}" 1256 ) 1257 self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1") 1258 self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2") 1259 self.body.writeline( 1260 f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)" 1261 ) 1262 update = [ 1263 f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w", 1264 f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w", 1265 f"{pfx}_new_weight", 1266 ] 1267 self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])") 1268 1269 unpacked = [] 1270 for i in range(3): 1271 unpacked.append(self.newfunc(result_var.used_dims)) 1272 self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]") 1273 return tuple(unpacked) 1274 1275 def scan( 1276 self, 1277 dtypes: Tuple[torch.dtype, ...], 1278 combine_fn: Callable[ 1279 [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...] 1280 ], 1281 values_orig: Tuple[CSEVariable, ...], 1282 ) -> Tuple[CSEVariable, ...]: 1283 assert self.inside_reduction 1284 assert len(dtypes) == len(values_orig) 1285 values: List[HalideCSEVariable] = [] 1286 all_used_dims = set() 1287 for value in values_orig: 1288 assert isinstance(value, HalideCSEVariable) and value.used_dims is not None 1289 if set(value.used_dims) & set(self.reduction_renames): 1290 values.append(value) 1291 else: 1292 values.append( 1293 self.genfunc( 1294 f"{value}", [*value.used_dims, [*self.reduction_renames][:1]] 1295 ) 1296 ) 1297 all_used_dims.update(value.used_dims) 1298 result_var = self.newfunc(self.sort_used_dims(all_used_dims)) 1299 assert result_var.used_dims and set(result_var.used_dims) & set( 1300 self.reduction_renames 1301 ) 1302 initial = [ 1303 f"hl.cast({halide_acc_type(dtype)}, {value})" 1304 for dtype, value in zip(dtypes, values) 1305 ] 1306 1307 length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel)) 1308 scan_dom = f"{result_var.name}_rdom" 1309 scan = f"{scan_dom}.x" 1310 self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])") 1311 1312 assert ( 1313 len(self.reduction_renames) == 1 1314 ), "multi-dimensional scan not implemented" 1315 (scan_var,) = [*self.reduction_renames] # type: ignore[misc] 1316 scan_renames_cur = {scan_var: sympy_index_symbol(scan)} 1317 scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1} 1318 1319 if len(values) == 1: 1320 1321 def maybe_tuple(x): 1322 return x[0] 1323 1324 read_left = [result_var.subs_str(scan_renames_pri)] 1325 read_right = [result_var.subs_str(scan_renames_cur)] 1326 else: 1327 1328 def maybe_tuple(x): 1329 return f"hl.Tuple([{', '.join(x)}])" 1330 1331 read_left = [ 1332 result_var.subs_str(scan_renames_pri) + f"[{i}]" 1333 for i in range(len(values)) 1334 ] 1335 read_right = [ 1336 result_var.subs_str(scan_renames_cur) + f"[{i}]" 1337 for i in range(len(values)) 1338 ] 1339 1340 self.body.writeline(f"{result_var} = {maybe_tuple(initial)}") 1341 1342 # Disable CSE for update fn 1343 with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))): 1344 combine_str = combine_fn(read_left, read_right) # type: ignore[arg-type] 1345 self.body.writeline( 1346 f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}" 1347 ) 1348 1349 if len(values) == 1: 1350 return (result_var,) 1351 1352 unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values] 1353 for i, v in enumerate(unpack_vars): 1354 self.body.writeline(f"{v} = {result_var}[{i}]") 1355 return tuple(unpack_vars) 1356 1357 def genfunc( 1358 self, line, used_dims, *, bounds=ValueRanges.unknown() 1359 ) -> HalideCSEVariable: 1360 var = self.cse.generate(self.body, line, bounds=bounds) 1361 assert isinstance(var, HalideCSEVariable) 1362 var.used_dims = used_dims 1363 return var 1364 1365 def newfunc(self, used_dims) -> HalideCSEVariable: 1366 var = self.cse.newvar() 1367 assert isinstance(var, HalideCSEVariable) 1368 var.used_dims = used_dims 1369 return var 1370 1371 def halide_buffer_numel(self, name: str): 1372 """ 1373 We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch 1374 supports. If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while 1375 PyTorch's numel excludes them. 1376 """ 1377 return V.graph.get_buffer(name).get_layout().storage_size() 1378 1379 def halide_argdefs(self): 1380 """ 1381 Halide requires scalar inputs before outputs, so need to reorder args. 1382 """ 1383 1384 def arg_order(arg_tuple): 1385 call_str, arg = arg_tuple 1386 if isinstance(arg, SizeArg): 1387 return 1 # this would normally be at the end, move it to middle 1388 elif "out_ptr" in arg.name: 1389 return 2 1390 else: 1391 assert "in_ptr" in arg.name 1392 return 0 1393 1394 result = [] 1395 _, a, b, _ = self.args.python_argdefs() 1396 for call_str, arg in sorted(zip(a, b), key=arg_order): 1397 result.append((call_str, arg)) 1398 if isinstance(arg, TensorArg): 1399 assert arg.offset == 0 and arg.alias_of is None 1400 for alias in self.buffer_aliases.get(arg.name, ()): 1401 result.append( 1402 ( 1403 None, 1404 TensorArg( 1405 alias, 1406 arg.buffer, 1407 arg.dtype, 1408 arg.offset, 1409 alias_of=arg.name, 1410 ), 1411 ) 1412 ) 1413 return result 1414 1415 def halide_kernel_meta(self) -> HalideMeta: 1416 """Compute metadata required by codecache.py""" 1417 argtypes = [] 1418 for _, arg in self.halide_argdefs(): 1419 if isinstance(arg, SizeArg): 1420 shape = None 1421 stride = None 1422 offset = None 1423 dtype = "long" 1424 else: 1425 shape = [ 1426 cexpr(self.rename_indexing(x.size)) 1427 for x in self.buffer_dimensions[arg.name] 1428 ] 1429 stride = [ 1430 cexpr(self.rename_indexing(x.stride)) 1431 for x in self.buffer_dimensions[arg.name] 1432 ] 1433 assert len(shape) == len(stride) 1434 offset = cexpr(self.buffer_offsets[arg.name]) 1435 dtype = f"{DTYPE_TO_CPP[arg.dtype]}*" 1436 argtypes.append( 1437 HalideInputSpec( 1438 dtype, 1439 arg.name, 1440 shape=shape, 1441 stride=stride, 1442 offset=offset, 1443 alias_of=arg.alias_of, 1444 ) 1445 ) 1446 1447 current_device = V.graph.scheduler.get_current_device_or_throw() 1448 if current_device.type == "cpu": 1449 target = [config.halide.cpu_target] 1450 schduler = config.halide.scheduler_cpu 1451 scheduler_flags = { 1452 "parallelism": parallel_num_threads(), 1453 } 1454 cuda_device = None 1455 else: 1456 assert current_device.type == "cuda", "only cpu/cuda supported" 1457 assert current_device.index <= 0, "only default device supported" 1458 target = [config.halide.gpu_target] 1459 schduler = config.halide.scheduler_cuda 1460 capability = torch.cuda.get_device_properties(current_device) 1461 if "cuda_capability" not in target[0]: 1462 for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]: 1463 if capability.major >= major and capability.minor >= minor: 1464 target.append(f"cuda_capability_{major}{minor}") 1465 break 1466 target.append("user_context") 1467 scheduler_flags = { 1468 "parallelism": capability.multi_processor_count, 1469 # TODO(jansel): explore other flags, see: 1470 # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp 1471 } 1472 cuda_device = max(0, current_device.index) 1473 1474 # strict_float is requires for correctness 1475 target.append("strict_float") 1476 1477 # without this we will initialize cuda once per kernel and hit errors 1478 target.append("no_runtime") 1479 1480 if not config.halide.asserts: 1481 target.append("no_asserts") 1482 1483 if config.halide.debug: 1484 target.append("debug") 1485 1486 if "64" in self.index_dtype: 1487 # TODO(jansel): it is unclear if this does anything, since input sizes are still int32 1488 target.append("large_buffers") 1489 1490 return HalideMeta( 1491 argtypes, 1492 target="-".join(target), 1493 scheduler=schduler, 1494 scheduler_flags=scheduler_flags, 1495 cuda_device=cuda_device, 1496 ) 1497 1498 def codegen_kernel(self, name=None): 1499 """Called at the end to generate a final kernel string""" 1500 if self.args.inplace_buffers: 1501 raise Unsupported("inplace_buffers") 1502 meta = self.halide_kernel_meta() # ensure needed args are added early 1503 code = IndentedBuffer() 1504 code.splice( 1505 """ 1506 import halide as hl 1507 from torch._inductor.runtime import halide_helpers 1508 from math import inf, nan 1509 1510 @hl.generator(name="kernel") 1511 class Kernel: 1512 """, 1513 strip=True, 1514 ) 1515 code.do_indent() 1516 for _, arg in self.halide_argdefs(): 1517 if isinstance(arg, SizeArg): 1518 code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})") 1519 else: 1520 assert arg.buffer, arg 1521 argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer" 1522 argtype = halide_type(arg.dtype) 1523 ndim = len(self.buffer_dimensions[arg.name]) 1524 code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})") 1525 code.splice( 1526 """ 1527 def generate(g): 1528 """ 1529 ) 1530 code.do_indent() 1531 for _, arg in self.halide_argdefs(): 1532 code.writeline(f"{arg.name} = g.{arg.name}") 1533 for old, new in self.args.aliases(): 1534 code.writeline(f"{old} = {new}") 1535 code.splice(self.indexing_code) 1536 1537 def update_index(m): 1538 var = self.cse.varname_map[m.group(1)] 1539 assert var.used_dims is not None, var 1540 return str(var) 1541 1542 for line in self.body._lines: 1543 if isinstance(line, str): 1544 # fill in missing indices 1545 line = HalideCSEVariable.undefined_re.sub(update_index, line) 1546 code.writeline(line) 1547 code.writeline("") 1548 code.writeline("assert g.using_autoscheduler()") 1549 1550 for _, arg in self.halide_argdefs(): 1551 # fallback=1 below because halide requires buffers to be at least as large as the estimates 1552 # This causes crashes if our estimate is greater than the vector length 1553 # https://github.com/halide/Halide/issues/3103 1554 if isinstance(arg, SizeArg): 1555 hint = V.graph.sizevars.size_hint(arg.expr, fallback=1) 1556 code.writeline(f"{arg.name}.set_estimate({hint})") 1557 else: 1558 dims = self.buffer_dimensions[arg.name] 1559 range_hints = [] 1560 for i, dim in enumerate(dims): 1561 hint = self._autoscheduler_workarounds( 1562 V.graph.sizevars.size_hint(dim.size, fallback=1), dims 1563 ) 1564 range_hints.append(f"hl.Range(0, {hint})") 1565 if "out" not in arg.name: 1566 code.writeline(f"{arg.name}.dim({i}).set_min(0)") 1567 try: 1568 code.writeline( 1569 f"{arg.name}.dim({i}).set_stride({int(dim.stride)})" 1570 ) 1571 except TypeError: 1572 pass # not integer 1573 try: 1574 code.writeline( 1575 f"{arg.name}.dim({i}).set_extent({int(dim.size)})" 1576 ) 1577 except TypeError: 1578 pass # not integer 1579 code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])") 1580 1581 code.do_unindent(2) 1582 code.splice( 1583 """ 1584 if __name__ == "__main__": 1585 hl.main() 1586 """.rstrip(), 1587 ) 1588 if meta.scheduler: 1589 code.splice( 1590 f""" 1591 else: 1592 hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r}) 1593 target = hl.Target({meta.target!r}) 1594 autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r}) 1595 with hl.GeneratorContext(target, autoscheduler): 1596 gen = Kernel() 1597 pipeline = gen._build_pipeline() 1598 # gen.compile_to_callable() does not run the autoscheduler 1599 pipeline.apply_autoscheduler(target, autoscheduler) 1600 kernel = pipeline.compile_to_callable([ 1601 gen._get_input_parameter(a.name)._to_argument() 1602 for a in gen._get_arginfos() 1603 if a.dir == hl.ArgInfoDirection.Input 1604 ], target) 1605 """, 1606 strip=True, 1607 ) 1608 else: 1609 code.splice( 1610 f""" 1611 else: 1612 with hl.GeneratorContext(hl.Target({meta.target!r})): 1613 kernel = Kernel().compile_to_callable() 1614 """, 1615 strip=True, 1616 ) 1617 return code.getvalue() 1618 1619 @staticmethod 1620 def _autoscheduler_workarounds(n, dims): 1621 if ( 1622 len(dims) == 1 1623 and config.halide.scheduler_cuda == "Anderson2021" 1624 and V.graph.scheduler.get_current_device_or_throw().type == "cuda" 1625 ): 1626 # workaround https://github.com/halide/Halide/issues/8246 1627 n = max(2, n) 1628 return n 1629 1630 def call_kernel(self, name: str, node=None): 1631 """Codegen a call to this kernel""" 1632 wrapper = V.graph.wrapper_code 1633 call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None] 1634 current_device = V.graph.scheduler.get_current_device_or_throw() 1635 if current_device.type == "cuda": 1636 stream_name = wrapper.write_get_raw_stream(current_device.index, V.graph) 1637 call_args.append(stream_name) 1638 wrapper.generate_kernel_call( 1639 name, 1640 call_args, 1641 cuda=False, # grid/stream is handled internally in halide 1642 ) 1643 1644 def generate_assert(self, check): 1645 return False # TODO(jansel): support asserts 1646 1647 def check_bounds( 1648 self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool 1649 ): 1650 pass # TODO(jansel): support asserts 1651 1652 1653class HalideScheduling(SIMDScheduling): 1654 int32_type = "hl.Int(32)" 1655 # TODO(jansel): Halide doesn't actually support 64 bit indexing... 1656 int64_type = "hl.Int(64)" 1657 kernel_type = HalideKernel # type: ignore[arg-type] 1658 1659 @classmethod 1660 def get_backend_features(cls, device: torch.device): 1661 result = dict.fromkeys( 1662 [ 1663 BackendFeature.TUPLE_REDUCTION, 1664 BackendFeature.PREFER_STORE_LOOP_ORDER, 1665 BackendFeature.REDUCE_TO_SINGLE_ELEMENT, 1666 ] 1667 ) 1668 if config.halide.scan_kernels: 1669 result[BackendFeature.SCAN] = None 1670 return result 1671 1672 def define_kernel(self, src_code, node_schedule, kernel): 1673 """Codegen kernel definition to go in output wrapper code""" 1674 wrapper = V.graph.wrapper_code 1675 if src_code in wrapper.src_to_kernel: 1676 kernel_name = wrapper.src_to_kernel[src_code] 1677 else: 1678 kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}" 1679 wrapper.src_to_kernel[src_code] = kernel_name 1680 wrapper.add_import_once( 1681 "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec" 1682 ) 1683 1684 compile_wrapper = IndentedBuffer() 1685 compile_wrapper.writeline( 1686 f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''" 1687 ) 1688 compile_wrapper.splice(src_code, strip=True) 1689 compile_wrapper.writeline("''')") 1690 1691 origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) 1692 metadata_comment = f"{origins}\n{detailed_origins}" 1693 wrapper.define_kernel( 1694 kernel_name, compile_wrapper.getvalue(), metadata_comment 1695 ) 1696 if is_metric_table_enabled("kernel_metadata"): 1697 log_kernel_metadata(kernel_name, "", src_code) 1698 1699 return kernel_name 1700