# mypy: allow-untyped-defs
import itertools
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Literal,
    NamedTuple,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import Protocol
from unittest.mock import patch

import sympy

import torch
import torch.utils._pytree as pytree

from ..utils._ordered_set import OrderedSet
from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str


T = TypeVar("T")
StoreMode = Optional[Literal["atomic_add"]]
ReductionType = Literal[
    "argmax",
    "argmin",
    "welford_reduce",
    "welford_combine",
    "any",
    "max",
    "min",
    "prod",
    "sum",
    "xor_sum",
]


def _arg_str(a) -> str:
    if isinstance(a, sympy.Expr):
        return sympy_str(a)
    return str(a)


# NB: This is not done as a parent class, because our ops handler
# implementations make heavy use of __getattr__ magic, and pre-existing
# stubs for methods would interfere with this mechanism.
#
# TODO: A superclass that does desugaring for operations like
# reciprocal/square might be useful.
class OpsHandler(Protocol[T]):
    """
    Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
    as well as the contract for op handlers.  The type T signifies the domain
    of the abstract analysis, i.e., what all of the functions return / take as
    arguments wherever compute occurs.

    While these operators are typically dtype polymorphic (e.g., you can use mul
    on both integers and floats), they do NOT do promotion and usually return the
    same dtype as the input.  You are expected to have handled type promotion
    during ATen decompositions.  Most operators correspond exactly to pointwise
    operations as defined by torch, so when in doubt about semantics, check the
    corresponding torch documentation.  These are all scalar operations (so they
    are defined to operate on a single element at a time).

    For convenience, many operators take a src_dtype which indicates what the dtype
    of the input argument is.  Although in principle this can be derived by an
    analysis, providing this for ops where it is useful helps avoid having to
    repeatedly recompute the dtype in code generation.

    Note that this often describes a class of static methods, for stateless
    ops handlers.

    Handlers are often defined using ``__getattr__`` metaprogramming, which means
    that you cannot declare that a type implements a protocol by inheriting from
    it (as the type stubs count as attribute declarations and impede the getattr
    magic method from being called).  Instead, define a function that casts an
    argument of your type to the protocol, which is sufficient to induce mypy to
    test that the protocol is implemented correctly.  Search for ``_typecheck_``
    in this file to see some examples.  If you see an obscure error where a
    class doesn't implement a Protocol, but mypy doesn't say why, check to see
    that ``__getattr__`` is typed correctly (typically, it is not possible to
    type ``__getattr__`` without typing it as ``Callable[..., Any]``).
    """

    def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
        """Produces a scalar constant of type dtype."""
        ...

    def load_seed(self, name: str, offset: T):
        """Computes inductor_prims.lookup_seed."""
        ...

    def rand(self, seed: T, offset: T) -> T:
        """Computes inductor_prims.random with mode="rand".  offset has dtype int32."""
        ...

    def randn(self, seed: T, offset: T) -> T:
        """Computes inductor_prims.random with mode="randn".  offset has dtype int32."""
        ...

    def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
        """Computes inductor_prims.randint.  offset has dtype int32."""
        ...

    def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
        """
        Computes body, but only perform loads/stores if the boolean mask
        evaluates to true.  For example, you would use this if you needed to
        perform an indirect load that may not be valid on some elements;
        without masking, invalid accesses can cause illegal memory accesses
        (IMAs).  When mask is true, the result is the result of body;
        otherwise it is other.  Here, `other` needs to be a constant.

        Contrast this with ops.where, which can multiplex between two values
        that have been unconditionally computed.
        """
        ...
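    # As an illustrative sketch only (the buffer name "buf0" and the index
    # expression idx are hypothetical), a lowering might guard an indirect
    # load with ops.masked like so:
    #
    #   val = ops.masked(
    #       in_bounds,                      # boolean mask computed earlier
    #       lambda: ops.load("buf0", idx),  # only evaluated where the mask holds
    #       0.0,                            # constant fallback used elsewhere
    #   )
    #
    # Contrast with ops.where(cond, a, b), where both a and b have already been
    # computed unconditionally.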

    def where(self, condition: T, input: T, other: T) -> T:
        """
        Computes torch.where: when condition is true, return input; otherwise return other.
        """
        ...

    def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
        """
        Converts a sympy expression into a scalar of type dtype.  expr is typically
        an indexing expression, thus the name; however, it can also be used in
        non-indexing situations.
        """
        ...

    def to_dtype(
        self,
        x: T,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types=True,
    ) -> T:
        """
        Convert x to dtype.  src_dtype can be optionally set to specify what the original
        dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
        """
        ...

    def trunc_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with truncation semantics (similar to how the int
        constructor works in Python).  In Inductor codegen, this just decays
        to trunc and then to_dtype, but this composite operation helps
        roundtrips for Sympy evaluation.

        dtype is taken as an explicit parameter because the desired output
        dtype is typically the index dtype, which may vary between int32 and
        int64 depending on if we've shown that all the indexing operations can
        be done in int32.
        """
        ...

    def ceil_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with ceiling semantics.  See also trunc_to_int.
        """
        ...

    def floor_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with floor semantics.  See also trunc_to_int.
        """
        ...

    def round_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with round-to-even semantics.  See also trunc_to_int.
        """
        ...

    def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
        """
        Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
        src_dtype must be the original type of x.
        """
        ...

    def identity(self, x: T) -> T:
        """
        Returns x as is.  This is used to trigger CSE.
        """
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These operations are only available in a "kernel" context.  Check
    # torch._inductor.codegen.common.CSEProxy for their typical implementation
    # in op handler (routing to their respective implementations in the kernel
    # handler)
    #
    # Importantly, inside a kernel, indexing and mask variables are available
    # in scope, which are typically used by sympy.Expr indexing.

    def indirect_indexing(
        self, x: T, size: sympy.Expr, check: bool = True, wrap_neg=True
    ) -> sympy.Expr:
        """
        Convert an integral x into a sympy.Expr that can be subsequently used in
        indexing computation.  'size' represents an upper bound on what valid
        indexes can be; when 'check' is True, we check that x is in bounds.

        NB: This is typically mandatory to implement for any analysis, because you
        MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
        """
        ...
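    # An illustrative sketch of the typical gather pattern (the buffer names
    # "index_buf" and "data_buf" and the bound 128 are hypothetical): an
    # integral value loaded from one buffer becomes a sympy.Expr that can
    # index another buffer.
    #
    #   i = ops.load("index_buf", idx)
    #   gathered = ops.load("data_buf", ops.indirect_indexing(i, 128))
    #
    # With check=True, codegen asserts the index is in bounds; wrap_neg
    # additionally wraps negative (Python-style) indices.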

    def load(self, name: str, index: sympy.Expr) -> T:
        """
        Load from the memory location 'name', offset by some indexing expression 'index'.
        """
        ...

    def store(
        self,
        name: str,
        index: sympy.Expr,
        value: T,
        mode: StoreMode = None,
    ) -> None:
        """
        Store 'value' to the memory location 'name' offset by 'index'.  If
        specified, 'mode' can require the store to be an atomic addition.
        """
        ...

    # TODO: Better explain the "collective" semantics of these ops;
    # remember that the input value is a scalar, so you can't reduce on it in
    # the traditional sense!
    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: T,
    ) -> Union[T, Tuple[T, ...]]:
        """
        Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype',
        using 'dtype' as the accumulation dtype for the reduction.  The result
        is an intermediate computation which should be stored to the final
        location using 'ops.store_reduction'.

        Valid reduction types are listed in ReductionType.  For Welford
        reduction types, this function returns multiple outputs; consult
        reduction_num_outputs to determine how many in metaprogramming
        applications.
        """
        ...

    # TODO: in practice, this seems to actually return None, but not returning
    # a T makes common __getattr__ idioms fail to type correctly.  Figure out
    # if this should be returning something.
    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
        """
        Store the fully accumulated result of 'reduction' to the memory
        location 'name' offset by 'index'.
        """
        ...
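    # An illustrative sketch of how the two ops pair up during lowering (the
    # buffer names "arg0" and "buf0" are hypothetical): the per-element value
    # goes through ops.reduction, and the accumulated result is written out
    # with ops.store_reduction.
    #
    #   acc = ops.reduction(
    #       torch.float32, torch.float32, "sum", ops.load("arg0", idx)
    #   )
    #   ops.store_reduction("buf0", out_idx, acc)
    #
    # For the Welford reduction types, acc would instead be a tuple of three
    # values; see reduction_num_outputs.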

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[[Tuple[T, ...], Tuple[T, ...]], Tuple[T, ...]],
        values: Tuple[T, ...],
    ) -> Tuple[T, ...]:
        """
        Perform an associative scan on 'values'.
        """
        # TODO: Improve the description with some pseudocode
        ...
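    # Roughly, as an illustrative sketch (not the actual implementation): for
    # a single scanned value and an associative combine_fn, the op computes an
    # inclusive prefix combine along the reduction dimension:
    #
    #   out[0] = values[0]
    #   out[i] = combine_fn(out[i - 1], values[i])
    #
    # Each element of 'values' is a scalar from the handler's point of view;
    # the iteration over the reduction dimension is implicit, as with
    # ops.reduction.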

    def sort(
        self,
        dtypes: Tuple[torch.dtype, ...],
        values: Tuple[T, ...],
        stable: bool,
        descending: bool,
    ) -> Tuple[T, ...]:
        """
        Sort values along the reduction dimension.
        """
        ...

    def bucketize(
        self,
        values: T,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> T:
        # See [Note: Inductor bucketize op]
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The following ops have semantics that correspond exactly to the torch
    # operation with the corresponding name.

    def abs(self, x0: T) -> T:
        ...

    def exp(self, x0: T) -> T:
        ...

    def exp2(self, x0: T) -> T:
        ...

    def expm1(self, x0: T) -> T:
        ...

    def sqrt(self, x0: T) -> T:
        ...

    def relu(self, x0: T) -> T:
        ...

    def minimum(self, x0: T, x1: T) -> T:
        ...

    def maximum(self, x0: T, x1: T) -> T:
        ...

    def cos(self, x0: T) -> T:
        ...

    def sin(self, x0: T) -> T:
        ...

    def lgamma(self, x0: T) -> T:
        ...

    def erf(self, x0: T) -> T:
        ...

    def cosh(self, x0: T) -> T:
        ...

    def sinh(self, x0: T) -> T:
        ...

    def acos(self, x0: T) -> T:
        ...

    def acosh(self, x0: T) -> T:
        ...

    def asin(self, x0: T) -> T:
        ...

    def asinh(self, x0: T) -> T:
        ...

    def atan2(self, x0: T, x1: T) -> T:
        ...

    def atan(self, x0: T) -> T:
        ...

    def atanh(self, x0: T) -> T:
        ...

    def copysign(self, x0: T, x1: T) -> T:
        ...

    def erfc(self, x0: T) -> T:
        ...

    def erfinv(self, x0: T) -> T:
        ...

    def frexp(self, x0: T):
        ...

    def hypot(self, x0: T, x1: T) -> T:
        ...

    def log10(self, x0: T) -> T:
        ...

    def log2(self, x0: T) -> T:
        ...

    def nextafter(self, x0: T, x1: T) -> T:
        ...

    def logical_and(self, x0: T, x1: T) -> T:
        ...

    def logical_not(self, x0: T) -> T:
        ...

    def logical_or(self, x0: T, x1: T) -> T:
        ...

    def logical_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_and(self, x0: T, x1: T) -> T:
        ...

    def bitwise_not(self, x0: T) -> T:
        ...

    def bitwise_or(self, x0: T, x1: T) -> T:
        ...

    def bitwise_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_left_shift(self, x0: T, x1: T) -> T:
        ...

    def bitwise_right_shift(self, x0: T, x1: T) -> T:
        ...

    def rsqrt(self, x0: T) -> T:
        ...

    def log1p(self, x0: T) -> T:
        ...

    def tan(self, x0: T) -> T:
        ...

    def tanh(self, x0: T) -> T:
        ...

    def sigmoid(self, x0: T) -> T:
        ...

    def signbit(self, x0: T) -> T:
        ...

    def fmod(self, x0: T, x1: T) -> T:
        ...

    def log(self, x0: T) -> T:
        ...

    def isinf(self, x0: T) -> T:
        ...

    def isnan(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    # This rounds half to even to break ties
    def round(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def floor(self, x0: T) -> T:
        ...

    def sign(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def trunc(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def ceil(self, x0: T) -> T:
        ...

    def neg(self, x0: T) -> T:
        ...

    def reciprocal(self, x0: T) -> T:
        ...

    def eq(self, x0: T, x1: T) -> T:
        ...

    def ne(self, x0: T, x1: T) -> T:
        ...

    def lt(self, x0: T, x1: T) -> T:
        ...

    def gt(self, x0: T, x1: T) -> T:
        ...

    def le(self, x0: T, x1: T) -> T:
        ...

    def ge(self, x0: T, x1: T) -> T:
        ...

    def add(self, x0: T, x1: T) -> T:
        ...

    def sub(self, x0: T, x1: T) -> T:
        ...

    def mul(self, x0: T, x1: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def pow(self, x0: T, x1: T) -> T:
        ...

    def and_(self, x0: T, x1: T) -> T:
        ...

    def or_(self, x0: T, x1: T) -> T:
        ...

    def xor(self, x0: T, x1: T) -> T:
        ...

    # These are metaprogrammed by MockHandler._init_cls
    def lshift(self, x0: T, x1: T) -> T:
        ...

    def rshift(self, x0: T, x1: T) -> T:
        ...

    def getitem(self, x0: T, x1: T) -> T:
        # TODO: this is probably just illegal lol
        ...

    def matmul(self, x0: T, x1: T) -> T:
        # TODO: this is probably just illegal lol
        ...

    def invert(self, x0: T) -> T:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These are "special" operators.  These only exist if the target
    # language actually supports the operator.  Keep this in sync with
    # pointwise_overrides_data.

    def airy_ai(self, x: T) -> T:
        ...

    def bessel_j0(self, x: T) -> T:
        ...

    def bessel_j1(self, x: T) -> T:
        ...

    def bessel_y0(self, x: T) -> T:
        ...

    def bessel_y1(self, x: T) -> T:
        ...

    def digamma(self, x: T) -> T:
        ...

    def erfcx(self, x: T) -> T:
        ...

    def fma(self, x: T, y: T, z: T) -> T:
        ...

    def igamma(self, x: T, y: T) -> T:
        ...

    def igammac(self, x: T, y: T) -> T:
        ...

    def gammainc(self, x: T, y: T) -> T:
        ...

    def gammaincc(self, x: T, y: T) -> T:
        ...

    def i0(self, x: T) -> T:
        ...

    def i0e(self, x: T) -> T:
        ...

    def i1(self, x: T) -> T:
        ...

    def i1e(self, x: T) -> T:
        ...

    def log_ndtr(self, x: T) -> T:
        ...

    def modified_bessel_i0(self, x: T) -> T:
        ...

    def modified_bessel_i1(self, x: T) -> T:
        ...

    def modified_bessel_k0(self, x: T) -> T:
        ...

    def modified_bessel_k1(self, x: T) -> T:
        ...

    def ndtr(self, x: T) -> T:
        ...

    def ndtri(self, x: T) -> T:
        ...

    def polygamma(self, x: T, y: T) -> T:
        ...

    def scaled_modified_bessel_k0(self, x: T) -> T:
        ...

    def scaled_modified_bessel_k1(self, x: T) -> T:
        ...

    def spherical_bessel_j0(self, x: T) -> T:
        ...

    def zeta(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_t(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_u(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_v(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_w(self, x: T, y: T) -> T:
        ...

    def legendre_polynomial_p(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T:
        ...

    def hermite_polynomial_h(self, x: T, y: T) -> T:
        ...

    def hermite_polynomial_he(self, x: T, y: T) -> T:
        ...

    def laguerre_polynomial_l(self, x: T, y: T) -> T:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These operators are a bit special, because they are conventionally
    # natively supported in both Python and C, but the semantics differ, so
    # care must be taken.

    def truncdiv(self, x0: T, x1: T) -> T:
        """C-style trunc division between integers only.  Computes the true
        division of two numbers and rounds the result toward zero.
        """
        ...

    def floordiv(self, x0: T, x1: T) -> T:
        """Python-style floor division between integers only.  Computes the
        true division of two numbers and floors the result.  If you want
        floor division for floats, do regular truediv and floor the result.
        """
        ...

    def truediv(self, x0: T, x1: T) -> T:
        """True division between floats.  Integer inputs are NOT valid.  To
        do Python-style (int, int) -> float division, use int_truediv."""
        ...

    def int_truediv(self, x0: T, x1: T) -> T:
        """True division between integers.  This is NOT the same as promoting
        to float and doing integer division; there is a bespoke algorithm for
        doing the division in higher precision than the above.
        """
        ...

    def div(self, x0: T, x1: T) -> T:
        """TODO: to be removed.  This renders as / no matter what the backend is,
        which is incoherent."""
        ...

    def mod(self, x0: T, x1: T) -> T:
        """C-style modulus, takes sign from LHS (x0)."""
        ...

    def remainder(self, x0: T, x1: T) -> T:
        """Python-style modulus, takes sign from RHS (x1)."""
        ...

    def round_decimal(self, x0: T, x1: T) -> T:
        """Python-style round with a decimal argument."""
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # In CUDA, optimized implementations of other mathematical operations are
    # offered separately via libdevice for double precision computation (in
    # Triton, these go to tl.math rather than tl).  We lower to these
    # operators when doing FP64 on CUDA.  Note that some operators
    # unconditionally go to tl.math.
    #
    # TODO(ezyang): Is this really the best way to do this?  What if we have
    # abs internally route to tl.math automatically when given a double
    # precision input?  One reason is that when doing codegen, we often don't
    # know what the dtypes of the inputs are!  (In principle we do know, but
    # for many analyses it's not conveniently available.)

    def libdevice_abs(self, x0: T) -> T:
        ...

    def libdevice_exp(self, x0: T) -> T:
        ...

    def libdevice_sqrt(self, x0: T) -> T:
        ...

    def libdevice_cos(self, x0: T) -> T:
        ...

    def libdevice_sin(self, x0: T) -> T:
        ...

    def libdevice_sigmoid(self, x0: T) -> T:
        ...

    def libdevice_log(self, x0: T) -> T:
        ...


class NoopHandler:
    def __getattr__(self, name):
        if name == "name":
            return "NoopHandler"

        def inner(*args, **kwargs):
            return None

        return inner

    @staticmethod
    def masked(mask, body, other) -> None:
        return None

    @staticmethod
    def frexp(x) -> Tuple[None, None]:
        return (None, None)

    @staticmethod
    def scan(dtypes, combine_fn, values) -> Tuple[None, ...]:
        return (None,) * len(values)

    @staticmethod
    def sort(dtypes, values, stable, descending) -> Tuple[None, ...]:
        return (None,) * len(values)

    @staticmethod
    def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Expr:
        return sympy.Integer(0)


# Use mypy to check protocol implemented correctly
def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
    return h


class MockHandler:
    def __getattr__(self, name):
        if name == "name":
            return "MockHandler"

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fargs.extend(f"{k}={v}" for k, v in kwargs.items())
            return f"ops.{name}({', '.join(fargs)})"

        return inner

    @staticmethod
    def masked(mask, body, other) -> str:
        return f"ops.masked({mask}, {body()}, {other})"

    @staticmethod
    def frexp(x):
        return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]")

    @staticmethod
    def scan(dtypes, combine_fn, values):
        return tuple(
            f"ops.scan({dtypes}, {combine_fn}, {values})[{i}]"
            for i in range(len(values))
        )

    @staticmethod
    def sort(dtypes, values, stable, descending):
        return tuple(
            f"ops.sort({dtypes}, {values}, stable={stable}, descending={descending})[{i}]"
            for i in range(len(values))
        )

    @staticmethod
    def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
        return sympy_index_symbol(str(index_var))

    @classmethod
    def _init_cls(cls):
        def make_handler(format_string):
            @staticmethod  # type: ignore[misc]
            def inner(*args):
                return format_string.format(*args)

            return inner

        for name, format_string in {
            "add": "{} + {}",
            "sub": "{} - {}",
            "mul": "{} * {}",
            "floordiv": "{} // {}",
            "truediv": "{} / {}",
            "mod": "{} % {}",  # careful: semantics vary depending on the target
            "pow": "{} ** {}",
            "lshift": "{} << {}",
            "rshift": "{} >> {}",
            "and_": "{} & {}",
            "or_": "{} | {}",
            "xor": "{} ^ {}",
            "eq": "{} == {}",
            "ne": "{} != {}",
            "lt": "{} < {}",
            "gt": "{} > {}",
            "le": "{} <= {}",
            "ge": "{} >= {}",
            "neg": "-{}",
        }.items():
            setattr(cls, name, make_handler(format_string))


MockHandler._init_cls()
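# As an illustrative sketch (the buffer name "arg0" and index idx are
# hypothetical), MockHandler renders handler calls as strings, so
#
#   ops.mul(ops.load("arg0", idx), ops.constant(2.0, torch.float32))
#
# produces something along the lines of
#
#   "ops.load(arg0, idx) * ops.constant(2.0, torch.float32)"
#
# (mul is one of the operators metaprogrammed by _init_cls into the
# "{} * {}" format string above).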


# Use mypy to check protocol implemented correctly
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
    return h


class KernelFormatterHandler:
    def __init__(self, parent_handler):
        self.parent_handler = parent_handler
        self.output = IndentedBuffer(1)
        self.var_counter = itertools.count()

    @staticmethod
    def ir_to_string(ir_fn, index, rindex=None) -> str:
        from .ir import FlexibleLayout
        from .virtualized import V

        args = [index, rindex] if rindex is not None else [index]
        names = ["index", "rindex"] if rindex is not None else ["index"]
        formatter = KernelFormatterHandler(MockHandler())

        with formatter.output.indent(-1):
            formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
        for name, arg in zip(names, args):
            if arg:
                lhs = ", ".join(
                    [
                        str("_" if isinstance(v, (int, sympy.Integer)) else v)
                        for v in arg
                    ]
                )
                formatter.output.writeline(f"{lhs} = {name}")

        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = ir_fn(*args)
            return formatter.getvalue(result)

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            line = getattr(self.parent_handler, name)(*args, **kwargs)
            if name == "indirect_indexing":
                return line

            def write(line):
                # replace line with a new variable name
                varname = f"tmp{next(self.var_counter)}"
                self.output.writeline(f"{varname} = {line}")
                return varname

            return pytree.tree_map(write, line)

        return inner

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[str, Tuple[str, ...]],
    ) -> Union[str, Tuple[str, ...]]:
        line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
        num_values = reduction_num_outputs(reduction_type)
        varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
        self.output.writeline(f"{','.join(varnames)} = {line}")
        return tuple(varnames) if num_values > 1 else varnames[0]

    def getvalue(self, result):
        self.output.writeline(f"return {result}")
        return self.output.getvalue()
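# An illustrative sketch of ir_to_string output (the inner function and buffer
# name below are hypothetical): for an IR function that loads a buffer and adds
# a constant, it renders roughly
#
#   def inner_fn(index):
#       i0 = index
#       tmp0 = ops.load(arg0, i0)
#       tmp1 = tmp0 + ops.constant(1, torch.float32)
#       return tmp1
#
# where each intermediate is assigned a fresh tmpN name by __getattr__ above.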


# Use mypy to check protocol implemented correctly
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
    return h


class WrapperHandler(Generic[T]):
    def __init__(self, inner: OpsHandler[T]):
        self._inner = inner

    def __getattr__(self, item):
        return getattr(self._inner, item)


# Use mypy to check protocol implemented correctly
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
    return h


class AddParenHandler(WrapperHandler[T]):
    def __getattr__(self, name):
        def inner(*args, **kwargs):
            val = getattr(self._inner, name)(*args, **kwargs)
            return f"({val})"

        return inner


# Use mypy to check protocol implemented correctly
def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]:
    return h
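# An illustrative sketch (assuming the wrapped handler is a MockHandler):
# every rendered subexpression gets parenthesized, so substituting it into a
# larger string expression preserves operator precedence, e.g.
#
#   AddParenHandler(MockHandler()).add("a", "b")  ->  "(a + b)"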


class OpCountResult(NamedTuple):
    num_ops: int
    used_ops: OrderedSet[str]
    read_buffers: List[str]
    nontrivial_read_count: int


class OpCounterCSE:
    """Shim to count how many ops are used"""

    def __init__(self, inner):
        super().__init__()
        self.parent_handler = inner
        self.op_count = 0
        self.var_names = {}
        self._used_ops: OrderedSet[str] = OrderedSet()
        self._read_names: List[str] = []
        self._nontrivial_read_count = 0

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            return pytree.tree_map(
                self._update_count, getattr(self.parent_handler, name)(*args, **kwargs)
            )

        self._used_ops.add(name)
        return inner

    def _update_count(self, val):
        varname = self.var_names.get(val)
        if not varname:
            varname = f"tmp{self.op_count}"
            self.op_count += 1
            self.var_names[val] = varname
        return varname

    def indirect_indexing(self, *args, **kwargs):
        self._used_ops.add("indirect_indexing")
        return self.parent_handler.indirect_indexing(*args, **kwargs)

    def load(self, name: str, index: sympy.Expr) -> str:
        val = self.parent_handler.load(name, index)
        if val not in self.var_names:
            self._used_ops.add("load")
            self._read_names.append(name)
            if not isinstance(index, (sympy.Integer, int)):
                self._nontrivial_read_count += 1
        return self._update_count(val)

    def load_seed(self, name: str, offset: T):
        val = self.parent_handler.load_seed(name, offset)
        if val not in self.var_names:
            self._used_ops.add("load_seed")
            self._read_names.append(name)
        return self._update_count(val)

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        val = self.parent_handler.bucketize(
            values, offsets_name, offsets_size, indexing_dtype, right
        )
        if val not in self.var_names:
            self._used_ops.add("bucketize")
            self._read_names.append(offsets_name)
        return self._update_count(val)

    def getvalue(self):
        return OpCountResult(
            self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count
        )
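# An illustrative sketch of how the counter might be used (the buffer name
# "arg0" and index idx are hypothetical):
#
#   counter = OpCounterCSE(MockHandler())
#   counter.add(counter.load("arg0", idx), counter.load("arg0", idx))
#   result = counter.getvalue()
#
# Here result.num_ops == 2 (the duplicated load is CSE'd into one tmp),
# result.used_ops contains "load" and "add", and result.read_buffers == ["arg0"].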


def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
    return h


class ExtractConstantsHandler(NoopHandler):
    def __init__(self, device):
        self.device = device

    def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
        from torch._inductor import ir

        return ir.Constant(value=value, dtype=dtype, device=self.device)


def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]:
    return h


class SimpleCSEHandler(WrapperHandler[T]):
    """Wraps the underlying handler with a CSE pass

    NOTE: Compared to codegen-level CSE, this is simplified as it
    doesn't support stores, which require load cache invalidation.
    """

    def __init__(self, inner: OpsHandler[T]):
        super().__init__(inner)
        self.cse_cache: Dict[str, Union[T, Tuple[T, ...]]] = {}
        self.mock = MockHandler()

    def indirect_indexing(self, *args, **kwargs) -> sympy.Expr:
        return super().indirect_indexing(*args, **kwargs)  # type: ignore[misc]

    def store(self, *args, **kwargs) -> T:
        raise NotImplementedError("store not implemented")

    def store_reduction(self, *args, **kwargs) -> T:
        raise NotImplementedError("store_reduction not implemented")

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            key = getattr(self.mock, name)(*args, **kwargs)
            val = self.cse_cache.get(key)
            if val is not None:
                return val

            val = getattr(self._inner, name)(*args, **kwargs)
            self.cse_cache[key] = val
            return val

        return inner
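# An illustrative sketch (wrapping a MockHandler; the buffer name "arg0" and
# index idx are hypothetical): identical calls are deduplicated by keying on
# the string MockHandler would render for them.
#
#   cse = SimpleCSEHandler(MockHandler())
#   a = cse.sin(cse.load("arg0", idx))
#   b = cse.sin(cse.load("arg0", idx))
#   assert a is b  # the second computation is served from cse_cache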


def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]:
    return h