1# mypy: allow-untyped-defs 2import contextlib 3import copy 4import functools 5import math 6import sys 7from collections import namedtuple 8from typing import Any, Callable, Dict, List, Optional, Set, Tuple 9from unittest.mock import patch 10 11import sympy 12 13import torch 14from torch._prims_common import is_integer_dtype 15from torch.utils._sympy.symbol import symbol_is_type, SymT 16from torch.utils._sympy.value_ranges import ValueRanges 17 18from .. import ir 19from ..loop_body import LoopBody 20from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs 21from ..virtualized import ops, OpsValue, V 22from .common import ( 23 CSEVariable, 24 deduce_output_dtype_by_name, 25 ExprPrinter, 26 Kernel, 27 KernelArgs, 28 OptimizationContext, 29) 30 31 32DTYPE_TO_CPP = { 33 torch.float32: "float", 34 torch.float64: "double", 35 torch.float16: "half", 36 torch.int64: "int64_t", 37 torch.int32: "int32_t", 38 torch.int16: "int16_t", 39 torch.int8: "int8_t", 40 torch.uint64: "uint64_t", 41 torch.uint32: "uint32_t", 42 torch.uint16: "uint16_t", 43 torch.uint8: "uint8_t", 44 torch.bool: "bool", 45 torch.bfloat16: "bfloat16", 46 torch.complex64: "c10::complex<float>", 47 torch.float8_e4m3fn: "float8_e4m3fn", 48 torch.float8_e5m2: "float8_e5m2", 49} 50 51DTYPE_TO_ATEN = { 52 torch.float32: "at::kFloat", 53 torch.float64: "at::kDouble", 54 torch.float16: "at::kHalf", 55 torch.int64: "at::kLong", 56 torch.int32: "at::kInt", 57 torch.int16: "at::kShort", 58 torch.int8: "at::kChar", 59 torch.uint64: "at::kUInt64", 60 torch.uint32: "at::kUInt32", 61 torch.uint16: "at::kUInt16", 62 torch.uint8: "at::kByte", 63 torch.uint32: "at::kUInt32", 64 torch.uint64: "at::kUInt64", 65 torch.bool: "at::kBool", 66 torch.bfloat16: "at::kBFloat16", 67 torch.complex32: "at::kComplexHalf", 68 torch.complex64: "at::kComplexFloat", 69 torch.complex128: "at::kComplexDouble", 70 torch.float8_e4m3fn: "at::kFloat8_e4m3fn", 71 torch.float8_e5m2: "at::kFloat8_e5m2", 72 torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", 73 torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", 74} 75 76DEVICE_TO_ATEN = { 77 "cpu": "at::kCPU", 78 "cuda": "at::kCUDA", 79} 80 81LAYOUT_TO_ATEN = { 82 torch.strided: "at::kStrided", 83 torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined] 84} 85 86_IS_WINDOWS = sys.platform == "win32" 87 88INDEX_TYPE = "int64_t" 89 90GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) 91 92 93def get_promote_dtype(args): 94 return ( 95 functools.reduce( 96 torch.promote_types, # type: ignore[arg-type] 97 [n.dtype for n in args if isinstance(n, CppCSEVariable)], 98 ) 99 if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable)) 100 else None # not enough info to calculate the promote dtype 101 ) 102 103 104def promote_args(new_args): 105 def promote_arg(arg, promote_type): 106 if ( 107 isinstance(arg, CppCSEVariable) 108 and arg.dtype 109 and promote_type 110 and arg.dtype != promote_type 111 ): 112 arg = ops.to_dtype(arg, promote_type) 113 arg = arg.value if isinstance(arg, OpsValue) else arg 114 arg.dtype = promote_type 115 return arg 116 117 promote_type = get_promote_dtype(new_args) 118 promote_fn = functools.partial( 119 promote_arg, 120 promote_type=promote_type, 121 ) 122 if ( 123 all( 124 new_arg.dtype is not None 125 for new_arg in new_args 126 if isinstance(new_arg, CppCSEVariable) 127 ) 128 and promote_type 129 ): 130 new_args = list(map(promote_fn, new_args)) 131 return new_args 132 133 134def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext: 135 return node.meta.get(OptimizationContext.key, None) 136 137 138def get_current_node_opt_ctx() -> OptimizationContext: 139 assert V.interpreter.current_node 140 return get_opt_ctx(V.interpreter.current_node) 141 142 143def deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs): 144 if ( 145 output_dtype := deduce_output_dtype_by_name( 146 name, 147 *args, 148 **kwargs, 149 ) 150 ) is not None: 151 return output_dtype 152 elif name == "masked": 153 # <TODO> Leslie: perhaps we can also deduce the masked dtype by 154 # inputs' CppCseVariable like other. Let's check it if any 155 # unexpected failures. 156 assert ( 157 hasattr(V.interpreter, "current_node") 158 and V.interpreter.current_node.target.startswith("masked_subblock") 159 and get_current_node_opt_ctx() is not None 160 ) 161 return get_current_node_opt_ctx().dtype 162 else: 163 # deduce output dtype by inputs' dtype 164 assert all( 165 arg.dtype is not None for arg in args if isinstance(arg, CppCSEVariable) 166 ) 167 return functools.reduce( 168 torch.promote_types, # type: ignore[arg-type] 169 [arg.dtype for arg in args if isinstance(arg, CppCSEVariable)], 170 ) 171 172 173class CppCSEVariable(CSEVariable): 174 def __init__(self, name, bounds: ValueRanges[Any]) -> None: 175 super().__init__(name, bounds) 176 self.is_vec = False 177 self.dtype: Optional[torch.dtype] = None 178 self.dependent_itervars: Set[sympy.Symbol] = set() 179 180 def __repr__(self) -> str: 181 return ( 182 f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, " 183 f"dependent_itervars: {self.dependent_itervars})" 184 ) 185 186 def update_on_args(self, name, args, kwargs): 187 if name == "load": 188 # args[2] is index 189 self._set_dependent_itervars(args[2]) 190 else: 191 # propagate relevant itervars and is_vec from args 192 self.dependent_itervars.update( 193 *[ 194 arg.dependent_itervars 195 for arg in args 196 if isinstance(arg, CppCSEVariable) 197 ] 198 ) 199 if name == "index_expr": 200 self._set_dependent_itervars(args[0]) 201 if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): 202 self.is_vec = True 203 # NOTE [Deduce dtype of CppCSEVariable at runtime] 204 self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs) 205 assert self.dtype is not None 206 207 def _set_dependent_itervars(self, index: sympy.Expr): 208 """ 209 Set the relevant itervars for this variable based on the `index` expression. 210 This includes the itervars directly used in the `index` as well as relevant itervars 211 of other cse variables used in the `index`. 212 """ 213 for s in index.free_symbols: 214 if s in V.kernel.itervars: 215 self.dependent_itervars.add(s) # type: ignore[arg-type] 216 elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] 217 self.dependent_itervars.update( 218 V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] 219 ) 220 221 def depends_on(self, itervar: sympy.Symbol): 222 return itervar in self.dependent_itervars 223 224 225class CppPrinter(ExprPrinter): 226 def _print_Integer(self, expr): 227 return ( 228 f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L" 229 ) 230 231 def _print_Where(self, expr): 232 c = self.paren(self.doprint(expr.args[0])) 233 p = self.paren(self.doprint(expr.args[1])) 234 q = self.paren(self.doprint(expr.args[2])) 235 return f"{c} ? {p} : {q}" 236 237 def _print_ModularIndexing(self, expr): 238 x, div, mod = expr.args 239 x = self.paren(self.doprint(x)) 240 if div != 1: 241 div = self.paren(self.doprint(div)) 242 if expr.is_integer: 243 x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))" 244 else: 245 x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))" 246 mod = self.paren(self.doprint(mod)) 247 return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" 248 249 def _print_FloorDiv(self, expr): 250 x, div = expr.args 251 x = self.paren(self.doprint(x)) 252 div = self.paren(self.doprint(div)) 253 if expr.is_integer: 254 return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))" 255 return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))" 256 257 def _print_floor(self, expr): 258 assert len(expr.args) == 1 259 r = f"std::floor({self._print(expr.args[0])})" 260 return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r 261 262 def _print_FloorToInt(self, expr): 263 assert len(expr.args) == 1 264 r = f"std::floor({self._print(expr.args[0])})" 265 return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r 266 267 def _print_TruncToInt(self, expr): 268 assert len(expr.args) == 1 269 r = f"std::trunc({self._print(expr.args[0])})" 270 return f"static_cast<{INDEX_TYPE}>({r})" 271 272 def _print_TruncToFloat(self, expr): 273 assert len(expr.args) == 1 274 return f"std::trunc({self._print(expr.args[0])})" 275 276 def _print_ToFloat(self, expr): 277 assert len(expr.args) == 1 278 return f"static_cast<double>({self._print(expr.args[0])})" 279 280 # TODO: This is wrong if one of the inputs is negative. This is hard to 281 # tickle though, as the inputs are typically positive (and if we can prove 282 # they are positive, we will have used Mod instead, for which this codegen 283 # is right). 284 def _print_PythonMod(self, expr): 285 return " % ".join(map(self.paren, map(self._print, expr.args))) 286 287 def _print_CMod(self, expr): 288 return " % ".join(map(self.paren, map(self._print, expr.args))) 289 290 def _print_IntTrueDiv(self, expr): 291 lhs, rhs = expr.args 292 # TODO: This is only accurate up to 2**53 293 return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})" 294 295 # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT 296 # use std::pow, that operates on floats 297 def _print_PowByNatural(self, expr): 298 raise NotImplementedError( 299 f"_print_PowByNatural not implemented for {type(self)}" 300 ) 301 302 def _print_FloatTrueDiv(self, expr): 303 lhs, rhs = expr.args 304 return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" 305 306 def _print_FloatPow(self, expr): 307 base, exp = expr.args 308 return f"std::pow({self._print(base)}, {self._print(exp)})" 309 310 def _print_Pow(self, expr): 311 # Uses float constants to perform FP div 312 base, exp = expr.args 313 base = self._print(base) 314 315 if exp == 0.5 or exp == -0.5: 316 return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" 317 if exp.is_integer: 318 exp = int(exp) 319 if exp > 0: 320 r = "*".join([self.paren(base)] * exp) 321 elif exp < 0: 322 r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) 323 else: # exp == 0 324 r = "1.0" 325 326 return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r 327 else: 328 # TODO: float vs double 329 return f"std::pow({base}, {float(exp)})" 330 331 def _print_Rational(self, expr): 332 # Uses float constants to perform FP div 333 if expr.q == 1: 334 r = f"{expr.p}" 335 else: 336 r = f"{expr.p}.0/{expr.q}.0" 337 return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r 338 339 def _print_ceiling(self, expr): 340 assert len(expr.args) == 1 341 r = f"std::ceil({self._print(expr.args[0])})" 342 return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r 343 344 def _print_CeilToInt(self, expr): 345 assert len(expr.args) == 1 346 r = f"std::ceil({self._print(expr.args[0])})" 347 return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r 348 349 def _print_Min(self, expr): 350 args = [self._print(a) for a in expr.args] 351 if len(args) == 2: 352 return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" 353 else: 354 # Initializer list overload 355 il = "{" + ", ".join(args) + "}" 356 return f"std::min({il})" 357 358 def _print_Max(self, expr): 359 args = [self._print(a) for a in expr.args] 360 if len(args) == 2: 361 return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" 362 else: 363 # Initializer list overload 364 il = "{" + ", ".join(args) + "}" 365 return f"std::max({il})" 366 367 def _print_Abs(self, expr): 368 assert len(expr.args) == 1 369 return f"std::abs({self._print(expr.args[0])})" 370 371 def _print_OpaqueUnaryFn_cos(self, expr): 372 assert len(expr.args) == 1 373 return f"std::cos({self._print(expr.args[0])})" 374 375 def _print_OpaqueUnaryFn_cosh(self, expr): 376 assert len(expr.args) == 1 377 return f"std::cosh({self._print(expr.args[0])})" 378 379 def _print_OpaqueUnaryFn_acos(self, expr): 380 assert len(expr.args) == 1 381 return f"std::acos({self._print(expr.args[0])})" 382 383 def _print_OpaqueUnaryFn_sin(self, expr): 384 assert len(expr.args) == 1 385 return f"std::sin({self._print(expr.args[0])})" 386 387 def _print_OpaqueUnaryFn_sinh(self, expr): 388 assert len(expr.args) == 1 389 return f"std::sinh({self._print(expr.args[0])})" 390 391 def _print_OpaqueUnaryFn_asin(self, expr): 392 assert len(expr.args) == 1 393 return f"std::asin({self._print(expr.args[0])})" 394 395 def _print_OpaqueUnaryFn_tan(self, expr): 396 assert len(expr.args) == 1 397 return f"std::tan({self._print(expr.args[0])})" 398 399 def _print_OpaqueUnaryFn_tanh(self, expr): 400 assert len(expr.args) == 1 401 return f"std::tanh({self._print(expr.args[0])})" 402 403 def _print_OpaqueUnaryFn_atan(self, expr): 404 assert len(expr.args) == 1 405 return f"std::atan({self._print(expr.args[0])})" 406 407 def _print_OpaqueUnaryFn_sqrt(self, expr): 408 return f"std::sqrt({self._print(expr.args[0])})" 409 410 def _print_RoundToInt(self, expr): 411 assert len(expr.args) == 1 412 # TODO: dispatch to llrint depending on index type 413 return f"std::lrint({self._print(expr.args[0])})" 414 415 def _print_RoundDecimal(self, expr): 416 assert len(expr.args) == 2 417 number, ndigits = expr.args 418 if number.is_integer: 419 # ndigits < 0 should have been filtered by the sympy function 420 assert ndigits < 0 421 raise ValueError( 422 f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." 423 ) 424 return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" 425 426 def _print_BooleanTrue(self, expr): 427 return "true" 428 429 def _print_BooleanFalse(self, expr): 430 return "false" 431 432 433# A function to print, useful for printing sympy symbols. 434cexpr = CppPrinter().doprint 435 436 437def cexpr_index(index): 438 return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" 439 440 441def value_to_cpp(value, cpp_type): 442 if value == float("-inf"): 443 return f"-std::numeric_limits<{cpp_type}>::infinity()" 444 elif value == float("inf"): 445 return f"std::numeric_limits<{cpp_type}>::infinity()" 446 elif isinstance(value, bool): 447 return f"static_cast<{cpp_type}>({str(value).lower()})" 448 elif math.isnan(value): 449 return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" 450 else: 451 return f"static_cast<{cpp_type}>({repr(value)})" 452 453 454def rewrite_index_for_function( 455 localize_buffer_handler: "LocalizeBufferHandler", 456 index: sympy.Expr, 457 global_buf_name: str, 458): 459 # Local buffer at the inner dimensions 460 snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op 461 local_buf = localize_buffer_handler.global_to_local[global_buf_name] 462 scheduler_nodes = snode.get_nodes() 463 _, (group, reduction_group) = max( 464 scheduler_nodes, key=lambda x: int(x.is_reduction()) 465 ).group 466 call_ranges = tuple(group) + tuple(reduction_group) 467 indices_to_keep = [ 468 f"x{len(call_ranges) - (idx + 1)}" 469 for idx in range(len(local_buf.get_layout().size)) 470 ] 471 sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined] 472 replacements = {} 473 for x in sorted_symbols: 474 if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined] 475 # Only keep index used by local buffer 476 replacements[x] = sympy.core.numbers.Zero() 477 index = sympy_subs(index, replacements) # type: ignore[arg-type] 478 return index 479 480 481def rewrite_index_for_nodes( 482 localize_buffer_handler: "LocalizeBufferHandler", 483 index: sympy.Expr, 484 global_buf_name: str, 485): 486 used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)} 487 index_vars = [] 488 local_buf = localize_buffer_handler.global_to_local[global_buf_name] 489 for i in range(len(local_buf.get_size())): 490 var = sympy_index_symbol_with_prefix(SymT.INDEX, i) 491 index_vars.append(var if var in used_vars else 0) 492 index = local_buf.layout.make_indexer()(index_vars) 493 return index 494 495 496class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined] 497 def __init__( 498 self, 499 inner, 500 global_to_local: Dict[str, ir.Buffer], 501 rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr], 502 ) -> None: 503 super().__init__(inner) 504 self.global_to_local = global_to_local 505 self.rewrite_index = rewrite_index 506 507 def localize(self, name: str, index: sympy.Expr): 508 if self.global_to_local and name in self.global_to_local: 509 assert self.rewrite_index is not None 510 index = self.rewrite_index(self, index, name) 511 name = self.global_to_local[name].get_name() 512 return name, index 513 514 def load(self, name: str, index: sympy.Expr): 515 return self._inner.load(*self.localize(name, index)) 516 517 def store(self, name, index, value, mode=None): 518 local_buffer_name, local_buffer_index = self.localize(name, index) 519 res = self._inner.store(local_buffer_name, local_buffer_index, value, mode) 520 if ( 521 self.global_to_local 522 and name in self.global_to_local 523 and isinstance(V.kernel, Kernel) 524 ): 525 # Remove name of local buffer from Kernel.store_buffer_names 526 # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store. 527 V.kernel.store_buffer_names.discard(local_buffer_name) 528 return res 529 530 def store_reduction(self, name, index, value): 531 return self._inner.store_reduction(*self.localize(name, index), value) 532 533 534class LocalBufferContext: 535 """ 536 This class creates a context that helps to generate code involving Inductor IR with 537 function local buffers. These buffers are constructed during the codegen process and 538 are used to store intermediate results such as local accumulators. We do not want to 539 add them to `V.graph` since they are not global and we do not want to add them as 540 function arguments either. So we patch the codegen processes under this scope to support 541 these buffers without exposure to the outside world. 542 """ 543 544 def __init__(self, kernel_args: KernelArgs) -> None: 545 self.kernel_args = kernel_args 546 self.exit_stack = contextlib.ExitStack() 547 # map local buffer name to local buffer 548 self.local_buffers: Dict[str, ir.Buffer] = {} 549 # map global buffer name to global buffer 550 self.global_buffers: Dict[str, ir.Buffer] = {} 551 # map global buffer name to local buffer 552 self.global_to_local: Dict[str, ir.Buffer] = {} 553 554 def __enter__(self): 555 self.exit_stack.__enter__() 556 original_get_dtype = V.graph.get_dtype 557 558 def get_dtype(name): 559 if name in self.local_buffers: 560 return self.local_buffers[name].get_dtype() 561 return original_get_dtype(name) 562 563 self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) 564 565 original_input = self.kernel_args.input 566 567 def input(name): 568 if name in self.local_buffers: 569 return name 570 return original_input(name) 571 572 self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input)) 573 574 original_output = self.kernel_args.output 575 576 def output(name): 577 if name in self.local_buffers: 578 return name 579 return original_output(name) 580 581 self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output)) 582 583 # Set current LocalBufferContext into V 584 self.exit_stack.enter_context(V.set_local_buffer_context(self)) 585 586 return self 587 588 def __exit__(self, exc_type, exc_val, exc_tb): 589 self.local_buffers.clear() 590 self.exit_stack.__exit__(exc_type, exc_val, exc_tb) 591 592 def add_local_buffer( 593 self, local_buffer: ir.Buffer, global_buffers: Optional[List[ir.Buffer]] = None 594 ): 595 assert local_buffer.get_name() not in self.local_buffers 596 self.local_buffers[local_buffer.get_name()] = local_buffer 597 if global_buffers: 598 for global_buffer in global_buffers: 599 global_buffer_name = global_buffer.get_name() 600 assert ( 601 global_buffer_name not in self.global_buffers 602 and global_buffer_name not in self.global_to_local 603 ) 604 self.global_buffers[global_buffer_name] = global_buffer 605 self.global_to_local[global_buffer_name] = local_buffer 606 V.graph.removed_buffers.add(global_buffer_name) 607 608 def localize_function( 609 self, 610 fn: Callable[..., Any], 611 rewrite_index: Callable[ 612 ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr 613 ] = rewrite_index_for_function, 614 ): 615 def inner(*args, **kwargs): 616 with V.set_ops_handler( 617 LocalizeBufferHandler( 618 V.get_ops_handler(), 619 global_to_local=self.global_to_local, 620 rewrite_index=rewrite_index, 621 ) 622 ): 623 return fn(*args, **kwargs) 624 625 return inner 626 627 def localize_nodes( 628 self, 629 nodes: List[ir.IRNode], 630 rewrite_index: Callable[ 631 ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr 632 ] = rewrite_index_for_nodes, 633 ) -> List[ir.IRNode]: 634 """ 635 Given `local_buf` and `global_buf` registered in current `LocalBufferContext` 636 though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf` 637 for the given `nodes` and returns a new list of IR nodes that work on `local_buf` 638 instead of `global_buf`, i.e., all the loads and stores are redirected to 639 `local_buf`. This helps the fused loops to work on smaller-sized local buffers 640 for better data locality. 641 642 The the data access of `local_buf` is assumed to be contiguous with the 643 same order as the `global_buf`. 644 """ 645 assert len(nodes) > 0 646 647 def wrap_inner_fn_for_node(node: ir.IRNode): 648 loops = node.data if isinstance(node, ir.ComputedBuffer) else node 649 assert isinstance(loops, ir.Loops) 650 new_loops = copy.copy(loops) 651 if isinstance(node, ir.ComputedBuffer): 652 new_node = ir.ComputedBuffer( 653 node.get_name(), node.get_layout(), new_loops 654 ) 655 else: 656 new_node = new_loops # type: ignore[assignment] 657 658 new_loops.inner_fn = self.localize_function( 659 new_loops.inner_fn, 660 rewrite_index, 661 ) 662 return new_node 663 664 return [wrap_inner_fn_for_node(node) for node in nodes] 665 666 667def unify_mask_base_type( 668 buffer: IndentedBuffer, 669 vars: Tuple[CSEVariable, ...], 670 dtype=torch.float, 671): 672 """ 673 Given list of cse variables, 674 Cast each to new mask base dtype and return casted cse variable. 675 """ 676 new_vars = ( 677 V.kernel.cse.generate( 678 buffer, 679 f"{V.kernel._get_mask_cast(var, dtype)}", 680 ) 681 for var in vars 682 ) 683 return new_vars 684 685 686def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32): 687 assert is_integer_dtype(offset.dtype) 688 code.writeline("[&]()") 689 with code.indent(): 690 code.writeline( 691 f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];" 692 ) 693 code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];") 694 code.writeline(f"{offset}.store(offset);") 695 code.writeline( 696 f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )" 697 ) 698 with code.indent(): 699 code.writeline(rand_function) 700 num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype) 701 if num_vectors == 1: 702 code.writeline( 703 f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);" 704 ) 705 else: 706 code.writeline( 707 f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);" 708 ) 709 code.writeline("()") 710 return code 711 712 713def get_gemm_template_output_and_compute_dtype(input_dtype): 714 if input_dtype == torch.uint8: 715 return (torch.int32, torch.int32) 716 else: 717 return (torch.float32, torch.float32) 718 719 720def create_epilogue_with_attr(input_buffer, attr, **kwargs): 721 input_loader = input_buffer.make_loader() 722 dtype = input_buffer.get_dtype() 723 if attr == "relu": 724 725 def inner_fn(index): 726 input = input_loader(index) 727 zero = ops.constant(0, dtype) 728 return ops.maximum(input, zero) 729 730 elif attr == "gelu": 731 assert "algorithm" in kwargs 732 if kwargs["algorithm"] == "none": 733 734 def inner_fn(index): 735 input = input_loader(index) 736 if dtype != torch.float: 737 input = ops.to_dtype(input, torch.float) 738 half = ops.constant(0.5, torch.float) 739 one = ops.constant(1.0, torch.float) 740 const = ops.constant(0.7071067811865476, torch.float) 741 result = input * half * (ops.erf(input * const) + one) 742 if dtype != torch.float: 743 result = ops.to_dtype(result, dtype) 744 return result 745 746 else: 747 assert kwargs["algorithm"] == "tanh" 748 749 def inner_fn(index): 750 input = input_loader(index) 751 if dtype != torch.float: 752 input = ops.to_dtype(input, torch.float) 753 half = ops.constant(0.5, torch.float) 754 one = ops.constant(1.0, torch.float) 755 const1 = ops.constant(0.7978845608028654, torch.float) 756 const2 = ops.constant(0.044715, torch.float) 757 result = ( 758 half 759 * input 760 * ( 761 one 762 + ops.tanh(const1 * (input + const2 * input * input * input)) 763 ) 764 ) 765 if dtype != torch.float: 766 result = ops.to_dtype(result, dtype) 767 return result 768 769 elif attr == "swish": 770 771 def inner_fn(index): 772 input = input_loader(index) 773 result = input * ops.sigmoid(input) 774 return result 775 776 elif attr == "sigmoid": 777 778 def inner_fn(index): 779 return ops.sigmoid(input_loader(index)) 780 781 elif attr == "tanh": 782 783 def inner_fn(index): 784 return ops.tanh(input_loader(index)) 785 786 elif attr == "hardswish" or attr == "hardsigmoid": 787 788 def hardsigmoid_float(input): 789 zero = ops.constant(0, torch.float) 790 six = ops.constant(6, torch.float) 791 three = ops.constant(3, torch.float) 792 one_over_six = ops.constant(0.16666666666666666, torch.float) 793 max = ops.maximum(input + three, zero) 794 min = ops.minimum(max, six) 795 return min * one_over_six 796 797 def inner_fn(index): 798 input = input_loader(index) 799 if dtype != torch.float: 800 input = ops.to_dtype(input, torch.float) 801 result = hardsigmoid_float(input) 802 if attr == "hardswish": 803 result = input * result 804 if dtype != torch.float: 805 result = ops.to_dtype(result, dtype) 806 return result 807 808 elif attr == "leaky_relu": 809 assert "scalars" in kwargs 810 assert len(kwargs["scalars"]) == 1 811 negative_slope = kwargs["scalars"][0] 812 813 def inner_fn(index): 814 input = input_loader(index) 815 if dtype != torch.float: 816 input = ops.to_dtype(input, torch.float) 817 zero = ops.constant(0, torch.float) 818 result = ops.where( 819 input > zero, input, input * ops.constant(negative_slope, torch.float) 820 ) 821 if dtype != torch.float: 822 result = ops.to_dtype(result, dtype) 823 return result 824 825 elif attr == "hardtanh": 826 assert "scalars" in kwargs 827 assert len(kwargs["scalars"]) == 2 828 min_value = kwargs["scalars"][0] 829 max_value = kwargs["scalars"][1] 830 831 def inner_fn(index): 832 input = input_loader(index) 833 if dtype != torch.float: 834 input = ops.to_dtype(input, torch.float) 835 result = ops.minimum( 836 ops.maximum(input, ops.constant(min_value, torch.float)), 837 ops.constant(max_value, torch.float), 838 ) 839 if dtype != torch.float: 840 result = ops.to_dtype(result, dtype) 841 return result 842 843 elif attr in ["add", "sub", "mul"]: 844 assert "other" in kwargs 845 other = kwargs["other"] 846 num_input_dims = len(input_buffer.get_size()) 847 num_other_dims = len(other.get_size()) 848 dims_diff = num_input_dims - num_other_dims 849 other_loader = other.make_loader() 850 851 def inner_fn(index): 852 op = getattr(ops, attr) 853 if dims_diff != 0: 854 return op(input_loader(index), other_loader(index[dims_diff:])) 855 else: 856 return op(input_loader(index), other_loader(index)) 857 858 elif attr == "bias_add": 859 assert "other" in kwargs 860 assert "beta" in kwargs 861 assert "dtype" in kwargs 862 beta = kwargs["beta"] 863 other = kwargs["other"] 864 dtype = kwargs["dtype"] 865 bias_loader = other.make_loader() 866 867 def inner_fn(index): 868 bias = bias_loader(index) 869 input = input_loader(index) 870 if beta != 1: 871 result = ops.constant(beta, torch.float) * bias + input 872 else: 873 result = bias + input 874 return result 875 876 else: 877 raise ValueError(f"Unsupported epilogue attribute: {attr}") 878 return ir.Pointwise( 879 device=input_buffer.get_device(), 880 dtype=dtype, 881 inner_fn=inner_fn, 882 ranges=input_buffer.get_size(), 883 ) 884 885 886def _get_loop_body(fn_list): 887 if all(isinstance(fn, LoopBody) for fn in fn_list): 888 loop_bodies = fn_list 889 else: 890 if hasattr(fn_list[0], "original_fn"): 891 # For the case of local buffer, we wrap the fn with localize_function 892 assert all(hasattr(fn, "original_fn") for fn in fn_list) 893 assert all( 894 isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list 895 ) 896 loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] 897 else: 898 assert all(isinstance(fn, functools.partial) for fn in fn_list) 899 assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) 900 loop_bodies = [fn.args[0]._body for fn in fn_list] 901 assert loop_bodies is not None 902 return loop_bodies 903 904 905def _get_dtype_from_loopbodies(loop_bodies): 906 dtypes = set() 907 for loop_body in loop_bodies: 908 graphs = [loop_body.root_block.graph] + [ 909 body.graph for body in list(loop_body.subblocks.values()) 910 ] 911 for graph in graphs: 912 for node in graph.nodes: 913 if node.op != "call_method": 914 continue 915 dtypes.add(node.meta[OptimizationContext.key].dtype) 916 return dtypes 917