1# mypy: allow-untyped-defs 2import contextlib 3import dataclasses 4import functools 5import itertools 6import logging 7import math 8import operator 9import re 10from enum import auto, Enum 11from itertools import chain 12from typing import ( 13 Any, 14 Callable, 15 ClassVar, 16 Dict, 17 List, 18 NamedTuple, 19 Optional, 20 Tuple, 21 Union, 22) 23 24import sympy 25from sympy.printing.printer import Printer 26 27import torch 28import torch.fx 29from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND 30from torch.utils import _pytree as pytree 31from torch.utils._ordered_set import OrderedSet 32from torch.utils._sympy.numbers import int_oo 33from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT 34from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges 35 36from .. import config, metrics 37from ..utils import ( 38 DeferredLineBase, 39 generate_assert, 40 IndentedBuffer, 41 sympy_dot, 42 sympy_subs, 43 unique, 44) 45from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V 46 47 48schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") 49 50 51def data_type_logger(msg): 52 if schedule_log.isEnabledFor(logging.DEBUG): 53 schedule_log.debug("Data type propagation: %s", msg) 54 55 56@dataclasses.dataclass 57class WorkspaceArg: 58 """A temporary buffer used for a single kernel, then discarded. 59 60 Not registered as a traditional buffer since there are no users, 61 so it would be dead code eliminated. 62 """ 63 64 nbytes: sympy.Expr 65 zero_fill: bool 66 67 68@dataclasses.dataclass 69class TensorArg: 70 name: str 71 buffer: str 72 dtype: torch.dtype 73 offset: sympy.Expr = sympy.Integer(0) # c++ only 74 alias_of: Optional[str] = None # halide only 75 76 77@dataclasses.dataclass 78class SizeArg: 79 name: str 80 expr: sympy.Expr 81 82 @property 83 def alias_of(self): 84 return None 85 86 87@dataclasses.dataclass 88class DeviceCodegen: 89 scheduling: Any 90 wrapper_codegen: type 91 cpp_wrapper_codegen: type = type(None) 92 93 94KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg] 95 96device_codegens: Dict[str, DeviceCodegen] = {} 97 98 99class DeviceOpOverrides: 100 def import_get_raw_stream_as(self, name): 101 raise NotImplementedError 102 103 def set_device(self, device_idx): 104 raise NotImplementedError 105 106 def synchronize(self): 107 raise NotImplementedError 108 109 def device_guard(self, device_idx): 110 raise NotImplementedError 111 112 113device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} 114 115 116# The code generated by Inductor consists of two main parts: kernel code and wrapper code. 117# For any new backend looking to integrate with Inductor, customization of these two main 118# parts are necessary to generate its specific code. 119# 120# Kernel code generation is determined by different Scheduling. Consequently, a new 121# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, 122# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. 123# 124# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code 125# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, 126# and override specific member functions to create backend-specific Python wrapper code. 127# 128# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part 129# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces 130# provide flexibility to the backend. A backend can choose to implement these classes from scratch, 131# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, 132# register_backend_for_device, to equip a new backend at runtime. 133# 134# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. 135# This backend can be used as a reference: 136# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 137def register_backend_for_device( 138 device: str, 139 device_scheduling: Any, 140 device_wrapper_codegen: type, 141 device_cpp_wrapper_codegen: type = type(None), 142): 143 device_codegens[device] = DeviceCodegen( 144 device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen 145 ) 146 147 148class BackendFeature(Enum): 149 FOREACH = auto() 150 BUCKETIZE = auto() 151 INPLACE_BUFFERS = auto() 152 MASKED_SCATTER_WITH_INDEX = auto() 153 SCAN = auto() 154 SORT = auto() 155 TUPLE_REDUCTION = auto() 156 PREFER_STORE_LOOP_ORDER = auto() 157 TRITON_TEMPLATES = auto() 158 REDUCE_TO_SINGLE_ELEMENT = auto() 159 160 161def get_backend_features(device: Union[torch.device, str]): 162 init_backend_registration() 163 if isinstance(device, torch.device): 164 device_type = device.type 165 else: 166 assert isinstance(device, str) 167 device_type = device 168 device = torch.device(device_type) 169 scheduling = get_scheduling_for_device(device_type) 170 return scheduling(None).get_backend_features(device) 171 172 173def has_backend_feature(device, feature): 174 """See also V.graph.has_feature""" 175 assert isinstance(feature, BackendFeature) 176 return feature in get_backend_features(device) 177 178 179def get_scheduling_for_device(device: str): 180 return device_codegens[device].scheduling if device in device_codegens else None 181 182 183def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): 184 if device in device_codegens: 185 wrapper_codegen_obj: DeviceCodegen = device_codegens[device] 186 return ( 187 wrapper_codegen_obj.cpp_wrapper_codegen 188 if cpp_wrapper 189 else wrapper_codegen_obj.wrapper_codegen 190 ) 191 else: 192 return None 193 194 195@functools.lru_cache(None) 196def init_backend_registration(): 197 from .cpp import CppScheduling 198 from .cpp_wrapper_cpu import CppWrapperCpu 199 from .cpp_wrapper_cuda import CppWrapperCuda 200 from .cuda_combined_scheduling import CUDACombinedScheduling 201 from .halide import HalideScheduling 202 from .triton import TritonScheduling 203 from .wrapper import WrapperCodeGen 204 205 if get_scheduling_for_device("cpu") is None: 206 cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling} 207 register_backend_for_device( 208 "cpu", 209 lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs), 210 WrapperCodeGen, 211 CppWrapperCpu, 212 ) 213 214 if get_scheduling_for_device("cuda") is None: 215 # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation 216 cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling} 217 register_backend_for_device( 218 "cuda", 219 lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs), 220 WrapperCodeGen, 221 CppWrapperCuda, 222 ) 223 224 if get_scheduling_for_device("xpu") is None: 225 register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen) 226 227 private_backend = torch._C._get_privateuse1_backend_name() 228 if ( 229 private_backend != "privateuseone" 230 and get_scheduling_for_device(private_backend) is None 231 ): 232 from torch.utils.backend_registration import _get_custom_mod_func 233 234 try: 235 device_scheduling = _get_custom_mod_func("Scheduling") 236 wrapper_codegen = _get_custom_mod_func("WrapperCodeGen") 237 cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen") 238 if device_scheduling and wrapper_codegen and cpp_wrapper_codegen: 239 register_backend_for_device( 240 private_backend, 241 device_scheduling, 242 wrapper_codegen, 243 cpp_wrapper_codegen, 244 ) 245 except RuntimeError: 246 pass 247 248 249def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): 250 from ..ir import FlexibleLayout 251 252 # added contiguous index prevents reordering 253 return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] 254 255 256def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides): 257 device_op_overrides_dict[device] = device_op_overrides 258 259 260def get_device_op_overrides(device: str): 261 assert isinstance(device, str) 262 263 if not device_op_overrides_dict.keys(): 264 from .cuda import device_op_overrides # noqa: F401 265 from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 266 267 if device in device_op_overrides_dict.keys(): 268 return device_op_overrides_dict[device] 269 270 271@functools.lru_cache(None) 272def boolean_ops(): 273 return ( 274 "isinf", 275 "isnan", 276 "logical_not", 277 "signbit", 278 "le", 279 "lt", 280 "ge", 281 "gt", 282 "eq", 283 "ne", 284 ) 285 286 287DTYPE_TO_COMPUTATION_DTYPE = { 288 torch.bfloat16: torch.float, 289 torch.float16: torch.float, 290 **{ 291 dtype: dtype 292 for dtype in [ 293 torch.bool, 294 torch.float32, 295 torch.float64, 296 torch.int8, 297 torch.int16, 298 torch.int32, 299 torch.int64, 300 torch.uint8, 301 torch.uint16, 302 torch.uint32, 303 torch.uint64, 304 ] 305 }, 306} 307 308 309def deduce_output_dtype_by_name( 310 op_name: str, 311 *args, 312 **kwargs, 313) -> Optional[torch.dtype]: 314 """ 315 Given op name and a list of input dtypes, deduce the output dtype 316 """ 317 if op_name in boolean_ops(): 318 return torch.bool 319 elif op_name in ( 320 "to_dtype", 321 "index_expr", 322 ): 323 return kwargs["dtype"] if "dtype" in kwargs else args[-1] 324 elif op_name in ( 325 "rand", 326 "randn", 327 ): 328 return torch.float 329 elif op_name in ( 330 "get_index", 331 "randint64", 332 "load_seed", 333 ): 334 return torch.int64 335 elif op_name == "reduction": 336 return kwargs["dtype"] if "dtype" in kwargs else args[1] 337 elif op_name == "constant": 338 dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1] 339 return DTYPE_TO_COMPUTATION_DTYPE[dtype] # type: ignore[index] 340 elif op_name in ( 341 "load", 342 "store", 343 "store_reduction", 344 ): 345 buf_name = args[1] 346 return V.graph.get_dtype(buf_name) # type: ignore[arg-type] 347 elif op_name == "to_dtype_bitcast": 348 return kwargs["dtype"] if "dtype" in kwargs else args[-2] 349 return None 350 351 352class DataTypePropagation: 353 def __init__(self, body) -> None: 354 self.body = body 355 self.graphs: Dict[Union[Callable[..., Any], str], Any] = { 356 "root": body.root_block.graph 357 } 358 for k, v in body.subblocks.items(): 359 self.graphs[k] = v.graph 360 361 def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): 362 inputs = node.all_input_nodes 363 input_nodes = [ 364 n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" 365 ] 366 if len(input_nodes) == 0: 367 return None 368 369 all_input_nodes_propagated = all( 370 OptimizationContext.key in n.meta 371 and n.meta[OptimizationContext.key].dtype is not None 372 for n in input_nodes 373 ) 374 if not all_input_nodes_propagated: 375 return None 376 377 return functools.reduce( 378 torch.promote_types, 379 [n.meta[OptimizationContext.key].dtype for n in input_nodes], 380 ) 381 382 def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): 383 sub_graph = self.graphs[node.target] 384 dtype = self.propagate_graph(sub_graph) 385 assert dtype 386 return dtype 387 388 def deduce_node_dtype(self, node: torch.fx.Node): 389 if node.op == "placeholder": 390 return None 391 392 if node.target == "output" and len(node.args) != 1: 393 # we can infer output node if it only have 1 arg 394 return None 395 396 if node.target == operator.getitem: 397 return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type] 398 399 assert isinstance(node.target, str) 400 401 if node.target.startswith("masked_subblock"): 402 return self.deduce_node_dtype_by_subgraph(node) 403 404 if ( 405 output_dtype := deduce_output_dtype_by_name( 406 node.target, 407 *node.args, 408 **node.kwargs, 409 ) 410 ) is not None: 411 return output_dtype 412 413 return self.deduce_node_dtype_by_inputs(node) 414 415 def propagate_graph(self, graph: torch.fx.Graph): 416 assert graph.nodes 417 graph_dtype = None 418 # For masked_subblock, we use output's dtype to represent 419 # the dtype of this subgraph. For other cases, graph_dtype 420 # might be None 421 for node in graph.nodes: 422 if OptimizationContext.key in node.meta: 423 opt_ctx = node.meta[OptimizationContext.key] 424 else: 425 opt_ctx = OptimizationContext() 426 427 opt_ctx.dtype = self.deduce_node_dtype(node) 428 node.meta[OptimizationContext.key] = opt_ctx 429 if node.target == "output": 430 graph_dtype = opt_ctx.dtype 431 return graph_dtype 432 433 def propagate(self): 434 self.propagate_graph(self.graphs["root"]) 435 436 @classmethod 437 def propagate_loopbody(cls, body): 438 return cls(body).propagate() 439 440 @classmethod 441 def propagate_scheduler_node(cls, node): 442 from ..loop_body import LoopBody 443 from ..scheduler import SchedulerNode 444 445 assert isinstance(node, SchedulerNode) 446 assert isinstance(node._body, LoopBody) 447 DataTypePropagation.propagate_loopbody(node._body) 448 449 450# This printer contains rules that are supposed to be generic for both C/C++ and 451# Python 452class ExprPrinter(Printer): 453 @staticmethod 454 def paren(string): 455 def all_in_parens(string): 456 if string[0] != "(" or len(string) < 2: 457 return False 458 count = 1 459 for i, char in enumerate(string[1:]): 460 if char == "(": 461 count += 1 462 elif char == ")": 463 count -= 1 464 if count == 0 and i != len(string) - 2: 465 return False 466 assert count == 0 467 return True 468 469 if ( 470 isinstance(string, CSEVariable) 471 or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE) 472 or re.match(r"^\([^)]*\)$", string, re.IGNORECASE) 473 or string == "" 474 ): 475 return string 476 # don't put extra parens for strings that are already wrapped in parens 477 if all_in_parens(string): 478 return string 479 return f"({string})" 480 481 def _print_Relational(self, expr): 482 return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) 483 484 def _print_Mul(self, expr): 485 return "*".join(map(self.paren, map(self._print, expr.args))) 486 487 def _print_Add(self, expr): 488 return " + ".join(map(self.paren, map(self._print, expr.args))) 489 490 # NB: this is OK to put here, because Mod is only defined for positive 491 # numbers, and so across C/Python its behavior is consistent 492 def _print_Mod(self, expr): 493 return " % ".join(map(self.paren, map(self._print, expr.args))) 494 495 def _print_FloatTrueDiv(self, expr): 496 lhs, rhs = expr.args 497 return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" 498 499 def _print_CleanDiv(self, expr): 500 return self._print_FloorDiv(expr) 501 502 def _print_Identity(self, expr): 503 return self._print(expr.args[0]) 504 505 def _print_GreaterThan(self, expr): 506 # GreaterThan: >= 507 # StrictlyGreaterThan: > 508 # Go figure... 509 return " >= ".join(map(self.paren, map(self._print, expr.args))) 510 511 # NB: The C implementation is injected into codegen at 512 # torch/_inductor/codegen/wrapper.py 513 def _print_align(self, expr): 514 assert len(expr.args) == 1 515 return f"align({self._print(expr.args[0])})" 516 517 # This must be implemented because sympy will collect x * x into Pow(x, 2), without 518 # any explicit intervention. We print it just like x * x, notably, we 519 # never generate sympy.Pow with floats. 520 # 521 # NB: this pow by natural, you should never have used builtin sympy.pow 522 # for FloatPow, and a symbolic exponent should be PowByNatural. These 523 # means exp is guaranteed to be integer. 524 def _print_Pow(self, expr): 525 base, exp = expr.args 526 base = self._print(base) 527 assert exp == int(exp), exp 528 exp = int(exp) 529 assert exp >= 0 530 if exp > 0: 531 return "*".join([self.paren(base)] * exp) 532 else: # exp == 0 533 return "1" 534 535 # Explicit NotImplemented functions are to prevent default sympy printing 536 # behavior, which will just barf out ToFloat(...) to your IR. The error 537 # message is better here because it tells you which printer class it needs 538 # to go in. 539 540 def _print_ToFloat(self, expr): 541 raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") 542 543 def _print_Infinity(self, expr): 544 raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") 545 546 def _print_NegativeInfinity(self, expr): 547 raise NotImplementedError( 548 f"_print_NegativeInfinity not implemented for {type(self)}" 549 ) 550 551 def _print_FloorDiv(self, expr): 552 raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") 553 554 def _print_PythonMod(self, expr): 555 raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") 556 557 def _print_IntTrueDiv(self, expr): 558 raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") 559 560 def _print_PowByNatural(self, expr): 561 raise NotImplementedError( 562 f"_print_PowByNatural not implemented for {type(self)}" 563 ) 564 565 def _print_FloatPow(self, expr): 566 raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") 567 568 def _print_TruncToInt(self, expr): 569 raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") 570 571 def _print_RoundToInt(self, expr): 572 raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") 573 574 def _print_RoundDecimal(self, expr): 575 raise NotImplementedError( 576 f"_print_RoundDecimal not implemented for {type(self)}" 577 ) 578 579 # NB: Some float operations are INTENTIONALLY not implemented for 580 # printers. You can implement them as a quick unblock, but it is better 581 # to ask yourself why we haven't done this computation in the Tensor 582 # universe instead 583 584 def _print_TruncToFloat(self, expr): 585 raise NotImplementedError( 586 f"_print_TruncToFloat not implemented for {type(self)}" 587 ) 588 589 def doprint(self, expr, *, simplify: bool = True): 590 # TODO: why are people passing strings to the printer here :think: 591 if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): 592 expr = V.graph.sizevars.simplify(expr) 593 return super().doprint(expr) 594 595 596class PythonPrinter(ExprPrinter): 597 def _print_ToFloat(self, expr): 598 assert len(expr.args) == 1 599 return f"float({self._print(expr.args[0])})" 600 601 def _print_ModularIndexing(self, expr): 602 x, div, mod = expr.args 603 x = self.paren(self.doprint(x)) 604 div = self.paren(self.doprint(div)) 605 mod = self.paren(self.doprint(mod)) 606 if div != "1": 607 x = f"({x} // {div})" 608 return f"{x} % {mod}" 609 610 def _print_Infinity(self, expr): 611 return "math.inf" 612 613 def _print_NegativeInfinity(self, expr): 614 return "-math.inf" 615 616 # WARNING: this is dangerous for Triton, which has C-style modulus 617 def _print_PythonMod(self, expr): 618 return " % ".join(map(self.paren, map(self._print, expr.args))) 619 620 # WARNING: this is dangerous for Triton, which has C-style modulus 621 def _print_FloorDiv(self, expr): 622 x, div = expr.args 623 x = self.paren(self.doprint(x)) 624 div = self.paren(self.doprint(div)) 625 return f"({x} // {div})" 626 627 # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python 628 # does a special algorithm 629 def _print_IntTrueDiv(self, expr): 630 lhs, rhs = expr.args 631 return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" 632 633 def _helper_sqrt(self, expr): 634 return f"math.sqrt({self._print(expr)})" 635 636 def _print_OpaqueUnaryFn_sqrt(self, expr): 637 return self._helper_sqrt(expr.args[0]) 638 639 def _print_FloatPow(self, expr): 640 base, exp = expr.args 641 return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" 642 643 # TODO: Not sure this works with Triton, even when base/exp are integral 644 def _print_PowByNatural(self, expr): 645 base, exp = expr.args 646 return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" 647 648 def _print_floor(self, expr): 649 assert len(expr.args) == 1 650 return f"math.floor({self._print(expr.args[0])})" 651 652 def _print_FloorToInt(self, expr): 653 assert len(expr.args) == 1 654 return f"math.floor({self._print(expr.args[0])})" 655 656 def _print_TruncToInt(self, expr): 657 assert len(expr.args) == 1 658 # This also could have been int(), they'll do the same thing for float 659 return f"math.trunc({self._print(expr.args[0])})" 660 661 def _print_ceiling(self, expr): 662 assert len(expr.args) == 1 663 return f"math.ceil({self._print(expr.args[0])})" 664 665 def _print_CeilToInt(self, expr): 666 assert len(expr.args) == 1 667 return f"math.ceil({self._print(expr.args[0])})" 668 669 def _print_Abs(self, expr): 670 assert len(expr.args) == 1 671 return f"abs({self._print(expr.args[0])})" 672 673 # NB: It's expected that we've made explicit any promotion in the sympy 674 # expression, so it doesn't matter that Python max/min doesn't perform 675 # promotion 676 def _print_Max(self, expr): 677 assert len(expr.args) >= 2 678 return f"max({', '.join(map(self._print, expr.args))})" 679 680 def _print_Min(self, expr): 681 assert len(expr.args) >= 2 682 return f"min({', '.join(map(self._print, expr.args))})" 683 684 def _print_OpaqueUnaryFn_cos(self, expr): 685 assert len(expr.args) == 1 686 return f"math.cos({self._print(expr.args[0])})" 687 688 def _print_OpaqueUnaryFn_cosh(self, expr): 689 assert len(expr.args) == 1 690 return f"math.cosh({self._print(expr.args[0])})" 691 692 def _print_OpaqueUnaryFn_acos(self, expr): 693 assert len(expr.args) == 1 694 return f"math.acos({self._print(expr.args[0])})" 695 696 def _print_OpaqueUnaryFn_sin(self, expr): 697 assert len(expr.args) == 1 698 return f"math.sin({self._print(expr.args[0])})" 699 700 def _print_OpaqueUnaryFn_sinh(self, expr): 701 assert len(expr.args) == 1 702 return f"math.sinh({self._print(expr.args[0])})" 703 704 def _print_OpaqueUnaryFn_asin(self, expr): 705 assert len(expr.args) == 1 706 return f"math.asin({self._print(expr.args[0])})" 707 708 def _print_OpaqueUnaryFn_tan(self, expr): 709 assert len(expr.args) == 1 710 return f"math.tan({self._print(expr.args[0])})" 711 712 def _print_OpaqueUnaryFn_tanh(self, expr): 713 assert len(expr.args) == 1 714 return f"math.tanh({self._print(expr.args[0])})" 715 716 def _print_OpaqueUnaryFn_atan(self, expr): 717 assert len(expr.args) == 1 718 return f"math.atan({self._print(expr.args[0])})" 719 720 def _print_RoundToInt(self, expr): 721 assert len(expr.args) == 1 722 return f"round({self._print(expr.args[0])})" 723 724 def _print_RoundDecimal(self, expr): 725 assert len(expr.args) == 2 726 number, ndigits = expr.args 727 assert isinstance(ndigits, sympy.Integer) 728 return f"round({self._print(number)}, {ndigits})" 729 730 731class OpOverrides: 732 def __init__(self, parent): 733 super().__init__() 734 self._parent = parent 735 736 def __getattr__(self, item): 737 return getattr(self._parent, item) 738 739 @staticmethod 740 def identity(value): 741 # used to trigger cse 742 return value 743 744 @staticmethod 745 def constant(value, dtype): 746 return repr(value) 747 748 @staticmethod 749 def reciprocal(x): 750 return ops.truediv(ops.constant(1, torch.int32), x) 751 752 @staticmethod 753 def square(x): 754 return ops.mul(x, x) 755 756 @staticmethod 757 def erfc(x): 758 return ops.sub(ops.constant(1, torch.float32), ops.erf(x)) 759 760 @staticmethod 761 def erfcx(x): 762 return ops.mul(ops.exp(ops.square(x)), ops.erfc(x)) 763 764 @staticmethod 765 def expm1(x): 766 return ops.sub(ops.exp(x), ops.constant(1, torch.float32)) 767 768 @staticmethod 769 def log10(x): 770 return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32)) 771 772 @staticmethod 773 def log2(x): 774 return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32)) 775 776 @staticmethod 777 def exp2(x): 778 return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32))) 779 780 @staticmethod 781 def log1p(x): 782 return ops.log(ops.add(x, ops.constant(1, torch.int32))) 783 784 @staticmethod 785 def sigmoid(x): 786 one = ops.constant(1, torch.int32) 787 return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) 788 789 @staticmethod 790 def libdevice_sigmoid(x): 791 one = ops.constant(1, torch.int32) 792 return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x)))) 793 794 @staticmethod 795 def relu(x): 796 return ops.maximum(x, ops.constant(0, torch.int32)) 797 798 @staticmethod 799 def libdevice_abs(x): 800 return ops.abs(x) 801 802 @staticmethod 803 def libdevice_sqrt(x): 804 return ops.sqrt(x) 805 806 @staticmethod 807 def libdevice_cos(x): 808 return ops.cos(x) 809 810 @staticmethod 811 def libdevice_sin(x): 812 return ops.sin(x) 813 814 @staticmethod 815 def libdevice_log(x): 816 return ops.log(x) 817 818 @staticmethod 819 def libdevice_exp(x): 820 return ops.exp(x) 821 822 @staticmethod 823 def bitwise_not(x): 824 return f"~{ExprPrinter.paren(x)}" 825 826 @staticmethod 827 def logical_not(a): 828 return f"{ExprPrinter.paren(a)} == 0" 829 830 @staticmethod 831 def bitwise_and(x, y): 832 return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" 833 834 @staticmethod 835 def bitwise_or(x, y): 836 return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" 837 838 @staticmethod 839 def bitwise_xor(x, y): 840 return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" 841 842 @staticmethod 843 def bitwise_left_shift(x, y): 844 return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" 845 846 @staticmethod 847 def bitwise_right_shift(x, y): 848 return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" 849 850 @staticmethod 851 def remainder(a, b): 852 r = ops.mod(a, b) 853 cond = ops.and_( 854 ops.ne(r, ops.constant(0, torch.int32)), 855 ops.ne(ops.signbit(r), ops.signbit(b)), 856 ) 857 return ops.where(cond, ops.add(r, b), r) 858 859 @staticmethod 860 def trunc_to_int(a, dtype): 861 return ops.to_dtype(ops.trunc(a), dtype) 862 863 @staticmethod 864 def floor_to_int(a, dtype): 865 return ops.to_dtype(ops.floor(a), dtype) 866 867 @staticmethod 868 def ceil_to_int(a, dtype): 869 return ops.to_dtype(ops.ceil(a), dtype) 870 871 @staticmethod 872 def round_to_int(a, dtype): 873 return ops.to_dtype(ops.round(a), dtype) 874 875 @staticmethod 876 def int_truediv(a, b): 877 # TODO: this is wrong 878 # TODO: an easy bandaid is to generate runtime asserts that it's 879 # <= 2**53, which is when this equation is correct 880 return ops.truediv(a, b) 881 882 @staticmethod 883 def load_seed(name, offset): 884 return ops.load(name, sympy.Integer(offset)) 885 886 @classmethod 887 def _initialize_pointwise_overrides(cls, target): 888 assert target in {"triton", "cpp", "cppvec"}, target 889 890 for funcname, data in pointwise_overrides_data.items(): 891 impl = getattr(data, target) 892 if impl is None: 893 continue 894 setattr(cls, funcname, staticmethod(impl)) 895 896 897@dataclasses.dataclass 898class OverridesData: 899 name: str 900 cpp: Callable[..., str] 901 # None when not impl in libdevice/triton 902 triton: Optional[Callable[..., str]] = None 903 # None when not impl in aten/.../vec 904 cppvec: Optional[Callable[..., str]] = None 905 type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( 906 ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT 907 ) 908 909 910# NB: if you add a new special function, don't forget to update 911# torch._inductor.ops_handler too 912pointwise_overrides_data: Dict[str, OverridesData] = dict( 913 airy_ai=OverridesData( 914 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 915 cpp=lambda x: f"airy_ai_forward({x})", 916 name="special_airy_ai", 917 ), 918 bessel_j0=OverridesData( 919 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 920 cpp=lambda x: f"bessel_j0_forward({x})", 921 triton=lambda x: f"libdevice.j0({x})", 922 name="special_bessel_j0", 923 ), 924 bessel_j1=OverridesData( 925 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 926 cpp=lambda x: f"bessel_j1_forward({x})", 927 triton=lambda x: f"libdevice.j1({x})", 928 name="special_bessel_j1", 929 ), 930 bessel_y0=OverridesData( 931 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 932 cpp=lambda x: f"bessel_y0_forward({x})", 933 triton=lambda x: f"libdevice.y0({x})", 934 name="special_bessel_y0", 935 ), 936 bessel_y1=OverridesData( 937 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 938 cpp=lambda x: f"bessel_y1_forward({x})", 939 triton=lambda x: f"libdevice.y1({x})", 940 name="special_bessel_y1", 941 ), 942 digamma=OverridesData( 943 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 944 cpp=lambda x: f"calc_digamma({x})", 945 cppvec=lambda x: f"{x}.digamma()", 946 name="digamma", 947 ), 948 # no cpp nor triton implementation for entr, it is defined as decomposition 949 # erf, erfc 950 erfcx=OverridesData( 951 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 952 cpp=lambda x: f"calc_erfcx({x})", 953 triton=lambda x: f"libdevice.erfcx({x})", 954 name="special_erfcx", 955 ), 956 fma=OverridesData( 957 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 958 cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})", 959 cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})", 960 triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})", 961 name="fma", 962 ), 963 # erfinv, exp2, expit, gammaln 964 igamma=OverridesData( 965 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 966 cpp=lambda x, y: f"calc_igamma({x}, {y})", 967 name="igamma", 968 ), 969 igammac=OverridesData( 970 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 971 cpp=lambda x, y: f"calc_igammac({x}, {y})", 972 name="igammac", 973 ), 974 gammainc=OverridesData( 975 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 976 cpp=lambda x, y: f"calc_igamma({x}, {y})", 977 name="special_gammainc", 978 ), 979 gammaincc=OverridesData( 980 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 981 cpp=lambda x, y: f"calc_igammac({x}, {y})", 982 name="special_gammaincc", 983 ), 984 i0=OverridesData( 985 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 986 cpp=lambda x: f"calc_i0({x})", 987 triton=lambda x: f"libdevice.cyl_bessel_i0({x})", 988 cppvec=lambda x: f"{x}.i0()", 989 name="i0", 990 ), 991 i0e=OverridesData( 992 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 993 cpp=lambda x: f"calc_i0e({x})", 994 cppvec=lambda x: f"{x}.i0e()", 995 name="special_i0e", 996 ), 997 i1=OverridesData( 998 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 999 cpp=lambda x: f"calc_i1({x})", 1000 triton=lambda x: f"libdevice.cyl_bessel_i1({x})", 1001 name="special_i1", 1002 ), 1003 i1e=OverridesData( 1004 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1005 cpp=lambda x: f"calc_i1e({x})", 1006 name="special_i1e", 1007 ), 1008 log_ndtr=OverridesData( 1009 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1010 cpp=lambda x: f"calc_log_ndtr({x})", 1011 name="special_log_ndtr", 1012 ), 1013 # logit 1014 modified_bessel_i0=OverridesData( 1015 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1016 cpp=lambda x: f"modified_bessel_i0_forward({x})", 1017 triton=lambda x: f"libdevice.cyl_bessel_i0({x})", 1018 name="special_modified_bessel_i0", 1019 ), 1020 modified_bessel_i1=OverridesData( 1021 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1022 cpp=lambda x: f"modified_bessel_i1_forward({x})", 1023 triton=lambda x: f"libdevice.cyl_bessel_i1({x})", 1024 name="special_modified_bessel_i1", 1025 ), 1026 modified_bessel_k0=OverridesData( 1027 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1028 cpp=lambda x: f"modified_bessel_k0_forward({x})", 1029 name="special_modified_bessel_k0", 1030 ), 1031 modified_bessel_k1=OverridesData( 1032 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1033 cpp=lambda x: f"modified_bessel_k1_forward({x})", 1034 name="special_modified_bessel_k1", 1035 ), 1036 # multigamma 1037 ndtr=OverridesData( 1038 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1039 cpp=lambda x: f"calc_ndtr({x})", 1040 name="special_ndtr", 1041 ), 1042 ndtri=OverridesData( 1043 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1044 cpp=lambda x: f"calc_ndtri({x})", 1045 name="special_ndtri", 1046 ), 1047 polygamma=OverridesData( 1048 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1049 cpp=lambda x, y: f"calc_polygamma({y}, {x})", 1050 name="polygamma", 1051 ), 1052 # psi - alias to digamma 1053 # round 1054 scaled_modified_bessel_k0=OverridesData( 1055 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1056 cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})", 1057 name="special_scaled_modified_bessel_k0", 1058 ), 1059 scaled_modified_bessel_k1=OverridesData( 1060 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1061 cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})", 1062 name="special_scaled_modified_bessel_k1", 1063 ), 1064 # sinc 1065 spherical_bessel_j0=OverridesData( 1066 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1067 cpp=lambda x: f"spherical_bessel_j0_forward({x})", 1068 name="special_spherical_bessel_j0", 1069 ), 1070 zeta=OverridesData( 1071 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1072 cpp=lambda x, y: f"zeta({x}, {y})", 1073 name="special_zeta", 1074 ), 1075 chebyshev_polynomial_t=OverridesData( 1076 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1077 cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})", 1078 name="special_chebyshev_polynomial_t", 1079 ), 1080 chebyshev_polynomial_u=OverridesData( 1081 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1082 cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})", 1083 name="special_chebyshev_polynomial_u", 1084 ), 1085 chebyshev_polynomial_v=OverridesData( 1086 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1087 cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})", 1088 name="special_chebyshev_polynomial_v", 1089 ), 1090 chebyshev_polynomial_w=OverridesData( 1091 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1092 cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})", 1093 name="special_chebyshev_polynomial_w", 1094 ), 1095 legendre_polynomial_p=OverridesData( 1096 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1097 cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})", 1098 name="special_legendre_polynomial_p", 1099 ), 1100 shifted_chebyshev_polynomial_t=OverridesData( 1101 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1102 cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})", 1103 name="special_shifted_chebyshev_polynomial_t", 1104 ), 1105 shifted_chebyshev_polynomial_u=OverridesData( 1106 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1107 cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})", 1108 name="special_shifted_chebyshev_polynomial_u", 1109 ), 1110 shifted_chebyshev_polynomial_v=OverridesData( 1111 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1112 cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})", 1113 name="special_shifted_chebyshev_polynomial_v", 1114 ), 1115 shifted_chebyshev_polynomial_w=OverridesData( 1116 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1117 cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})", 1118 name="special_shifted_chebyshev_polynomial_w", 1119 ), 1120 hermite_polynomial_h=OverridesData( 1121 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1122 cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})", 1123 name="special_hermite_polynomial_h", 1124 ), 1125 hermite_polynomial_he=OverridesData( 1126 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1127 cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})", 1128 name="special_hermite_polynomial_he", 1129 ), 1130 laguerre_polynomial_l=OverridesData( 1131 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 1132 cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})", 1133 name="special_laguerre_polynomial_l", 1134 ), 1135) 1136 1137 1138# Use mypy to check protocol implemented correctly 1139def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]: 1140 return h 1141 1142 1143class DeferredLine(DeferredLineBase): 1144 """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" 1145 1146 def __init__(self, name, line): 1147 super().__init__(line) 1148 self.name = name 1149 assert not isinstance(line, DeferredLineBase) 1150 1151 def __call__(self): 1152 if all( 1153 self.name not in x 1154 for x in ( 1155 V.graph.removed_buffers, 1156 V.kernel.removed_buffers, 1157 V.graph.inplaced_to_remove, 1158 V.kernel.inplaced_to_remove, 1159 ) 1160 ): 1161 return self.line 1162 return None 1163 1164 def _new_line(self, line): 1165 return DeferredLine(self.name, line) 1166 1167 1168class BracesBuffer(IndentedBuffer): 1169 def indent(self, offset=1): 1170 @contextlib.contextmanager 1171 def ctx(): 1172 for _ in range(offset): 1173 self.writeline("{") 1174 self._indent += 1 1175 for _ in range(-offset): 1176 self._indent -= 1 1177 self.writeline("}") 1178 yield 1179 for _ in range(-offset): 1180 self.writeline("{") 1181 self._indent += 1 1182 for _ in range(offset): 1183 self._indent -= 1 1184 self.writeline("}") 1185 1186 return ctx() 1187 1188 1189class InplacedBuffer(NamedTuple): 1190 inner_name: str 1191 other_names: List[str] 1192 1193 1194class KernelArgs: 1195 @staticmethod 1196 def _lookup(prefix, odict, name): 1197 assert isinstance(name, (str, sympy.Symbol)) 1198 if name not in odict: 1199 odict[name] = f"{prefix}{len(odict)}" 1200 return odict[name] 1201 1202 def __init__(self, sizevars=None): 1203 self.input_buffers = {} 1204 self.output_buffers = {} 1205 self.inplace_buffers = {} 1206 self.sizevars = sizevars or {} 1207 self.workspace_arg = None 1208 1209 def __repr__(self): 1210 return "KernelArgs({})".format( 1211 ", ".join( 1212 map( 1213 repr, 1214 [ 1215 self.input_buffers, 1216 self.output_buffers, 1217 self.inplace_buffers, 1218 self.sizevars, 1219 ], 1220 ) 1221 ) 1222 ) 1223 1224 def _buffer_is_marked_removed(self, name): 1225 return isinstance(name, str) and name.startswith("REMOVED") 1226 1227 def input(self, name): 1228 if V.graph.scheduler: 1229 name = V.graph.scheduler.mutation_real_name.get(name, name) 1230 assert name not in V.graph.removed_buffers, name 1231 if name in self.output_buffers: 1232 return self.output_buffers[name] 1233 if name in self.inplace_buffers: 1234 return self.inplace_buffers[name].inner_name 1235 if name.startswith("seed"): 1236 return self._lookup("seed", self.input_buffers, name) 1237 return self._lookup("in_ptr", self.input_buffers, name) 1238 1239 def output(self, name): 1240 if V.graph.scheduler: 1241 name = V.graph.scheduler.mutation_real_name.get(name, name) 1242 assert name not in V.graph.removed_buffers, name 1243 if name in self.inplace_buffers: 1244 return self.inplace_buffers[name].inner_name 1245 return self._lookup("out_ptr", self.output_buffers, name) 1246 1247 def make_inplace(self, input_name, output_name): 1248 assert output_name not in self.inplace_buffers 1249 if input_name in self.inplace_buffers: 1250 buf = self.inplace_buffers[input_name] 1251 buf.other_names.append(output_name) 1252 self.inplace_buffers[output_name] = buf 1253 else: 1254 buf = InplacedBuffer( 1255 f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", 1256 [input_name, output_name], 1257 ) 1258 self.inplace_buffers[input_name] = buf 1259 self.inplace_buffers[output_name] = buf 1260 1261 def workspace(self, nbytes: sympy.Expr, zero_fill: bool): 1262 if self.workspace_arg is None: 1263 self.workspace_arg = WorkspaceArg(nbytes, zero_fill) 1264 return "ws_ptr", 0 1265 1266 offset = self.workspace_arg.nbytes 1267 zero_fill = zero_fill or self.workspace_arg.zero_fill 1268 self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill) 1269 return "ws_ptr", offset 1270 1271 def seed_offset(self, name, value): 1272 if value in self.sizevars: 1273 return self.sizevars[value] 1274 if name in self.sizevars.values(): 1275 name = ( 1276 f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" 1277 ) 1278 self.sizevars[value] = name 1279 return name 1280 1281 def size(self, name): 1282 if str(name) == "seed": 1283 self.sizevars["seed"] = "seed" 1284 return "seed" 1285 return self._lookup("ks", self.sizevars, name) 1286 1287 def call_names(self): 1288 return chain( 1289 self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() 1290 ) 1291 1292 def wrap_ptr_arg(self, buf, dtype): 1293 return buf 1294 1295 def wrap_size_arg(self, size): 1296 return str(size) 1297 1298 def cpp_argdefs(self): 1299 from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE 1300 1301 call_args = [] 1302 arg_defs = [] 1303 arg_types = [] 1304 for inplaced in unique(self.inplace_buffers.values()): 1305 if self._buffer_is_marked_removed(inplaced): 1306 continue 1307 outer = inplaced.other_names[-1] 1308 inner = inplaced.inner_name 1309 dtype = V.graph.get_dtype(outer) 1310 cpp_dtype = DTYPE_TO_CPP[dtype] 1311 arg_defs.append(f"{cpp_dtype}* {inner}") 1312 call_args.append(self.wrap_ptr_arg(outer, dtype)) 1313 arg_types.append(f"{cpp_dtype}*") 1314 for outer, inner in self.input_buffers.items(): 1315 if outer in self.inplace_buffers: 1316 continue 1317 dtype = V.graph.get_dtype(outer) 1318 cpp_dtype = DTYPE_TO_CPP[dtype] 1319 arg_defs.append(f"const {cpp_dtype}* {inner}") 1320 call_args.append(self.wrap_ptr_arg(outer, dtype)) 1321 arg_types.append(f"const {cpp_dtype}*") 1322 for outer, inner in self.output_buffers.items(): 1323 if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): 1324 continue 1325 dtype = V.graph.get_dtype(outer) 1326 cpp_dtype = DTYPE_TO_CPP[dtype] 1327 arg_defs.append(f"{cpp_dtype}* {inner}") 1328 call_args.append(self.wrap_ptr_arg(outer, dtype)) 1329 arg_types.append(f"{cpp_dtype}*") 1330 for outer, inner in self.sizevars.items(): 1331 arg_defs.append(f"const {INDEX_TYPE} {inner}") 1332 call_args.append(self.wrap_size_arg(outer)) 1333 arg_types.append(f"const {INDEX_TYPE}") 1334 if V.graph.wrapper_code: 1335 V.graph.wrapper_code.ensure_size_computed(outer) 1336 assert self.workspace_arg is None, "Workspace not supported on CPU " 1337 return arg_defs, call_args, arg_types 1338 1339 def python_argdefs(self): 1340 arg_defs: List[str] = [] 1341 call_args: List[str] = [] 1342 arg_types: List[torch.dtype] = [] 1343 precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = [] 1344 for inplaced in unique(self.inplace_buffers.values()): 1345 if self._buffer_is_marked_removed(inplaced): 1346 continue 1347 arg_defs.append(inplaced.inner_name) 1348 call_args.append(inplaced.other_names[-1]) 1349 arg_types.append(V.graph.get_dtype(inplaced.other_names[-1])) 1350 precompile_args.append( 1351 TensorArg( 1352 name=inplaced.inner_name, 1353 buffer=inplaced.other_names[-1], 1354 dtype=V.graph.get_dtype(inplaced.other_names[-1]), 1355 ) 1356 ) 1357 for outer, inner in chain( 1358 self.input_buffers.items(), self.output_buffers.items() 1359 ): 1360 if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): 1361 continue 1362 arg_defs.append(inner) 1363 call_args.append(outer) 1364 arg_types.append(V.graph.get_dtype(outer)) 1365 precompile_args.append( 1366 TensorArg( 1367 name=inner, 1368 buffer=outer, 1369 dtype=V.graph.get_dtype(outer), 1370 ) 1371 ) 1372 for outer, inner in self.sizevars.items(): 1373 arg_defs.append(inner) 1374 call_args.append(outer) 1375 arg_types.append(type(outer)) # type: ignore[arg-type] 1376 precompile_args.append(SizeArg(inner, outer)) 1377 if V.graph.wrapper_code: 1378 V.graph.wrapper_code.ensure_size_computed(outer) 1379 if self.workspace_arg is not None: 1380 arg_defs.append("ws_ptr") 1381 call_args.append("workspace") 1382 precompile_args.append(self.workspace_arg) 1383 return arg_defs, call_args, precompile_args, arg_types 1384 1385 def aliases(self): 1386 for inplaced in unique(self.inplace_buffers.values()): 1387 if self._buffer_is_marked_removed(inplaced): 1388 continue 1389 for other in inplaced.other_names: 1390 if ( 1391 other in V.graph.inplaced_to_remove 1392 or other in V.kernel.inplaced_to_remove 1393 ): 1394 continue 1395 if other in self.input_buffers: 1396 yield self.input_buffers[other], inplaced.inner_name 1397 if other in self.output_buffers: 1398 yield self.output_buffers[other], inplaced.inner_name 1399 1400 def is_removed(self, name): 1401 def _is_removed(name, buffers): 1402 return name not in buffers or self._buffer_is_marked_removed(buffers[name]) 1403 1404 return _is_removed(name, self.output_buffers) and _is_removed( 1405 name, self.inplace_buffers 1406 ) 1407 1408 # Includes inplace buffers, excludes removed buffers. Essentially, 1409 # after you do a call into this kernel, which buffers actually contain 1410 # updated data? Modeled off of python_argdefs. 1411 def live_output_buffers(self): 1412 live_outs = OrderedSet() # type: ignore[var-annotated] 1413 for inplaced in unique(self.inplace_buffers.values()): 1414 if self._buffer_is_marked_removed(inplaced): 1415 continue 1416 live_outs.add(inplaced.other_names[-1]) 1417 for outer, inner in self.output_buffers.items(): 1418 if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): 1419 continue 1420 live_outs.add(outer) 1421 return live_outs 1422 1423 1424class CSEVariable: 1425 """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. 1426 To do so, the backends can simply overload `Kernel.create_cse_var` 1427 The "CSEVariable.update_on_args" method gives you a hook for annotations 1428 See example of TritonCSEVariable in triton.py 1429 """ 1430 1431 def __init__(self, name, bounds: ValueRanges[Any]): 1432 assert isinstance(bounds, ValueRanges) 1433 self.name = name 1434 self.bounds = bounds 1435 self.use_count = 1 # track how many tims this expression is used 1436 1437 def __str__(self): 1438 return self.name 1439 1440 def __hash__(self) -> int: 1441 return hash(self.name) 1442 1443 def __eq__(self, other) -> bool: 1444 return type(other) == type(self) and other.name == self.name 1445 1446 def update_on_args(self, name, args, kwargs): 1447 pass 1448 1449 def __repr__(self): 1450 return f"{self.__class__.__name__}({self.name!r})" 1451 1452 1453class CppWrapperKernelArgs(KernelArgs): 1454 def wrap_ptr_arg(self, buf, dtype): 1455 from .cpp_utils import DTYPE_TO_CPP 1456 1457 if config.abi_compatible: 1458 # In the abi_compatible model, we just return the buf here. 1459 # We will form correct call args later in wrapper.generate_kernel_all. 1460 return buf 1461 else: 1462 return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" 1463 1464 def wrap_size_arg(self, size): 1465 return f"{size}" 1466 1467 1468class CSE: 1469 """Common subexpression elimination""" 1470 1471 def __init__( 1472 self, 1473 prefix="", 1474 suffix="", 1475 name_prefix="tmp", 1476 iter_buffers=None, 1477 store_cache=None, 1478 reduction_cache=None, 1479 varname_map=None, 1480 ): 1481 self.prefix = prefix 1482 self.suffix = suffix 1483 self.cache = {} 1484 self.name_prefix = name_prefix 1485 self.store_cache = store_cache or {} 1486 self.reduction_cache = reduction_cache or {} 1487 self.iter_buffer_ids = iter_buffers or itertools.count() 1488 self.invalidated_stores = OrderedSet() # type: ignore[var-annotated] 1489 self.varname_map = varname_map or {} 1490 1491 def invalidate(self, keep_vars: OrderedSet[str]): 1492 for name, tmp in list(self.store_cache.items()): 1493 if tmp not in keep_vars: 1494 del self.store_cache[name] 1495 self.invalidated_stores.add(name) 1496 self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} 1497 1498 def clone(self): 1499 # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional 1500 return CSE( 1501 prefix=self.prefix, 1502 suffix=self.suffix, 1503 name_prefix=self.name_prefix, 1504 iter_buffers=self.iter_buffer_ids, 1505 store_cache=self.store_cache, 1506 varname_map=self.varname_map, 1507 ) 1508 1509 def generate( 1510 self, 1511 buffer: IndentedBuffer, 1512 expr: Union[str, CSEVariable, OpsValue, IndentedBuffer], 1513 *, 1514 bounds: ValueRanges[Any] = ValueRanges.unknown(), 1515 write=True, 1516 assignment=True, 1517 ) -> CSEVariable: 1518 if isinstance(expr, OpsValue): 1519 expr = expr.value 1520 1521 assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr) 1522 assert write or assignment 1523 if isinstance(expr, CSEVariable): 1524 # If the expressions were always created with all the information, we could 1525 # assert expr.bounds == bounds, but sometimes the expression is created 1526 # with the loose ValueRanges.unknown(), so we need to tighten the bounds 1527 expr.bounds = expr.bounds.tighten(bounds) 1528 expr.use_count += 1 1529 return expr 1530 cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr 1531 var = self.cache.get(cache_key, None) 1532 if not var: 1533 var = self.newvar(bounds) 1534 self.cache[cache_key] = var 1535 if write: 1536 if V.kernel.current_node: 1537 V.kernel.current_node.codegen_originating_info( 1538 buffer, only_once=True 1539 ) 1540 if isinstance(expr, IndentedBuffer): 1541 if assignment: 1542 buffer.writeline(f"{self.prefix}{var} =") 1543 buffer.splice(expr) 1544 buffer.writeline(self.suffix) 1545 else: 1546 if assignment: 1547 line = f"{self.prefix}{var} = {expr}{self.suffix}" 1548 else: 1549 line = f"{expr}{self.suffix}" 1550 buffer.writeline(line) 1551 else: 1552 var.bounds = var.bounds.tighten(bounds) 1553 var.use_count += 1 1554 1555 return var 1556 1557 def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable: 1558 var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" 1559 var = V.kernel.create_cse_var(var_name, bounds) 1560 self.varname_map[var_name] = var 1561 return var 1562 1563 1564class CodeGen: 1565 def __init__(self) -> None: 1566 super().__init__() 1567 self.exit_stack = contextlib.ExitStack() 1568 1569 def __enter__(self): 1570 self.exit_stack.__enter__() 1571 return self 1572 1573 def __exit__(self, exc_type, exc_val, exc_tb): 1574 self.exit_stack.__exit__(exc_type, exc_val, exc_tb) 1575 1576 1577class ScopedDict: 1578 def __init__(self, original_dict): 1579 self.original_dict = original_dict 1580 self.new_items = {} 1581 1582 def __getitem__(self, key): 1583 if key in self.new_items: 1584 return self.new_items[key] 1585 return self.original_dict[key] 1586 1587 def __setitem__(self, key, value): 1588 self.new_items[key] = value 1589 1590 def __contains__(self, key): 1591 return key in self.new_items or key in self.original_dict 1592 1593 def get(self, key, default=None): 1594 if key in self.new_items: 1595 return self.new_items[key] 1596 return self.original_dict.get(key, default) 1597 1598 1599class Kernel(CodeGen): 1600 newvar_prefix = "" 1601 suffix = "" 1602 overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None 1603 # TODO: these look dead, but with all the getattr it's hard to tell... 1604 load_format: None = None 1605 store_format: None = None 1606 1607 def __init__(self, args=None, increase_kernel_count=True): 1608 super().__init__() 1609 if increase_kernel_count: 1610 metrics.generated_kernel_count += 1 1611 self.args = args or KernelArgs() 1612 self.loads = IndentedBuffer() 1613 self.compute = IndentedBuffer() 1614 self.stores = IndentedBuffer() 1615 1616 self.num_load = 0 1617 self.num_reduction = 0 1618 1619 self.cse: CSE = CSE(self.newvar_prefix, self.suffix) 1620 self.must_keep_buffers = OrderedSet() # type: ignore[var-annotated] 1621 self.store_buffer_names = OrderedSet() # type: ignore[var-annotated] 1622 self._load_mask = None 1623 self._load_other = None 1624 # OrderedSet in set_current_node 1625 self.current_node = None 1626 self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None 1627 1628 self.removed_buffers = OrderedSet() # type: ignore[var-annotated] 1629 self.inplaced_to_remove = OrderedSet() # type: ignore[var-annotated] 1630 1631 # key: the buffer to write 1632 # value: the buffer to read and whose memory can be reused for 1633 # the buffer specified by key 1634 self.inplace_update_buffers = {} 1635 # Set minimum number of elements processed per thread. 1636 self.min_elem_per_thread = 1 1637 self.kernel_name = None 1638 1639 @contextlib.contextmanager 1640 def set_current_node(self, node): 1641 prior = self.current_node 1642 self.current_node = node 1643 self.node_to_bounds = node._body.bounds().get_bounds() 1644 try: 1645 yield 1646 finally: 1647 self.current_node = prior 1648 1649 @contextlib.contextmanager 1650 def swap_buffers(self, lb, cb=None, sb=None): 1651 def scope_cse(cse): 1652 new_cse = cse.clone() 1653 new_cse.cache = ScopedDict(cse.cache) 1654 new_cse.reduction_cache = ScopedDict(cse.reduction_cache) 1655 new_cse.store_cache = ScopedDict(cse.store_cache) 1656 return new_cse 1657 1658 if cb is None: 1659 cb = lb 1660 loads = self.loads 1661 compute = self.compute 1662 stores = self.stores 1663 cse = self.cse 1664 self.loads = lb 1665 self.compute = cb 1666 self.stores = sb 1667 self.cse = scope_cse(cse) 1668 try: 1669 yield 1670 finally: 1671 self.loads = loads 1672 self.compute = compute 1673 self.stores = stores 1674 self.cse = cse 1675 1676 def load(self, name: str, index: sympy.Expr) -> CSEVariable: 1677 raise NotImplementedError 1678 1679 def indirect_load(self, name: str, index: sympy.Expr): 1680 """A load the depends on an index we have read""" 1681 prior = self.loads 1682 try: 1683 # put the load in the compute section as it might have deps 1684 self.loads = self.compute 1685 return self.load(name, index) 1686 finally: 1687 self.loads = prior 1688 1689 def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): 1690 raise NotImplementedError 1691 1692 def store( 1693 self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None 1694 ) -> None: 1695 raise NotImplementedError 1696 1697 def reduction( 1698 self, 1699 dtype: torch.dtype, 1700 src_dtype: torch.dtype, 1701 reduction_type: ReductionType, 1702 value: Union[CSEVariable, Tuple[CSEVariable, ...]], 1703 ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: 1704 raise NotImplementedError 1705 1706 def scan( 1707 self, 1708 dtypes: Tuple[torch.dtype, ...], 1709 combine_fn: Callable[ 1710 [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...] 1711 ], 1712 values: Tuple[CSEVariable, ...], 1713 ) -> Tuple[CSEVariable, ...]: 1714 raise NotImplementedError 1715 1716 def sort( 1717 self, 1718 dtypes: Tuple[torch.dtype, ...], 1719 values: Tuple[CSEVariable, ...], 1720 stable: bool, 1721 descending: bool, 1722 ) -> Tuple[CSEVariable, ...]: 1723 raise NotImplementedError 1724 1725 def var_ranges(self): 1726 raise NotImplementedError 1727 1728 def bucketize( 1729 self, 1730 values: CSEVariable, 1731 offsets_name: str, 1732 offsets_size: sympy.Expr, 1733 indexing_dtype: torch.dtype, 1734 right: bool, 1735 ) -> CSEVariable: 1736 """ 1737 See [Note: Inductor bucketize op] 1738 """ 1739 raise NotImplementedError 1740 1741 @property 1742 def assert_function(self) -> str: 1743 raise NotImplementedError 1744 1745 def indirect_assert( 1746 self, 1747 var: Union[CSEVariable, str], 1748 lower: Optional[str], 1749 upper: Optional[str], 1750 mask: Optional[Union[CSEVariable, str]] = None, 1751 ) -> str: 1752 if isinstance(var, CSEVariable): 1753 var = str(var) 1754 assert isinstance(var, str) 1755 assert lower is None or isinstance(lower, str) 1756 assert upper is None or isinstance(upper, str) 1757 if lower and upper: 1758 # The conditions need to be in parens because of Python's operator precedence. 1759 # It'd be less error-prone to use and/or/not, which is suported by triton 1760 cond = f"({lower} <= {var}) & ({var} < {upper})" 1761 cond_print = f"{lower} <= {var} < {upper}" 1762 elif lower: 1763 cond = f"{lower} <= {var}" 1764 cond_print = cond 1765 else: 1766 assert upper 1767 cond = f"{var} < {upper}" 1768 cond_print = cond 1769 1770 if mask: 1771 cond = f"({cond}) | ~({mask})" 1772 1773 return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' 1774 1775 def check_bounds( 1776 self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool 1777 ): 1778 raise NotImplementedError 1779 1780 def index_to_str(self, index: sympy.Expr) -> str: 1781 raise NotImplementedError 1782 1783 def __enter__(self): 1784 # TODO: hoist this to top level 1785 class CSEProxy: 1786 self.name = "CSEProxy" 1787 vr_analysis = ValueRangeAnalysis() 1788 1789 @staticmethod 1790 def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] 1791 def inner(*args, **kwargs): 1792 bounds = CSEProxy._bound_variable(name, *args, **kwargs) 1793 1794 value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] 1795 1796 def do_cse(v): 1797 csevar = V.kernel.cse.generate( 1798 V.kernel.compute, v, bounds=bounds 1799 ) 1800 csevar.update_on_args(name, args, kwargs) 1801 return csevar 1802 1803 return pytree.tree_map(do_cse, value) 1804 1805 return inner 1806 1807 @staticmethod 1808 def _bound_variable(name, *args, **kwargs): 1809 """ 1810 If the variable comes from an FX node, we forward the bound we have already computed 1811 Else, if the variable when codegen'ing another op, we try to compute its bounds 1812 """ 1813 from ..select_algorithm import TritonTemplateKernel 1814 1815 if isinstance(V.kernel, TritonTemplateKernel): 1816 return ValueRanges.unknown() 1817 1818 fx_node = V.interpreter.current_node 1819 if fx_node.target == name and self.node_to_bounds is not None: 1820 assert isinstance(self.node_to_bounds, dict) 1821 return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) 1822 elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): 1823 # These create lots of inner strings. We would need to compute the bounds at the ops 1824 # We will also likely not get much from computing VRs on these nodes 1825 if any( 1826 s in fx_node.target 1827 for s in ("set_indirect", "reduction", "scan") 1828 ): 1829 return ValueRanges.unknown() 1830 1831 # We assume that the inputs come from `ops.` and are not strings. If you want to generate 1832 # intermediary strings, wrap them in CSE variables with properly initialised bounds. 1833 1834 # If there is no FX bound but we know how to compute one we do so 1835 assert not kwargs 1836 1837 def arg_to_bound(x): 1838 if isinstance(x, CSEVariable): 1839 return x.bounds 1840 elif isinstance(x, sympy.Expr): 1841 return bound_sympy(x) 1842 else: 1843 return x 1844 1845 arg_bounds = list(map(arg_to_bound, args)) 1846 return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) 1847 else: 1848 return ValueRanges.unknown() 1849 1850 @staticmethod 1851 def indirect_indexing( 1852 var: CSEVariable, 1853 size: Union[sympy.Expr, int], 1854 check: bool = True, 1855 wrap_neg=True, 1856 ): 1857 if isinstance(size, int): 1858 size = sympy.Integer(size) 1859 assert isinstance(size, sympy.Expr), size 1860 # Skip CSE since this doesn't return an expression 1861 1862 if var.bounds.lower < 0: # type: ignore[operator] 1863 if wrap_neg: 1864 stm = ops.add(var, ops.index_expr(size, torch.long)) 1865 # Mixed negative and non-negative 1866 if var.bounds.upper >= 0: # type: ignore[operator] 1867 lt = ops.lt(var, 0) 1868 stm = ops.where(lt, stm, var) 1869 else: 1870 stm = var 1871 1872 # Propagate bounds as we know how to compute them properly 1873 new_bounds = ValueRanges.unknown() 1874 if var.bounds != ValueRanges.unknown() and isinstance( 1875 size, sympy.Number 1876 ): 1877 # Take the negative part of the bound and add size to it 1878 # Then take union of that and the positive part 1879 # This is a tighter bound than that of a generic ops.where, as we have info on the cond 1880 neg_bounds = var.bounds & ValueRanges(-int_oo, -1) 1881 new_bounds = ValueRanges( 1882 neg_bounds.lower + size, neg_bounds.upper + size 1883 ) 1884 # We don't have a good way of representing the empty range 1885 if var.bounds.upper >= 0: # type: ignore[operator] 1886 pos = var.bounds & ValueRanges(0, int_oo) 1887 new_bounds = new_bounds | pos 1888 1889 var = self.cse.generate(self.compute, stm, bounds=new_bounds) 1890 1891 sympy_var = parent_handler.indirect_indexing(var, size, check) 1892 if generate_assert(check): 1893 assert_lower = not (var.bounds.lower >= 0) 1894 # value ranges cannot x < s when x and s are symbols 1895 assert_upper = not isinstance(size, sympy.Number) or not ( 1896 var.bounds.upper < size 1897 ) 1898 self.check_bounds(sympy_var, size, assert_lower, assert_upper) 1899 return sympy_var 1900 1901 @staticmethod 1902 def check_bounds( 1903 expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool 1904 ): 1905 return self.check_bounds(expr, size, lower, upper) 1906 1907 @staticmethod 1908 def load(name: str, index: sympy.Expr) -> CSEVariable: 1909 if name in self.cse.invalidated_stores: 1910 # A load from an invalidated store requires us to 1911 # keep the actual buffer around 1912 V.kernel.must_keep_buffers.add(name) 1913 if free_symbol_is_type(index, SymT.TMP): 1914 return self.indirect_load(name, index) 1915 store_cache = self.cse.store_cache 1916 if name in store_cache: 1917 return store_cache[name] 1918 out = self.load(name, index) 1919 # count load that is not in the store_cache, and also not in the 1920 # cse cache. 1921 if out.use_count == 1: 1922 self.num_load += 1 1923 return out 1924 1925 @staticmethod 1926 def _update_store_cache(name: str, value: CSEVariable): 1927 self.cse.store_cache[name] = value 1928 if self.current_node and name in V.graph.name_to_buffer: 1929 buf = self.current_node.get_output(name) 1930 for other_name in buf.get_mutations(): 1931 self.cse.store_cache[other_name] = value 1932 1933 @staticmethod 1934 def store( 1935 name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None 1936 ) -> None: 1937 self.store_buffer_names.add(name) 1938 if mode is None: 1939 CSEProxy._update_store_cache(name, value) 1940 if name not in V.graph.removed_buffers: 1941 return self.store(name, index, value, mode=mode) 1942 else: 1943 return None # type: ignore[return-value] 1944 1945 @staticmethod 1946 def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): 1947 self.store_buffer_names.add(name) 1948 CSEProxy._update_store_cache(name, value) 1949 1950 if name not in V.graph.removed_buffers: 1951 return self.store_reduction(name, index, value) 1952 1953 @staticmethod 1954 def reduction( 1955 dtype: torch.dtype, 1956 src_dtype: torch.dtype, 1957 reduction_type: ReductionType, 1958 value: Union[CSEVariable, Tuple[CSEVariable, ...]], 1959 ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: 1960 self.num_reduction += 1 1961 return self.reduction(dtype, src_dtype, reduction_type, value) 1962 1963 @staticmethod 1964 def scan( 1965 dtypes: Tuple[torch.dtype, ...], 1966 combine_fn: Callable[ 1967 [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], 1968 Tuple[CSEVariable, ...], 1969 ], 1970 values: Tuple[CSEVariable, ...], 1971 ) -> Tuple[CSEVariable, ...]: 1972 return self.scan(dtypes, combine_fn, values) 1973 1974 @staticmethod 1975 def sort( 1976 dtypes: Tuple[torch.dtype, ...], 1977 values: Tuple[CSEVariable, ...], 1978 stable: bool, 1979 descending: bool, 1980 ) -> Tuple[CSEVariable, ...]: 1981 return self.sort(dtypes, values, stable, descending) 1982 1983 @staticmethod 1984 def bucketize( 1985 values: CSEVariable, 1986 offsets_name: str, 1987 offsets_size: sympy.Expr, 1988 indexing_dtype: torch.dtype, 1989 right: bool, 1990 ) -> CSEVariable: 1991 """ 1992 [Note: Inductor bucketize op] 1993 1994 Given values (tensor) and offsets_name (reference to the name of a 1D 1995 tensor), calculate the bucket that each value belongs to. 1996 1997 e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True 1998 return = [ 0, 1, 1, 1, 1, 3, 3, 4]. 1999 2000 When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. 2001 When right == True, bucket i refers to range [offsets[i], offsets[i+1]). 2002 2003 Offsets must be non-decreasing or the result is undefined. 2004 """ 2005 return self.bucketize( 2006 values, offsets_name, offsets_size, indexing_dtype, right 2007 ) 2008 2009 # Use mypy to check protocol implemented correctly 2010 def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: 2011 return h 2012 2013 super().__enter__() 2014 assert self.overrides 2015 parent_handler = self.overrides(V.get_ops_handler()) 2016 self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) 2017 self.exit_stack.enter_context(V.set_kernel_handler(self)) 2018 return self 2019 2020 def __exit__(self, exc_type, exc_val, exc_tb): 2021 """ 2022 Note that V.graph.scheduler can be None when codegening triton template 2023 kernels. 2024 """ 2025 if V.graph.scheduler: 2026 V.graph.scheduler.remove_kernel_local_buffers() 2027 super().__exit__(exc_type, exc_val, exc_tb) 2028 2029 def rename_indexing(self, index) -> sympy.Expr: 2030 # adds the necessary kernel args for index expressions 2031 # and renames variables in index expressions to kernel arg names 2032 if isinstance(index, (list, tuple)): 2033 return [self.rename_indexing(x) for x in index] # type: ignore[return-value] 2034 index = V.graph.sizevars.simplify(index) 2035 sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) 2036 replacements = { 2037 x: self.args.size(x) 2038 for x in sorted_symbols 2039 if symbol_is_type( 2040 x, 2041 ( 2042 SymT.UNBACKED_INT, 2043 SymT.SIZE, 2044 SymT.PRECOMPUTED_SIZE, 2045 ), 2046 ) 2047 } 2048 return sympy_subs(index, replacements) 2049 2050 def create_cse_var(self, *args, **kwargs): 2051 return CSEVariable(*args, **kwargs) 2052 2053 2054@dataclasses.dataclass 2055class OptimizationContext: 2056 key: ClassVar[str] = "opt_ctx" 2057 2058 dtype: Optional[torch.dtype] = None 2059 ops_name: str = "" 2060 2061 2062@functools.lru_cache(None) 2063def jinja2_env(): 2064 try: 2065 import jinja2 2066 2067 return jinja2.Environment( 2068 undefined=jinja2.StrictUndefined, 2069 ) 2070 except ImportError: 2071 return None 2072 2073 2074class KernelTemplate: 2075 """ 2076 Base class for defining kernel templates. 2077 2078 Children classes: TritonTemplate, CUDATemplate 2079 """ 2080 2081 @staticmethod 2082 def indent_except_first(source: str, num_indents: int, indents_spacing=4): 2083 lines = source.splitlines(True) 2084 if len(lines) > 1: 2085 lines[1:] = [ 2086 (" " * indents_spacing * num_indents) + line for line in lines[1:] 2087 ] 2088 return "".join(lines) 2089 2090 @staticmethod 2091 def _template_from_string(source): 2092 env = jinja2_env() 2093 if env is not None: 2094 env.filters["indent_except_first"] = KernelTemplate.indent_except_first 2095 from jinja2 import TemplateSyntaxError 2096 2097 class DetailedTemplateSyntaxError(TemplateSyntaxError): 2098 def __init__(self, original_error): 2099 super().__init__( 2100 original_error.message, 2101 original_error.lineno, 2102 original_error.name, 2103 original_error.filename, 2104 ) 2105 self.original_error = original_error 2106 2107 def __str__(self): 2108 error_info = f"Error in template at line {self.lineno}\n" 2109 error_info += f"Error message: {self.message}\n" 2110 if hasattr(self.original_error, "source"): 2111 lines = self.original_error.source.split("\n") 2112 error_info += "Context:\n" 2113 start = max(0, self.lineno - 2) 2114 end = min(len(lines), self.lineno + 2) 2115 for i in range(start, end): 2116 if i == self.lineno - 1: 2117 error_info += f"{i+1}: --> {lines[i]}\n" 2118 if hasattr(self.original_error, "column"): 2119 error_info += ( 2120 " " 2121 + " " * (self.original_error.column - 1) 2122 + "^\n" 2123 ) 2124 else: 2125 error_info += f"{i+1}: {lines[i]}\n" 2126 return error_info 2127 2128 try: 2129 return env.from_string(source) 2130 except TemplateSyntaxError as e: 2131 raise DetailedTemplateSyntaxError(e) from e 2132 2133 return None 2134 2135 @staticmethod 2136 def _fake_get_dtype(fake_out): 2137 _get_dtype_real = V.graph.get_dtype 2138 2139 def get_dtype(name): 2140 if name == fake_out.get_name(): 2141 return fake_out.get_dtype() 2142 return _get_dtype_real(name) 2143 2144 return get_dtype 2145 2146 def __init__(self, name: str): 2147 self.name = name 2148 2149 def maybe_append_choice(self, choices, **kwargs): 2150 """ 2151 Maybe generates a new ChoiceCaller and appends it into existing choices. 2152 2153 choices: A list of ChoiceCallers. 2154 kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. 2155 """ 2156 2157 try: 2158 choices.append(self.generate(**kwargs)) 2159 except NotImplementedError as e: 2160 pass 2161 2162 def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller": 2163 """ 2164 Generates a ChoiceCaller instance from the given arguments. 2165 """ 2166 2167 raise NotImplementedError 2168