# mypy: allow-untyped-defs
import itertools
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Literal,
    NamedTuple,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import Protocol
from unittest.mock import patch

import sympy

import torch
import torch.utils._pytree as pytree

from ..utils._ordered_set import OrderedSet
from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str


T = TypeVar("T")
StoreMode = Optional[Literal["atomic_add"]]
ReductionType = Literal[
    "argmax",
    "argmin",
    "welford_reduce",
    "welford_combine",
    "any",
    "max",
    "min",
    "prod",
    "sum",
    "xor_sum",
]


def _arg_str(a) -> str:
    """Render an op argument as text; sympy expressions go through sympy_str."""
    if isinstance(a, sympy.Expr):
        return sympy_str(a)
    return str(a)


# NB: This is not done as a parent class, because our ops handlers
# implementations make heavy use of __getattr__ magic, and pre-existing
# stubs for methods would interfere with this mechanism.
#
# TODO: A superclass that does desugaring for operations like
# reciprocal/square might be useful.
class OpsHandler(Protocol[T]):
    """
    Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
    as well as the contract for op handlers.  The type T signifies the domain
    of the abstract analysis AKA what all of the functions return / take as arguments
    anywhere compute occurs.

    While these operators are typically dtype polymorphic (e.g., you can use mul
    on both integers and floats), they do NOT do promotion and usually return the
    same dtype as the input.  You are expected to have handled type promotion
    during ATen decompositions.  Most operators correspond exactly to pointwise
    operations as defined by torch, so when in doubt about semantics, check the
    corresponding torch documentation.  These are all scalar operations (so they
    are defined to operate on a single element at a time.)

    For convenience, many operators take a src_dtype which indicates what the dtype
    of the input argument is.  Although in principle this can be derived by an
    analysis, providing this for ops where it is useful helps avoid having to repeatedly
    recompute dtype in code generation.

    Note that this often describes a class of static methods, for stateless
    ops handlers.

    Handlers are often defined using ``__getattr__`` metaprogramming, which means
    that you cannot declare that a type implements a protocol by inheriting from
    it (as the type stubs count as attribute declarations and impede the getattr
    magic method from being called).  Instead, define a function that casts an
    argument of your type to the protocol, which is sufficient to induce mypy to
    test that the protocol is implemented correctly.  Search for ``_typecheck_``
    in this file to see some examples.  If you see an obscure error where a
    class doesn't implement a Protocol, but mypy doesn't say why, check to see
    that ``__getattr__`` is typed correctly (typically, it is not possible to
    type ``__getattr__`` without typing it as ``Callable[..., Any]``)
    """

    def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
        """Produces a scalar constant of type dtype."""
        ...

    def load_seed(self, name: str, offset: T):
        """Computes inductor_prims.lookup_seed."""
        ...

    def rand(self, seed: T, offset: T) -> T:
        """Computes inductor_prims.random with mode="rand".  offset has dtype int32."""
        ...

    def randn(self, seed: T, offset: T) -> T:
        """Computes inductor_prims.random with mode="randn".  offset has dtype int32."""
        ...

    def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
        """Computes inductor_prims.randint.  offset has dtype int32."""
        ...

    def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
        """
        Computes body, but only performs loads/stores if the boolean mask
        evaluates to true.  For example, you would use this if you needed to
        perform an indirect load that may not be valid on some elements;
        without masking, invalid accesses can cause IMAs.  When mask is true,
        the result is the result of body; otherwise it is other.  Here, `other`
        needs to be a constant.

        Contrast this with ops.where, which can multiplex between two values
        that have been unconditionally computed.
        """
        ...

    def where(self, condition: T, input: T, other: T) -> T:
        """
        Computes torch.where: when condition is true, return input; otherwise return other.
        """
        ...

    def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
        """
        Converts a sympy expression into a scalar of type dtype.  expr is typically
        an indexing expression, thus the name; however, it can also be used in
        non-indexing situations.
        """
        ...

    def to_dtype(
        self,
        x: T,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types=True,
    ) -> T:
        """
        Convert x to dtype.  src_dtype can be optionally set to specify what the original
        dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
        """
        ...

    def trunc_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with truncation semantics (similar to how the int
        constructor works in Python).  In Inductor codegen, this just decays
        to trunc and then to_dtype, but this composite operation helps
        roundtrips for Sympy evaluation.

        dtype is taken as an explicit parameter because the desired output
        dtype is typically the index dtype, which may vary between int32 and
        int64 depending on if we've shown that all the indexing operations can
        be done in int32.
        """
        ...

    def ceil_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with ceiling semantics.  See also trunc_to_int.
        """
        ...

    def floor_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with floor semantics.  See also trunc_to_int.
        """
        ...

    def round_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with round-to-even semantics.  See also trunc_to_int.
        """
        ...

    def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
        """
        Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
        src_dtype must be the original type of x.
        """
        ...

    def identity(self, x: T) -> T:
        """
        Returns x as is.  This is used to trigger CSE.
        """
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These operations are only available in a "kernel" context.  Check
    # torch._inductor.codegen.common.CSEProxy for their typical implementation
    # in op handler (routing to their respective implementations in the kernel
    # handler)
    #
    # Importantly, inside a kernel, indexing and mask variables are available
    # in scope, which are typically used by sympy.Expr indexing.

    def indirect_indexing(
        self, x: T, size: sympy.Expr, check: bool = True, wrap_neg=True
    ) -> sympy.Expr:
        """
        Convert an integral x into a sympy.Expr that can be subsequently used in
        indexing computation.  'size' represents an upper bound on what valid
        indexes can be; when 'check' is True, we check that the x is in bounds.

        NB: This is typically mandatory to implement for any analysis, because you
        MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
        """
        ...

    def load(self, name: str, index: sympy.Expr) -> T:
        """
        Load from the memory location 'name', offset by some indexing expression 'index'.
        """
        ...

    def store(
        self,
        name: str,
        index: sympy.Expr,
        value: T,
        mode: StoreMode = None,
    ) -> None:
        """
        Store 'value' to the memory location 'name' offset by 'index'.  If
        specified, 'mode' can require the store to be an atomic addition.
        """
        ...

    # TODO: Better explain how the "collective" semantics of these ops;
    # remember that the input value is a scalar, you can't reduce on it in the
    # traditional sense!
    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: T,
    ) -> Union[T, Tuple[T, ...]]:
        """
        Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype',
        using 'dtype' as the accumulation dtype for the reduction.  The result
        is an intermediate computation which should be stored to the final
        location using 'ops.store_reduction'.

        Valid reduction types are those listed in ``ReductionType``: argmax,
        argmin, welford_reduce, welford_combine, any, max, min, prod, sum and
        xor_sum.  For Welford reduction types, this function returns multiple
        outputs; consult reduction_num_outputs to determine the amount in
        metaprogramming applications.
        """
        ...

    # TODO: in practice, this seems to actually return None, but not returning
    # a T makes common __getattr__ idioms not type correctly.  Figure out if
    # this should be returning something.
    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
        """
        Store the fully accumulated result of 'reduction' to the memory
        location 'name' offset by 'index'.
        """
        ...

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[[Tuple[T, ...], Tuple[T, ...]], Tuple[T, ...]],
        values: Tuple[T, ...],
    ) -> Tuple[T, ...]:
        """
        Perform an associative scan on 'values'.
        """
        # TODO: Improve the description with some pseudocode
        ...

    def sort(
        self,
        dtypes: Tuple[torch.dtype, ...],
        values: Tuple[T, ...],
        stable: bool,
        descending: bool,
    ) -> Tuple[T, ...]:
        """
        Sort values along the reduction dimension.
        """
        ...

    def bucketize(
        self,
        values: T,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> T:
        # See [Note: Inductor bucketize op]
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The following ops have semantics that correspond exactly to the torch
    # operation with the same corresponding name.

    def abs(self, x0: T) -> T:
        ...

    def exp(self, x0: T) -> T:
        ...

    def exp2(self, x0: T) -> T:
        ...

    def expm1(self, x0: T) -> T:
        ...

    def sqrt(self, x0: T) -> T:
        ...

    def relu(self, x0: T) -> T:
        ...

    def minimum(self, x0: T, x1: T) -> T:
        ...

    def maximum(self, x0: T, x1: T) -> T:
        ...

    def cos(self, x0: T) -> T:
        ...

    def sin(self, x0: T) -> T:
        ...

    def lgamma(self, x0: T) -> T:
        ...

    def erf(self, x0: T) -> T:
        ...

    def cosh(self, x0: T) -> T:
        ...

    def sinh(self, x0: T) -> T:
        ...

    def acos(self, x0: T) -> T:
        ...

    def acosh(self, x0: T) -> T:
        ...

    def asin(self, x0: T) -> T:
        ...

    def asinh(self, x0: T) -> T:
        ...

    def atan2(self, x0: T, x1: T) -> T:
        ...

    def atan(self, x0: T) -> T:
        ...

    def atanh(self, x0: T) -> T:
        ...

    def copysign(self, x0: T, x1: T) -> T:
        ...

    def erfc(self, x0: T) -> T:
        ...

    def erfinv(self, x0: T) -> T:
        ...

    def frexp(self, x0: T):
        ...

    def hypot(self, x0: T, x1: T) -> T:
        ...

    def log10(self, x0: T) -> T:
        ...

    def log2(self, x0: T) -> T:
        ...

    def nextafter(self, x0: T, x1: T) -> T:
        ...

    def logical_and(self, x0: T, x1: T) -> T:
        ...

    def logical_not(self, x0: T) -> T:
        ...

    def logical_or(self, x0: T, x1: T) -> T:
        ...

    def logical_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_and(self, x0: T, x1: T) -> T:
        ...

    def bitwise_not(self, x0: T) -> T:
        ...

    def bitwise_or(self, x0: T, x1: T) -> T:
        ...

    def bitwise_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_left_shift(self, x0: T, x1: T) -> T:
        ...

    def bitwise_right_shift(self, x0: T, x1: T) -> T:
        ...

    def rsqrt(self, x0: T) -> T:
        ...

    def log1p(self, x0: T) -> T:
        ...

    def tan(self, x0: T) -> T:
        ...

    def tanh(self, x0: T) -> T:
        ...

    def sigmoid(self, x0: T) -> T:
        ...

    def signbit(self, x0: T) -> T:
        ...

    def fmod(self, x0: T, x1: T) -> T:
        ...

    def log(self, x0: T) -> T:
        ...

    def isinf(self, x0: T) -> T:
        ...

    def isnan(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    # This rounds half to even to break ties
    def round(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def floor(self, x0: T) -> T:
        ...

    def sign(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def trunc(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def ceil(self, x0: T) -> T:
        ...

    def neg(self, x0: T) -> T:
        ...

    def reciprocal(self, x0: T) -> T:
        ...

    def eq(self, x0: T, x1: T) -> T:
        ...

    def ne(self, x0: T, x1: T) -> T:
        ...

    def lt(self, x0: T, x1: T) -> T:
        ...

    def gt(self, x0: T, x1: T) -> T:
        ...

    def le(self, x0: T, x1: T) -> T:
        ...

    def ge(self, x0: T, x1: T) -> T:
        ...

    def add(self, x0: T, x1: T) -> T:
        ...

    def sub(self, x0: T, x1: T) -> T:
        ...

    def mul(self, x0: T, x1: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def pow(self, x0: T, x1: T) -> T:
        ...

    def and_(self, x0: T, x1: T) -> T:
        ...

    def or_(self, x0: T, x1: T) -> T:
        ...

    def xor(self, x0: T, x1: T) -> T:
        ...

    # These are metaprogrammed by MockHandler._init_cls
    def lshift(self, x0: T, x1: T) -> T:
        ...

    def rshift(self, x0: T, x1: T) -> T:
        ...

    def getitem(self, x0: T, x1: T) -> T:
        # TODO: this is probably just illegal lol
        ...

    def matmul(self, x0: T, x1: T) -> T:
        # TODO: this is probably just illegal lol
        ...

    def invert(self, x0: T) -> T:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These are "special" operators.  These only exist if the target
    # language actually supports the operator.  Keep this in sync with
    # pointwise_overrides_data.

    def airy_ai(self, x: T) -> T:
        ...

    def bessel_j0(self, x: T) -> T:
        ...

    def bessel_j1(self, x: T) -> T:
        ...

    def bessel_y0(self, x: T) -> T:
        ...

    def bessel_y1(self, x: T) -> T:
        ...

    def digamma(self, x: T) -> T:
        ...

    def erfcx(self, x: T) -> T:
        ...

    def fma(self, x: T, y: T, z: T) -> T:
        ...

    def igamma(self, x: T, y: T) -> T:
        ...

    def igammac(self, x: T, y: T) -> T:
        ...

    def gammainc(self, x: T, y: T) -> T:
        ...

    def gammaincc(self, x: T, y: T) -> T:
        ...

    def i0(self, x: T) -> T:
        ...

    def i0e(self, x: T) -> T:
        ...

    def i1(self, x: T) -> T:
        ...

    def i1e(self, x: T) -> T:
        ...

    def log_ndtr(self, x: T) -> T:
        ...

    def modified_bessel_i0(self, x: T) -> T:
        ...

    def modified_bessel_i1(self, x: T) -> T:
        ...

    def modified_bessel_k0(self, x: T) -> T:
        ...

    def modified_bessel_k1(self, x: T) -> T:
        ...

    def ndtr(self, x: T) -> T:
        ...

    def ndtri(self, x: T) -> T:
        ...

    def polygamma(self, x: T, y: T) -> T:
        ...

    def scaled_modified_bessel_k0(self, x: T) -> T:
        ...

    def scaled_modified_bessel_k1(self, x: T) -> T:
        ...

    def spherical_bessel_j0(self, x: T) -> T:
        ...

    def zeta(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_t(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_u(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_v(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_w(self, x: T, y: T) -> T:
        ...

    def legendre_polynomial_p(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T:
        ...

    def hermite_polynomial_h(self, x: T, y: T) -> T:
        ...

    def hermite_polynomial_he(self, x: T, y: T) -> T:
        ...

    def laguerre_polynomial_l(self, x: T, y: T) -> T:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These operators are a bit special, because they are conventionally
    # natively supported in both Python and C, but the semantics differ so
    # care must be taken

    def truncdiv(self, x0: T, x1: T) -> T:
        """C-style trunc division between integers only.  Computes the true
        division of two numbers and rounds the result to zero.
        """
        ...

    def floordiv(self, x0: T, x1: T) -> T:
        """Python-style floor division between integers only.  Computes the
        true division of two numbers and floors the result.  If you want
        floor division for floats, do regular truediv and floor the result.
        """
        ...

    def truediv(self, x0: T, x1: T) -> T:
        """True division between floats.  Integer inputs are NOT valid.  To
        do Python-style (int, int) -> float division, use int_truediv"""
        ...

    def int_truediv(self, x0: T, x1: T) -> T:
        """True division between integers.  This is NOT the same as promoting
        to float and doing integer division, there is a bespoke algorithm for
        doing the division in higher precision than the above.
        """
        ...

    def div(self, x0: T, x1: T) -> T:
        """TODO: to be removed.  This renders as / no matter what the backend is
        which is incoherent."""
        ...

    def mod(self, x0: T, x1: T) -> T:
        """C-style modulus, take sign from LHS (x0)."""
        ...

    def remainder(self, x0: T, x1: T) -> T:
        """Python-style modulus, take sign from RHS (x1)."""
        ...

    def round_decimal(self, x0: T, x1: T) -> T:
        """Python-style round with decimal argument"""
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # In CUDA, optimized implementations of other mathematical operations are
    # offered separately via libdevice for double precision computation (in
    # Triton, these go to tl.math rather than tl).  We lower to these
    # operators when doing FP64 on CUDA.  Note that some operators
    # unconditionally go to tl.math.
    #
    # TODO(ezyang): Is this really the best way to do this?  What if we have
    # abs internally route to tl.math automatically when given a double
    # precision input?  One reason is that when doing codegen, we often don't
    # know what the dtype of the inputs are!  (In principle we do know, but
    # for many analyses it's not conveniently available.)

    def libdevice_abs(self, x0: T) -> T:
        ...

    def libdevice_exp(self, x0: T) -> T:
        ...

    def libdevice_sqrt(self, x0: T) -> T:
        ...

    def libdevice_cos(self, x0: T) -> T:
        ...

    def libdevice_sin(self, x0: T) -> T:
        ...

    def libdevice_sigmoid(self, x0: T) -> T:
        ...

    def libdevice_log(self, x0: T) -> T:
        ...


class NoopHandler:
    """Ops handler that discards every operation.

    Any op not defined explicitly resolves (via ``__getattr__``) to a function
    returning ``None``.  Ops whose result is destructured by callers are
    overridden to return a value of the right shape: tuples for
    ``frexp``/``scan``/``sort`` and a sympy expression for
    ``indirect_indexing``.  Note that ``masked`` does not evaluate ``body``.
    """

    def __getattr__(self, name):
        if name == "name":
            return "NoopHandler"

        def inner(*args, **kwargs):
            return None

        return inner

    @staticmethod
    def masked(mask, body, other) -> None:
        return None

    @staticmethod
    def frexp(x) -> Tuple[None, None]:
        return (None, None)

    @staticmethod
    def scan(dtypes, combine_fn, values) -> Tuple[None, ...]:
        return (None,) * len(values)

    @staticmethod
    def sort(dtypes, values, stable, descending) -> Tuple[None, ...]:
        return (None,) * len(values)

    @staticmethod
    def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Expr:
        # The protocol requires a sympy.Expr; a constant zero is a valid
        # placeholder (it is an Integer, not a Symbol).
        return sympy.Integer(0)


# Use mypy to check protocol implemented correctly
def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
    return h


class MockHandler:
    """Ops handler that renders each op call as a string such as ``ops.add(a, b)``.

    Generic ops are formatted by ``__getattr__``.  The binary/unary operators
    registered in ``_init_cls`` are rendered with infix syntax instead
    (e.g. ``a + b``).  ``masked`` evaluates ``body`` to render it inline.
    """

    def __getattr__(self, name):
        if name == "name":
            return "MockHandler"

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fargs.extend(f"{k}={v}" for k, v in kwargs.items())
            return f"ops.{name}({', '.join(fargs)})"

        return inner

    @staticmethod
    def masked(mask, body, other) -> str:
        return f"ops.masked({mask}, {body()}, {other})"

    @staticmethod
    def frexp(x):
        return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]")

    @staticmethod
    def scan(dtypes, combine_fn, values):
        return tuple(
            f"ops.scan({dtypes}, {combine_fn}, {values})[{i}]"
            for i in range(len(values))
        )

    @staticmethod
    def sort(dtypes, values, stable, descending):
        return tuple(
            f"ops.sort({dtypes}, {values}, stable={stable}, descending={descending})[{i}]"
            for i in range(len(values))
        )

    @staticmethod
    def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
        return sympy_index_symbol(str(index_var))

    @classmethod
    def _init_cls(cls):
        """Install infix-formatting static methods for the common operators."""

        def make_handler(format_string):
            @staticmethod  # type: ignore[misc]
            def inner(*args):
                return format_string.format(*args)

            return inner

        for name, format_string in {
            "add": "{} + {}",
            "sub": "{} - {}",
            "mul": "{} * {}",
            "floordiv": "{} // {}",
            "truediv": "{} / {}",
            "mod": "{} % {}",  # careful, depending on target semantics varies
            "pow": "{} ** {}",
            "lshift": "{} << {}",
            "rshift": "{} >> {}",
            "and_": "{} & {}",
            "or_": "{} | {}",
            "xor": "{} ^ {}",
            "eq": "{} == {}",
            "ne": "{} != {}",
            "lt": "{} < {}",
            "gt": "{} > {}",
            "le": "{} <= {}",
            "ge": "{} >= {}",
            "neg": "-{}",
        }.items():
            setattr(cls, name, make_handler(format_string))


MockHandler._init_cls()


# Use mypy to check protocol implemented correctly
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
    return h


class KernelFormatterHandler:
    """Ops handler that pretty-prints an IR function as a Python-like listing.

    Each op result produced by ``parent_handler`` is assigned to a fresh
    ``tmpN`` variable and written as a line into ``self.output``.
    """

    def __init__(self, parent_handler):
        self.parent_handler = parent_handler
        self.output = IndentedBuffer(1)
        self.var_counter = itertools.count()

    @staticmethod
    def ir_to_string(ir_fn, index, rindex=None) -> str:
        """Render ``ir_fn(index[, rindex])`` as the source of an ``inner_fn`` def."""
        from .ir import FlexibleLayout
        from .virtualized import V

        args = [index, rindex] if rindex is not None else [index]
        names = ["index", "rindex"] if rindex is not None else ["index"]
        formatter = KernelFormatterHandler(MockHandler())

        with formatter.output.indent(-1):
            formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
        for name, arg in zip(names, args):
            if arg:
                # Constant index components are shown as "_" placeholders.
                lhs = ", ".join(
                    [
                        str("_" if isinstance(v, (int, sympy.Integer)) else v)
                        for v in arg
                    ]
                )
                formatter.output.writeline(f"{lhs} = {name}")

        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = ir_fn(*args)
            return formatter.getvalue(result)

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            line = getattr(self.parent_handler, name)(*args, **kwargs)
            # indirect_indexing yields an index expression, not a value, so it
            # is passed through without being bound to a tmp variable.
            if name == "indirect_indexing":
                return line

            def write(text):
                # replace line with a new variable name
                varname = f"tmp{next(self.var_counter)}"
                self.output.writeline(f"{varname} = {text}")
                return varname

            # tree_map handles ops that return tuples (e.g. frexp/scan/sort).
            return pytree.tree_map(write, line)

        return inner

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[str, Tuple[str, ...]],
    ) -> Union[str, Tuple[str, ...]]:
        """Bind a reduction's output(s) to ``reduction_num_outputs`` tmp variables."""
        line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
        num_values = reduction_num_outputs(reduction_type)
        varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
        self.output.writeline(f"{','.join(varnames)} = {line}")
        return tuple(varnames) if num_values > 1 else varnames[0]

    def getvalue(self, result):
        """Append ``return <result>`` and return the accumulated text."""
        self.output.writeline(f"return {result}")
        return self.output.getvalue()


# Use mypy to check protocol implemented correctly
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
    return h


class WrapperHandler(Generic[T]):
    """Base class that forwards every op to the wrapped ``inner`` handler."""

    def __init__(self, inner: OpsHandler[T]):
        self._inner = inner

    def __getattr__(self, item):
        return getattr(self._inner, item)


# Use mypy to check protocol implemented correctly
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
    return h


class AddParenHandler(WrapperHandler[T]):
    """Wrapper that surrounds every result of the inner handler with parentheses."""

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            val = getattr(self._inner, name)(*args, **kwargs)
            return f"({val})"

        return inner


# Use mypy to check protocol implemented correctly
def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]:
    return h


class OpCountResult(NamedTuple):
    num_ops: int
    used_ops: OrderedSet[str]
    read_buffers: List[str]
    nontrivial_read_count: int


class OpCounterCSE:
    """Shim to count how many ops are used.

    Every distinct value produced by the parent handler is given a ``tmpN``
    name; ``op_count`` is the number of distinct values seen.  Loads,
    ``load_seed`` and ``bucketize`` also record the buffer names they read
    (only the first time a given value is produced).  Loads whose index is
    not a constant integer count as "nontrivial" reads.
    """

    def __init__(self, inner):
        super().__init__()
        self.parent_handler = inner
        self.op_count = 0
        self.var_names: Dict[Any, str] = {}
        self._used_ops: OrderedSet[str] = OrderedSet()
        self._read_names: List[str] = []
        self._nontrivial_read_count = 0

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            return pytree.tree_map(
                self._update_count, getattr(self.parent_handler, name)(*args, **kwargs)
            )

        # Recorded on attribute lookup, i.e. before the op is actually called.
        self._used_ops.add(name)
        return inner

    def _update_count(self, val):
        """Return the tmp name for ``val``, allocating a new one if unseen."""
        varname = self.var_names.get(val)
        if varname is None:
            varname = f"tmp{self.op_count}"
            self.op_count += 1
            self.var_names[val] = varname
        return varname

    def indirect_indexing(self, *args, **kwargs):
        self._used_ops.add("indirect_indexing")
        return self.parent_handler.indirect_indexing(*args, **kwargs)

    def load(self, name: str, index: sympy.Expr) -> str:
        val = self.parent_handler.load(name, index)
        if val not in self.var_names:
            self._used_ops.add("load")
            self._read_names.append(name)
            if not isinstance(index, (sympy.Integer, int)):
                self._nontrivial_read_count += 1
        return self._update_count(val)

    def load_seed(self, name: str, offset: T):
        val = self.parent_handler.load_seed(name, offset)
        if val not in self.var_names:
            self._used_ops.add("load_seed")
            self._read_names.append(name)
        return self._update_count(val)

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        val = self.parent_handler.bucketize(
            values, offsets_name, offsets_size, indexing_dtype, right
        )
        if val not in self.var_names:
            self._used_ops.add("bucketize")
            self._read_names.append(offsets_name)
        return self._update_count(val)

    def getvalue(self):
        return OpCountResult(
            self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count
        )


def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
    return h


class ExtractConstantsHandler(NoopHandler):
    """NoopHandler variant whose ``constant`` op builds an ``ir.Constant`` on ``device``."""

    def __init__(self, device):
        self.device = device

    def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
        from torch._inductor import ir

        return ir.Constant(value=value, dtype=dtype, device=self.device)


def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]:
    return h


class SimpleCSEHandler(WrapperHandler[T]):
    """Wraps the underlying handler with a CSE pass

    NOTE: Compared to codegen level CSE this is simplified as it
    doesn't support stores which require load cache invalidation.
    """

    def __init__(self, inner: OpsHandler[T]):
        super().__init__(inner)
        # Keys are MockHandler renderings of the call (usually str, but tuples
        # for multi-output ops such as frexp/scan/sort).
        self.cse_cache: Dict[Any, Union[T, Tuple[T, ...]]] = {}
        self.mock = MockHandler()

    def indirect_indexing(self, *args, **kwargs) -> sympy.Expr:
        return super().indirect_indexing(*args, **kwargs)  # type: ignore[misc]

    def store(self, *args, **kwargs) -> T:
        raise NotImplementedError("store not implemented")

    def store_reduction(self, *args, **kwargs) -> T:
        raise NotImplementedError("store_reduction not implemented")

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            # The MockHandler rendering of the call identifies it structurally.
            key = getattr(self.mock, name)(*args, **kwargs)
            # Use a membership test rather than ``get(...) is not None`` so
            # that results which are legitimately ``None`` are also reused.
            if key in self.cse_cache:
                return self.cse_cache[key]

            val = getattr(self._inner, name)(*args, **kwargs)
            self.cse_cache[key] = val
            return val

        return inner


def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]:
    return h