# mypy: allow-untyped-defs
import collections
import contextlib
import dataclasses
import functools
import itertools
import logging
import re
import textwrap
import traceback
from contextlib import nullcontext
from functools import partial
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from unittest.mock import patch

import sympy
from sympy import Expr, Integer

import torch._export.serde.schema as export_schema

import torch._logging

import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import identity
from torch._export.serde.serialize import GraphModuleSerializer
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
from torch._inductor import metrics
from torch._prims_common import (
    compute_required_storage_length,
    is_boolean_dtype,
    is_float_dtype,
    make_channels_last_strides_for,
    StrideType,
)
from torch._subclasses.fake_tensor import get_schema_info
from torch.fx.experimental.symbolic_shapes import (
    CallMethodKey,
    compute_unbacked_bindings,
    DivideByKey,
    free_unbacked_symbols,
    rebind_unbacked,
    resolve_unbacked_bindings,
    SymTypes,
)
from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import SymT

from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .dependencies import (
    extract_free_unbacked_symbols,
    extract_input_node_reduction_ranges,
    extract_read_writes,
    var_builder,
)
from .ops_handler import OpCounterCSE
from .runtime.hints import ReductionHint
from .runtime.runtime_utils import do_bench
from .utils import (
    argsort,
    cache_on_self,
    ceildiv,
    convert_shape_to_inductor,
    convert_shape_to_symint,
    developer_warning,
    get_kernel_metadata,
    is_dynamic,
    is_gpu,
    pad_listlike,
    sympy_dot,
    sympy_index_symbol,
    sympy_index_symbol_with_prefix,
    sympy_product,
    sympy_subs,
)
from .virtualized import ops, V

if TYPE_CHECKING:
    from .graph import GraphLowering

log = logging.getLogger(__name__)
indent = functools.partial(textwrap.indent, prefix="  ")
aten = torch.ops.aten

""" [Note: Inductor IR]

Inductor's IR is produced by executing 'lowering' code (see lowering.py).  Each
lowering is registered to a particular aten operator, and expects inputs that
correspond to the aten schema.  However, in place of torch Tensor inputs, lowerings
expect Inductor TensorBox inputs.

TensorBox IR represents torch tensors.  Tensors are sometimes single objects owning
storage, and sometimes views of another Tensor's storage.  Mutating tensor operations
(such as add_()) affect the underlying storage and any associated views.  Other operations
(such as .t_()) update metadata about the current view but don't modify the underlying storage.

To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.

TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
output from an operation.  But just as torch.Tensors take different forms, TensorBox IR can
reference View IR or directly reference StorageBox IRs.

Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
may take an existing TensorBox and point it to a new underlying View IR.

Tensors that directly own storage are represented as a chain of:
TensorBox -> StorageBox -> Buffer
where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.

If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
(leaving the old buffer unmodified and functionalizing the operation).

Tensors backed by views add one more indirection to the IR.
TensorBox -> View -> StorageBox -> Buffer
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
"""
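
# A condensed sketch of the two ownership chains described above (illustrative
# only; real boxes are built by lowerings, never constructed by hand):
#
#   y = relu(x)     ->  TensorBox -> StorageBox -> Buffer
#   y_t = y.t()     ->  TensorBox -> View -> StorageBox -> Buffer (shared with y)
#
# A later mutation such as y.add_(1) swings the shared StorageBox to point at
# a freshly allocated Buffer, functionalizing the mutation.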


def validate_ir(node_or_nodes):
    def _check_tensorbox(nodes):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        if nodes is None:
            pass
        elif isinstance(nodes, (list, tuple)):
            for node in nodes:
                _check_tensorbox(node)
        elif isinstance(nodes, dict):
            for node in nodes.values():
                _check_tensorbox(node)
        else:
            assert isinstance(
                nodes,
                (
                    torch._inductor.ir.ExpandView,
                    DynamicScalar,
                    AssertScalar,
                    TensorBox,
                    sympy.logic.boolalg.Boolean,
                    Expr,
                    EffectfulKernel,
                ),
            ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"

    # Be picky about the accepted data structure (don't use pytree here)
    _check_tensorbox(node_or_nodes)


def ops_wrapper(name):
    assert isinstance(name, str)

    def fn(*args, **kwargs):
        return getattr(ops, name)(*args, **kwargs)

    return fn
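
# For example, ops_wrapper("add") returns a callable equivalent to ops.add,
# built from the string name of the op; REDUCTION_COMBINE_FN below relies on
# this to map reduction types ("sum", "max", ...) to their combine operations.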


def inverse_reorder(order):
    inv_order = dict(zip(order, range(len(order))))

    def reindex(index):
        assert len(index) == len(inv_order)
        return [index[inv_order[i]] for i in range(len(index))]

    return reindex


def same_reorder(order):
    def reindex(index):
        assert len(index) == len(order)
        return [index[order[i]] for i in range(len(index))]

    return reindex


def fuse_reindexing(reindex1, reindex2):
    def reindex(index):
        return reindex1(reindex2(index))

    return reindex
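

# Minimal sketch of how the reindexing helpers compose (the function below is
# purely illustrative and not used by Inductor): same_reorder applies a
# permutation to an index, inverse_reorder applies its inverse, and
# fuse_reindexing composes two reindexers.
def _example_reorder_roundtrip():
    order = [2, 0, 1]
    index = ["a", "b", "c"]
    assert same_reorder(order)(index) == ["c", "a", "b"]
    assert inverse_reorder(order)(index) == ["b", "c", "a"]
    roundtrip = fuse_reindexing(inverse_reorder(order), same_reorder(order))
    assert roundtrip(index) == index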


NHWC_STRIDE_ORDER = [3, 0, 2, 1]
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]


def stride_order2fill_order(order):
    """
    Convert stride order to fill order
    For the channels-last format,

    stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
    """
    lookup = {pos: idx for idx, pos in enumerate(order)}
    fill_order = [lookup[i] for i in range(len(order))]
    return fill_order


def get_stride_order(seq: Sequence[int]) -> List[int]:
    """
    Convert strides to stride order
    """
    sorted_idx: List[int] = argsort(seq)
    out = [0 for _ in range(len(seq))]
    for i, elem in enumerate(sorted_idx):
        out[elem] = i
    return out
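

# Worked example (illustrative helper, not used elsewhere): for a contiguous
# channels-last NHWC tensor, the strides recover NHWC_STRIDE_ORDER, and
# converting that stride order gives the fill order from the docstring above.
def _example_stride_orders():
    strides = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last).stride()
    assert list(strides) == [60, 1, 15, 3]
    order = get_stride_order(strides)
    assert order == NHWC_STRIDE_ORDER  # [3, 0, 2, 1]
    assert stride_order2fill_order(order) == [1, 3, 2, 0]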


def ir_node_to_tensor(x, guard_shape=True):
    if x is None:
        return None

    shape_fn: Callable[[Expr], Union[int, Expr]]
    if not guard_shape:
        shape_fn = V.graph.sizevars.size_hint
    else:
        shape_fn = identity
    size = [shape_fn(s) for s in x.get_size()]
    stride: StrideType
    if is_storage_and_layout(x):
        stride = [shape_fn(s) for s in x.get_layout().stride]  # type: ignore[misc]
    else:
        stride = FlexibleLayout.contiguous_strides(size)  # type: ignore[arg-type]
    dtype = x.get_dtype()
    device = x.get_device()
    size = convert_shape_to_symint(size)
    stride = convert_shape_to_symint(stride)
    t = torch.empty_strided(
        size=size, stride=stride, dtype=dtype, device=device
    ).zero_()
    return t


def may_convert_to_optional(value):
    if isinstance(value, list) and not value:
        # [None] makes sure the cpp wrapper codegen will generate something like
        # {c10::nullopt} instead of {}
        return [None]
    return value


def get_device_type(x):
    if getattr(x, "get_device", None):
        return get_device_type(x.get_device())
    if isinstance(x, torch.device):
        return x.type
    return None


def is_triton(x):
    return is_gpu(get_device_type(x))


def is_cpu(x):
    return get_device_type(x) == "cpu"
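

# Small sketch of the device predicates above (illustrative only): anything
# exposing get_device() is unwrapped recursively, so both torch.device objects
# and IR nodes can be passed.
def _example_device_type_checks():
    assert get_device_type(torch.device("cpu")) == "cpu"
    assert is_cpu(torch.device("cpu"))
    assert not is_triton(torch.device("cpu"))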


class IRNode:
    _current_origins: ClassVar[Set[Any]] = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[torch.fx.Node]):
        old = IRNode._current_origins
        IRNode._current_origins = old | origins
        try:
            yield
        finally:
            IRNode._current_origins = old

    def __post_init__(self):
        self.origins = set(self._current_origins)
        self.traceback = traceback.format_stack() if config.debug_ir_traceback else None

    def get_traceback(self):
        return self.traceback

    def common_repr(self):
        origins = f"origins={getattr(self, 'origins', '')}"
        if len(origins) > 64:
            # this can get *very* long
            origins = f"{origins[:61]}..."
        return [origins]

    def str_helper(self, lines):
        lines = lines + self.common_repr()
        lines = indent(",\n".join(map(str, lines)))
        return f"{type(self).__name__}(\n{lines}\n)"

    def is_user_of(self, name):
        return name in self.get_read_names()

    @cache_on_self
    def get_read_names(self):
        return {dep.name for dep in self.get_reads()}

    def get_dtype(self):
        return self.dtype

    def get_layout(self):
        raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")

    def get_size(self):
        raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")

    def get_numel(self):
        return sympy_product(self.get_size())

    def is_zero_elements(self):
        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]

    def realize(self):
        """
        If the IRNode refers to data which has not been materialized (e.g.,
        it is a Pointwise/Reduction that could potentially have more
        compute fused into it), realize the IRNode into physical memory,
        ending the possibility of fusing into it, but allowing, e.g., multiple
        users to access the data without having to recompute.

        Check StorageBox.realize for a particularly notable implementation.

        TODO(ezyang): I think, in principle, every IRNode should have an
        implementation of this, and most of the time no-op is OK, but you
        really do have to audit each IRNode for this, so for now, raise
        an error if it's not implemented.  Note that some code in graph.py
        will catch this thrown error and suppress it with a warning.
        """
        raise NotImplementedError(f"realize NYI on {type(self)}")

    def codegen_reference(self, writer=None):
        raise NotImplementedError(f"codegen_reference NYI on {type(self)}")

    # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions
    # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of
    # the code dynamically check for defined attributes.
    get_device: Callable[[], torch.device]
    dtype: torch.dtype
    get_name: Callable[[], str]
    get_reads: Callable[[], Any]
    get_stride: Callable[[], Any]
    get_storage_numel: Callable[[], Any]
    has_exceeded_max_reads: Callable[[], bool]
    make_loader: Callable[[], Callable[[Any], Any]]
    make_indexer: Callable[[], Callable[[Any], Any]]
    mark_reuse: Callable[[int], None]
    realize_hint: Callable[[], None]
    get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]]


@dataclasses.dataclass
class Loops(IRNode):
    device: torch.device
    dtype: torch.dtype
    inner_fn: Callable[..., Any]
    ranges: List[Expr]

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        return set().union(
            *(free_unbacked_symbols(e) for e in self.ranges),
            self.inner_fn_free_unbacked_symbols(),
        )

    def __str__(self, names=("ranges",)):
        return self.str_helper(
            [
                f"'{self.device.type}'",
                str(self.dtype),
                self.inner_fn_str(),
            ]
            + [f"{name}={getattr(self, name)}" for name in names]
            + [f"origin_node={self.origin_node!r}"]
        )

    def __post_init__(self):
        super().__post_init__()
        self.origin_node = None

    __repr__ = __str__

    def get_device(self):
        return self.device

    def get_origin_node(self):
        return self.origin_node

    def get_size(self):
        return self.ranges

    def get_pointwise_size(self):
        return self.ranges

    def is_extern(self):
        return False

    @classmethod
    def create(cls, *args, **kwargs):
        origin_node = kwargs.pop("origin_node", None)
        tb = kwargs.pop("traceback", None)
        r = cls(*args, **kwargs)
        r.origin_node = origin_node
        r.traceback = (
            tb or traceback.format_stack() if config.debug_ir_traceback else None
        )
        return TensorBox.create(r)

    @staticmethod
    def _index(ranges, prefix=SymT.INDEX):
        return [
            sympy.Integer(0) if s == 1 else sympy_index_symbol_with_prefix(prefix, n)
            for n, s in enumerate(ranges)
        ]

    @cache_on_self
    def inner_fn_opcount(self):
        opcounter = OpCounterCSE(V.MockHandler())

        with V.set_ops_handler(opcounter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            self.inner_fn(*self.inner_fn_args())
            return opcounter.op_count

    def inner_fn_args(self):
        return (self._index(self.ranges),)

    def inner_fn_str(self):
        return V.KernelFormatterHandler.ir_to_string(
            self.inner_fn, *self.inner_fn_args()
        )

    def has_large_inner_fn(self):
        return self.inner_fn_opcount() > config.realize_opcount_threshold

    def inner_fn_free_unbacked_symbols(self):
        index = self._index(self.ranges)
        return extract_free_unbacked_symbols(self.inner_fn, index)

    def get_reads(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.get_reduction_type():
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                    self.get_reduction_size(),
                ).reads
            else:
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                ).reads

    def get_reduction_size(self):
        raise NotImplementedError(
            f"get_reduction_size() is not implemented by {type(self)}!"
        )

    def get_reduction_type(self):
        raise NotImplementedError(
            f"get_reduction_type() is not implemented by {type(self)}!"
        )

    def constant_to_device(self, device):
        raise NotImplementedError(
            f"constant_to_device() is not implemented by {type(self)}!"
        )


def nop_loader_fn(idx, *, dtype):
    if dtype.is_floating_point:
        return ops.constant(float("nan"), dtype)
    else:
        return ops.constant(0, dtype)


class Pointwise(Loops):
    def make_loader(self):
        # Make zero-element loops into a no-op
        if self.is_zero_elements():
            return partial(nop_loader_fn, dtype=self.dtype)

        return self.inner_fn

    def get_reduction_size(self):
        return []

    def get_reduction_type(self):
        return None

    def store_output(self, output_name, indexer, vars):
        loader = self.make_loader()
        return ops.store(output_name, indexer(vars), loader(vars))

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(device, self.dtype, loader, self.ranges)


@dataclasses.dataclass
class Scatter(Pointwise):
    output_indexer: Callable[[List[Expr]], Expr]
    scatter_mode: Optional[str] = None

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Scatter(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.output_indexer,
            self.scatter_mode,
        )

    def store_output(self, output_name, indexer, vars):
        loader = self.make_loader()
        return ops.store(
            output_name,
            indexer(self.output_indexer(vars)),
            loader(vars),
            mode=self.scatter_mode,
        )


REDUCTION_COMBINE_FN = {
    "any": ops_wrapper("logical_or"),
    "max": ops_wrapper("maximum"),
    "min": ops_wrapper("minimum"),
    "prod": ops_wrapper("mul"),
    "sum": ops_wrapper("add"),
    "xor_sum": ops_wrapper("bitwise_xor"),
}


def get_reduction_combine_fn(reduction_type, dtype, arg_break_ties_left=True):
    if reduction_type in REDUCTION_COMBINE_FN:
        combine_fn = REDUCTION_COMBINE_FN[reduction_type]
    elif reduction_type in {"argmax", "argmin"}:

        def combine_fn(a, b):
            a_value, a_index = a
            b_value, b_index = b

            if reduction_type == "argmin":
                mask = ops.lt(a_value, b_value)
            else:
                mask = ops.gt(a_value, b_value)

            equal = ops.eq(a_value, b_value)
            if is_float_dtype(dtype):
                a_isnan = ops.ne(a_value, a_value)
                b_isnan = ops.ne(b_value, b_value)
                mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan))
                equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan))

            tie = (
                ops.lt(a_index, b_index)
                if arg_break_ties_left
                else ops.gt(a_index, b_index)
            )
            mask = ops.logical_or(mask, ops.logical_and(equal, tie))
            return (
                ops.where(mask, a_value, b_value),
                ops.where(mask, a_index, b_index),
            )

    elif reduction_type == "welford_combine":

        def combine_fn(a, b):
            a_mean, a_m2, a_weight = a
            b_mean, b_m2, b_weight = b

            delta = b_mean - a_mean
            new_weight = a_weight + b_weight
            w2_over_w = b_weight / new_weight
            return (
                a_mean + delta * w2_over_w,
                a_m2 + b_m2 + delta * delta * a_weight * w2_over_w,
                new_weight,
            )

    else:
        raise NotImplementedError(f"unknown reduction_type={reduction_type}")

    return combine_fn
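

# Worked example of the Welford combine step above (illustrative only; the
# "welford_combine" branch uses plain arithmetic, so it can be exercised with
# Python floats outside any ops handler): combining the (mean, m2, weight)
# statistics of [1., 2.] and [3., 4.] must match those of [1., 2., 3., 4.].
def _example_welford_combine():
    combine = get_reduction_combine_fn("welford_combine", torch.float32)
    a = (1.5, 0.5, 2.0)  # mean, m2, weight of [1., 2.]
    b = (3.5, 0.5, 2.0)  # mean, m2, weight of [3., 4.]
    mean, m2, weight = combine(a, b)
    assert (mean, m2, weight) == (2.5, 5.0, 4.0)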


@dataclasses.dataclass
class Reduction(Loops):
    reduction_ranges: List[Expr]
    reduction_type: str
    # self.dtype represents the dst dtype
    src_dtype: torch.dtype
    reduction_hint: ReductionHint

    def __str__(self):
        return Loops.__str__(  # type: ignore[call-arg]
            self, names=("ranges", "reduction_ranges", "reduction_type")
        )

    def __repr__(self):
        return self.__str__()

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        return super().get_unbacked_symbol_uses() | set().union(
            *(free_unbacked_symbols(e) for e in self.reduction_ranges)
        )

    def get_reduction_size(self):
        return self.reduction_ranges

    def get_reduction_type(self):
        return self.reduction_type

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        value = ops.reduction(
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            self.inner_fn(vars, reduction_vars),
        )
        return ops.store_reduction(output_name, indexer(vars), value)

    def index_length(self):
        return len(self.ranges) + len(self.reduction_ranges)

    def inner_fn_args(self):
        index = self._index(self.ranges)
        rindex = self._index(self.reduction_ranges, SymT.RINDEX)
        return (index, rindex)

    def inner_fn_free_unbacked_symbols(self):
        index = self._index(self.ranges)
        rindex = self._index(self.reduction_ranges, SymT.RINDEX)
        return extract_free_unbacked_symbols(self.inner_fn, index, rindex)

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Reduction(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.reduction_ranges,
            self.reduction_type,
            self.src_dtype,
            ReductionHint.DEFAULT,
        )

    @staticmethod
    def num_splits(
        device,
        dst_dtype,
        src_dtype,
        inner_fn,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_numel,
        input_node: Optional[IRNode] = None,
    ):
        def _is_static(x):
            return isinstance(x, (int, sympy.Integer))

        reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))

        should_split = (
            is_gpu(get_device_type(device))
            and reduction_type
            not in {
                "argmax",
                "argmin",
            }
            and config.split_reductions
            # We don't support unbacked symints
            and _is_static(reduction_numel_hint)
            and _is_static(numel_hint)
        )
        if not should_split:
            return ReductionHint.DEFAULT, 1

        device_interface = get_interface_for_device(get_device_type(device))
        device_properties = device_interface.Worker.get_device_properties(device)
        if get_device_type(device) == "xpu":
            num_sm = device_properties.gpu_subslice_count
        else:
            # default is cuda behavior
            num_sm = device_properties.multi_processor_count

        min_elements_per_thread = 32
        max_elements_per_thread = 512
        threads_per_sm = 2048
        min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
        max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm

        def inner_reduction_splits(reduction_numel_hint, numel_hint):
            # use heuristics close to eager mode for split inner reductions;
            # we leak reduction autotune configs here, and will need to refactor to avoid this later
            num_warps = 8
            num_threads = 32 * num_warps
            if numel_hint >= 2 * num_sm:  # don't split if there are enough outputs
                return 1
            if reduction_numel_hint <= 8192:
                return 1
            if reduction_numel_hint * numel_hint <= min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (2 * num_threads)
                blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
                tmp_split_size = (
                    reduction_numel_hint + num_threads * blocks_per_output - 1
                ) // (num_threads * blocks_per_output)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(closest - tmp_split_size) < 30:
                    # prefer even splits, but never smaller than min_elements_per_thread
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread
            return (reduction_numel_hint + split_size * num_threads - 1) // (
                split_size * num_threads
            )

        def outer_reduction_splits(reduction_numel_hint, numel_hint):
            # TODO: the best current heuristic has XBLOCK (corresponding to numel_hint) of 128;
            # extend it to even smaller numbers of outputs
            num_warps = 8
            num_threads = num_warps * 32
            rvals_per_thread = 4  # comes from heuristics, refactor to not leak here
            xvals_per_block = 128
            xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
            if reduction_numel_hint * numel_hint < min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (num_threads)
                target_blocks = (target_blocks + xblocks - 1) // xblocks
                tmp_split_size = (
                    reduction_numel_hint + rvals_per_thread * target_blocks - 1
                ) // (rvals_per_thread * target_blocks)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(tmp_split_size - closest) < 20:
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread

            return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
                rvals_per_thread * split_size
            )

        # easy cases
        if numel_hint == 1:
            split = inner_reduction_splits(reduction_numel_hint, numel_hint)
            if split == 1:
                # No need to split.
                return ReductionHint.INNER, split
            if input_node is not None and isinstance(input_node, TensorBox):
                new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
                    input_node
                )
                if new_ranges is not None and new_reduction_ranges is not None:
                    extracted_numel_hint = V.graph.sizevars.symbolic_hint(
                        sympy_product(new_ranges + new_reduction_ranges)
                    )
                    if reduction_numel_hint == extracted_numel_hint:
                        log.debug(
                            "Use previous IRNode's range and reduction_ranges instead of split. "
                            "current ranges: %s, current reduction ranges: %s, current split: %d, "
                            "new ranges: %s, new reduction ranges: %s",
                            ranges,
                            reduction_ranges,
                            split,
                            new_ranges,
                            new_reduction_ranges,
                        )
                        # If the input_node or its dependent nodes are also Reduction nodes,
                        # use reduction_sizes of this node or its dependent nodes directly.
                        return ReductionHint.INNER, -1
            return ReductionHint.INNER, split
        if (
            reduction_numel_hint <= min_elements_per_thread
            or numel_hint >= num_sm * 2 * 32
        ):
            return ReductionHint.DEFAULT, 1

        r = Reduction(
            device,
            dst_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            src_dtype,
            ReductionHint.DEFAULT,
        )

        def get_read_indices(r):
            cb = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=r.get_device(),
                    dtype=r.get_dtype(),
                    size=r.get_size(),
                ),
                data=r,
            )
            read_writes = cb.get_read_writes()
            # try finding the full size producer
            # TODO this will fail for something like ((1, N) * (N, 1)).sum()
            # this would also possibly be wrong for producers with different contiguity, but we hope those cases are rare
            range_vars = [
                r
                for r in read_writes.range_vars
                if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
            ]
            indices = []
            changed = False
            for md in sorted(read_writes.reads, key=lambda x: x.name):
                if all(r in md.index.free_symbols for r in range_vars):
                    indices.append(md.index)
                    if md.name in V.graph.name_to_buffer:
                        buf = V.graph.name_to_buffer[md.name]
                        original_stride = buf.layout.stride
                        buf.decide_layout()
                        if buf.layout.stride != original_stride:
                            changed = True
            return indices, changed

        indices, changed = get_read_indices(r)
        if changed:
            indices, _ = get_read_indices(r)

        if len(indices) == 0:
            # TODO determine splits when all inputs are broadcast
            return ReductionHint.DEFAULT, 1

        (_, reduction_vars), ranges = dependencies.index_vars_squeeze(
            r.get_size(), r.get_reduction_size()
        )
        num_outer = 0
        num_inner = 0
        for i in indices:
            i = V.graph.sizevars.simplify_with_ranges(i, ranges)
            strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys())
            outer = all(s > 1 for s in strides)
            if outer:
                num_outer += 1
            else:
                num_inner += 1
        if num_inner > num_outer:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        else:
            return ReductionHint.OUTER, outer_reduction_splits(
                reduction_numel_hint, numel_hint
            )

    @staticmethod
    def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype):
        """Convert inner_fn from a reduction to a pointwise"""
        reduction_ranges = [
            V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges
        ]

        combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)

        def fn(index):
            return functools.reduce(
                combine_fn,
                (
                    value_fn(index, rindex)
                    for rindex in itertools.product(
                        *[range(x) for x in reduction_ranges]
                    )
                ),
            )

        if reduction_type in ("argmin", "argmax"):
            flatten_index = FixedLayout(
                None,  # type: ignore[arg-type]
                None,  # type: ignore[arg-type]
                reduction_ranges,
                FlexibleLayout.contiguous_strides(reduction_ranges),
            ).make_indexer()

            def value_fn(index, rindex):
                rindex = [sympy.expand(i) for i in rindex]
                return (
                    inner_fn(index, rindex),
                    ops.index_expr(flatten_index(rindex), torch.int64),
                )

            return lambda index: fn(index)[1]
        else:
            value_fn = inner_fn
            return fn
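
    # Illustrative sketch of the unrolling above (not executable IR): for
    # reduction_ranges == [2, 2] and reduction_type == "sum", the returned fn
    # behaves like
    #
    #   fn(i) == add(add(add(load(i, (0, 0)), load(i, (0, 1))),
    #                    load(i, (1, 0))),
    #                load(i, (1, 1)))
    #
    # i.e. the reduction loop is fully expanded into a chain of combine ops
    # inside a single pointwise body.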

    @classmethod
    def create(  # type: ignore[override]
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
        input_node: Optional[IRNode] = None,
    ):
        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        if reduction_numel == 0:
            # N.B. This is a hack to generate the literal of the given type
            # Ideally, we should be fixing `def constant` in triton.py
            # but it breaks due to hardcoded dtypes in other places
            def py_cnst(val):
                return (
                    bool(val)
                    if dst_dtype == torch.bool
                    else float(val)
                    if dst_dtype.is_floating_point
                    else int(val)
                )

            rtypes_to_inits = {
                "sum": py_cnst(0),
                "xor_sum": py_cnst(0),
                "prod": py_cnst(1),
                "any": py_cnst(0),
                # "all" is desugared to `!any(!val)`
            }

            assert (
                reduction_type in rtypes_to_inits.keys()
            ), f"{reduction_type} not supported for zero-dimension tensors!"

            def const_fn(index):
                return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)

            return Pointwise.create(
                device=device,
                dtype=src_dtype,
                inner_fn=const_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 1:
            # this reduction is actually a pointwise op
            if reduction_type in ("argmin", "argmax"):

                def fn(index):
                    return ops.constant(0, dst_dtype)

            else:

                def fn(index):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return inner_fn(index, reduction_index)

            return Pointwise.create(device, dst_dtype, fn, ranges)

        if (
            isinstance(reduction_numel, sympy.Integer)
            and V.graph.sizevars.size_hint(reduction_numel)
            < config.unroll_reductions_threshold
            and sympy_product(ranges) != 1
        ):
            return Pointwise.create(
                device,
                dst_dtype,
                cls._unroll_reduction_fn(
                    inner_fn, reduction_ranges, reduction_type, src_dtype
                ),
                ranges,
            )

        # triton doesn't support reduce to single element well, so break it up
        hint, split = cls.num_splits(
            device,
            dst_dtype,
            src_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            reduction_numel,
            input_node,
        )
        # intermediate reduction in split can contain complex indexing,
        # and num_splits will fail to correctly set the hint;
        # reuse the passed hint if available
        if reduction_hint == ReductionHint.DEFAULT:
            reduction_hint = hint
        if split == -1:
            assert input_node is not None
            new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
                input_node  # type: ignore[arg-type]
            )
            assert new_ranges is not None
            assert new_reduction_ranges is not None
            return cls.create_multilayer_existing_ranges(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                new_ranges,
                new_reduction_ranges,
                reduction_type,
                reduction_hint,
            )
        elif split > 1:
            # triton doesn't support reduce to single element well, so break it up
            return cls.create_multilayer(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                split,
                reduction_hint,
            )

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )

    @staticmethod
    def default_accumulator(reduction_type, dtype):
        if reduction_type in {"max", "argmax"}:
            if is_float_dtype(dtype):
                return float("-inf")
            elif is_boolean_dtype(dtype):
                return 0
            else:
                return torch.iinfo(dtype).min
        if reduction_type in {"min", "argmin"}:
            if is_float_dtype(dtype):
                return float("inf")
            elif is_boolean_dtype(dtype):
                return 1
            else:
                return torch.iinfo(dtype).max

        return {
            "sum": 0,
            "prod": 1,
            "xor_sum": 0,
            "any": 0,
            "welford_reduce": (0, 0, 0),
            "welford_combine": (0, 0, 0),
        }[reduction_type]

    @staticmethod
    def default_value(reduction_type, dtype):
        if reduction_type == "welford_reduce":
            return 0
        return Reduction.default_accumulator(reduction_type, dtype)

    @staticmethod
    def _multilayer_second_step_hint(
        split: int, numel_hint: int, reduction_hint: ReductionHint
    ) -> ReductionHint:
        if split == -1:
            return reduction_hint
        if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
            return ReductionHint.OUTER_TINY
        if (
            split <= 1024
            and numel_hint <= 256
            and reduction_hint == ReductionHint.OUTER
        ):
            return ReductionHint.OUTER_TINY

        return reduction_hint

    @classmethod
    def _multilayer_wrap_loader(
        cls,
        loader,
        reduction_ranges,
        reduction_numel,
        split,
        block_size,
        default,
    ):
        reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
        need_mask = not V.graph.sizevars.is_expr_static_and_true(
            sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
        )

        def wrapper_fn(index, reduction_index):
            (reduction_index,) = reduction_index
            *new_index, reduction_block = index
            indices = block_size * reduction_block + reduction_index

            def body():
                return loader(new_index, reindex([indices]))

            if need_mask:
                mask = ops.lt(
                    ops.index_expr(indices, torch.int32),
                    ops.index_expr(reduction_numel, torch.int32),
                )
                return ops.masked(mask, body, default)
            else:
                return body()

        return wrapper_fn
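
    # Worked example of the split arithmetic above (numbers are illustrative):
    # for reduction_numel == 1000 and split == 4, block_size is
    # FloorDiv(1000 + 3, 4) == 250 and 1000 % 4 == 0, so no mask is needed;
    # for reduction_numel == 1000 and split == 3, block_size is 334, the last
    # block is only partially valid, and out-of-range lanes are replaced by
    # `default` via ops.masked.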

    @classmethod
    def _multilayer_wrap_loader_existing_ranges(
        cls,
        loader,
        original_ranges,
        original_reduction_ranges,
        new_ranges,
        new_reduction_ranges,
        default,
    ):
        assert all(
            r == 1 for r in original_ranges
        ), f"Only enabled for numel_hint == 1, found {original_ranges=}"
        reindex = View.dynamic_reshape_indexer(
            original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
        )

        def wrapper_fn(merged_index, new_reduction_index):
            original_idx = merged_index[: len(original_ranges)]
            new_index = merged_index[len(original_ranges) :]
            return loader(
                original_idx,
                reindex(tuple(new_index) + tuple(new_reduction_index)),
            )

        return wrapper_fn

    @classmethod
    def create_multilayer_helper(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        wrapper_fn: Callable[..., Any],
        original_ranges: List[Expr],
        original_reduction_ranges: List[Expr],
        new_ranges: List[Expr],
        new_reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively
        """
        # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
        # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
        # in fp32 and not reduce precision by breaking up the kernel into multiple layers
        intermediate_dtype = (
            dst_dtype
            if dst_dtype not in (torch.float16, torch.bfloat16)
            else torch.float
        )
        intermediate = Reduction.create(
            device,
            intermediate_dtype,
            src_dtype,
            wrapper_fn,
            new_ranges,
            new_reduction_ranges,
            reduction_type,
            reduction_hint,
        )
        intermediate.realize()
        intermediate_loader = intermediate.make_loader()

        def intermediate_fn(index, reduction_index):
            return intermediate_loader([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges))
        reduction_hint = cls._multilayer_second_step_hint(
            split, numel_hint, reduction_hint
        )

        assert original_ranges == new_ranges[: len(original_ranges)]
        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                intermediate_fn,
                original_ranges,
                new_ranges[len(original_ranges) :],
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )

    @classmethod
    def create_multilayer(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively
        """
        # TODO(jansel): realize the reduction so we can do dynamic indexing
        reduction_numel = sympy_product(reduction_ranges)
        block_size = FloorDiv(reduction_numel + (split - 1), split)
        default = cls.default_value(reduction_type, dst_dtype)
        wrapper_fn = cls._multilayer_wrap_loader(
            inner_fn, reduction_ranges, reduction_numel, split, block_size, default
        )

        return cls.create_multilayer_helper(
            device,
            dst_dtype,
            src_dtype,
            wrapper_fn,
            ranges,
            reduction_ranges,
            [*ranges, split],  # type: ignore[list-item]
            [block_size],
            reduction_type,
            split,
            reduction_hint,
        )

    @classmethod
    def create_multilayer_existing_ranges(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        original_ranges: List[Expr],
        original_reduction_ranges: List[Expr],
        new_ranges: List[Expr],
        new_reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively
        """
        default = cls.default_value(reduction_type, dst_dtype)
        wrapper_fn = cls._multilayer_wrap_loader_existing_ranges(
            inner_fn,
            original_ranges,
            original_reduction_ranges,
            new_ranges,
            new_reduction_ranges,
            default,
        )
        return cls.create_multilayer_helper(
            device,
            dst_dtype,
            src_dtype,
            wrapper_fn,
            original_ranges,
            original_reduction_ranges,
            [*original_ranges, *new_ranges],
            new_reduction_ranges,
            reduction_type,
            -1,
            reduction_hint,
        )
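

# Quick sketch of the accumulator defaults above (illustrative helper, not
# used by Inductor): max/min reductions start from -inf/+inf for floats and
# the integer extremes otherwise, while welford_reduce seeds its scalar input
# with 0 but its (mean, m2, weight) accumulator with (0, 0, 0).
def _example_default_accumulators():
    assert Reduction.default_accumulator("max", torch.float32) == float("-inf")
    assert Reduction.default_accumulator("min", torch.int32) == torch.iinfo(torch.int32).max
    assert Reduction.default_value("welford_reduce", torch.float32) == 0
    assert Reduction.default_value("sum", torch.float32) == 0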


def num_reduction_outputs(reduction_type):
    return 3 if "welford" in reduction_type else 1


class WelfordReduction(Reduction):
    output_index: int

    def __init__(
        self,
        device,
        dtype,
        inner_fns,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_hint,
        output_index,
    ):
        if len(inner_fns) == 1:
            loader = inner_fns[0]
        else:

            def loader(idx, reduction_idx):
                return tuple(fn(idx, reduction_idx) for fn in inner_fns)

        super().__init__(
            device,
            dtype,
            loader,
            ranges,
            reduction_ranges,
            reduction_type,
            dtype,
            reduction_hint,
        )
        self.output_index = output_index

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        values = ops.reduction(
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            self.inner_fn(vars, reduction_vars),
        )
        value = values[self.output_index]
        return ops.store_reduction(output_name, indexer(vars), value)

    @classmethod
    def create(  # type: ignore[override]
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fns: Sequence[Callable[..., Any]],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ):
        assert reduction_type in {"welford_reduce", "welford_combine"}

        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        def const(val):
            def inner_fn(idx):
                return ops.constant(
                    val,
                    dtype,
                )

            return Pointwise.create(
                device=device,
                dtype=dtype,
                inner_fn=inner_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 0:
            mean = const(0)
            m2 = const(0)
            weight = const(0)
            return mean, m2, weight

        if reduction_numel == 1:

            def copy(loader):
                def inner_fn(idx):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return loader(idx, reduction_index)

                return Pointwise.create(
                    device=device,
                    dtype=dtype,
                    inner_fn=inner_fn,
                    ranges=list(ranges),
                )

            if reduction_type == "welford_reduce":
                return copy(inner_fns[0]), const(0), const(1)
            else:
                return tuple(copy(fn) for fn in inner_fns)

        # TODO: Unrolled reduction
        # if (
        #     isinstance(reduction_numel, sympy.Integer)
        #     and V.graph.sizevars.size_hint(reduction_numel)
        #     < config.unroll_reductions_threshold
        #     and sympy_product(ranges) != 1
        # ):
        #     return Pointwise.create(
        #         device,
        #         dst_dtype,
        #         cls._unroll_reduction_fn(
        #             inner_fn, reduction_ranges, reduction_type, src_dtype
        #         ),
        #         ranges,
        #     )

        # triton doesn't support reduce to single element well, so break it up
        hint, split = Reduction.num_splits(
            device,
            dtype,
            dtype,
            inner_fns[0],
            ranges,
            reduction_ranges,
            reduction_type=reduction_type,
            reduction_numel=reduction_numel,
        )
        # intermediate reduction in split can contain complex indexing,
        # and num_splits will fail to correctly set the hint;
        # reuse the passed hint if available
        if reduction_hint == ReductionHint.DEFAULT:
            reduction_hint = hint
        if split > 1:
            # triton doesn't support reduce to single element well, so break it up
            return cls.create_multilayer(
                device,
                dtype,
                inner_fns,
                ranges,
                reduction_ranges,
                reduction_type,
                split,
                reduction_hint,
            )

        results = [
            TensorBox.create(
                WelfordReduction(
                    device,
                    dtype,
                    inner_fns,
                    ranges,
                    reduction_ranges,
                    reduction_type,
                    reduction_hint,
                    output_idx,
                )
            )
            for output_idx in range(3)
        ]
        for t in results:
            t.realize()
        return results

    @staticmethod
    def default_value(reduction_type, dtype):
        return (0, 0, 0)

    @classmethod
    def create_multilayer(  # type: ignore[override]
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fns: Sequence[Callable[..., Any]],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively
        """
        reduction_numel = sympy_product(reduction_ranges)
        need_mask = not V.graph.sizevars.is_expr_static_and_true(
            sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
        )

        if need_mask and reduction_type != "welford_combine":
            # If we need a mask, then "welford_reduce" doesn't work because
            # masked inputs shouldn't count towards the welford weight
1526
1527            def constant(idx, reduction_idx, value):
1528                return ops.constant(value, dtype)
1529
1530            return cls.create_multilayer(
1531                device=device,
1532                dtype=dtype,
1533                inner_fns=(
1534                    inner_fns[0],
1535                    partial(constant, value=0),
1536                    partial(constant, value=1),
1537                ),
1538                ranges=ranges,
1539                reduction_ranges=reduction_ranges,
1540                reduction_type="welford_combine",
1541                split=split,
1542                reduction_hint=reduction_hint,
1543            )
1544
1545        block_size = FloorDiv(reduction_numel + (split - 1), split)
1546        intermediates = WelfordReduction.create(
1547            device,
1548            dtype,
1549            tuple(
1550                cls._multilayer_wrap_loader(
1551                    loader,
1552                    reduction_ranges,
1553                    reduction_numel,
1554                    split,
1555                    block_size,
1556                    default=0,
1557                )
1558                for loader in inner_fns
1559            ),
1560            [*ranges, split],  # type: ignore[list-item]
1561            [block_size],
1562            reduction_type,
1563            reduction_hint,
1564        )
1565        for i in intermediates:
1566            i.realize()
1567
1568        i_loaders = [i.make_loader() for i in intermediates]
1569
1570        def intermediate_loader_fn(index, reduction_index, loader):
1571            return loader([*index, *reduction_index])
1572
1573        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
1574        reduction_hint = cls._multilayer_second_step_hint(
1575            split, numel_hint, reduction_hint
1576        )
1577        return WelfordReduction.create(
1578            device,
1579            dtype,
1580            tuple(
1581                partial(intermediate_loader_fn, loader=i.make_loader())
1582                for i in intermediates
1583            ),
1584            ranges,
1585            [split],  # type: ignore[list-item]
1586            # welford_reduce turns one input into three outputs, which are combined with welford_combine
1587            "welford_combine",
1588            reduction_hint,
1589        )
1590
1591
1592@dataclasses.dataclass
1593class Scan(Loops):
1594    scan_ranges: List[Expr]
1595    size: List[Expr]
1596    combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]]
1597    reindex: Callable[[List[Expr], List[Expr]], List[Expr]]
1598    reduction_hint: ReductionHint
1599    output_index: int
1600    # output_index indexes the following tuples
1601    dtypes: Tuple[torch.dtype, ...]
1602    inner_fns: Tuple[Callable[..., Any], ...]
1603
1604    # HACK we mimick reduction
1605    # HACK: we mimic reduction
1606    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
1607        # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
1608        # need to explicitly represent the closure so we can pull out unbacked
1609        # symbols here
1610        return (
1611            super().get_unbacked_symbol_uses()
1612            | set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges))
1613            | set().union(*(free_unbacked_symbols(e) for e in self.size))
1614        )
1615
1616    def __post_init__(self):
1617        assert len(self.ranges) + len(self.scan_ranges) == len(self.size)
1618        super().__post_init__()
1619
1620    def store_reduction(self, output_name, indexer, vars, scan_vars):
1621        idx = self.reindex(vars, scan_vars)
1622        values = [inner_fn(idx) for inner_fn in self.inner_fns]
1623        result = ops.scan(self.dtypes, self.combine_fn, values)
1624        return ops.store(output_name, indexer(idx), result[self.output_index])
1625
1626    def get_reduction_type(self):
1627        # return self.scan_op
1628        return "custom"
1629
1630    def get_reduction_size(self):
1631        return self.scan_ranges
1632
1633    def get_size(self):
1634        return self.size
1635
1636    def get_pointwise_size(self):
1637        return self.ranges
1638
1639    def index_length(self):
1640        return len(self.ranges) + len(self.scan_ranges)
1641
1642    def inner_fn_args(self):
1643        index = self._index(self.ranges)
1644        rindex = self._index(self.scan_ranges, SymT.RINDEX)
1645        idx = self.reindex(index, rindex)
1646        return (idx,)
1647
1648    def inner_fn_free_unbacked_symbols(self):
1649        index = self._index(self.ranges)
1650        rindex = self._index(self.scan_ranges, SymT.RINDEX)
1651        idx = self.reindex(index, rindex)
1652        return extract_free_unbacked_symbols(self.inner_fn, idx)
1653
1654    @classmethod
1655    def create(
1656        cls,
1657        device: torch.device,
1658        dtypes: Tuple[torch.dtype, ...],
1659        inner_fns: Tuple[Callable[[List[Expr]], Any], ...],
1660        size: List[Expr],
1661        axis: int,
1662        combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]],
1663        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
1664        **kwargs,
1665    ) -> List[Optional["TensorBox"]]:
1666        pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
1667        scan_ranges = [size[axis]]
1668
1669        if not is_gpu(device.type):
1670            # TODO: CPU support
1671            return [None] * len(dtypes)
1672
1673        if torch.version.hip is not None and len(dtypes) > 1:
1674            # TODO: Remove this when ROCm triton adds support for multiple inputs
1675            return [None] * len(dtypes)
1676
1677        sizevars = V.graph.sizevars
1678        scan_numel = sizevars.simplify(sympy_product(scan_ranges))
1679
1680        assert len(dtypes) == len(inner_fns)
1681
1682        # Scan with a single element is just a copy
1683        if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):  # type: ignore[arg-type]
1684            return [
1685                Pointwise.create(
1686                    device=device,
1687                    dtype=dtypes[output_index],
1688                    inner_fn=inner_fns[output_index],
1689                    ranges=size,
1690                )
1691                for output_index in range(len(dtypes))
1692            ]
1693
1694        reduction_hint, num_splits = cls.num_splits(
1695            device=device,
1696            dtype=dtypes[0],
1697            inner_fn=inner_fns[0],
1698            axis=axis,
1699            pointwise_ranges=pointwise_ranges,
1700            scan_ranges=scan_ranges,
1701            combine_fn=combine_fn,
1702            scan_numel=scan_numel,
1703        )
1704        scan_type = Scan if num_splits <= 1 else SplitScan
1705
1706        if num_splits > 1 and torch.version.hip is not None:
1707            # Fallback for split-scan on ROCm
1708            return [None] * len(dtypes)
1709
1710        if num_splits > 1 and len(dtypes) > 1:
1711            # Fallback for split-scans for multiple inputs
1712            return [None] * len(dtypes)
1713
1714        def reindex(index, scan_index):
1715            assert len(scan_index) == len(scan_ranges)
1716            assert len(index) == len(pointwise_ranges)
1717            return [*index[:axis], *scan_index, *index[axis:]]
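
        # Illustrative (hypothetical indices): with axis=1, pointwise index
        # [i0, i1] and scan index [s0], reindex gives [i0, s0, i1] -- the
        # index into the original, un-split tensor.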
1718
1719        results = [
1720            TensorBox.create(
1721                scan_type(
1722                    device=device,
1723                    dtype=dtypes[output_index],
1724                    dtypes=dtypes,
1725                    inner_fn=inner_fns[output_index],
1726                    inner_fns=inner_fns,
1727                    size=size,
1728                    ranges=pointwise_ranges,
1729                    scan_ranges=scan_ranges,
1730                    combine_fn=combine_fn,
1731                    reindex=reindex,
1732                    reduction_hint=reduction_hint,
1733                    output_index=output_index,
1734                    **kwargs,
1735                )
1736            )
1737            for output_index in range(len(dtypes))
1738        ]
1739
1740        for result in results:
1741            result.realize()
1742
1743        return results
1744
1745    @classmethod
1746    def num_splits(
1747        cls,
1748        device: torch.device,
1749        dtype: torch.dtype,
1750        inner_fn: Callable[[List[Expr]], Any],
1751        axis: int,
1752        pointwise_ranges: List[Expr],
1753        scan_ranges: List[Expr],
1754        combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]],
1755        scan_numel: Expr,
1756    ):
1757        # TODO: custom splitting heuristic for scan
1758        def wrapper_fn(idx, reduction_idx):
1759            return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]])
1760
1761        return Reduction.num_splits(
1762            device=device,
1763            dst_dtype=dtype,
1764            src_dtype=dtype,
1765            inner_fn=wrapper_fn,
1766            ranges=pointwise_ranges,
1767            reduction_ranges=scan_ranges,
1768            reduction_type="sum",
1769            reduction_numel=scan_numel,
1770        )
1771
1772
1773# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA.
1774@dataclasses.dataclass
1775class SplitScan(Scan):
1776    pass
1777
1778
1779def is_storage_and_layout(x):
1780    try:
1781        as_storage_and_layout(x, freeze=False)
1782        return True
1783    except NotImplementedError:
1784        return False
1785
1786
1787def is_contiguous_storage_and_layout(x):
1788    try:
1789        buffer, layout = as_storage_and_layout(x, freeze=False)
1790        # pad the strides here so we do NOT claim a tensor is contiguous
1791        # when padding is going to happen.
1792        if layout.should_pad_strides():
1793            layout.pad_strides()
1794        return layout.is_contiguous()
1795    except NotImplementedError:
1796        return False
1797
1798
1799def as_storage_and_layout(
1800    x, freeze=True, want_contiguous=False, stride_order=None, allow_padding=False
1801):
1802    """
1803    Try to simplify x into a StorageBox and a Layout.
1804
1805    allow_padding only affects how we apply stride_order. When allow_padding
1806    is True, we have the freedom to add padding while applying the stride_order.
1807    """
1808    if isinstance(x, TensorBox):
1809        return as_storage_and_layout(
1810            x.data,
1811            freeze=freeze,
1812            want_contiguous=want_contiguous,
1813            stride_order=stride_order,
1814            allow_padding=allow_padding,
1815        )
1816    if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
1817        if freeze:
1818            if want_contiguous:
1819                x.data.freeze_layout()
1820                assert x.data.layout.is_contiguous()
1821            elif stride_order is not None:
1822                x.data.freeze_layout_with_stride_order(
1823                    stride_order, allow_padding=allow_padding
1824                )
1825            else:
1826                x.data.decide_layout()
1827        return x, x.data.layout
1828    if isinstance(x, ReinterpretView):
1829        # making the base of x contiguous or stride_ordered will not necessarily make
1830        # the ReinterpretView contiguous or stride_ordered, so don't pass those arguments along
1831        buffer, _ = as_storage_and_layout(
1832            x.data,
1833            freeze=freeze,
1834        )
1835        return buffer, x.layout
1836    raise NotImplementedError
1837
1838
1839as_contiguous_storage_and_layout = functools.partial(
1840    as_storage_and_layout, want_contiguous=True
1841)
1842
1843
1844def is_stride_order_storage_and_layout(x, stride_order):
1845    try:
1846        buffer, layout = as_storage_and_layout(x, freeze=False)
1847        return layout.is_stride_ordered(stride_order)
1848    except NotImplementedError:
1849        return False
1850
1851
1852@dataclasses.dataclass
1853class BaseView(IRNode):
1854    data: IRNode
1855
1856    def get_unbacked_symbol_uses(self):
1857        return self.data.get_unbacked_symbol_uses()
1858
1859    def make_reindexer(self):
1860        raise NotImplementedError(f"make_reindexer NYI on {self}")
1861
1862    def make_indexer(self):
1863        inner = self.data.make_indexer()
1864        reindex = self.make_reindexer()
1865
1866        def indexer(idx):
1867            return inner(reindex(idx))
1868
1869        return indexer
1870
1871    def make_loader(self):
1872        inner = self.data.make_loader()
1873        reindex = self.make_reindexer()
1874
1875        def loader(idx):
1876            return inner(reindex(idx))
1877
1878        return loader
1879
1880    @property
1881    def dtype(self):
1882        return self.data.dtype
1883
1884    def get_layout(self):
1885        return self.data.get_layout()
1886
1887    def get_device(self):
1888        return self.data.get_device()
1889
1890    def get_origin_node(self):
1891        return None
1892
1893    def get_name(self):
1894        return self.data.get_name()
1895
1896    def get_pointwise_size(self):
1897        return self.get_size()
1898
1899    def mark_reuse(self, users):
1900        return self.data.mark_reuse(users)
1901
1902    def has_exceeded_max_reads(self):
1903        return self.data.has_exceeded_max_reads()
1904
1905    def realize(self):
1906        return self.data.realize()
1907
1908    def realize_hint(self):
1909        return self.data.realize_hint()
1910
1911    def get_storage_numel(self):
1912        return self.data.get_storage_numel()
1913
1914    def is_extern(self):
1915        return self.data.is_extern()  # type: ignore[attr-defined]
1916
1917    def is_module_buffer(self):
1918        return self.data.is_module_buffer()  # type: ignore[attr-defined]
1919
1920    def get_reads(self):
1921        with patch.object(FlexibleLayout, "allow_indexing", True):
1922            return extract_read_writes(
1923                self.make_loader(),
1924                self.get_size(),
1925            ).reads
1926
1927    def unwrap_view(self):
1928        x: IRNode = self
1929        while isinstance(x, BaseView):
1930            x = x.data
1931        return x
1932
1933    def constant_to_device(self, device):
1934        """Move this to a given device. Requires that all reads are to constants."""
1935        loader = self.make_loader()
1936        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
1937        return Pointwise(device, self.get_dtype(), loader, self.get_size())
1938
1939
1940@dataclasses.dataclass
1941class ExpandView(BaseView):
1942    size: List[Expr]
1943
1944    @staticmethod
1945    def _normalize_size(x, new_size):
1946        """Replace `-1` with correct sizes"""
1947        sizevars = V.graph.sizevars
1948        new_size = list(map(sympy.expand, new_size))
1949        old_size = x.get_size()
1950        old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
1951        assert len(new_size) == len(old_size)
1952        for i in range(len(new_size)):
1953            if new_size[i] == -1:
1954                assert old_size[i] is not None
1955                new_size[i] = old_size[i]
1956            elif old_size[i] is None or old_size[i] == 1:
1957                pass
1958            else:
1959                # Sanity check: Expect broadcast compatibility
1960                #
1961                # NB: new_size[i] == old_size[i] is expected to already be
1962                # guarded because the meta formula was expected to have taught
1963                # us this equality.
1964                assert (
1965                    sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0
1966                ), f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
1967        return new_size
1968
1969    @classmethod
1970    def create(cls, x, new_size):
1971        new_size = cls._normalize_size(x, new_size)
1972
1973        if is_storage_and_layout(x):
1974            storage, old_layout = as_storage_and_layout(x)
1975            skip = len(new_size) - len(old_layout.size)
1976            assert skip >= 0
1977            new_stride = [sympy.Integer(0)] * skip
1978            for stride, size in zip(old_layout.stride, old_layout.size):
1979                new_stride.append(stride if size != 1 else sympy.Integer(0))
1980            new_layout = FixedLayout(
1981                old_layout.device,
1982                old_layout.dtype,
1983                list(new_size),
1984                new_stride,
1985                old_layout.offset,
1986            )
1987            return ReinterpretView(storage, new_layout)
1988
1989        return ExpandView(x, new_size)
1990
1991    def get_size(self):
1992        return self.size
1993
1994    def make_reindexer(self):
1995        target = self.get_size()
1996        actual = self.data.get_size()
1997        skip = len(target) - len(actual)
1998
1999        def reindex(index):
2000            index = list(index[skip:])
2001            assert len(index) == len(actual)
2002            for i in range(len(actual)):
2003                if actual[i] == 1:
2004                    # zero out broadcast dimension
2005                    index[i] = sympy.Integer(0)
2006            return index
2007
2008        return reindex
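
    # Illustrative sketch (hypothetical sizes): expanding data of size [1, 3]
    # to [2, 3] gives reindex([i0, i1]) == [0, i1]; broadcast dimensions are
    # zeroed out and leading expanded dimensions are dropped.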
2009
2010
2011@dataclasses.dataclass
2012class PermuteView(BaseView):
2013    dims: List[Expr]
2014
2015    @classmethod
2016    def create(cls, x, dims):
2017        dims = cls._map_neg_dims(dims)
2018        assert set(dims) == set(range(len(dims)))
2019
2020        if is_storage_and_layout(x):
2021            storage, old_layout = as_storage_and_layout(x)
2022            new_layout = FixedLayout(
2023                old_layout.device,
2024                old_layout.dtype,
2025                [old_layout.size[i] for i in dims],
2026                [old_layout.stride[i] for i in dims],
2027                old_layout.offset,
2028            )
2029            return ReinterpretView(storage, new_layout)
2030
2031        return PermuteView(x, dims)
2032
2033    @classmethod
2034    def _map_neg_dims(cls, dims):
2035        return [dim if dim >= 0 else len(dims) + dim for dim in dims]
2036
2037    def get_size(self):
2038        assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
2039        size = self.data.get_size()
2040        return [size[i] for i in self.dims]
2041
2042    def make_reindexer(self):
2043        inv = {j: i for i, j in enumerate(self.dims)}
2044        inv = [inv[i] for i in range(len(self.dims))]  # type: ignore[index]
2045        assert set(inv) == set(range(len(self.dims)))
2046
2047        def reindex(index):
2048            return [index[i] for i in inv]
2049
2050        return reindex
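
    # Illustrative sketch (hypothetical dims): with dims=[1, 0, 2], an index
    # (a, b, c) into the permuted view reads self.data at (b, a, c).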
2051
2052
2053class SqueezeView(BaseView):
2054    @classmethod
2055    def create(cls, x, *, dim=None):
2056        if is_storage_and_layout(x):
2057            storage, old_layout = as_storage_and_layout(x)
2058            new_size = []
2059            new_stride = []
2060            if dim is not None:
2061                assert isinstance(dim, int), "expected integer dim argument"
2062            assert 0 <= dim < len(old_layout.size)
2063
2064            for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
2065                if dim is None:
2066                    if size != 1:
2067                        new_size.append(size)
2068                        new_stride.append(stride)
2069                else:
2070                    if i != dim:
2071                        new_size.append(size)
2072                        new_stride.append(stride)
2073                    else:
2074                        assert size == 1, "expected squeezed size to be 1"
2075
2076            new_layout = FixedLayout(
2077                old_layout.device,
2078                old_layout.dtype,
2079                new_size,
2080                new_stride,
2081                old_layout.offset,
2082            )
2083            return ReinterpretView(storage, new_layout)
2084
2085        if dim is None:
2086            # redirect to a generic view
2087            return View.create(x, [s for s in x.get_size() if s != 1])
2088        else:
2089            assert x.get_size()[dim] == 1
2090            return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
2091
2092    @staticmethod
2093    def squeezer(size: Tuple[sympy.Expr, ...]):
2094        new_size = [s for s in size if s != 1]
2095        not_one = [i for i, s in enumerate(size) if s != 1]
2096        length = len(size)
2097
2098        def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]:
2099            assert len(index) == len(not_one), f"{index} {not_one}"
2100            new_index = [sympy.Integer(0)] * length
2101            for idx, s in zip(not_one, index):
2102                new_index[idx] = s
2103            return tuple(new_index)
2104
2105        return new_size, reindex
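
    # Illustrative sketch (hypothetical shape): squeezer((2, 1, 3)) returns
    # new_size=[2, 3] and a reindex with reindex([i0, i1]) == (i0, 0, i1).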
2106
2107    def __init__(self, data):
2108        raise AssertionError("use SqueezeView.create()")
2109
2110
2111@dataclasses.dataclass
2112class GenericView(BaseView):
2113    size: List[Expr]
2114    reindex: Callable[..., Any]
2115
2116    def make_reindexer(self):
2117        return self.reindex
2118
2119    def reindex_str(self):
2120        index_old = [
2121            sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size))
2122        ]
2123        index_new = list(self.reindex(index_old))
2124        return f"lambda {', '.join(map(str, index_old))}: {index_new}"
2125
2126    def __str__(self):
2127        return self.str_helper(
2128            [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
2129        )
2130
2131    __repr__ = __str__
2132
2133    @classmethod
2134    def create(cls, x, new_size, reindex):
2135        return cls(x, list(new_size), reindex)
2136
2137    def get_size(self):
2138        return self.size
2139
2140
2141@dataclasses.dataclass
2142class View(GenericView):
2143    @staticmethod
2144    def handle_negative_index(idx, size):
2145        idx = sympy.expand(idx)
2146        size = sympy.expand(size)
2147        evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
2148        if evaluate_expr(sympy.Lt(idx, 0)):
2149            idx = idx + size
2150        return idx
2151
2152    @classmethod
2153    def create(cls, x, new_size):
2154        assert isinstance(new_size, (tuple, list))
2155        old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)
2156
2157        # Skip pointless views
2158        if V.graph.sizevars.statically_known_list_equals(old_size, new_size):
2159            return x
2160
2161        unbacked_symbols_in_sizes = False
2162        if (
2163            len(free_unbacked_symbols(old_size)) > 0
2164            or len(free_unbacked_symbols(new_size)) > 0
2165        ):
2166            unbacked_symbols_in_sizes = True
2167
2168        if 0 in new_size:
2169
2170            def fake_reindex(index):
2171                return tuple([0] * len(old_size))
2172
2173            return cls(x, list(new_size), fake_reindex)
2174        # TODO: a new FixedTransferLayout class whose output layout is constrained by the input layout
2175        elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes:
2176            if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)):
2177                # realize x; otherwise, the dynamic_reshape_indexer below will fail
2178                # due to the size_hint's inability to process unbacked SymInts
2179                x = ExternKernel.realize_input(x)
2180
2181            storage, old_layout = as_contiguous_storage_and_layout(x)
2182            new_layout = FixedLayout(
2183                old_layout.device,
2184                old_layout.dtype,
2185                new_size,
2186                FlexibleLayout.contiguous_strides(new_size),
2187                old_layout.offset,
2188            )
2189            return ReinterpretView(storage, new_layout)
2190
2191        reindex = cls.dynamic_reshape_indexer(old_size, new_size)
2192        return cls(x, list(new_size), reindex)
2193
2194    @staticmethod
2195    def resolve_negative_size(old_size, new_size):
2196        new_size = [V.graph.sizevars.simplify(x) for x in new_size]
2197        old_size = [V.graph.sizevars.simplify(x) for x in old_size]
2198
2199        new_size = list(new_size)
2200        for i in range(len(new_size)):
2201            if new_size[i] == -1:
2202                new_size[i] = sympy.Integer(1)
2203                new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
2204                break
2205
2206        V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
2207        return old_size, new_size
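
    # Illustrative sketch (hypothetical sizes): resolve_negative_size([6, 4],
    # [-1, 8]) replaces the -1 with CleanDiv(6*4, 8) == 3 and returns
    # ([6, 4], [3, 8]), guarding that both products are equal.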
2208
2209    @classmethod
2210    def dynamic_reshape_indexer(cls, old_size, new_size):
2211        try:
2212            reindex = cls._dynamic_reshape_indexer(old_size, new_size)
2213        except (AssertionError, IndexError):
2214            # the optimistic algorithm failed; fall back to reshaping via a flattened intermediate
2215            flat = [sympy_product(old_size)]
2216            reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
2217            reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
2218            reindex = fuse_reindexing(reindex1, reindex2)
2219        return reindex
2220
2221    @staticmethod
2222    def _dynamic_reshape_indexer(old_size, new_size):
2223        """
2224        Perform a reshape entirely by modifying indexing math
2225        """
2226        size_hint = V.graph.sizevars.size_hint
2227        # TODO: These symbols may not escape; if they don't, assert so and
2228        # treat them as temporary
2229        vars = [
2230            sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size))
2231        ]
2232
2233        stack_new = list(zip(vars, new_size))
2234        stack_old = list(old_size)
2235
2236        view_expr = []
2237        while stack_new and stack_old:
2238            size_old = stack_old.pop()
2239            var, size_new = stack_new.pop()
2240            if size_old == 1:
2241                view_expr.append(sympy.Integer(0))
2242                stack_new.append((var, size_new))  # re-add
2243            elif size_new == 1:
2244                stack_old.append(size_old)  # re-add
2245            elif size_hint(size_new) == size_hint(size_old):
2246                view_expr.append(var)
2247                V.graph.sizevars.guard_equals(size_new, size_old)
2248            elif size_hint(size_new) < size_hint(size_old):
2249                while size_hint(size_new) < size_hint(size_old):
2250                    var2, size_new2 = stack_new.pop()
2251                    var = var2 * size_new + var
2252                    size_new = size_new * size_new2
2253                view_expr.append(var)
2254                V.graph.sizevars.guard_equals(size_new, size_old)
2255            elif size_hint(size_new) > size_hint(size_old):
2256                divisor = sympy.Integer(1)
2257                modulus = size_old
2258                view_expr.append(ModularIndexing(var, divisor, modulus))
2259                divisor = divisor * modulus
2260                while size_hint(size_new) > size_hint(size_old):
2261                    modulus = stack_old.pop()
2262                    view_expr.append(ModularIndexing(var, divisor, modulus))
2263                    divisor = divisor * modulus
2264                    size_old = size_old * modulus
2265                V.graph.sizevars.guard_equals(size_new, size_old)
2266            else:
2267                raise AssertionError
2268
2269        while stack_old:
2270            size_old = stack_old.pop()
2271            V.graph.sizevars.guard_equals(size_old, 1)  # type: ignore[arg-type]
2272            view_expr.append(sympy.Integer(0))
2273
2274        while stack_new:
2275            var, size_new = stack_new.pop()
2276            V.graph.sizevars.guard_equals(size_new, 1)  # type: ignore[arg-type]
2277
2278        view_expr.reverse()
2279        assert len(view_expr) == len(old_size)
2280
2281        def reindex(index):
2282            assert len(index) == len(vars), (len(index), len(vars))
2283            replacements = dict(zip(vars, index))
2284            return tuple(sympy_subs(x, replacements) for x in view_expr)  # type: ignore[arg-type]
2285
2286        return reindex
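
    # Illustrative sketch (hypothetical sizes): reshaping old_size=[2, 12] to
    # new_size=[2, 3, 4] builds view_expr=[v0, v1*4 + v2] over the view
    # variables, so reindex([i0, i1, i2]) == (i0, i1*4 + i2).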
2287
2288
2289@dataclasses.dataclass
2290class ReinterpretView(BaseView):
2291    """Pretend our storage has a different layout"""
2292
2293    layout: "Layout"
2294
2295    def __post_init__(self):
2296        super().__post_init__()
2297        if isinstance(self.data, BaseView):
2298            self.data = self.data.unwrap_view()
2299
2300    def __str__(self):
2301        return self.str_helper(
2302            [
2303                self.data,
2304                self.layout,
2305            ]
2306        )
2307
2308    __repr__ = __str__
2309
2310    def get_name(self):
2311        return self.data.get_name()
2312
2313    def get_device(self):
2314        return self.layout.device
2315
2316    def get_origin_node(self):
2317        return None
2318
2319    @property
2320    def dtype(self):
2321        return self.layout.dtype
2322
2323    def get_size(self):
2324        return list(self.layout.size)
2325
2326    def get_stride(self):
2327        return list(self.layout.stride)
2328
2329    def make_loader(self):
2330        def loader(index):
2331            indexer = self.layout.make_indexer()
2332            return ops.load(self.get_name(), indexer(index))
2333
2334        return loader
2335
2336    def make_indexer(self):
2337        return self.layout.make_indexer()
2338
2339    def get_layout(self):
2340        return self.layout
2341
2342    def freeze_layout(self):
2343        pass
2344
2345    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
2346        return (
2347            free_unbacked_symbols(self.layout.size)
2348            | free_unbacked_symbols(self.layout.stride)
2349            | free_unbacked_symbols(self.layout.offset)
2350        )
2351
2352    def codegen_reference(self, writer=None):
2353        # reinterpret_tensor is similar to as_strided except:
2354        # - offset is added to the existing offset (rather than replacing it)
2355        # - view tracking is disabled similar to unsafe_view
2356        return V.graph.wrapper_code.codegen_reinterpret_view(
2357            self.data,
2358            self.layout.size,
2359            self.layout.stride,
2360            self.layout.offset,
2361            writer,
2362        )
2363
2364
2365class SliceView(View):
2366    @classmethod
2367    def normalize_start_end(cls, x, dim, start, end):
2368        """
2369        Normalize start and end such that both are in the range
2370        [0, x.get_size()[dim]] and start <= end.
2371        """
2372        sizevars = V.graph.sizevars
2373        dim_size = x.get_size()[dim]
2374
2375        if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
2376
2377            def clamp(x, lower, upper):
2378                return sympy.Min(sympy.Max(x, lower), upper)
2379
2380        else:
2381
2382            def clamp(x, lower, upper):
2383                return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper)
2384
2385        def clamp_wrap(val, lower, upper, default):
2386            if val is None:
2387                return default
2388            val = cls.handle_negative_index(val, dim_size)
2389            return clamp(val, lower, upper)
2390
2391        start = clamp_wrap(start, 0, dim_size, 0)
2392        end = clamp_wrap(end, start, dim_size, dim_size)
2393        return start, end
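
    # Illustrative sketch (hypothetical values): for a dimension of size 10,
    # start=-3 and end=None normalize to (7, 10); negative indices wrap
    # around and None falls back to the defaults (0, dim_size).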
2394
2395    @classmethod
2396    def create(cls, x, dim, start, end, step=1, clamp=True):
2397        step = sympy.expand(step)
2398        assert step > 0
2399        try:
2400            if start == 0 and end >= 2**63 - 1 and step == 1:
2401                return x
2402        except TypeError:
2403            pass
2404
2405        sizevars = V.graph.sizevars
2406        new_size = list(x.get_size())
2407
2408        # NB: Ordinarily we default to clamping.
2409        # We only skip clamping for split_with_sizes, whose sizes should already be valid;
2410        # failing in that situation is ok, since invalid sizes could trigger silent errors.
2411        if clamp:
2412            start, end = cls.normalize_start_end(x, dim, start, end)
2413
2414        new_size[dim] = FloorDiv(end - start + (step - 1), step)
2415
2416        if is_storage_and_layout(x):
2417            # Fast path
2418            storage, old_layout = as_storage_and_layout(x)
2419            new_stride = list(old_layout.stride)
2420            new_stride[dim] = new_stride[dim] * step
2421            new_layout = FixedLayout(
2422                old_layout.device,
2423                old_layout.dtype,
2424                new_size,
2425                new_stride,
2426                old_layout.offset + old_layout.stride[dim] * start,
2427            )
2428            return ReinterpretView(storage, new_layout)
2429
2430        def reindex(index):
2431            assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
2432            index = list(index)
2433            index[dim] = index[dim] * step + start
2434            return index
2435
2436        # redirect to a generic view
2437        return SliceView(x, size=new_size, reindex=reindex)
2438
2439
2440class BaseConstant(IRNode):
2441    dtype: torch.dtype
2442    device: torch.device
2443
2444    def get_size(self):
2445        return ()
2446
2447    def get_device(self):
2448        return self.device
2449
2450    def get_origin_node(self):
2451        return None
2452
2453    def mark_reuse(self, users):
2454        pass
2455
2456    def has_exceeded_max_reads(self):
2457        return False
2458
2459    def get_reads(self):
2460        return ()
2461
2462    def is_extern(self):
2463        return False
2464
2465
2466@dataclasses.dataclass
2467class Constant(BaseConstant):
2468    value: Any
2469    dtype: torch.dtype
2470    device: torch.device
2471
2472    def make_loader(self):
2473        def loader(index):
2474            return ops.constant(self.value, self.dtype)
2475
2476        return loader
2477
2478    def realize(self):
2479        pass
2480
2481    def constant_to_device(self, device):
2482        return Constant(self.value, self.dtype, device)
2483
2484
2485@dataclasses.dataclass
2486class IndexingConstant(BaseConstant):
2487    index: Any
2488    dtype: torch.dtype
2489    device: torch.device
2490
2491    def make_loader(self):
2492        def loader(index):
2493            return ops.index_expr(self.index, self.dtype)
2494
2495        return loader
2496
2497    def constant_to_device(self, device):
2498        return IndexingConstant(self.index, self.dtype, device)
2499
2500
2501def is_contiguous_strides_for_shape(stride, shape):
2502    return all(
2503        size == 1 or left == right
2504        for left, right, size in zip(
2505            stride, FlexibleLayout.contiguous_strides(shape), shape
2506        )
2507    )
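
# Illustrative sketch (hypothetical values):
# is_contiguous_strides_for_shape([12, 4, 1], [2, 3, 4]) is True, and size-1
# dims are ignored, so strides [4, 999, 1] for shape [2, 1, 4] also pass.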
2508
2509
2510def get_align_for_dtype(dtype):
2511    """
2512    The maximum CUDA memory transaction size is 128 bytes per warp.
2513    We pick `128 // dtype.itemsize` as the alignment so the GPU can do
2514    coalesced memory access.
2515    """
2516    return 128 // dtype.itemsize
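
# e.g. get_align_for_dtype(torch.float32) == 32 and
# get_align_for_dtype(torch.float16) == 64 (one 128-byte transaction).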
2517
2518
2519@dataclasses.dataclass
2520class Layout(IRNode):
2521    def __init__(
2522        self,
2523        device: torch.device,
2524        dtype: torch.dtype,
2525        size: List[Expr],
2526        stride: Optional[Sequence[Union[Expr, int]]],
2527        offset: Expr = Integer(0),
2528    ):
2529        assert stride is None or len(size) == len(
2530            stride
2531        ), f"size={size}, stride={stride}"
2532        self.device = device
2533        self.dtype = dtype
2534        assert all(isinstance(s, (Expr, int)) for s in size)
2535        self.size = size
2536        self._stride = stride
2537        self.offset = offset
2538
2539    @property
2540    def stride(self):
2541        return self._stride
2542
2543    def __str__(self):
2544        offset = ""
2545        if self.offset != 0:
2546            offset = f", offset={self.offset}"
2547        return (
2548            f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
2549            f"size={self.size}, stride={self.stride}{offset})"
2550        )
2551
2552    __repr__ = __str__
2553
2554    def is_contiguous(self):
2555        return is_contiguous_strides_for_shape(self.stride, self.size)
2556
2557    @staticmethod
2558    def is_channels_last_contiguous(shape, strides):
2559        ndim = len(shape)
2560        if ndim not in [4, 5] or shape[1] == 1:
2561            return False
2562        for left, right, size in zip(
2563            strides, make_channels_last_strides_for(shape), shape  # type: ignore[arg-type]
2564        ):
2565            if size != 1 and left != right:
2566                return False
2567        return True
2568
2569    def is_transposed(self):
2570        for left, right, size in zip(
2571            self.stride,
2572            reversed(FlexibleLayout.contiguous_strides(self.size)),
2573            self.size,
2574        ):
2575            if size != 1 and left != right:
2576                return False
2577        return True
2578
2579    def is_stride_ordered(self, order):
2580        assert len(self.stride) == len(order)
2581
2582        # ignore dimensions of size 1; they don't affect layout
2583        non_1_indices = [
2584            i
2585            for i, dim in enumerate(self.size)
2586            if V.graph.sizevars.size_hint(dim, fallback=2) != 1
2587        ]
2588
2589        stride = [self.stride[i] for i in non_1_indices]
2590        order = [order[i] for i in non_1_indices]
2591
2592        def sorted_indices(arr):
2593            sorted_arr = sorted(arr)
2594            return [sorted_arr.index(element) for element in arr]
2595
2596        # since we may have removed dimensions, need to re-sort & re-index order
2597        order = sorted_indices(order)
2598
2599        # reorder the stride given order
2600        stride_ordered = [-1] * len(order)
2601        for i in range(len(order)):
2602            stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i])
2603        # check if it is in ascending order
2604        for i in range(len(order) - 1):
2605            if stride_ordered[i] > stride_ordered[i + 1]:
2606                return False
2607        return True
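
    # Illustrative sketch (hypothetical strides): contiguous NCHW strides
    # [60, 20, 5, 1] satisfy order=[3, 2, 1, 0], since order[i] gives the
    # ascending rank of stride i.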
2608
2609    def is_channels_last_stride_ordered(self):
2610        # create the channels_last order (for NCHW and NCDHW, C is filled first).
2611        order = [0] + list(reversed(range(1, len(self.stride) - 1)))
2612        order = [len(order)] + order
2613        return self.is_stride_ordered(order)
2614
2615    @staticmethod
2616    def _pad_strides(in_strides, size, dtype):
2617        """
2618        The padding does not change the stride order but makes sure all strides larger
2619        than the threshold are a multiple of the alignment.
2620        """
2621        align = get_align_for_dtype(dtype)
2622        if len(in_strides) == 0:
2623            return in_strides
2624
2625        if not config.pad_channels_last and Layout.is_channels_last_contiguous(
2626            size, in_strides
2627        ):
2628            return in_strides
2629
2630        current_fx_node = V.get_current_node()
2631        if hasattr(current_fx_node, "meta") and current_fx_node.meta.get(
2632            "dislike_padding", False
2633        ):
2634            return in_strides
2635
2636        # get_stride_order does not work with dynamic shape. Also we can not
2637        # statically decide if a padding is needed or how much padding we should
2638        # do for dynamic shape.
2639        #
2640        # Skip padding the strides for dynamic shape for now.
2641        if not all(
2642            isinstance(s, (int, sympy.Integer))
2643            for s in itertools.chain(in_strides, size)
2644        ):
2645            return in_strides
2646
2647        stride_order = get_stride_order(in_strides)
2648        fill_order = stride_order2fill_order(stride_order)
2649
2650        new_strides = [0 for _ in range(len(in_strides))]
2651        # since we pad when the layout is flexible, we are free to pick the
2652        # smallest stride to be 1.
2653        new_strides[fill_order[0]] = 1
2654
2655        # Don't align too small a stride since that increases memory use too much.
2656        # Padding a too-small stride may also cause a perf loss: we may end up with
2657        # many tiny data blocks with gaps in between, which causes less coalesced GPU memory access!
2658        #
2659        # Initially we picked 320 as the threshold since, for alignment=16,
2660        # that results in at most 5% memory cost.
2661        #
2662        # But later on we raised the threshold to 1024 to avoid interfering with persistent reductions.
2663        # Say an inner reduction has a row size of 513. Inductor will generate
2664        # persistent reduction code.
2665        # If we do padding, the strides are no longer contiguous. Inductor
2666        # uses a much smaller threshold for persistent reductions in this case and
2667        # generates potentially worse non-persistent reduction code.
2668        #
2669        # This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x.
2670        # (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms)
2671        align_stride_threshold = 1024
2672        padded = False
2673        for rank, idx in enumerate(fill_order[1:], start=1):
2674            prev_idx = fill_order[rank - 1]
2675            stride = new_strides[prev_idx] * size[prev_idx]
2676
2677            if stride > align_stride_threshold and stride % align != 0:
2678                stride = ceildiv(stride, align) * align
2679                padded = True
2680            new_strides[idx] = stride
2681
2682        if not padded:
2683            # Consider a tensor with shape [256, 1, 5, 5]
2684            # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides
2685            # [25, 25, 5, 1].
2686            return in_strides
2687
2688        metrics.num_comprehensive_padding += 1
2689        return new_strides
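
    # Illustrative sketch (hypothetical shape, float32 so align == 32): static
    # shape [128, 1030] with contiguous strides [1030, 1] is padded to
    # [1056, 1], since 1030 > align_stride_threshold and 1030 % 32 != 0.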
2690
2691    def pad_strides(self):
2692        assert isinstance(self, FlexibleLayout)
2693        assert self._stride is not None
2694        self._stride = self._pad_strides(self._stride, self.size, self.dtype)
2695
2696    def should_pad_strides(self):
2697        return config.comprehensive_padding and isinstance(self, FlexibleLayout)
2698
2699    def as_fixed(self):
2700        if isinstance(self, FixedLayout):
2701            return self
2702
2703        if self.should_pad_strides():
2704            self.pad_strides()
2705        return FixedLayout(
2706            self.device,
2707            self.dtype,
2708            self.size,
2709            self.stride,
2710            self.offset,
2711        )
2712
2713    def make_indexer(self):
2714        assert (
2715            FlexibleLayout.allow_indexing
2716        ), f"convert {type(self).__name__} to FixedLayout first"
2717        return self.as_fixed().make_indexer()
2718
2719    def __eq__(self, other) -> bool:
2720        return (
2721            self.device == other.device
2722            and self.dtype == other.dtype
2723            and self.size == other.size
2724            and self.stride == other.stride
2725            and self.offset == other.offset
2726        )
2727
2728    def storage_size(self) -> sympy.Expr:
2729        return compute_required_storage_length(self.size, self.stride, self.offset)  # type: ignore[arg-type, return-value]
2730
2731
2732class FixedLayout(Layout):
2733    """A Tensor layout we cannot change"""
2734
2735    def __init__(
2736        self,
2737        device: torch.device,
2738        dtype: torch.dtype,
2739        size: Union[List[Expr], List[int]],
2740        stride: Optional[Sequence[Union[Expr, int]]] = None,
2741        offset: Union[Expr, int] = Integer(0),
2742    ):
2743        if stride is None:
2744            stride = FlexibleLayout.contiguous_strides(size)
2745        super().__init__(
2746            device,
2747            dtype,
2748            size,  # type: ignore[arg-type]
2749            stride,
2750            offset,  # type: ignore[arg-type]
2751        )
2752
2753    def make_indexer(self):
2754        """A closure containing math to read a given element"""
2755
2756        def indexer(index):
2757            assert len(index) == len(self.stride)
2758            assert len(index) == len(self.size)
2759            result = self.offset
2760            for idx, stride, sz in zip(index, self.stride, self.size):
2761                if sz != 1:
2762                    result = result + idx * stride
2763            return result
2764
2765        return indexer
2766
2767
2768class FlexibleLayout(Layout):
2769    """A Tensor layout we are allowed to change"""
2770
2771    allow_indexing = False
2772
2773    # WARNING!  This doesn't handle zero-size tensors correctly
2774    @staticmethod
2775    def contiguous_strides(sizes):
2776        if len(sizes) == 0:
2777            return []
2778        reversed_strides = [sympy.Integer(1)]
2779        for size in reversed(sizes[1:]):
2780            reversed_strides.append(size * reversed_strides[-1])
2781        return list(reversed(reversed_strides))
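
    # e.g. contiguous_strides([2, 3, 4]) == [12, 4, 1]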
2782
2783    @staticmethod
2784    def fill_ordered(sizes, order):
2785        """
2786        Create a stride based on the order the dimensions should be filled in.
2787
2788        In this format, channels last would be:
2789            [1, 3, 2, 0]
2790        """
2791        assert set(range(len(sizes))) == set(order)
2792        next_stride = sympy.Integer(1)
2793        strides = [None] * len(order)
2794
2795        for i in order:
2796            strides[i] = next_stride
2797            next_stride = next_stride * sizes[i]
2798        return strides
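
    # e.g. fill_ordered([2, 3, 4, 5], [1, 3, 2, 0]) == [60, 1, 15, 3], the
    # channels-last strides for an N=2, C=3, H=4, W=5 tensor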
2799
2800    @staticmethod
2801    def stride_ordered(sizes, order):
2802        """
2803        Create a stride based on the sorted order of a permuted range.
2804
2805        In this format, channels last would be:
2806            [3, 0, 2, 1]
2807        """
2808        assert set(range(len(sizes))) == set(order)
2809        fill_order = stride_order2fill_order(order)
2810        return FlexibleLayout.fill_ordered(sizes, fill_order)
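
    # e.g. stride_ordered([2, 3, 4, 5], [3, 0, 2, 1]) == [60, 1, 15, 3],
    # equivalent to fill_ordered with fill order [1, 3, 2, 0]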
2811
2812    @staticmethod
2813    def stride_ordered_for_memory_format(sizes, memory_format):
2814        """
2815        Create a stride based on a memory format.
2816
2817        Memory format is translated into a stride order,
2818        so channels_last is the same as:
2819            FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1])
2820
2821        This interface does not support memory_format `torch.preserve_format`,
2822        which should be used to deduce a format from another source.
2823        """
2824        if memory_format == torch.channels_last:
2825            return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER)
2826        elif memory_format == torch.channels_last_3d:
2827            return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER)
2828        elif memory_format == torch.contiguous_format:
2829            return FlexibleLayout.contiguous_strides(sizes)
2830        else:
2831            log.debug(
2832                "stride_ordered_for_memory_format, unsupported memory_format: %s",
2833                memory_format,
2834            )
2835            raise NotImplementedError
2836
2837    @staticmethod
2838    def same_ordered(sizes, stride):
2839        """
2840        Create a stride that has the same stride order as the given stride.
2841
2842        For example, if the given stride is [1000, 1, 100, 10],
2843        the fill order should be [1, 3, 2, 0]
2844        """
2845        assert len(sizes) == len(stride)
2846        stride = [V.graph.sizevars.size_hint(x) for x in stride]
2847        fill_order = sorted(range(len(stride)), key=stride.__getitem__)
2848        return FlexibleLayout.fill_ordered(sizes, fill_order)
2849
2850    def as_stride_order(self, order, allow_padding=False):
2851        new_stride = self.stride_ordered(self.size, order)
2852        if self.should_pad_strides() and allow_padding:
2853            new_stride = self._pad_strides(new_stride, self.size, self.dtype)
2854
2855        return FixedLayout(
2856            self.device,
2857            self.dtype,
2858            self.size,
2859            new_stride,
2860            self.offset,
2861        )
2862
2863    def as_fill_order(self, order):
2864        new_stride = self.fill_ordered(self.size, order)
2865        if self.should_pad_strides():
2866            new_stride = self._pad_strides(new_stride, self.size, self.dtype)
2867        return FixedLayout(
2868            self.device,
2869            self.dtype,
2870            self.size,
2871            new_stride,
2872            self.offset,
2873        )
2874
2875    def as_same_order(self, stride):
2876        new_stride = self.same_ordered(self.size, stride)
2877        if self.should_pad_strides():
2878            new_stride = self._pad_strides(new_stride, self.size, self.dtype)
2879        return FixedLayout(
2880            self.device,
2881            self.dtype,
2882            self.size,
2883            new_stride,
2884            self.offset,
2885        )
2886
2887    def __init__(self, device, dtype, size, stride_order=None):
2888        if stride_order:
2889            strides = FlexibleLayout.fill_ordered(size, stride_order)
2890        else:
2891            strides = FlexibleLayout.contiguous_strides(size)
2892        super().__init__(device, dtype, size, strides)
2893
2894
2895class NonOwningLayout(Layout):
2896    """Is a view into the storage of another tensor"""
2897
2898    def __init__(self, view: Union[BaseView, "TensorBox"]):
2899        layout = view.get_layout()
2900        super().__init__(
2901            layout.device,
2902            layout.dtype,
2903            layout.size,
2904            layout.stride,
2905        )
2906        self.view = view
2907
2908    def make_indexer(self):
2909        return self.as_fixed().make_indexer()
2910
2911    def maybe_guard_aligned(self):
2912        offset = self.view.get_layout().offset
2913        if offset == 0:
2914            return True
2915        from .compile_fx import ALIGNMENT
2916
2917        return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)  # type: ignore[arg-type]
2918
2919
2920class NoneLayout(IRNode):
2921    # This is janky, I figured out what fields to populate by just running
2922    # the model I was interested in and adding properties/methods as needed.
2923    # This doesn't inherit from Layout because Layout assumes you have stuff
2924    # like sizes, but I don't really have anything here.
2925    #
2926    # If you have an ir.Node with NoneLayout, you probably need to setup
2927    # dependencies manually in scheduler
2928
2929    def __init__(self, device):
2930        self.device = device
2931        self.size = [0]
2932        self.stride = [0]
2933
2934    def storage_size(self):
2935        return 0
2936
2937    def as_fixed(self):
2938        return self
2939
2940
2941class MutationLayoutSHOULDREMOVE(Layout):
2942    def __init__(self, target: IRNode):
2943        super().__init__(
2944            target.get_device(),
2945            target.get_dtype(),
2946            target.get_size(),
2947            None,
2948        )
2949        self.target = target
2950        name = self.get_buffer().get_name()
2951        V.graph.mark_buffer_mutated(name)
2952
2953    @Layout.stride.getter  # type: ignore[attr-defined]
2954    def stride(self):
2955        return self.real_layout().stride
2956
2957    def storage_size(self) -> sympy.Expr:
2958        return self.real_layout().storage_size()
2959
2960    def get_buffer(self) -> "Buffer":
2961        def unwrap_views(target):
2962            if isinstance(target, MutationLayoutSHOULDREMOVE):
2963                return unwrap_views(target.target)
2964            if isinstance(target, BaseView):
2965                return unwrap_views(target.unwrap_view())
2966            if isinstance(target, MutableBox):
2967                return unwrap_views(target.data)
2968            return target
2969
2970        result = unwrap_views(self.target)
2971        assert isinstance(
2972            result, Buffer
2973        ), "MutationLayoutSHOULDREMOVE must refer to a buffer"
2974        return result
2975
2976    def real_layout(self):
2977        return self.get_buffer().layout
2978
2979    @classmethod
2980    def realize_into(cls, src, dst, unsafe_alias=False):
2981        dst.realize()
2982        # NOTE: We must realize users of `dst` before we realize `src`, since
2983        # realization order determines scheduling order. Otherwise, src's
2984        # mutation would be scheduled before the existing users of dst!
2985        V.graph.mark_buffer_mutated(dst.get_name())
2986
2987        if isinstance(src, TensorBox):
2988            src = src.data
2989
2990        # We copy the contents of src into dst. In most cases this should
2991        # be fused into a single kernel by the scheduler.
2992        # NOTE: We cannot change src's layout to mutate dst directly as this
2993        # would alias src to dst, which is not correct as further mutations to
2994        # dst would affect users of src. However, if there are no more users of
2995        # dst, we can alias src to dst.
2996        src.realize_hint()
2997
2998        if not unsafe_alias:
2999            src = Pointwise.create(
3000                device=src.get_device(),
3001                dtype=src.get_dtype(),
3002                inner_fn=src.make_loader(),
3003                ranges=[
3004                    V.graph.sizevars.guard_equals(a, b)
3005                    for a, b in zip(src.get_size(), dst.get_size())
3006                ],
3007            ).data
3008
3009        src.realize()
3010        assert isinstance(src.data.layout, FlexibleLayout)
3011        src.data.layout = MutationLayoutSHOULDREMOVE(dst)
3012        return src.data
3013
3014    def as_fixed(self):
3015        return self
3016
3017    def make_indexer(self):
3018        return self.target.make_indexer()
3019
3020
3021@dataclasses.dataclass
3022class Buffer(IRNode):
3023    # Name is sometimes None; e.g., ForceInPlace, where there isn't
3024    # a meaningful name
3025    name: Optional[str]
3026    layout: Layout
3027
3028    # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly,
3029    # MultiOutput does NOT define this!
3030
3031    def __post_init__(self):
3032        super().__post_init__()
3033        self.origin_node = None
3034
3035    def make_indexer(self):
3036        return self.layout.make_indexer()
3037
3038    def get_name(self) -> str:
3039        assert self.name, self
3040        return self.name
3041
3042    def get_device(self):
3043        return self.layout.device
3044
3045    def get_origin_node(self):
3046        return self.origin_node
3047
3048    @property
3049    def dtype(self):
3050        return getattr(self.layout, "dtype", None)
3051
3052    def get_size(self):
3053        return list(self.layout.size)
3054
3055    def get_stride(self):
3056        return list(self.layout.stride)
3057
3058    def get_offset(self):
3059        return self.layout.offset
3060
3061    def get_layout(self):
3062        return self.layout
3063
3064    def get_storage_numel(self):
3065        return self.get_numel()
3066
3067    def is_extern(self):
3068        return False
3069
3070    def freeze_layout(self):
3071        if not isinstance(self.layout, (MultiOutputLayout, NonOwningLayout)):
3072            self.layout = self.layout.as_fixed()
3073
3074    def freeze_layout_with_stride_order(self, order, allow_padding=False):
3075        assert isinstance(self.layout, FlexibleLayout)
3076        self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding)
3077
3078    def freeze_layout_with_fill_order(self, order):
3079        assert isinstance(self.layout, FlexibleLayout)
3080        self.layout = self.layout.as_fill_order(order)
3081
3082    def freeze_layout_with_same_order(self, stride):
3083        assert isinstance(self.layout, FlexibleLayout)
3084        self.layout = self.layout.as_same_order(stride)
3085
3086    def is_zero_elements(self):
3087        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]
3088
3089    def make_loader(self):
3090        # Loading from a zero-element buffer is a no-op
3091        if self.is_zero_elements():
3092            return partial(nop_loader_fn, dtype=self.get_dtype())
3093
3094        def loader(index):
3095            indexer = self.layout.make_indexer()
3096            return ops.load(self.name, indexer(index))
3097
3098        return loader
3099
3100    def is_no_op(self):
3101        return False
3102
3103    def codegen_reference(self, writer=None):
3104        return self.get_name()
3105
3106    def decide_layout(self):
3107        pass
3108
3109    def get_inputs_that_alias_output(self):
3110        if isinstance(self.layout, NonOwningLayout):
3111            return [self.layout.view.get_name()]
3112        return ()
3113
3114    def get_mutation_names(self):
3115        if isinstance(self.layout, MutationLayoutSHOULDREMOVE):
3116            return [self.layout.target.get_name()]
3117        return ()
3118
3119    def get_read_writes(self):
3120        with patch.object(FlexibleLayout, "allow_indexing", True):
3121            return extract_read_writes(
3122                self.make_loader(),
3123                self.get_size(),
3124            )
3125
3126    def get_reads(self):
3127        return self.get_read_writes().reads
3128
3129    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
3130        return set()
3131
3132    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
3133        """
3134        Returns the unbacked symbols which are required to be in scope in
3135        order to successfully perform codegen for this buffer.  For example,
3136        a buffer that corresponds to an extern kernel call that takes i0 as
3137        an argument would return {i0} here.  This is used to generate necessary
3138        dependencies that ensure we actually bind i0 in codegen before you
3139        try to use it.
3140
3141        Note that this is NOT transitive; in particular, if this buffer takes
3142        in as input another buffer with dynamic shape (e.g., (i0,)), we will
3143        not report it here, because you will already have a dependency
3144        on that buffer, which will eventually have a dependency on i0 if
3145        necessary.
3146        """
3147        return set()
3148
3149    def realize(self):
3150        pass
3151
3152    def get_workspace_size(self):
3153        """
3154        Gets extra global memory size needed by this buffer.
3155        Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
3156        """
3157        return 0
3158
3159    def should_allocate(self):
3160        # Returns False by default.
3161        return False
3162
3163
3164class InputBuffer(Buffer):
3165    pass
3166
3167
3168class ConstantBuffer(InputBuffer):
3169    override_device: Optional[torch.device] = None
3170
3171    def make_loader(self):
3172        def loader(index):
3173            indexer = self.layout.make_indexer()
3174            return ops.load(
3175                V.graph.constant_name(self.get_name(), self.override_device),
3176                indexer(index),
3177            )
3178
3179        return loader
3180
3181    def constant_to_device(self, device):
3182        return ConstantBuffer(
3183            V.graph.constant_name(self.get_name(), device), self.layout
3184        )
3185
3186
3187class NoneAsConstantBuffer(IRNode):
3188    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
3189        return set()
3190
3191    def codegen_reference(self, writer=None):
3192        return V.graph.wrapper_code.none_str
3193
3194
3195class ShapeAsConstantBuffer(IRNode):
3196    def __init__(self, shape):
3197        super().__init__()
3198        self.shape = shape
3199
3200    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
3201        return free_unbacked_symbols(self.shape)
3202
3203    def codegen_reference(self, writer=None):
3204        return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape))
3205
3206
3207@dataclasses.dataclass
3208class ComputedBuffer(Buffer):
3209    data: Loops
3210
3211    def get_computed_buffer_name(self):
3212        """
3213        Returns self.name if it exists, otherwise returns the name of the data node if that exists.
3214        If neither exists, returns None.
3215        """
3216        if self.name is not None:
3217            return self.name
3218        if hasattr(self.data, "name"):
3219            return self.data.name
3220        return None
3221
3222    @cache_on_self
3223    def num_reads(self):
3224        return len(self.get_read_writes().reads)
3225
3226    def get_read_writes(self):
3227        with patch.object(FlexibleLayout, "allow_indexing", True):
3228            if self.data.get_reduction_type():
3229                return extract_read_writes(
3230                    self.get_store_function(),
3231                    self.data.get_pointwise_size(),
3232                    self.data.get_reduction_size(),
3233                )
3234            else:
3235                return extract_read_writes(
3236                    self.get_store_function(),
3237                    self.data.get_size(),
3238                )
3239
3240    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
3241        # Ordinarily, we'd like to just peek at the arguments list,
3242        # but ComputedBuffers have no argument list.
3243        #
3244        # Morally, this logic needs to be synchronized with the
3245        # KernelArgs.size calls, which are responsible for making symbols make
3246        # their way in as kernel arguments (and it is precisely passing in one of
3247        # those symbols that establishes a dependency).  However, we haven't
3248        # started codegen yet so we can't directly reuse that logic.
3249        #
3250        # For now, I'm just yoloing with the size of the buffer.  Not sure if
3251        # it is enough.
3252        #
3253        # One thing you might wonder is if this is enough for a ComputedBuffer
3254        # denoting a reduction over i0.  Empirically, it is enough, but for an
3255        # unusual reason: we only need accurate dependencies for item() calls,
3256        # but it's impossible to end up with a reduction over i0 from an
3257        # item() call without a regular non-reduction buffer first.
3258        return (
3259            free_unbacked_symbols(self.get_size())
3260            | free_unbacked_symbols(self.get_stride())
3261            | free_unbacked_symbols(self.get_offset())
3262            | self.data.get_unbacked_symbol_uses()
3263        )
3264
3265    def make_loader(self):
3266        # Inline constants and index_expressions
3267        if (
3268            hasattr(self.data, "make_loader")
3269            and self.name not in V.graph.mutated_buffers
3270            and self.num_reads() == 0
3271        ):
3272            # can be inlined
3273            return self.data.make_loader()
3274        return super().make_loader()
3275
3276    def get_store_function(self):
3277        indexer = self.layout.as_fixed().make_indexer()
3278        if isinstance(self.data, (Reduction, Scan)):
3279            return partial(self.data.store_reduction, self.name, indexer)
3280        else:
3281            assert isinstance(self.data, Pointwise)
3282            return partial(self.data.store_output, self.name, indexer)
3283
3284    def get_fill_order(self):
3285        """
3286        If our layout is still flexible, try to determine the stride order based on stride orders of reads.
3287
3288        TODO(jansel): A better algorithm here would look at downstream consumers of this
3289                      value and try to do global graph-level layout optimization.
3290                      This is also something just begging to be autotuned.
3291        """
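        # Example (hypothetical): if every read of this buffer uses
        # channels-last strides on a 4D tensor, pick_loop_order yields a
        # fill order like [1, 3, 2, 0] (C filled innermost, then W, H, N),
        # which decide_layout() below uses to freeze the layout.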
3292        if isinstance(self.layout, FlexibleLayout):
3293            (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
3294                self.data.get_pointwise_size(), self.data.get_reduction_size()
3295            )
3296            reads = self.get_read_writes().reads
3297            reads_bufs = [
3298                V.graph.name_to_buffer[r.name]
3299                if r.name in V.graph.name_to_buffer
3300                else None
3301                for r in reads
3302            ]
3303            # only consider reads from buffers of the same size
3304            # ignore StarDeps because they don't contribute stride information
3305            assert all(
3306                isinstance(r, (dependencies.StarDep, dependencies.MemoryDep))
3307                for r in reads
3308            )
3309            reads = [
3310                sympy_subs(
3311                    r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
3312                )
3313                for r in reads
3314                if isinstance(r, dependencies.MemoryDep)
3315            ]
3316
3317            if reads:
3318                if isinstance(self.data, Scan):
3319                    indices = self.data.reindex(index_vars, reduction_vars)
3320                else:
3321                    indices = index_vars
3322                stride_lengths = [
3323                    V.graph.sizevars.stride_hints(expr, indices) for expr in reads  # type: ignore[arg-type]
3324                ]
3325                from .scheduler import pick_loop_order
3326
3327                return pick_loop_order(stride_lengths, self.get_size())
3328
3329        return None
3330
3331    def decide_layout(self):
3332        if isinstance(self.layout, FlexibleLayout):
3333            order = self.get_fill_order()
3334            if order:
3335                self.freeze_layout_with_fill_order(order)
3336            else:
3337                self.freeze_layout()
3338
3339    @cache_on_self
3340    def get_default_sizes_body(self):
3341        args, var_ranges = dependencies.index_vars_squeeze(
3342            self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
3343        )
3344        with patch.object(ConstantBuffer, "override_device", self.get_device()):
3345            body = LoopBody(
3346                self.get_store_function(),
3347                (args if self.get_reduction_type() else args[:1]),
3348                var_ranges,
3349            )
3350        index_vars = []
3351        reduce_vars: List[Any] = []
3352        index_size = []
3353        reduce_size = []
3354        for v, s in var_ranges.items():
3355            if v in args[0]:
3356                assert not reduce_vars
3357                index_vars.append(v)
3358                index_size.append(s)
3359            else:
3360                assert v in args[1]
3361                reduce_vars.append(v)
3362                reduce_size.append(s)
3363        return (index_size, reduce_size), body, (index_vars, reduce_vars)
3364
3365    def simplify_and_reorder(
3366        self,
3367        extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None,
3368    ):
3369        """
3370        This is a main place where we do loop transformations in a
3371        backend-agnostic way.
3372
3373        Here we:
3374            1) Remove any 1 dimensions
3375            1) Remove any size-1 dimensions
3376            3) Reorder dimensions based on stride orders
3377
3378        The optional argument extra_indexing_constraints can be used to append additional
3379        indexing expressions to the ones derived from the buffer's body. This can be useful
3380        for fusing scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...),
3381        on CPU: it prevents indexing simplifications and yields index/reduce ranges for
3382        the scheduler node that are compatible with those of other nodes.
3383        """
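        # Worked example (hypothetical shapes): a contiguous pointwise op
        # over sizes [s0, 1, s1] first drops the size-1 dimension and then
        # fuses the remaining contiguous dimensions, yielding
        # iter_ranges = [s0*s1] plus a reindexing back to the original dims.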
3384        (
3385            (index_size, reduce_size),
3386            body,
3387            (index_vars, reduce_vars),
3388        ) = self.get_default_sizes_body()
3389
3390        index_formulas = [*body.indexing_exprs.values()]
3391        if extra_indexing_constraints is not None:
3392            assert (
3393                isinstance(extra_indexing_constraints, tuple)
3394                and len(extra_indexing_constraints) == 2
3395            )
3396            extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints
3397            assert isinstance(extra_indexing_ranges, dict)
3398            assert isinstance(extra_indexing_expr, list)
3399            assert all(isinstance(f, Expr) for f in extra_indexing_expr)
3400
3401            expected_var_ranges = body.var_ranges
3402            assert expected_var_ranges == extra_indexing_ranges, (
3403                expected_var_ranges,
3404                extra_indexing_ranges,
3405            )
3406            # remove already existing expressions
3407            extra_indexing_expr = [
3408                e for e in extra_indexing_expr if e not in index_formulas
3409            ]
3410            index_formulas += extra_indexing_expr
3411
3412        reads_bufs = [
3413            V.graph.name_to_buffer[reads_name]
3414            if reads_name in V.graph.name_to_buffer
3415            else None
3416            for reads_name in body.reads_name2expr
3417        ]
3418        memory_addrs = [
3419            *body.reads_name2expr.values(),
3420            *body.writes_name2expr.values(),
3421        ]
3422
3423        def simplify_and_reorder(x_vars, support_vars, sizes):
3424            sizes, reindex0, reindex1 = self._apply_loop_reordering(
3425                x_vars, support_vars, sizes, memory_addrs
3426            )
3427            # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
3428            x_vars = reindex0(x_vars)
3429            sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
3430                x_vars,
3431                sizes,
3432                index_prevent_reordering(index_formulas, x_vars, sizes),
3433            )
3434            x_vars = prune(x_vars)
3435            # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
3436            # x_vars = prune(x_vars)
3437            # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
3438            reindex = fuse_reindexing(reindex1, reindex2)
3439            return sizes, reindex, reindex1
3440
3441        support_vars = index_vars + reduce_vars
3442        iter_ranges, iter_reindex, _ = simplify_and_reorder(
3443            index_vars,
3444            support_vars,
3445            index_size,
3446        )
3447        reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
3448            reduce_vars, support_vars, reduce_size
3449        )
3450
3451        # retrace the loop body with simplification and reordering applied
3452        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
3453            iter_ranges, reduce_ranges, prefix="z"
3454        )
3455        body = LoopBody(
3456            body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
3457        )
3458        return (iter_ranges, reduce_ranges), body
3459
3460    @staticmethod
3461    def _apply_loop_reordering(
3462        index_vars,
3463        support_vars,
3464        sizes,
3465        memory_addrs,
3466        priority_idx=None,
3467    ):
3468        """
3469        Shuffle the order of loops around to hopefully improve performance.
3470        """
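        # Sketch (hypothetical values): for sizes [128, 64] and memory
        # accesses whose stride hints are [1, 128], pick_loop_order prefers
        # the stride-1 dimension as the innermost loop. If stride extraction
        # fails (e.g. the indexing is too complex), we fall back to the
        # identity order below and leave the loops unchanged.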
3471        from .scheduler import pick_loop_order
3472
3473        if priority_idx is None:
3474            priority_idx = []
3475
3476        try:
3477            strides = [
3478                V.graph.sizevars.stride_hints(expr, index_vars, support_vars)
3479                for expr in memory_addrs
3480            ]
3481            assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
3482                index_vars
3483            )
3484            order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
3485        except Exception:
3486            if config.debug:
3487                log.warning(
3488                    "Did not simplify complex index:\n%s\n%s",
3489                    dict(zip(index_vars, sizes)),
3490                    memory_addrs,
3491                )
3492            order = list(range(len(sizes)))
3493        sizes = [sizes[i] for i in order]
3494        return sizes, same_reorder(order), inverse_reorder(order)
3495
3496    def get_reduction_size(self):
3497        return self.data.get_reduction_size()
3498
3499    def get_reduction_type(self):
3500        return self.data.get_reduction_type()
3501
3502    def is_no_op(self):
3503        return self.data.is_zero_elements()
3504
3505    def should_allocate(self):
3506        return True
3507
3508    def constant_to_device(self, device):
3509        """Move this to a given device. Requires that all reads are to constants."""
3510        return self.data.constant_to_device(device)
3511
3512
3513class TemplateBuffer(Buffer):
3514    """
3515    Represents a Triton (and, in the future, other types of) template operator
3516    that we can fuse an epilogue onto.
3517    """
3518
3519    def __init__(self, layout, inputs, make_kernel_render):
3520        super().__init__(name=None, layout=layout)
3521        self.inputs = InputsKernel.unwrap_storage(inputs)
3522        self.make_kernel_render = make_kernel_render
3523        self.name = V.graph.register_buffer(self)
3524
3525    def get_read_writes(self):
3526        return self.normalized_read_writes()
3527
3528    def normalized_read_writes(self):
3529        name = self.get_name()
3530        indexer = self.layout.make_indexer()
3531
3532        def dummy(index, rindex):
3533            assert len(rindex) == 0
3534            return ops.store(name, indexer(index), "fake")
3535
3536        deps = dependencies.extract_read_writes(
3537            dummy, self.get_size(), (), normalize=True
3538        )
3539        deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
3540        return deps
3541
3542    def get_reduction_size(self):
3543        return 1
3544
3545    def get_reduction_type(self):
3546        return None
3547
3548    def is_no_op(self):
3549        return False
3550
3551    def should_allocate(self):
3552        return True
3553
3554    def simplify_and_reorder(
3555        self,
3556        extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None,
3557    ):
3558        return (
3559            (
3560                self.get_size(),
3561                (),
3562            ),
3563            None,
3564        )
3565
3566
3567class TritonTemplateBuffer(TemplateBuffer):
3568    def __init__(
3569        self,
3570        layout,
3571        inputs,
3572        make_kernel_render,
3573        debug_extra=None,
3574        mutated_inputs: Optional[Iterable[IRNode]] = None,
3575    ):
3576        """
3577        NOTE: [TritonTemplates with multiple outputs]
3578        We want the ability for TritonTemplates to output multiple tensors. Triton
3579        kernels have no notion of outputs; instead, this is done by creating tensors
3580        that are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't
3581        support creating multi-node outputs for Triton templates.
3582        We work around this by creating extra input buffers during the lowering
3583        and marking them as mutated inputs.
3584        """
3585        super().__init__(layout, inputs, make_kernel_render)
3586        self.debug_extra = debug_extra
3587        self.mutated_inputs = mutated_inputs
3588        if mutated_inputs is not None:
3589            # Ensure that the mutated inputs are only allowed for certain nodes
3590            allowed_set = {
3591                torch.ops.higher_order.flex_attention,
3592                torch.ops.higher_order.flex_attention_backward,
3593            }
3594            current_node = V.graph.current_node.target
3595            assert (
3596                current_node in allowed_set
3597            ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
3598            mark_node_as_mutating(self, *mutated_inputs)
3599
3600    def __str__(self):
3601        out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})"
3602        return out
3603
3604
3605PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
3606
3607
3608class ChoiceCaller:
3609    """
3610    Represents a possible choice used in autotune_process.py.
3611    During autotuning, self.benchmark() is first called to get a benchmark result,
3612    and if this choice is selected, self.output_node() is called to get the output node.
3613
3614    Child classes: TritonTemplateCaller, CUDATemplateCaller.
3615    """
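
    # Typical use during autotuning (sketch; names are hypothetical):
    #   timings = {c: c.benchmark(*args, out=out) for c in choices}
    #   best = min(timings, key=timings.get)
    #   node = best.output_node()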
3616
3617    def __init__(self, name, input_nodes, layout):
3618        super().__init__()
3619        self.name = name
3620        self.layout = layout
3621        self.input_nodes = input_nodes
3622
3623    def benchmark(self, *args, out) -> float:
3624        algo = self.to_callable()
3625        return do_bench(algo, args, {"out": out})
3626
3627    def call_name(self) -> str:
3628        raise NotImplementedError
3629
3630    def to_callable(self):
3631        raise NotImplementedError
3632
3633    def hash_key(self) -> str:
3634        raise NotImplementedError
3635
3636    def output_node(self) -> "TensorBox":
3637        raise NotImplementedError
3638
3639    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
3640        """Information returned here is logged to the autotune log file when that is enabled."""
3641        return {}
3642
3643
3644class TritonTemplateCallerBase(ChoiceCaller):
3645    def get_make_kernel_render(self) -> Any:
3646        raise NotImplementedError
3647
3648
3649class MultiTemplateBuffer(TritonTemplateBuffer):
3650    """
3651    Represents a Buffer with multiple backing implementation choices.
3652
3653    Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential
3654    epilogue we will benchmark each of the choices with the epilogue to determine an implementation.
3655    Otherwise, the fastest base choice will be chosen.
3656    """
3657
3658    def __init__(
3659        self,
3660        layout: Layout,
3661        inputs: List[IRNode],
3662        choice_timings: Callable[[], Dict[ChoiceCaller, float]],
3663    ):
3664        super().__init__(layout=layout, inputs=inputs, make_kernel_render=None)
3665        self._choice_timings_fn = choice_timings
3666        self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None
3667        self.original_inputs = inputs
3668
3669    @property
3670    def choice_timings(self) -> Dict[ChoiceCaller, float]:
3671        if self._choice_timings is None:
3672            self._choice_timings = self._choice_timings_fn()
3673        return self._choice_timings
3674
3675    @contextlib.contextmanager
3676    def swap_as_triton_caller(self, caller: TritonTemplateCallerBase):
3677        assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller)
3678        assert self.layout == caller.layout
3679
3680        render = self.make_kernel_render
3681        self.make_kernel_render = caller.get_make_kernel_render()
3682        try:
3683            yield
3684        finally:
3685            self.make_kernel_render = render
3686
3687    def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase):
3688        assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller)
3689        assert self.layout.size == caller.layout.size
3690        assert self.layout.stride == caller.layout.stride
3691        self.make_kernel_render = caller.get_make_kernel_render()
3692
3693    def get_min_choice(self) -> Tuple[ChoiceCaller, float]:
3694        min_choice = min(self.choice_timings, key=self.choice_timings.get)  # type: ignore[arg-type]
3695        return (min_choice, self.choice_timings[min_choice])
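
    # Sketch of scheduler-side usage (names hypothetical): to evaluate an
    # epilogue fusion, benchmark each candidate under
    #   with multi_buf.swap_as_triton_caller(caller): ...
    # and otherwise fall back to the fastest standalone choice via
    #   caller, ms = multi_buf.get_min_choice()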
3696
3697
3698class CUDATemplateBuffer(TemplateBuffer):
3699    def __init__(
3700        self,
3701        layout,
3702        inputs,
3703        make_kernel_render,
3704        workspace_size: int,
3705        template: "CUDATemplate",  # type: ignore[name-defined]  # noqa: F821
3706    ):
3707        super().__init__(layout, inputs, make_kernel_render)
3708        # Global memory (in bytes) needed for this template.
3709        self.workspace_size = workspace_size
3710        self.template = template
3711
3712    def get_workspace_size(self):
3713        return self.workspace_size if self.workspace_size is not None else 0
3714
3715
3716class CppTemplateBuffer(TemplateBuffer):
3717    def __init__(self, layout, inputs, make_kernel_render, template, choice):
3718        super().__init__(layout, inputs, make_kernel_render)
3719        self.template = template
3720        self.choice = choice
3721
3722
3723@dataclasses.dataclass
3724class InputsKernel(Buffer):
3725    inputs: List[Buffer]
3726
3727    def get_read_writes_input(self, x):
3728        return dependencies.StarDep(x.get_name())
3729
3730    def get_read_writes(self):
3731        star_dep = []
3732        for input in self.inputs:
3733            if isinstance(input, list):
3734                star_dep.extend([self.get_read_writes_input(x) for x in input])
3735            else:
3736                star_dep.append(self.get_read_writes_input(input))
3737
3738        return dependencies.ReadWrites(
3739            set(star_dep),
3740            {dependencies.StarDep(self.get_name())},
3741            set(),
3742            [],
3743            None,
3744            op_counts=collections.Counter(),
3745        )
3746
3747    @classmethod
3748    def unwrap_storage_for_input(cls, x):
3749        if isinstance(x, TensorBox):
3750            x = x.data
3751        if isinstance(x, StorageBox):
3752            x = x.data
3753        if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
3754            x = ExternKernel.realize_input(x)
3755        if isinstance(x, TensorBox):
3756            # when converting to ReinterpretView fails in the
3757            # realize_input call above, the result will be wrapped
3758            # into TensorBox / StorageBox pair as a result of the
3759            # cls.copy_input call; so we should unwrap recursively
3760            return cls.unwrap_storage_for_input(x)
3761        if isinstance(x, TorchBindObject):
3762            return x
3763        assert isinstance(x, (Buffer, ReinterpretView)), x
3764        return x
3765
3766    @staticmethod
3767    def unwrap_storage(inputs):
3768        inputs_new = []
3769        for x in inputs:
3770            if isinstance(x, list):
3771                x = [InputsKernel.unwrap_storage_for_input(i) for i in x]
3772            else:
3773                x = InputsKernel.unwrap_storage_for_input(x)
3774            inputs_new.append(x)
3775        return inputs_new
3776
3777    def is_extern(self):
3778        return True
3779
3780
3781class NopKernel(InputsKernel):
3782    def is_no_op(self):
3783        return True
3784
3785
3786class ConcatKernel(NopKernel):
3787    """
3788    There isn't actually a real kernel for concat; we just change the
3789    storage for the upstream data.
3790    """
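
    # e.g. (hypothetical shapes) torch.cat([a, b], dim=0) with a: [2, 4] and
    # b: [3, 4] allocates a single [5, 4] buffer; realize_into() then writes
    # `a` into rows 0:2 and `b` into rows 2:5 through SliceView, aliasing
    # the inputs via NonOwningLayout where possible instead of copying.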
3791
3792    @classmethod
3793    def create(cls, inputs, dim):
3794        device = inputs[0].get_device()
3795        dtype = inputs[0].get_dtype()
3796        new_size = list(inputs[0].get_size())
3797        offsets_start = [0]
3798        offsets_end = [new_size[dim]]
3799        assert 0 <= dim < len(new_size)
3800        for i in range(1, len(inputs)):
3801            input_size = inputs[i].get_size()
3802            offsets_start.append(new_size[dim])
3803            assert len(input_size) == len(new_size)
3804            assert inputs[i].get_dtype() == dtype
3805            assert inputs[i].get_device() == device
3806            for j in range(len(new_size)):
3807                if j == dim:
3808                    new_size[j] = new_size[j] + input_size[j]
3809                else:
3810                    new_size[j] = V.graph.sizevars.guard_equals(
3811                        new_size[j], input_size[j]
3812                    )
3813            offsets_end.append(new_size[dim])
3814
3815        output_stride = FlexibleLayout.contiguous_strides(new_size)
3816        # If any of the inputs is in CL format, use CL format for the output
3817        for i in range(len(inputs)):
3818            x = inputs[i]
3819            if is_storage_and_layout(x):
3820                layout = x.get_layout()
3821                if isinstance(
3822                    layout, FixedLayout
3823                ) and Layout.is_channels_last_contiguous(layout.size, layout.stride):
3824                    # use CL stride for the output
3825                    output_stride = make_channels_last_strides_for(new_size)
3826                    break
3827        any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs)
3828        fx_node_args = V.graph.current_node.args[0]
3829        assert isinstance(fx_node_args, list)
3830        # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output
3831        if any_input_is_storage_and_layout is False and any(
3832            "val" in arg.meta
3833            and (
3834                arg.meta["val"].is_contiguous(memory_format=torch.channels_last)
3835                or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d)
3836            )
3837            for arg in fx_node_args
3838        ):
3839            output_stride = make_channels_last_strides_for(new_size)
3840
3841        concat_kernel = ConcatKernel(
3842            name=None,
3843            layout=FixedLayout(
3844                device=device,
3845                dtype=dtype,
3846                size=new_size,
3847                stride=output_stride,
3848            ),
3849            inputs=[],
3850        )
3851        kernel = StorageBox(concat_kernel)
3852        buffer_names = []
3853        for i in range(len(inputs)):
3854            input_buffer = cls.realize_into(
3855                inputs[i],
3856                SliceView.create(
3857                    kernel, dim, offsets_start[i], offsets_end[i], clamp=False
3858                ),
3859            )
3860            concat_kernel.inputs.append(input_buffer)
3861
3862            if isinstance(inputs[i].data, BaseView):
3863                input_unwrapped = inputs[i].data.unwrap_view()
3864            else:
3865                input_unwrapped = inputs[i].data
3866
3867            if (
3868                input_unwrapped.is_input_buffer()
3869                and is_gpu(inputs[i].get_device().type)
3870                and not is_dynamic(input_buffer)
3871            ):
3872                buffer_names.append(input_buffer.get_name())
3873
3874        if len(buffer_names) > 1:
3875            V.graph.register_list(buffer_names)
3876
3877        concat_kernel.name = V.graph.register_buffer(concat_kernel)
3878        concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)
3879
3880        return kernel
3881
3882    @classmethod
3883    def can_realize_into_without_copy(cls, src):
3884        if isinstance(src, TensorBox):
3885            # unwrap a TensorBox
3886            return cls.can_realize_into_without_copy(src.data)
3887
3888        return isinstance(src.data.layout, FlexibleLayout) and not isinstance(
3889            src.data, ExternKernelAlloc
3890        )
3891
3892    @classmethod
3893    def realize_into(cls, src, dst):
3894        # Attempt to turn this into a ReinterpretView rather than assert.
3895        # This has concessions around layout, because as_storage_and_layout
3896        # can cause us to go from a flexible to a fixed layout.
3897        if not isinstance(dst, ReinterpretView):
3898            if is_storage_and_layout(dst):
3899                storage, layout = as_storage_and_layout(dst)
3900                dst = ReinterpretView(storage, layout)
3901        assert isinstance(dst, ReinterpretView), dst
3902        if isinstance(src, TensorBox):
3903            # unwrap a TensorBox
3904            return cls.realize_into(src.data, dst)
3905        if isinstance(src, StorageBox):
3906            src.realize()
3907            # ExternKernelAlloc has specific requirements for output layout, should create a copy
3908            assert hasattr(src.data, "layout")
3909            if cls.can_realize_into_without_copy(src):
3910                src.data.layout = NonOwningLayout(dst)
3911                return src.data
3912        # introduce a copy
3913        pw = Pointwise.create(
3914            device=src.get_device(),
3915            dtype=src.get_dtype(),
3916            inner_fn=src.make_loader(),
3917            ranges=[
3918                V.graph.sizevars.guard_equals(a, b)
3919                for a, b in zip(src.get_size(), dst.get_size())
3920            ],
3921        )
3922        return cls.realize_into(pw, dst)
3923
3924    def should_allocate(self):
3925        return True
3926
3927
3928def get_aten_cpp_kernel_name(kernel):
3929    # Calling via the default kernel name can be ambiguous when overloads exist, as in the following example:
3930    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
3931    # repeat_interleave(const at::Tensor & self, int64_t repeats,
3932    #       c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
3933    if not isinstance(kernel, torch._ops.OpOverload) or kernel.namespace != "aten":
3934        return None
3935    opname = (
3936        kernel.__name__.split(".")[0]
3937        if kernel._overloadname == "default"
3938        else kernel.__name__.replace(".", "_")
3939    )
3940    return f"at::_ops::{opname}::call"
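
# For example, aten.mm.default maps to "at::_ops::mm::call", while the
# named overload aten.mm.out maps to "at::_ops::mm_out::call".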
3941
3942
3943@dataclasses.dataclass
3944class ExternKernel(InputsKernel):
3945    constant_args: Tuple[Any, ...] = ()
3946    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
3947    output_view: Optional[ReinterpretView] = None
3948    python_kernel_name: Optional[str] = None
3949    cpp_kernel_name: Optional[str] = None
3950    # FIXME: in some cases we still need to explicitly pass in ordered_kwargs_for_cpp_kernel
3951    # We shouldn't need to do this since the information can be retrieved from op_overload._schema.
3952    ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
3953        default_factory=list
3954    )
3955    op_overload: Optional[
3956        Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
3957    ] = None
3958    arg_properties: Optional[List[Dict[str, Any]]] = None
3959    kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None
3960    unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
3961        default_factory=dict
3962    )
3963
3964    def __init__(
3965        self,
3966        name,
3967        layout,
3968        inputs,
3969        constant_args=(),
3970        kwargs=None,
3971        output_view=None,
3972        python_kernel_name=None,
3973        cpp_kernel_name=None,
3974        ordered_kwargs_for_cpp_kernel=(),
3975        op_overload=None,
3976    ):
3977        super().__init__(
3978            name,
3979            layout,
3980            inputs,
3981        )
3982        self.constant_args = constant_args
3983        self.kwargs = kwargs if kwargs else {}
3984        self.output_view = output_view
3985        self.python_kernel_name = python_kernel_name
3986        # If cpp_kernel_name is None, we will try to construct it from op_overload
3987        self.cpp_kernel_name = cpp_kernel_name or get_aten_cpp_kernel_name(op_overload)
3988        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
3989        self.op_overload = op_overload
3990        self.collect_arg_kwarg_properties()
3991        self.unbacked_bindings = {}
3992        self.fx_node = V.graph.current_node
3993
3994    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
3995        return set()
3996
3997    def collect_arg_kwarg_properties(self):
3998        # if self.op_overload is a torch._ops.OpOverload, we can use its schema to collect additional
3999        # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen
4000        self.arg_properties = (
4001            [
4002                {
4003                    "name": x.name,
4004                    "type": x.real_type,
4005                    "default_value": x.default_value,
4006                }
4007                for x in self.op_overload._schema.arguments
4008                if not x.kwarg_only
4009            ]
4010            if isinstance(self.op_overload, torch._ops.OpOverload)
4011            else [{} for i in range(len(self.inputs))]
4012        )
4013        self.allarg_properties = (
4014            {
4015                x.name: {"type": x.real_type, "default_value": x.default_value}
4016                for x in self.op_overload._schema.arguments
4017            }
4018            if isinstance(self.op_overload, torch._ops.OpOverload)
4019            else {}
4020        )
4021        # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes
4022        # ordered_kwargs_for_cpp_kernel is explicitly passed in.
4023        if (
4024            isinstance(self.op_overload, torch._ops.OpOverload)
4025            and not self.ordered_kwargs_for_cpp_kernel
4026        ):
4027            self.ordered_kwargs_for_cpp_kernel = [
4028                x.name for x in self.op_overload._schema.arguments if x.kwarg_only
4029            ]
4030
4031    def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
4032        # Previously, we wanted to maintain forward compatibility by skipping
4033        # default args in the serialized artifacts in fbcode. However,
4034        # some of our shim interfaces require default values to be set.
4035        # We discussed with Sherlock offline and decided to allow serializing
4036        # default args into the C++ wrapper code for now. We will refine this
4037        # part if we see a real FC requirement. More details related to FC
4038        # can be found at:
4039        # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
4040        assert isinstance(args, (list, tuple))
4041        if isinstance(args, tuple):
4042            args = list(args)
4043        assert self.arg_properties, "ExternKernel.arg_properties should not be empty"
4044
4045        n_args = len(args)
4046        n_pos_args = len(self.arg_properties)
4047        # For cpp wrapper, if some positional args are not provided, we need to check
4048        # if they're in the kwargs or use their default value
4049        if n_args < n_pos_args:
4050            log.debug(
4051                "%s has %d unprovided positional arguments. "
4052                "Will check if they are in the keyword arguments or will use default values.",
4053                self.op_overload,
4054                n_pos_args - n_args,
4055            )
4056            for i in range(n_args, n_pos_args):
4057                arg_name = self.arg_properties[i]["name"]
4058                args.append(
4059                    kwargs[arg_name]
4060                    if arg_name in kwargs
4061                    else self.arg_properties[i]["default_value"]
4062                )
4063        return args
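
    # Sketch (hypothetical schema): for foo(Tensor a, int b, int c=1) called
    # with args=[ta] and kwargs={"b": 2}, the missing positional args are
    # filled in to produce [ta, 2, 1].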
4064
4065    def decide_layout(self):
4066        if isinstance(self.layout, FlexibleLayout):
4067            self.apply_constraint()
4068            self.freeze_layout()
4069
4070    def codegen_comment(self, wrapper):
4071        origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper)
4072        if origin_str:
4073            wrapper.writeline(origin_str)
4074
4075    def codegen(self, wrapper):
4076        raise NotImplementedError
4077
4078    def get_kernel_name(self):
4079        return (
4080            (
4081                V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name)  # type: ignore[attr-defined]
4082                if config.abi_compatible
4083                else self.cpp_kernel_name
4084            )
4085            if V.graph.cpp_wrapper
4086            else self.python_kernel_name
4087        )
4088
4089    @staticmethod
4090    def copy_input(x):
4091        pw = Pointwise.create(
4092            device=x.get_device(),
4093            dtype=x.get_dtype(),
4094            inner_fn=x.make_loader(),
4095            ranges=x.get_size(),
4096            origin_node=x.get_origin_node(),
4097            traceback=x.get_traceback(),
4098        )
4099        pw.realize()
4100        return pw
4101
4102    @classmethod
4103    def process_kernel(
4104        cls, kernel, *args, **kwargs
4105    ) -> Tuple[
4106        Any,
4107        List[Any],
4108        List[Any],
4109        Callable[[Any, Any], Any],
4110        Optional[Dict[sympy.Symbol, pytree.KeyPath]],
4111    ]:
4112        binded_args = {"args": args, "kwargs": kwargs}
4113
4114        args_flat, args_spec = pytree.tree_flatten(binded_args)
4115
4116        is_arg_tensor = []
4117        tensor_args = []
4118        non_tensor_args: List[Any] = []
4119        for arg in args_flat:
4120            is_arg_tensor.append(isinstance(arg, IRNode))
4121            if is_arg_tensor[-1]:
4122                tensor_args.append(arg)
4123            else:
4124                if isinstance(arg, sympy.Expr):
4125                    arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
4126                non_tensor_args.append(arg)
4127
4128        def unflatten_args(new_tensor_args, new_non_tensor_args):
4129            result = []
4130            it_tensors = iter(new_tensor_args)
4131            it_non_tensors = iter(new_non_tensor_args)
4132            for is_tensor in is_arg_tensor:
4133                if is_tensor:
4134                    result.append(next(it_tensors))
4135                else:
4136                    result.append(next(it_non_tensors))
4137            r = pytree.tree_unflatten(result, args_spec)
4138            return r.get("args", []), r.get("kwargs", {})
4139
4140        tensor_args = [cls.realize_input(x) for x in tensor_args]
4141
4142        # freeze layout otherwise our output stride calculation might
4143        # become incorrect
4144        for x in tensor_args:
4145            if is_storage_and_layout(x):
4146                as_storage_and_layout(x, freeze=True)
4147
4148        # Rerun fake tensor propagation, because Inductor may have changed the
4149        # strides of inputs and we need to determine accurately what the
4150        # output stride will be.
4151        example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = []
4152
4153        # We need to retain the constant values of fake tensors that we originally
4154        # propagated the graph with, because for some operators, running without a
4155        # constant would trigger an error / DataDependentException.
4156        for x in tensor_args:
4157            if x.get_name() in V.graph.constants:
4158                example_args.append(V.graph.constants[x.get_name()])
4159            elif x.get_name() in V.graph.torchbind_constants:
4160                example_args.append(V.graph.torchbind_constants[x.get_name()])
4161            else:
4162                example_args.append(ir_node_to_tensor(x, guard_shape=True))
4163
4164        new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
4165        example_output = kernel(*new_args, **new_kwargs)
4166
4167        unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None
4168        if shape_env := V.fake_mode.shape_env:
4169            rebind_unbacked(shape_env, V.current_node, example_output)
4170            unbacked_bindings = compute_unbacked_bindings(
4171                shape_env, example_output, V.current_node.meta.get("val")
4172            )
4173
4174        example_out_li = (
4175            [example_output]
4176            if not isinstance(example_output, (list, tuple))
4177            else example_output
4178        )
4179        for t in example_out_li:
4180            if isinstance(t, torch.Tensor) and t.is_sparse:
4181                msg = "sparsity not handled. Please file issue for sparse inference weights."
4182                if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
4183                    msg = f"{msg} Found from : \n {stack_trace}"
4184                V.graph.disable_cudagraphs_reason = msg
4185
4186        return (
4187            example_output,
4188            tensor_args,
4189            non_tensor_args,
4190            unflatten_args,
4191            unbacked_bindings,
4192        )
4193
4194    @classmethod
4195    def convert_to_reinterpret_view(cls, x):
4196        """
4197        In order to pass this to an extern kernel we need a
4198        ReinterpretView not a View.  This allows us to avoid some
4199        unneeded copies.
4200        """
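        # Sketch (hypothetical view): transposing a contiguous [2, 3] buffer
        # gives a view of size [3, 2] whose indexer yields index = r0 + 3*r1;
        # stride_vars/offset_var below recover stride [1, 3] and offset 0,
        # so we can build a FixedLayout(size=[3, 2], stride=[1, 3]).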
4201        assert isinstance(x, BaseView)
4202        if isinstance(x, ReinterpretView):
4203            return x
4204
4205        # NOTE: Don't use extract_read_writes here as it fails when
4206        # make_loader() inlines the computation
4207        x_unwrap_view = x.unwrap_view()
4208        x_unwrap_view_fx_node = V.graph.get_buffer(
4209            x_unwrap_view.get_name()
4210        ).get_origin_node()
4211        # Prefer channels-last format, matching how the format was set in eager mode.
4212        if (
4213            x_unwrap_view_fx_node is not None
4214            and "val" in x_unwrap_view_fx_node.meta
4215            and isinstance(x_unwrap_view.layout, FlexibleLayout)
4216            and (
4217                x_unwrap_view_fx_node.meta["val"].is_contiguous(
4218                    memory_format=torch.channels_last
4219                )
4220                or x_unwrap_view_fx_node.meta["val"].is_contiguous(
4221                    memory_format=torch.channels_last_3d
4222                )
4223            )
4224        ):
4225            x_unwrap_view.freeze_layout_with_same_order(
4226                make_channels_last_strides_for(x_unwrap_view.get_size())
4227            )
4228        else:
4229            x_unwrap_view.freeze_layout()
4230
4231        index_args, var_ranges = dependencies.index_vars_squeeze(
4232            x.get_size(), prefix="r"
4233        )
4234        range_vars = index_args[0]
4235        index = x.make_indexer()(range_vars)
4236
4237        index = V.graph.sizevars.simplify_with_ranges(index, var_ranges)
4238        strides = V.graph.sizevars.stride_vars(index, range_vars)
4239        offset = V.graph.sizevars.offset_var(index, range_vars)
4240        expected = sympy_dot(range_vars, strides) + offset
4241
4242        if index != expected:
4243            log.debug(
4244                "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
4245                strides,
4246                offset,
4247                index,
4248            )
4249            raise NotImplementedError
4250
4251        return ReinterpretView(
4252            data=x.data,
4253            layout=FixedLayout(
4254                device=x.get_device(),
4255                dtype=x.get_dtype(),
4256                size=x.get_size(),
4257                stride=strides,
4258                offset=offset,
4259            ),
4260        )
4261
4262    @classmethod
4263    def realize_input(cls, x):
4264        if x is None:
4265            return NoneAsConstantBuffer()
4266        if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)):
4267            return ShapeAsConstantBuffer(x)
4268        if isinstance(x, Constant):
4269            return V.graph.add_tensor_constant(
4270                torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
4271            )
4272        if isinstance(x, ConstantBuffer):
4273            return x
4274        if isinstance(x, TensorBox):
4275            return cls.realize_input(x.data)
4276        if isinstance(x, ReinterpretView):
4277            return ReinterpretView(cls.realize_input(x.data), x.get_layout())
4278        if isinstance(x, BaseView):
4279            x.realize()
4280            if is_storage_and_layout(x.unwrap_view()):
4281                try:
4282                    return cls.convert_to_reinterpret_view(x)
4283                except NotImplementedError:
4284                    pass
4285        if isinstance(x, StorageBox):
4286            # TODO(jansel): impose layout preference on realized buffer
4287            x.realize()
4288            return x
4289        if isinstance(x, TorchBindObject):
4290            return x
4291        return cls.copy_input(x)
4292
4293    @classmethod
4294    def require_stride1(cls, x):
4295        if is_storage_and_layout(x):
4296            if len(x.get_stride()) == 0:
4297                return x
4298            for stride in x.get_stride():
4299                if stride == 1:
4300                    return x
4301        return cls.copy_input(x)
4302
4303    @classmethod
4304    def require_stride_order(cls, x, order, allow_padding=False):
4305        if x.get_numel() == 0:  # Layout doesn't matter
4306            return x
4307
4308        # require x to have the layout as strided_ordered as order
4309        if is_storage_and_layout(x):
4310            while isinstance(x.get_layout(), NonOwningLayout):
4311                x = x.get_layout().view
4312            if isinstance(x.get_layout(), FlexibleLayout):
4313                # If the FlexibleLayout already has the size and stride in the required order,
4314                # freeze it to a FixedLayout by using its current size and stride.
4315                # The behavior of using its current size and stride or the given order can differ
4316                # if the size and stride are ambiguous, for example for a 4D input where iC = 1:
4317                # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last),
4318                # the current size and stride already satisfy this order.
4319                # However, by freezing it to the required order, the layout would be changed to
4320                # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1], which is not actually necessary.
4321
4322                # fix flexiblelayout to be FixedLayout with stride_order
4323                as_storage_and_layout(
4324                    x,
4325                    freeze=True,
4326                    want_contiguous=False,
4327                    stride_order=get_stride_order(
4328                        V.graph.sizevars.size_hints(x.get_layout().stride)
4329                    )
4330                    if is_stride_order_storage_and_layout(x, order)
4331                    else order,
4332                    allow_padding=allow_padding,
4333                )
4334                return x
4335            elif isinstance(
4336                x.get_layout(), FixedLayout
4337            ) and x.get_layout().is_stride_ordered(order):
4338                return x
4339            elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE):
4340                if isinstance(x.get_layout().real_layout(), FlexibleLayout):
4341                    raise AssertionError(
4342                        "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout"
4343                    )
4344                elif isinstance(
4345                    x.get_layout().real_layout(), FixedLayout
4346                ) and x.get_layout().real_layout().is_stride_ordered(order):
4347                    return x
4348
4349        # TODO - Storage to InputBuffer
4350        if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
4351            return x
4352        if (
4353            isinstance(x, TensorBox)
4354            and isinstance(x.data, BaseView)
4355            and not isinstance(x.data, ReinterpretView)
4356            and is_storage_and_layout(x.unwrap_view())
4357            and not isinstance(x.unwrap_view().data, ExternKernelAlloc)
4358        ):
4359            try:
4360                x.data = cls.convert_to_reinterpret_view(x.data)
4361                return cls.require_stride_order(x, order, allow_padding=allow_padding)
4362            except NotImplementedError:
4363                pass
4364        x = cls.copy_input(x)
4365        as_storage_and_layout(
4366            x,
4367            freeze=True,
4368            want_contiguous=False,
4369            stride_order=order,
4370            allow_padding=allow_padding,
4371        )
4372        assert is_stride_order_storage_and_layout(x, order)
4373        return x
4374
4375    @classmethod
4376    def require_channels_last(cls, x):
4377        return cls.require_stride_order(x, NHWC_STRIDE_ORDER)
4378
4379    @classmethod
4380    def require_channels_last_3d(cls, x):
4381        return cls.require_stride_order(x, NHWDC_STRIDE_ORDER)
4382
4383    @classmethod
4384    def require_contiguous(cls, x):
4385        return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))
4386
4387    def apply_constraint(self):
4388        pass
4389
4390    def codegen_const_args(self):
4391        if V.graph.cpp_wrapper:
4392            result = []
4393            for i, x in enumerate(self.constant_args):
4394                idx = len(self.inputs) + i
4395                type_ = (
4396                    self.arg_properties[i].get("type")
4397                    if self.arg_properties and idx < len(self.arg_properties)
4398                    else None
4399                )
4400                result.append(
4401                    V.graph.wrapper_code.val_to_arg_str(x, type_)  # type: ignore[arg-type]
4402                )
4403            return result
4404        else:
4405            return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args)
4406
4407    def codegen_args(self):
4408        args = []
4409        for i, x in enumerate(self.inputs):
4410            if isinstance(x, list):
4411                names = [i.codegen_reference() for i in x]
4412                codegen_reference = f'[{", ".join(names)}]'
4413                args.append(codegen_reference)
4414            else:
4415                if V.graph.cpp_wrapper:
4416                    assert self.arg_properties and i < len(
4417                        self.arg_properties
4418                    ), "Invalid access to ExternKernel.arg_properties"
4419                    type_ = self.arg_properties[i].get("type")
4420                    args.append(
4421                        V.graph.wrapper_code.val_to_arg_str(  # type: ignore[arg-type]
4422                            x, type_
4423                        )
4424                    )
4425                else:
4426                    args.append(x.codegen_reference())
4427        args.extend(self.codegen_const_args())
4428        return args
4429
4430    def get_kwargs_value(self, arg_name):
4431        if arg_name in self.kwargs:
4432            return self.kwargs.get(arg_name)
4433        if self.allarg_properties and self.allarg_properties.get(arg_name):
4434            return self.allarg_properties.get(arg_name).get("default_value")  # type: ignore[union-attr]
4435        else:
4436            raise AssertionError(f"{arg_name} not in self.allarg_properties")
4437
4438    def codegen_kwargs(self, skip_out=False):
4439        if V.graph.cpp_wrapper:
4440            kwargs = []
4441            for arg_name in self.ordered_kwargs_for_cpp_kernel:
4442                if skip_out and arg_name == "out":
4443                    # ExternKernelOut has its own logic for inserting the out parameter
4444                    continue
4445
4446                v = self.get_kwargs_value(arg_name)
4447                if isinstance(v, sympy.Expr):
4448                    kwargs.append(v)
4449                else:
4450                    type_ = (
4451                        self.allarg_properties.get(arg_name).get("type")  # type: ignore[union-attr]
4452                        if self.allarg_properties and arg_name in self.allarg_properties
4453                        else None
4454                    )
4455                    kwargs.append(
4456                        V.graph.wrapper_code.val_to_arg_str(  # type: ignore[arg-type]
4457                            v, type_
4458                        )
4459                    )
4460        else:
4461            kwargs = [
4462                f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"  # type: ignore[misc]
4463                for k, v in self.kwargs.items()
4464            ]
4465        return kwargs
4466
4467    def codegen_size_asserts(self, wrapper):
4468        if config.size_asserts and not V.graph.cpp_wrapper:
4469            # comparing strides for 0 size tensor is tricky. Ignore them for now.
4470            if sympy_product(self.get_size()) == 0:
4471                return
4472            size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
4473            stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
4474            wrapper.writeline(
4475                f"assert_size_stride({self.get_name()}, {size}, {stride})"
4476            )
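
    # The generated assertion looks like (hypothetical buffer):
    #   assert_size_stride(buf0, (8, 16), (16, 1))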
4477
4478    def get_group_stride(self):
4479        """
4480        get output sizes and strides, for template_codegen
4481        """
4482        _size = self.get_size()
4483        _stride = self.get_stride()
4484        # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
4485        return [_size, []], _stride
4486
4487    def canonicalize(self):
4488        """
4489        Manually get canonicalization of the output index
4490        """
4491        # manually generate index formula for conv
4492        sizevars = V.graph.sizevars
4493        sizes = self.get_size()
4494        strides = self.get_stride()
4495        strides = [sizevars.size_hint(x) for x in strides]
4496        # TODO: I can't tell if the symbols here are temporary
4497        index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))]
4498        # reorder index vars according to stride
4499        index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
4500        lookup = {pos: idx for idx, pos in enumerate(index_order)}
4501        order = [lookup[i] for i in range(len(lookup))]
4502        index_vars = [index_vars[i] for i in order]
4503        indexer = self.make_indexer()
4504        index = indexer(index_vars)
4505
4506        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
4507            index_vars, sizes, [index]
4508        )
4509
4510        # assign new variables each dimension to deal with numbering mismatches
4511        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
4512        _, add_var = var_builder("c")
4513        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
4514
4515        index = sympy_subs(sympy.expand(index), replacement)  # type: ignore[arg-type]
4516        return index, tuple(new_sizes)
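
    # Illustrative sketch of the stride-based reordering above (hypothetical
    # values): with strides [4, 1, 8],
    #
    #   index_order = sorted(range(3), key=[4, 1, 8].__getitem__, reverse=True)
    #   # index_order == [2, 0, 1] (dims sorted by decreasing stride)
    #   # lookup == {2: 0, 0: 1, 1: 2}, so order == [1, 2, 0]
    #   # and index_vars becomes [d1, d2, d0]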
4517
4518    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
4519        # NB: It's not necessary to check regular inputs as we automatically
4520        # have dependencies on them
4521        r = set()
4522        for arg in self.constant_args:
4523            r |= maybe_free_unbacked_symbols(arg)
4524        for arg in self.kwargs.values():
4525            r |= maybe_free_unbacked_symbols(arg)
4526        return r
4527
4528    def __str__(self):
4529        kernel_name = getattr(self, "python_kernel_name", None)
4530        lines = [
4531            f"python_kernel_name={kernel_name!r}",
4532        ]
4533        lines += [
4534            f"{field.name}={getattr(self, field.name)}"
4535            for field in dataclasses.fields(self)
4536        ]
4537        lines.append(f"origin_node={self.origin_node!r}")
4538        return self.str_helper(lines)
4539
4540    __repr__ = __str__
4541
4542
4543@dataclasses.dataclass
4544class ExternKernelOut(ExternKernel):
4545    def codegen(self, wrapper):
4546        self.codegen_comment(wrapper)
4547        args = [*self.codegen_args(), *self.codegen_kwargs(skip_out=True)]
4548        wrapper.generate_extern_kernel_out(
4549            self.get_kernel_name(),
4550            self.codegen_reference(),
4551            self.output_view.codegen_reference() if self.output_view else None,
4552            args,
4553        )
4554
4555    def __init__(
4556        self,
4557        layout,
4558        inputs,
4559        constant_args=(),
4560        kwargs=None,
4561        output_view=None,
4562        python_kernel_name=None,
4563        cpp_kernel_name=None,
4564        ordered_kwargs_for_cpp_kernel=(),
4565        op_overload=None,
4566    ):
4567        super().__init__(
4568            None,
4569            layout,
4570            self.unwrap_storage(inputs),
4571            constant_args,
4572            kwargs or {},
4573            None,
4574            python_kernel_name,
4575            cpp_kernel_name,
4576            ordered_kwargs_for_cpp_kernel,
4577            op_overload,
4578        )
4579        self.name = V.graph.register_buffer(self)
4580
4581    def should_allocate(self):
4582        return True
4583
4584
4585class RandomSeeds(ExternKernelOut):
4586    def __init__(self, count: int, device: torch.device):
4587        limits = torch.iinfo(torch.int64)
4588        super().__init__(
4589            layout=FixedLayout(
4590                device=device,
4591                dtype=torch.int64,
4592                size=[count],
4593            ),
4594            inputs=[],
4595            constant_args=[limits.min, limits.max, [count]],
4596            python_kernel_name="aten.randint.low_out",
4597            # FIXME: Ideally we should only use at::_ops::randint_low_out::call here,
            # but its signature is different from at::randint_out's. Again,
4599            # we can simplify the code when only keeping an ABI-compatible version.
4600            cpp_kernel_name="at::_ops::randint_low_out::call"
4601            if config.abi_compatible
4602            else "at::randint_out",
4603            op_overload=aten.randint.low_out,
4604        )
4605
4606
4607class ExternKernelAlloc(ExternKernel):
4608    def codegen(self, wrapper):
4609        self.codegen_comment(wrapper)
4610        args = [*self.codegen_args(), *self.codegen_kwargs()]
4611        V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
4612        if isinstance(self.layout, Layout):
4613            self.codegen_size_asserts(wrapper)
4614
4615    def __init__(
4616        self,
4617        layout,
4618        inputs,
4619        constant_args=(),
4620        kwargs=None,
4621        python_kernel_name=None,
4622        cpp_kernel_name=None,
4623        ordered_kwargs_for_cpp_kernel=(),
4624        op_overload=None,
4625    ):
4626        super().__init__(
4627            None,
4628            layout,
4629            self.unwrap_storage(inputs),
4630            constant_args,
4631            kwargs or {},
4632            None,
4633            python_kernel_name,
4634            cpp_kernel_name,
4635            ordered_kwargs_for_cpp_kernel,
4636            op_overload,
4637        )
4638        self.name = V.graph.register_buffer(self)
4639
4640    def should_allocate(self):
4641        return False
4642
4643    def apply_constraint(self):
4644        raise NotImplementedError
4645
4646
4647class UserDefinedTritonKernel(ExternKernel):
4648    def get_kernel_and_configs(self):
4649        from triton.runtime.autotuner import Autotuner
4650
4651        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
4652
4653        kernel = kernel_side_table.get_kernel(self.kernel_idx)
4654        configs = []
4655        if isinstance(kernel, Autotuner):
4656            configs = kernel.configs
4657            kernel = kernel.fn
4658        return kernel, configs
4659
4660    def codegen(self, wrapper):
4661        kernel, configs = self.get_kernel_and_configs()
4662
4663        # Definition of kernel
4664        new_name, triton_meta = wrapper.define_user_defined_triton_kernel(
4665            kernel, configs, self.kwargs
4666        )
4667
4668        args = self.codegen_kwargs()
4669        arg_types = []
4670        if V.graph.cpp_wrapper:
4671            # in C++ wrapper, we don't pass constexpr args, as they don't
4672            # get added as parameters to the PTX code compiled from the
4673            # user-defined Triton kernel (only non-constexpr args do)
4674            args = [arg for i, arg in enumerate(args) if i not in kernel.constexprs]
4675            # cpp wrapper needs arg type info for codegen
4676            for arg_name in self.ordered_kwargs_for_cpp_kernel:
4677                val = self.get_kwargs_value(arg_name)
4678                arg_types.append(
4679                    val.get_dtype() if hasattr(val, "get_dtype") else type(val)
4680                )
4681            arg_types = [
4682                t for i, t in enumerate(arg_types) if i not in kernel.constexprs
4683            ]
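            # Illustrative sketch (hypothetical kernel signature): for
            #
            #   @triton.jit
            #   def my_kernel(in_ptr, out_ptr, n, BLOCK: tl.constexpr): ...
            #
            # BLOCK is in kernel.constexprs, so it is dropped from both `args`
            # and `arg_types` here; only in_ptr, out_ptr, and n remain.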
4684
4685        # Call to kernel
4686        self.codegen_comment(wrapper)
4687        wrapper.generate_user_defined_triton_kernel(
4688            new_name, self.grid, configs, args, triton_meta, arg_types
4689        )
4690
4691    def should_allocate(self):
4692        return False
4693
4694    def has_side_effects(self):
4695        # UserDefinedTritonKernel does not return anything, but rather
        # modifies its inputs in place, so do not let it get DCEd
4697        return True
4698
4699    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
4700        # add unbacked symbols used in the grid to the ones used
4701        # in the kwargs (the latter is generated by ExternKernel)
4702        return super().get_unbacked_symbol_uses() | free_unbacked_symbols(self.grid)
4703
4704    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
4705        return set()
4706
4707    def get_mutation_names(self):
4708        # NB: Inductor only allows a node to mutate 0 or 1 buffers.
        # To get around that, we create MutationOutputs, each of which marks
        # its assigned input as mutated, thus adhering to Inductor's constraint.
4711        return []
4712
4713    def __init__(self, *, kernel_idx, grid, kernel_args):
4714        inputs = []
4715        kwargs = dict()
4716        constant_args = []
4717        for k, v in kernel_args.items():
4718            if isinstance(v, TensorBox):
4719                t = InputsKernel.unwrap_storage_for_input(self.realize_input(v))
4720                inputs.append(t)
4721                kwargs[k] = t
4722            else:
4723                constant_args.append(v)
4724                kwargs[k] = v
4725
4726        assert len(inputs) != 0
4727        device = inputs[0].get_device()
4728
4729        super().__init__(
4730            None,
4731            NoneLayout(device),  # type: ignore[arg-type]
4732            inputs,
4733            tuple(constant_args),
4734            kwargs,
4735        )
4736        self.name = V.graph.register_buffer(self)
4737        self.kernel_idx = kernel_idx
4738        self.grid = grid
4739
4740        kernel, configs = self.get_kernel_and_configs()
4741        # If we are autotuning, not all arguments will be passed
4742        self.ordered_kwargs_for_cpp_kernel = [
4743            arg for arg in kernel.arg_names if arg in kernel_args
4744        ]
4745
4746        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
4747
4748        autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {}
4749        self.mutable_args = [
4750            kernel_args[key]
4751            for key in identify_mutated_tensors(
4752                kernel, {**kernel_args, **autotuned_kwargs}
4753            )
4754        ]
4755        mark_node_as_mutating(self, *self.mutable_args)
4756
4757    def get_inputs_that_alias_output(self):
4758        return [i.get_name() for i in self.mutable_args]
4759
4760
4761def mark_node_as_mutating(cur_buffer, *mutated_nodes: IRNode):
4762    """
    Marks each node in mutated_nodes as mutated and indicates to the
    scheduler that these nodes depend on cur_buffer.
4765
4766    NB: Use this instead of directly constructing MutationOutput
4767    """
4768    for node in mutated_nodes:
4769        assert isinstance(
4770            node, IRNode
4771        ), f"{node} node is type {type(node)} and is not an IRNode"
4772        V.graph.mark_buffer_mutated(node.get_name())
4773        MutationOutput(node.get_layout(), node, cur_buffer)
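

# Illustrative usage sketch (hypothetical nodes):
#
#   kernel_node = SomeExternKernel(...)  # the op performing the mutation
#   mark_node_as_mutating(kernel_node, buf)
#
# registers `buf` as mutated with the graph and creates a MutationOutput that
# ties `buf` to `kernel_node` for scheduling purposes.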
4774
4775
4776class MutationOutput(ExternKernel):
4777    def get_mutation_names(self):
4778        return [self.inputs[0].get_name()]
4779
4780    def __init__(self, layout, mutated_node, node_doing_mutating):
4781        # NB: Do not directly construct this - use `mark_node_as_mutating`
4782        super().__init__(None, layout, [mutated_node, node_doing_mutating], ())
4783        self.node_doing_mutating = node_doing_mutating
4784        self.name = V.graph.register_buffer(self)
4785
4786    def should_allocate(self):
4787        return False
4788
4789    def is_no_op(self):
4790        return True
4791
4792    def has_side_effects(self):
4793        return True
4794
4795    def get_inputs_that_alias_output(self):
4796        return [self.inputs[0].get_name()]
4797
4798
4799class InplaceBernoulliFallback(ExternKernel):
4800    """
4801    This needs to be a custom class to handle mutation properly
4802    """
4803
4804    def codegen(self, wrapper):
4805        (x,) = (t.codegen_reference() for t in self.inputs)
4806
4807        if V.graph.cpp_wrapper and config.abi_compatible:
4808            # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here,
4809            # which needs to be explicitly generated for cpp wrapper
4810            wrapper.writeline(
4811                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}"
4812            )
4813        else:
4814            wrapper.writeline(
4815                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
4816            )
4817
4818    def should_allocate(self):
4819        return False
4820
4821    def get_mutation_names(self):
4822        return [self.inputs[0].get_name()]
4823
4824    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
4825        return set()
4826
4827    def __init__(self, op_overload, x, *constant_args):
4828        super().__init__(
4829            None,
4830            NoneLayout(x.get_device()),  # type: ignore[arg-type]
4831            self.unwrap_storage([x]),
4832            constant_args,
4833            op_overload=op_overload,
4834        )
4835        self.name = V.graph.register_buffer(self)
4836        self.python_kernel_name = "aten.bernoulli_"
4837        if not config.abi_compatible:
4838            # TODO: this should be simplified once we switch to ABI-compatible only
4839            self.cpp_kernel_name = "at::native::bernoulli_"
4840        mark_node_as_mutating(self, x)
4841
4842
4843# Used to deal with torch.complex types
4844class InplaceCopyFallback(ExternKernel):
4845    """
4846    This needs to be a custom class to handle mutation properly
4847    """
4848
4849    def codegen(self, wrapper):
4850        (dst, src, non_blocking) = self.codegen_args()
4851        wrapper.writeline(
4852            f"{self.get_kernel_name()}({dst}, {src}, {non_blocking}){wrapper.ending}"
4853        )
4854
4855    def should_allocate(self):
4856        return False
4857
4858    def get_mutation_names(self):
4859        return [self.inputs[0].get_name()]
4860
4861    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
4862        return set()
4863
4864    def __init__(
4865        self,
4866        layout,
4867        inputs,
4868        constant_args,
4869    ):
4870        super().__init__(
4871            None,
4872            layout,
4873            inputs,
4874            constant_args,
4875            python_kernel_name="aten.copy_",
4876            cpp_kernel_name=(
4877                "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call"
4878            ),
4879        )
4880        self.name = V.graph.register_buffer(self)
4881
4882    @classmethod
4883    def create(cls, dst, src, non_blocking: bool = False):
4884        inputs = [cls.realize_input(t) for t in [dst, src]]
4885        constant_args = (non_blocking,)
4886        result = InplaceCopyFallback(
4887            NoneLayout(dst.get_device()),  # type: ignore[arg-type]
4888            inputs,
4889            constant_args,
4890        )
4891        mark_node_as_mutating(result, dst)
4892        return result
4893
4894
4895class MutatingFirstArgExternKernel(ExternKernel):
4896    """
4897    This needs to be a custom class to handle mutation properly
4898    """
4899
4900    def codegen(self, wrapper):
4901        argrefs = [
4902            *(t.codegen_reference() for t in self.inputs),
4903            *map(repr, self.constant_args),
4904        ]
4905        wrapper.writeline(
4906            f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}"
4907        )
4908
4909    def should_allocate(self):
4910        return False
4911
4912    def get_mutation_names(self):
4913        return [self.inputs[0].get_name()]
4914
4915    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
4916        return set()
4917
4918    def has_side_effects(self):
4919        return True
4920
4921
4922class ResizeStorageBytes(MutatingFirstArgExternKernel):
4923    def __init__(self, variable, new_size):
4924        assert isinstance(new_size, int), "TODO: dynamic shapes"
4925        super().__init__(
4926            None,
4927            NoneLayout(variable.get_device()),  # type: ignore[arg-type]
4928            self.unwrap_storage([variable]),
4929            constant_args=(new_size,),
4930        )
4931        V.graph.mark_buffer_mutated(variable.get_name())
4932        self.name = V.graph.register_buffer(self)
4933        self.python_kernel_name = "inductor_ops.resize_storage_bytes_"
4934        self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_"
4935        V.graph.never_reuse_buffers.add(variable.data.get_name())
4936        mark_node_as_mutating(self, variable)
4937
4938
4939class SetSourceTensorKernel(ExternKernelAlloc):
4940    def __init__(self, self_tensor, storage_tensor):
4941        self_tensor.freeze_layout()
4942        super().__init__(
4943            self_tensor.get_layout(),
4944            [self_tensor, storage_tensor],
4945            python_kernel_name="torch.ops.aten.set_.source_Tensor",
4946        )
4947        V.graph.never_reuse_buffers.add(self_tensor.data.get_name())
4948        V.graph.never_reuse_buffers.add(storage_tensor.get_name())
4949        V.graph.never_reuse_buffers.add(self.get_name())
4950        mark_node_as_mutating(self, self_tensor, storage_tensor)
4951
4952    def get_inputs_that_alias_output(self):
4953        return [self.inputs[0].get_name(), self.inputs[1].get_name()]
4954
4955    def get_mutation_names(self):
4956        return [self.inputs[1].get_name()]
4957
4958    def has_side_effects(self):
4959        return True
4960
4961
4962class ScatterFallback(ExternKernel):
4963    """
4964    This needs to be a custom class to handle mutation properly.
4965    This class handles both aten.scatter_ and aten.scatter_reduce_.
    It also handles the case of `src` being a scalar properly.
4967    """
4968
4969    def codegen(self, wrapper):
4970        reduce = self.kwargs["reduce"]
4971        if V.graph.cpp_wrapper:
4972            # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
4973            get_operator_enum = {"add": "sum", "multiply": "prod"}
4974            if reduce in get_operator_enum:
4975                reduce = get_operator_enum[reduce]
4976
4977        if self.src_is_tensor:
4978            (x, index, src) = (t.codegen_reference() for t in self.inputs)
4979        else:
4980            (x, index) = (t.codegen_reference() for t in self.inputs)
4981            src = self.constant_args[1]
4982        wrapper.generate_scatter_fallback(
4983            x,
4984            [x, self.constant_args[0], index, src],
4985            self.cpp_kernel_name,
4986            self.python_kernel_name,
4987            self.src_is_tensor,
4988            reduce,
4989            self.codegen_kwargs(),
4990        )
4991
4992    def should_allocate(self):
4993        return False
4994
4995    def get_mutation_names(self):
4996        return [self.inputs[0].get_name()]
4997
4998    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
4999        return set()
5000
5001    def __init__(
5002        self,
5003        op_overload,
5004        x,
5005        dim: int,
5006        index,
5007        src,
5008        *,
5009        reduce: Optional[str] = None,
5010        include_self: bool = True,
5011    ):
5012        self.src_is_tensor = isinstance(src, TensorBox)
5013
5014        constant_args: Tuple[Any, ...]
5015        if self.src_is_tensor:
5016            tensors = [self.realize_input(t) for t in [x, index, src]]
5017            constant_args = (dim,)
5018        else:
5019            tensors = [self.realize_input(t) for t in [x, index]]
5020            constant_args = (dim, src)
5021
5022        super().__init__(
5023            None,
5024            NoneLayout(x.get_device()),  # type: ignore[arg-type]
5025            self.unwrap_storage(tensors),
5026            constant_args,
5027            {"reduce": reduce, "include_self": include_self},
5028            python_kernel_name=str(op_overload),
5029            ordered_kwargs_for_cpp_kernel=["reduce", "include_self"],
5030            op_overload=op_overload,
5031        )
5032        self.cpp_kernel_name = get_aten_cpp_kernel_name(op_overload)
5033        self.name = V.graph.register_buffer(self)
5034        mark_node_as_mutating(self, x)
5035
5036
5037class IndexPutFallback(ExternKernel):
5038    """
5039    This needs to be a custom class to handle mutation and indices properly
5040    """
5041
5042    def codegen(self, wrapper):
5043        (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs)
5044        indices = []
5045        iter_valid_indices = iter(valid_indices)
        for idx in self.indices:
            if idx is not None:
                indices.append(next(iter_valid_indices))
            else:
                indices.append(V.graph.wrapper_code.none_str)
5051
5052        wrapper.generate_index_put_fallback(
5053            self.get_kernel_name(), x, indices, values, *self.codegen_const_args()
5054        )
5055
5056    def should_allocate(self):
5057        return False
5058
5059    def get_mutation_names(self):
5060        return [self.inputs[0].get_name()]
5061
5062    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
5063        return set()
5064
5065    def __init__(self, op_overload, x, indices, values, accumulate):
5066        self.indices = indices
5067        valid_indices = [i for i in indices if i is not None]
5068        tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
5069        cpp_kernel_name = (
5070            "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out"
5071        )
5072        super().__init__(
5073            None,
5074            NoneLayout(x.get_device()),  # type: ignore[arg-type]
5075            self.unwrap_storage(tensors),
5076            (accumulate,),
5077            python_kernel_name="aten.index_put_",
5078            cpp_kernel_name=cpp_kernel_name,
5079            op_overload=op_overload,
5080        )
5081        self.name = V.graph.register_buffer(self)
5082        mark_node_as_mutating(self, x)
5083
5084
5085class DeviceCopy(ExternKernelOut):
5086    @classmethod
5087    def create(cls, x, device):
5088        if (
5089            not x.is_extern()
5090            and all(
5091                (r.name in V.graph.constants and isinstance(r, dependencies.MemoryDep))
5092                for r in x.get_reads()
5093            )
5094            and not config.aot_inductor.use_runtime_constant_folding
5095        ):
5096            return x.constant_to_device(device)
5097
5098        V.graph.add_device_info(device)
5099        V.graph.add_device_info(x.get_device())
5100
5101        developer_warning("DeviceCopy in input program")
5102        return DeviceCopy(
5103            FlexibleLayout(
5104                device=device,
5105                dtype=x.get_dtype(),
5106                size=x.get_size(),
5107            ),
5108            [cls.realize_input(x)],
5109        )
5110
5111    def codegen(self, wrapper):
5112        args = self.codegen_args()
5113        assert len(args) == 1
5114        if self.output_view:
5115            wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference())
5116        else:
5117            wrapper.codegen_device_copy(args[0], self.codegen_reference())
5118
5119
5120class DynamicScalar(ExternKernel):
5121    """
5122    The result of a call to aten._local_scalar_dense.
5123    """
5124
5125    def get_reads(self):
5126        return ()
5127
5128    def should_allocate(self):
5129        return False
5130
5131    def __init__(self, sym, keypath, data):
5132        data.realize()
5133        super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data]))  # type: ignore[arg-type]
5134        self.sym = sym
5135        self.keypath = keypath
5136
5137    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
5138        return {self.sym}
5139
5140    def codegen(self, wrapper):
5141        wrapper.codegen_dynamic_scalar(self)
5142
5143
5144class AssertScalar(ExternKernel):
5145    """
5146    The result of a call to aten._assert_scalar
5147    """
5148
5149    def get_reads(self):
5150        return ()
5151
5152    def should_allocate(self):
5153        return False
5154
5155    def __init__(self, scalar, msg):
5156        super().__init__(
            # Buffer(name, layout)
5158            None,
5159            NoneLayout(torch.device("cpu")),  # type: ignore[arg-type]
5160            # InputsKernel(inputs)
5161            [],
5162        )  # type: ignore[arg-type]
5163        self.scalar = scalar
5164        self.msg = msg
5165
5166    def has_side_effects(self):
5167        return True
5168
5169    def get_unbacked_symbol_uses(self):
5170        return free_unbacked_symbols(self.scalar)
5171
5172    def codegen(self, wrapper):
5173        if V.graph.cpp_wrapper:
5174            pass
5175        else:
5176            # NB: It is EXTREMELY important not to simplify the scalar under
5177            # assertion here, because simplify is done with respect to
5178            # runtime asserts.  So if you have "u0 == 0" in the runtime
            # asserts and subsequently try to simplify(u0 == 0), you will
5180            # get True (because we've already runtime assert'ed that it's
5181            # true).  But we're code generating the actual runtime assert
5182            # here!!
5183            wrapper.writeline(
5184                f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar, simplify=False)}:"
5185            )
5186            wrapper.writeline(f"    raise RuntimeError({repr(self.msg)})")
5187            # No one should ever use this buffer, but for uniformity
5188            # define the variable and assign it None
5189            wrapper.writeline(f"{self.get_name()} = None")
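            # Illustrative sketch of the emitted Python (hypothetical symbol
            # u0, message, and buffer name buf3), roughly:
            #
            #   if not u0 >= 0:
            #       raise RuntimeError('Expected u0 >= 0')
            #   buf3 = None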
5190
5191
5192@dataclasses.dataclass
5193class ExternKernelNode:
5194    name: str
5195    node: export_schema.Node
5196
5197
5198has_c_shim = {
5199    aten._embedding_bag.default,
5200    aten._fft_c2c.default,
5201    aten._scaled_dot_product_efficient_attention.default,
5202    aten._scaled_dot_product_flash_attention.default,
5203    aten._scaled_mm.default,
5204    aten.addmm.out,
5205    aten.bmm.out,
5206    aten.copy_.default,
5207    aten.mm.out,
5208    aten.repeat_interleave.Tensor,
5209    aten.nonzero.default,
5210    aten.view.dtype,
5211    aten.view_as_real.default,
5212}
5213
5214
5215class FallbackKernel(ExternKernelAlloc):
5216    def __init__(
5217        self,
5218        layout,
5219        kernel,
5220        tensor_args,
5221        nontensor_args,
5222        unflatten_args,
5223        kwargs=None,
5224        *,
5225        unbacked_bindings=None,
5226    ):
5227        if (
5228            kernel == aten.mul.Tensor
5229            and len(tensor_args) == 1
5230            and len(nontensor_args) == 1
5231        ):
5232            # When aten.mul.Tensor's second arg is constant, cpp wrapper expects
5233            # to call mul_Scalar. A more proper fix is to do it in decomposition.
5234            # See https://github.com/pytorch/pytorch/issues/123478
5235            kernel = aten.mul.Scalar
5236
5237        super().__init__(
5238            layout,
5239            tuple(tensor_args),
5240            tuple(nontensor_args),
5241            op_overload=kernel,
5242        )
5243
5244        # We need output buffers for generating kernel arguments in the
        # abi-compatible mode, where we retrieve outputs by passing each individual
5246        # output through the abi-compatible interface.
5247        self.outputs: Sequence[Any] = []
5248        self.use_runtime_dispatch = False
5249        self.unbacked_bindings = unbacked_bindings
5250
5251        assert isinstance(
5252            kernel,
5253            (
5254                torch._ops.OpOverload,
5255                torch._ops.HigherOrderOperator,
5256            ),
        ), f"Failed to create FallbackKernel for {kernel}: {type(kernel)} not supported"
5258        self.op_overload = kernel
5259        self.unflatten_args = unflatten_args
5260        self.kwargs = {} if kwargs is None else kwargs
5261        V.graph.warn_fallback(self.python_kernel_name)
5262
5263        # args that are aliased
5264        self.alias_names: List[str] = []
5265        # args that are mutated AND returned from the op
5266        self.mutation_names: List[str] = []
5267
5268        if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
5269            # We assume here that HOPs with FallbackKernel are functional.
            # This may not always be true! HOPs must individually opt in to
            # FallbackKernel, so please check this if you opt in.
5272            return
5273
5274        if "_c10d_functional" in self.op_overload.name():
5275            # _c10d_functional kernels are lowered into _CollectiveKernel which
5276            # derives from FallbackKernel for the cpp codegen. The kernels
5277            # don't pass the can_auto_functionalize check, but their mutation
5278            # is handled properly by _CollectiveKernel.
5279            return
5280
5281        schema = self.op_overload._schema
5282
5283        # NOTE: [FallbackKernel supported operators]
        # We only support four types of operators:
5285        # - functional ops
5286        # - view ops
5287        # - inplace aten ops
5288        # - mutating ops that are auto-functionalizable. That is,
5289        # the operator may mutate any number of inputs, but its outputs
5290        # may not alias any of the inputs.
5291        #
5292        # The unsupported cases usually do not show up here (because
5293        # AOTAutograd functionalized them away); the only way for an in-place
5294        # op to show up here is if a lowering or pass introduced it.
5295        if torch._library.utils.mutates_and_returns_first_arg(self.op_overload):
5296            self.mutation_names.append(tensor_args[0].get_name())
5297            return
5298
5299        if schema.is_mutable and not can_auto_functionalize(kernel):
5300            raise NotImplementedError(
5301                f"NYI: Can't generate FallbackKernel for {kernel}"
5302            )
5303
5304        schema_args = schema.arguments
5305        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
5306
5307        def handle_aliasing_and_mutation(info, arg):
5308            # Assertions to make sure we didn't mismatch args
5309            if isinstance(info.type, torch.ListType):
5310                assert isinstance(arg, (list, tuple))
5311            is_optional_tensor = isinstance(
5312                info.type, torch.OptionalType
5313            ) and isinstance(info.type.getElementType(), torch.TensorType)
5314            if is_optional_tensor or isinstance(info.type, torch.TensorType):
5315                # PyTorch also accepts None and scalar types for args marked as "Tensor".
5316                # We're not going to check all of them here.
5317                assert not isinstance(arg, (tuple, list))
5318
5319            if arg is None:
5320                return
5321            if info.alias_info is None:
5322                return
5323            # can_auto_functionalize already filters out mutable List[Tensor].
5324            # We can support this in the future, but this is very uncommon.
5325            assert isinstance(info.type, torch.TensorType) or is_optional_tensor
5326            self.alias_names.append(arg.get_name())
5327            if info.alias_info.is_write:
5328                mark_node_as_mutating(self, arg)
5329
5330        for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
5331            handle_aliasing_and_mutation(info, arg)
5332
5333    def codegen_unbacked_symbol_defs(self, wrapper):
5334        if not hasattr(self, "unbacked_bindings"):
5335            return
5336
5337        unbacked_bindings = resolve_unbacked_bindings(
5338            V.graph.sizevars.shape_env, self.unbacked_bindings
5339        )
5340
5341        if not unbacked_bindings:
5342            return
5343
5344        for s, keypath in unbacked_bindings.items():
5345
5346            def go(expr, keypath):
5347                if keypath == ():
5348                    return expr
5349
5350                if (
5351                    len(keypath) >= 2
5352                    and isinstance(keypath[0], CallMethodKey)
5353                    and isinstance(keypath[1], pytree.SequenceKey)
5354                ):
5355                    return go(
5356                        f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:]
5357                    )
5358                elif isinstance(keypath[0], CallMethodKey):
5359                    return go(f"{expr}.{keypath[0].name}()", keypath[1:])
5360                elif isinstance(keypath[0], pytree.SequenceKey):
5361                    return go(f"{expr}[{keypath[0].idx}]", keypath[1:])
5362                elif isinstance(keypath[0], DivideByKey):
5363                    # TODO: need to assert divisibility
5364                    # TODO: this is invalid C++ codegen
5365                    return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:])
5366                else:
5367                    raise AssertionError(f"unrecognized keypath {keypath}")
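
            # Illustrative sketch (hypothetical buffer name): a keypath of
            # (CallMethodKey("size"), SequenceKey(1)) applied to "buf0"
            # expands to the string "buf0.size(1)", which is then assigned
            # to the unbacked symbol's declaration below.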
5368
5369            def go_outer():
5370                if V.graph.cpp_wrapper and config.abi_compatible:
5371                    # Special handling for the top level buffer access,
5372                    # because self.get_name() is actually never bound; the
5373                    # individual output arguments are bound by
5374                    # generate_c_shim_fallback_kernel
5375                    if len(self.outputs) == 1:
5376                        return go(self.outputs[0].get_name(), keypath)
5377                    else:
5378                        assert isinstance(keypath[0], pytree.SequenceKey)
5379                        return go(self.outputs[keypath[0].idx].get_name(), keypath[1:])
5380                else:
5381                    return go(self.get_name(), keypath)
5382
5383            wrapper.writeline(
5384                f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}"
5385            )
5386
5387    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
5388        if unbacked_bindings := getattr(self, "unbacked_bindings", None):
5389            return resolve_unbacked_bindings(
5390                V.graph.sizevars.shape_env, unbacked_bindings
5391            ).keys()
5392        else:
5393            return set()
5394
5395    def set_cpp_kernel(self, kernel):
5396        from .codegen.wrapper import get_cpp_op_schema
5397
5398        assert (
5399            not kernel._schema.is_mutable
5400        ), f"mutable {kernel.__name__} is not supported with cpp_wrapper"
5401
5402        # These checks are here because ops that return aliasing tensors will
5403        # return type Tensor& instead of Tensor, but codegen will always write
5404        # type Tensor on the LHS.
5405        def is_not_write(arg):
5406            return arg.alias_info is None or not arg.alias_info.is_write
5407
5408        assert all(
5409            is_not_write(x) for x in kernel._schema.arguments
5410        ), f"{kernel.__name__} with alias_info arguments is not supported with cpp_wrapper"
5411        assert all(
5412            is_not_write(x) for x in kernel._schema.returns
5413        ), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper"
5414
5415        self.cpp_kernel_name = kernel._schema.name
5416        self.cpp_kernel_overload_name = kernel._schema.overload_name
5417        self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"  # type: ignore[union-attr]
5418
5419        self.cpp_op_schema = get_cpp_op_schema(kernel)
5420
5421    def codegen_args(self):
5422        @dataclasses.dataclass
5423        class Shim:
5424            ref: Any
5425
5426            def __repr__(self):
5427                return self.ref
5428
5429        tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
5430        args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
5431        if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
5432            args = self.fill_non_provided_args(args, kwargs)
5433            args = [
5434                V.graph.wrapper_code.val_to_arg_str(x, param.real_type)
5435                for param, x in zip(self.op_overload._schema.arguments, args)
5436            ]
5437        else:
5438            args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
5439
5440        # let self.codegen_kwargs handle kwargs
5441        self.kwargs.update(kwargs)
5442        return args
5443
5444    @staticmethod
5445    def find_device(tensor_args, example_output):
5446        if tensor_args:
5447            devices = [arg.get_device() for arg in tensor_args if arg.get_device()]
5448            return devices[0]
5449        if isinstance(example_output, torch.Tensor):
5450            return example_output.device
5451        if isinstance(example_output, (list, tuple)):
5452            device_set = {FallbackKernel.find_device(None, x) for x in example_output}
5453            # Remove None
5454            devices = [device for device in device_set if device]
5455            if len(devices) == 1:
5456                return devices[0]
5457            for device in devices:
5458                if is_gpu(device.type):
5459                    return device
5460            return devices[0]
5461        return None
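
    # Illustrative sketch (hypothetical tensors): for a mixed list output the
    # GPU device wins when several devices appear,
    #
    #   FallbackKernel.find_device(None, [cpu_t, cuda_t])  # -> cuda device
    #   FallbackKernel.find_device(None, [cpu_t])          # -> cpu device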
5462
5463    def has_side_effects(self):
5464        if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
5465            return False
5466        return get_schema_info(self.op_overload).is_mutable()
5467
5468    def get_inputs_that_alias_output(self):
5469        return self.alias_names
5470
5471    def get_mutation_names(self):
5472        assert len(self.mutation_names) <= 1
5473        return self.mutation_names
5474
5475    # ProxyExecutor Design Note
5476    # We export the ExternFallbackNodes (for custom ops) into a serialized file
5477    # and run it with a host side proxy executor to address the ABI problem
5478    # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
5479    # Detailed design doc can be found at
5480    # https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing
5481    def export_extern_kernel_node(self):
5482        assert isinstance(self, FallbackKernel)
5483        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
5484        args = self.fill_non_provided_args(args, kwargs)
5485        ordered_kwargs = [
5486            kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel
5487        ]
5488        if not V.graph.aot_mode:
5489            # No need to serialize in the cpp wrapper JIT mode
5490            return [*args, *ordered_kwargs]
5491
5492        serializer = GraphModuleSerializer(None, None)  # type: ignore[arg-type]
5493        named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)  # type: ignore[arg-type]
5494
5495        # serialize_outputs
5496        def handle_single_output(return_type, output):
5497            if isinstance(return_type, torch.TensorType):
5498                # For single Tensor
5499                out = output
5500                if isinstance(output, (list, tuple)):
5501                    assert len(output) == 1
5502                    out = output[0]
5503                return export_schema.Argument.create(
5504                    as_tensor=export_schema.TensorArgument(name=out.get_name())
5505                )
5506            elif isinstance(return_type, torch.ListType) and isinstance(
5507                return_type.getElementType(), torch.TensorType
5508            ):
5509                # For single TensorList
5510                return export_schema.Argument.create(
5511                    as_tensors=[
5512                        export_schema.TensorArgument(name=out.get_name())
5513                        for out in output
5514                    ]
5515                )
5516            else:
5517                raise RuntimeError(f"Unsupported return type {type(return_type)}")
5518
5519        target = self.op_overload
5520        returns = target._schema.returns  # type: ignore[union-attr]
5521        if len(returns) == 1:
5522            return_type = returns[0].real_type
5523            output_arguments = [handle_single_output(return_type, self.outputs)]
5524        else:
            # For tuple returns, e.g. "-> (Tensor, Tensor)" or "-> (Tensor, Tensor[])"
5526            assert isinstance(self.outputs, tuple)
5527            assert len(returns) == len(self.outputs)
5528            output_arguments = [
5529                handle_single_output(return_schema.real_type, output)
5530                for return_schema, output in zip(returns, self.outputs)
5531            ]
5532
5533        node = ExternKernelNode(
5534            name=self.get_name(),
5535            node=export_schema.Node(
5536                target=self.op_overload.name(),  # type: ignore[union-attr]
5537                inputs=named_arguments,
5538                outputs=output_arguments,
5539                metadata={},
5540            ),
5541        )
5542
5543        V.graph.extern_kernel_nodes.append(node)
5544
5545        return [*args, *ordered_kwargs]
5546
5547    def codegen(self, wrapper):
5548        kernel = self.op_overload
5549        if kernel.namespace == "aten":  # type: ignore[union-attr]
5550            # Aten Fallback Ops
5551            assert isinstance(kernel, torch._ops.OpOverload)
5552            if V.graph.cpp_wrapper:
5553                if (
5554                    config.is_fbcode()
5555                    and kernel not in has_c_shim
5556                    # C shim v2 is torchgen-ed, which should cover all aten ops.
5557                    # If you do hit a missed op, please update gen_aoti_c_shim.py.
5558                    and config.c_shim_version == "1"
5559                ):
5560                    log.warning(
5561                        "%s is missing a c-shim implementation, using proxy executor as fallback",
5562                        kernel,
5563                    )
5564                    self.use_runtime_dispatch = True
5565                    self.set_cpp_kernel(kernel)
5566            else:
5567                self.python_kernel_name = str(kernel)
5568        elif kernel.namespace == "_quantized":  # type: ignore[union-attr]
5569            # Internal Quantized Fallback Ops
5570            assert isinstance(kernel, torch._ops.OpOverload)
5571            if V.graph.cpp_wrapper:
5572                self.set_cpp_kernel(kernel)
5573                if not config.abi_compatible:
5574                    self.use_runtime_dispatch = True
5575            else:
5576                self.python_kernel_name = str(kernel)
5577        elif isinstance(kernel, torch._ops.HigherOrderOperator):
5578            self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}"
5579        else:
5580            # For non-aten OpOverload, i.e. custom ops
5581            self.python_kernel_name = f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"  # type: ignore[union-attr]
5582            if V.graph.cpp_wrapper:
5583                self.use_runtime_dispatch = True
5584                self.set_cpp_kernel(kernel)
5585
5586        if self.use_runtime_dispatch:
5587            self.codegen_comment(wrapper)
5588
5589            exported_args = None
5590            args = None
5591            if config.abi_compatible:
5592                exported_args = self.export_extern_kernel_node()
5593            else:
5594                args = [*self.codegen_args(), *self.codegen_kwargs()]
5595
5596            wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5597                self.get_name(),
5598                self.python_kernel_name,
5599                self.cpp_kernel_name,
5600                args,
5601                self.cpp_op_schema,
5602                self.cpp_kernel_key,
5603                self.cpp_kernel_overload_name,
5604                self.op_overload,
5605                exported_args,
5606                self.outputs,
5607            )
5608        else:
5609            self.codegen_comment(wrapper)
5610            args = [*self.codegen_args(), *self.codegen_kwargs()]
5611            V.graph.wrapper_code.generate_fallback_kernel(self, args)
5612            if isinstance(self.layout, Layout):
5613                self.codegen_size_asserts(wrapper)
5614
5615        self.codegen_unbacked_symbol_defs(wrapper)
5616
5617    @staticmethod
5618    def tensor_to_layout(output: torch.Tensor):
5619        return FixedLayout(
5620            output.device,
5621            output.dtype,
5622            convert_shape_to_inductor(output.size()),
5623            convert_shape_to_inductor(output.stride()),
5624        )
5625
5626    @classmethod
5627    def create(cls, kernel, *args, **kwargs):
5628        fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
5629        context = (
5630            V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
5631        )
5632        with context:
5633            (
5634                example_output,
5635                tensor_args,
5636                non_tensor_args,
5637                unflatten_args,
5638                unbacked_bindings,
5639            ) = cls.process_kernel(kernel, *args, **kwargs)
5640
5641        device = cls.find_device(tensor_args, example_output)
5642        if example_output is None:
5643            packed = cls(
5644                NoneLayout(device),
5645                kernel,
5646                tensor_args,
5647                non_tensor_args,
5648                unflatten_args,
5649                unbacked_bindings=unbacked_bindings,
5650            )
5651
5652        else:
5653            assert device, "Not sure where to find device info"
5654            packed = cls(
5655                MultiOutputLayout(device),
5656                kernel,
5657                tensor_args,
5658                non_tensor_args,
5659                unflatten_args,
5660                unbacked_bindings=unbacked_bindings,
5661            )
5662
5663        def generate_output(output, indices):
5664            if isinstance(output, (list, tuple)):
5665                return type(output)(
5666                    generate_output(output[i], indices + [(type(output), i)])
5667                    for i in range(len(output))
5668                )
5669            elif isinstance(output, dict):
5670                return {
5671                    key: generate_output(val, indices + [(type(output), key)])
5672                    for key, val in output.items()
5673                }
5674            elif isinstance(output, torch.Tensor):
5675                return MultiOutput(
5676                    cls.tensor_to_layout(output),
5677                    packed,
5678                    indices,
5679                )
5680            elif isinstance(output, int):
5681                return output
5682            elif isinstance(output, torch.SymInt):
5683                return output.node.expr
5684            else:
5685                assert (
5686                    output is None
5687                ), f"FallbackKernel output type {type(output)} is not supported"
5688                return None
5689
5690        outputs = generate_output(example_output, [])
5691        if isinstance(outputs, (list, tuple, dict)):
5692            packed.outputs = outputs  # type: ignore[assignment]
5693        else:
5694            packed.outputs = [outputs]
5695        return outputs
5696
5697    def apply_constraint(self):
5698        return super().apply_constraint()
5699
5700
5701@dataclasses.dataclass
5702class ComplexView(FallbackKernel):
    """View a complex tensor as two real-dtyped tensors, or vice versa"""
5704
5705    def should_allocate(self):
5706        return False
5707
5708    def get_inputs_that_alias_output(self):
5709        # Signal to codegen that our output buffer isn't safe to reuse
5710        return [self.inputs[0].get_name()]
5711
5712    def __init__(
5713        self,
5714        layout,
5715        kernel,
5716        tensor_args,
5717        nontensor_args,
5718        unflatten_args,
5719        *,
5720        unbacked_bindings=None,
5721    ):
5722        super().__init__(
5723            layout,
5724            kernel,
5725            tensor_args,
5726            nontensor_args,
5727            unflatten_args,
5728            unbacked_bindings=unbacked_bindings,
5729        )
5730
5731
5732@dataclasses.dataclass
5733class MultiOutputLayout(IRNode):
5734    device: torch.device
5735
5736
5737class MultiOutput(ExternKernel):
5738    # Given an input MultiOutputLayout buffer, indexes out an actual buffer
5739    # from that result.  This doesn't actually produce multiple outputs,
5740    # that's MultiOutputLayout!
5741    def codegen_list_tuple_access(self, basename, indices):
5742        if len(indices) > 0:
5743            itype, i = indices[0]
5744            if issubclass(itype, list):
5745                return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
5746            elif issubclass(itype, tuple):
5747                # cpp wrapper code needs to use std::get<> to access a tuple
5748                tuple_access = V.graph.wrapper_code.codegen_tuple_access(
5749                    basename, self.get_name(), str(i)
5750                )
5751                return self.codegen_list_tuple_access(tuple_access, indices[1:])
5752            elif issubclass(itype, dict):
5753                return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
5754            else:
                raise AssertionError(f"unsupported index type: {itype}")
5756        else:
5757            return basename
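
    # Illustrative sketch (hypothetical buffer and indices): for basename
    # "buf2" and indices [(tuple, 0), (list, 1)], the Python wrapper produces
    # the access string "buf2[0][1]"; the cpp wrapper uses std::get<0>(...)
    # for the tuple level instead.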
5758
5759    def codegen(self, wrapper):
5760        wrapper.codegen_multi_output(
5761            self.get_name(),
5762            self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
5763        )
5764
5765    def __init__(self, layout, input, indices: List[Tuple[Any, ...]]):
5766        super().__init__(None, layout, [input], ())
5767        self.name = V.graph.register_buffer(self)
5768        self.indices = indices
5769
5770    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
5771        return self.inputs[0].get_unbacked_symbol_uses()
5772
5773    def should_allocate(self):
5774        return False
5775
5776    def get_inputs_that_alias_output(self):
5777        return [
5778            inp.get_name()
5779            for inp in self.inputs
5780            if isinstance(inp, FallbackKernel)
5781            and len(inp.get_inputs_that_alias_output()) > 0
5782        ]
5783
5784
5785def _prepare_convolution_fusion_create(
5786    cls,
5787    x: "TensorBox",
5788    weight: "TensorBox",
5789    bias: "TensorBox",
5790    padding: List[int],
5791    stride: List[int],
5792    dilation: List[int],
5793    groups: int,
5794    transposed: bool = False,
5795    output_padding: Optional[List[int]] = None,
5796):
5797    """
    This is a helper function to prepare the inputs, layout, and constant args
    for a convolution post-op fusion's create function, including deciding the
    output layout (channels first or channels last), realizing the inputs, and
    enforcing the required stride order on them. The function only supports the
    CPU device, since the conv post-op fusion kernel is currently only
    supported on CPU.
5803    """
5804
5805    # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
5806    def _conv_input_size(
5807        output_size, weight_size, padding, output_padding, stride, dilation, groups
5808    ):
5809        assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
5810        dim = len(output_size)
5811        assert dim > 2, "Expect input dim > 2"
5812
5813        BATCH_DIM = 0
5814        WEIGHT_INPUT_CHANNELS_DIM = 1
5815        input_size = []
5816        input_size.append(output_size[BATCH_DIM])
5817        input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
5818        for d in range(2, dim):
5819            kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
5820            input_size_d = (
5821                (output_size[d] - 1) * stride[d - 2]
5822                - (padding[d - 2] * 2)
5823                + kernel
5824                + output_padding[d - 2]
5825            )
5826            input_size.append(input_size_d)
5827        return list(map(int, input_size))
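
    # Worked example (hypothetical numbers): for output_size [1, 8, 5],
    # weight_size [8, 4, 3], padding [1], output_padding [1], stride [2],
    # dilation [1], groups 1:
    #
    #   kernel = (3 - 1) * 1 + 1 = 3
    #   spatial = (5 - 1) * 2 - 1 * 2 + 3 + 1 = 10
    #   input_size == [1, 4, 10]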
5828
5829    # The size of prepacked_weight is the prepacked weight size of deconv:
5830    #   Groups > 1:  [g*o, i/g, ...]
5831    #   Groups == 1: [o, i, ...]
5832    # Returns original weight size in [i, o, ...]
5833    def _original_deconv_weight_size(
5834        prepacked_weight,
5835        groups,
5836    ):
5837        prepacked_weight_size = prepacked_weight.size()
5838        dim = len(prepacked_weight_size)
5839        assert dim > 2, "Expect weight dim > 2"
5840        if groups > 1:
5841            weight_size = []
5842            weight_size.append(prepacked_weight_size[1] * groups)
5843            weight_size.append(prepacked_weight_size[0] / groups)
5844            for d in range(2, dim):
5845                weight_size.append(prepacked_weight_size[d])
5846        else:
5847            weight_size = prepacked_weight.transpose(0, 1).size()
5848        return weight_size
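
    # Worked example (hypothetical numbers): with groups == 2 and a prepacked
    # weight of size [6, 2, 3] (g*o == 6, i/g == 2, so o == 3 and i == 4):
    #
    #   weight_size == [2 * 2, 6 / 2, 3] == [4, 3.0, 3]  # i.e. [i, o, kernel]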
5849
5850    x.realize()
5851    weight.realize()
5852    if bias is not None:
5853        bias.realize()
5854    with V.graph.fake_mode:
        # TODO <Leslie>: clean up the fake_tensor trace, as in the Linear implementation
5856        x_fake = ir_node_to_tensor(x, guard_shape=True)
5857        weight_fake = ir_node_to_tensor(weight, guard_shape=True)
5858        dims = len(x_fake.size()) - 2
5859        assert 0 < len(padding) <= dims
5860        assert 0 < len(dilation) <= dims
5861        assert 0 < len(stride) <= dims
5862        padding = pad_listlike(padding, dims)
5863        dilation = pad_listlike(dilation, dims)
5864        stride = pad_listlike(stride, dims)
5865        if output_padding is None:
5866            output_padding = pad_listlike([0], dims)
5867        else:
5868            assert 0 < len(output_padding) <= dims
5869            output_padding = pad_listlike(output_padding, dims)
5870        assert isinstance(groups, int)
5871        if transposed:
5872            # When transposed, the size of the prepacked oneDNN weight is different
5873            # from the PyTorch weight. We're not able to run aten conv with such
5874            # size. We infer the output size from the input params here:
5875            weight_size = _original_deconv_weight_size(weight_fake, groups)
5876            input_size = x_fake.size()
5877            output_size = _conv_input_size(
5878                input_size,
5879                weight_size,
5880                padding,
5881                output_padding,
5882                stride,
5883                dilation,
5884                groups,
5885            )
5886        else:
5887            bias_fake = (
5888                ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
5889            )
5890            output = torch.ops.aten.convolution(
5891                x_fake,
5892                weight_fake,
5893                bias_fake,
5894                stride,
5895                padding,
5896                dilation,
5897                transposed,
5898                output_padding,
5899                groups,
5900            )
5901            output_size = output.size()
5902
5903        req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
5904        req_stride_order = [len(req_stride_order)] + req_stride_order
5905
5906    x = cls.require_stride_order(x, req_stride_order)
5907
5908    # We don't do weight prepacking for Conv under dynamic shapes.
5909    # In static shape cases, since the weight is prepacked, we always force the output to be channels last in the Conv kernel.
5910    # In dynamic shape cases, for an input with channels = 1, e.g. a tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1),
5911    # x = cls.require_stride_order(x, req_stride_order), where req_stride_order is the channels-last order,
5912    # won't change the strides of this tensor, since strides of size-1 dimensions are ignored. In the Conv kernel,
5913    # however, this tensor is treated as channels first, and the output will be in contiguous format.
5914    # To align with the Conv kernel's behavior, we set output_stride in this case to contiguous instead of channels last.
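    # For illustration: for size (s0, 1, 28, 28) the contiguous strides are
    # (784, 784, 28, 1) while the channels-last strides are (784, 1, 28, 1);
    # they differ only in the stride of the size-1 channel dim, which never
    # affects addressing, so require_stride_order has nothing to fix up here.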
5915    dynamic_shapes = not all(isinstance(i, int) for i in output_size)
5916    if dynamic_shapes and is_contiguous_storage_and_layout(x):
5917        output_stride = FlexibleLayout.contiguous_strides(output_size)
5918    else:
5919        output_stride = make_channels_last_strides_for(output_size)
5920
5921    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
5922    inputs = [x, weight]
5923
5924    kernel_layout = FixedLayout(
5925        x.get_device(),
5926        x.get_dtype(),
5927        convert_shape_to_inductor(output_size),
5928        convert_shape_to_inductor(output_stride),
5929    )
5930    constant_args = [padding, stride, dilation, groups]
5931    if transposed:
5932        constant_args.insert(1, output_padding)
5933
5934    if bias is not None:
5935        inputs.append(bias)
5936    else:
5937        constant_args.insert(0, bias)
5938    return inputs, constant_args, kernel_layout, req_stride_order
5939
5940
5941def _prepare_linear_fusion_create(
5942    cls,
5943    x: "TensorBox",
5944    weight: "TensorBox",
5945    bias: "TensorBox",
5946):
5947    """
5948    This function is a helper function to prepare inputs, layout and constant args
5949    for linear post-op fusion's create function. The function only supports the CPU device
5950    since linear post-op fusion kernel is only supported on CPU right now.
5951    """
5952    x.realize()
5953    weight.realize()
5954    if bias is not None:
5955        bias.realize()
5956
5957    *m, _ = x.get_size()
5958    # The weight has been transposed during the qlinear weight prepack process.
5959    # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/
5960    # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291
5961    _, oc = weight.get_size()
5962    output_size = list(m) + [oc]
5963    req_stride_order = list(reversed(range(len(x.get_size()))))
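    # e.g. for a 3-D input x of size (b, s, ic): m == [b, s],
    # output_size == [b, s, oc], and req_stride_order == [2, 1, 0]
    # (the contiguous stride order).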
5964
5965    x = cls.require_stride_order(x, req_stride_order)
5966    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
5967    inputs = [x, weight]
5968
5969    output_stride = FlexibleLayout.contiguous_strides(output_size)
5970    kernel_layout = FixedLayout(
5971        x.get_device(),
5972        x.get_dtype(),
5973        output_size,
5974        output_stride,
5975    )
5976    constant_args: List[Any] = []
5977
5978    if bias is not None:
5979        inputs.append(bias)
5980    else:
5981        constant_args.insert(0, bias)
5982    return inputs, constant_args, kernel_layout, req_stride_order
5983
5984
5985class ConvolutionUnary(ExternKernelAlloc):
5986    def __init__(
5987        self,
5988        layout,
5989        inputs,
5990        constant_args=(),
5991    ):
5992        super().__init__(
5993            layout,
5994            inputs,
5995            constant_args,
5996            None,
5997            python_kernel_name="torch.ops.mkldnn._convolution_pointwise",
5998            cpp_kernel_name="mkldnn::_convolution_pointwise",
5999        )
6000        self.cpp_kernel_key = "convolution_pointwise"
6001        self.cpp_op_schema = """
6002            at::Tensor(
6003                const at::Tensor& input_t,
6004                const at::Tensor& weight_t,
6005                const c10::optional<at::Tensor>& bias_opt,
6006                at::IntArrayRef padding,
6007                at::IntArrayRef stride,
6008                at::IntArrayRef dilation,
6009                int64_t groups,
6010                c10::string_view attr,
6011                torch::List<c10::optional<at::Scalar>> scalars,
6012                c10::optional<c10::string_view> algorithm)"""
6013
6014    def codegen(self, wrapper):
6015        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6016            self.get_name(),
6017            self.python_kernel_name,
6018            self.cpp_kernel_name,
6019            self.codegen_args(),
6020            self.cpp_op_schema,
6021            self.cpp_kernel_key,
6022        )
6023        if isinstance(self.layout, Layout):
6024            self.codegen_size_asserts(wrapper)
6025
6026    @classmethod
6027    def create(
6028        cls,
6029        x: "TensorBox",
6030        weight: "TensorBox",
6031        bias: "TensorBox",
6032        padding_: List[int],
6033        stride_: List[int],
6034        dilation_: List[int],
6035        groups: int,
6036        attr,
6037        scalars: Optional[List[Any]],
6038        algorithm,
6039    ):
6040        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
6041            cls, x, weight, bias, padding_, stride_, dilation_, groups
6042        )
6043        constant_args = constant_args + [
6044            attr,
6045            may_convert_to_optional(scalars),
6046            algorithm,
6047        ]
6048        return ConvolutionUnary(
6049            layout=kernel_layout,
6050            inputs=inputs,
6051            constant_args=constant_args,
6052        )
6053
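# A rough usage sketch of ConvolutionUnary.create (hypothetical shapes and
# post-op arguments, assuming the mkldnn weight-prepack pass has already
# produced the TensorBox inputs; not executed here):
#
#   conv = ConvolutionUnary.create(
#       x, weight, bias,
#       padding_=[1, 1], stride_=[1, 1], dilation_=[1, 1], groups=1,
#       attr="relu", scalars=None, algorithm="",
#   )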
6054
6055class ConvolutionBinary(ExternKernelAlloc):
6056    def __init__(
6057        self,
6058        layout,
6059        inputs,
6060        constant_args=(),
6061        cpp_constant_args=(),
6062    ):
6063        super().__init__(
6064            layout,
6065            inputs,
6066            constant_args,
6067            None,
6068            python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary",
6069            cpp_kernel_name="mkldnn::_convolution_pointwise",
6070        )
6071        self.cpp_kernel_overload_name = "binary"
6072        self.cpp_kernel_key = "convolution_pointwise_binary"
6073        self.cpp_op_schema = """
6074            at::Tensor(
6075                const at::Tensor& input_t,
6076                const at::Tensor& other_t,
6077                const at::Tensor& weight_t,
6078                const c10::optional<at::Tensor>& bias_opt,
6079                at::IntArrayRef padding,
6080                at::IntArrayRef stride,
6081                at::IntArrayRef dilation,
6082                int64_t groups,
6083                c10::string_view binary_attr,
6084                c10::optional<at::Scalar> alpha,
6085                c10::optional<c10::string_view> unary_attr,
6086                torch::List<c10::optional<at::Scalar>> unary_scalars,
6087                c10::optional<c10::string_view> unary_algorithm)"""
6088        self.cpp_constant_args = cpp_constant_args
6089
6090    def codegen(self, wrapper):
6091        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6092            self.get_name(),
6093            self.python_kernel_name,
6094            self.cpp_kernel_name,
6095            self.codegen_args(),
6096            self.cpp_op_schema,
6097            self.cpp_kernel_key,
6098            self.cpp_kernel_overload_name,
6099        )
6100        if isinstance(self.layout, Layout):
6101            self.codegen_size_asserts(wrapper)
6102
6103    @classmethod
6104    def create(
6105        cls,
6106        x: "TensorBox",
6107        other: "TensorBox",
6108        weight: "TensorBox",
6109        bias: "TensorBox",
6110        padding_: List[int],
6111        stride_: List[int],
6112        dilation_: List[int],
6113        groups: int,
6114        binary_attr: str,
6115        binary_alpha: Optional[float],
6116        unary_attr: Optional[str],
6117        unary_scalars: Optional[List[Any]],
6118        unary_algorithm: Optional[str],
6119    ):
6120        (
6121            inputs,
6122            constant_args,
6123            kernel_layout,
6124            req_stride_order,
6125        ) = _prepare_convolution_fusion_create(
6126            cls, x, weight, bias, padding_, stride_, dilation_, groups
6127        )
6128        other = cls.require_stride_order(other, req_stride_order)
6129        inputs.insert(1, other)
6130        constant_args = constant_args + [
6131            binary_attr,
6132            binary_alpha,
6133            unary_attr,
6134            may_convert_to_optional(unary_scalars),
6135            unary_algorithm,
6136        ]
6137        return ConvolutionBinary(
6138            layout=kernel_layout,
6139            inputs=inputs,
6140            constant_args=constant_args,
6141        )
6142
6143
6144class ConvolutionBinaryInplace(ExternKernelAlloc):
6145    def __init__(
6146        self,
6147        kernel_layout,
6148        inputs,
6149        constant_args=(),
6150    ):
6151        # Due to a constraint of op.call, other (Tensor&) must be at inputs[0]
6152        reordered_inputs = [inputs[1], inputs[0]] + inputs[2:]
6153
6154        super().__init__(
6155            kernel_layout,
6156            reordered_inputs,
6157            constant_args,
6158            None,
6159            python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary",
6160            cpp_kernel_name="mkldnn::_convolution_pointwise_",
6161        )
6162        self.cpp_kernel_overload_name = "binary"
6163        self.cpp_kernel_key = "convolution_pointwise_binary_"
6164        # TODO: op.call: input[0] should be at::Tensor&
6165        self.cpp_op_schema = """
6166            at::Tensor&(
6167                at::Tensor& other_t,
6168                const at::Tensor& input_t,
6169                const at::Tensor& weight_t,
6170                const c10::optional<at::Tensor>& bias_opt,
6171                at::IntArrayRef padding,
6172                at::IntArrayRef stride,
6173                at::IntArrayRef dilation,
6174                int64_t groups,
6175                c10::string_view binary_attr,
6176                c10::optional<at::Scalar> alpha,
6177                c10::optional<c10::string_view> unary_attr,
6178                torch::List<c10::optional<at::Scalar>> unary_scalars,
6179                c10::optional<c10::string_view> unary_algorithm)"""
6180
6181    def codegen(self, wrapper):
6182        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6183            self.get_name(),
6184            self.python_kernel_name,
6185            self.cpp_kernel_name,
6186            self.codegen_args(),
6187            self.cpp_op_schema,
6188            self.cpp_kernel_key,
6189            self.cpp_kernel_overload_name,
6190        )
6191
6192    def get_mutation_names(self):
6193        return [self.inputs[0].get_name()]
6194
6195    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
6196        return set()
6197
6198    @classmethod
6199    def create(
6200        cls,
6201        x: "TensorBox",
6202        other: "TensorBox",
6203        weight: "TensorBox",
6204        bias: "TensorBox",
6205        padding_: List[int],
6206        stride_: List[int],
6207        dilation_: List[int],
6208        groups: int,
6209        binary_attr: str,
6210        binary_alpha: Optional[float],
6211        unary_attr: Optional[str],
6212        unary_scalars: Optional[List[Any]],
6213        unary_algorithm: Optional[str],
6214    ):
6215        (
6216            inputs,
6217            constant_args,
6218            _,
6219            req_stride_order,
6220        ) = _prepare_convolution_fusion_create(
6221            cls, x, weight, bias, padding_, stride_, dilation_, groups
6222        )
6223        other = cls.require_stride_order(other, req_stride_order)
6224        inputs.insert(1, other)
6225        constant_args = constant_args + [
6226            binary_attr,
6227            binary_alpha,
6228            unary_attr,
6229            may_convert_to_optional(unary_scalars),
6230            unary_algorithm,
6231        ]
6232        packed = ConvolutionBinaryInplace(
6233            kernel_layout=NoneLayout(inputs[1].get_device()),  # type: ignore[arg-type]
6234            inputs=inputs,
6235            constant_args=constant_args,
6236        )
6237        mark_node_as_mutating(packed, inputs[1])
6238        # This op mutates in place, which means that the result is not the
6239        # target but rather the input that is being mutated.
6240        # __init__ reorders the inputs, so the original inputs[1] becomes packed.inputs[0].
6241        return packed.inputs[0]
6242
6243
6244class MKLPackedLinear(ExternKernelAlloc):
6245    def __init__(
6246        self,
6247        layout,
6248        inputs,
6249        constant_args=(),
6250    ):
6251        super().__init__(
6252            layout,
6253            inputs,
6254            constant_args,
6255            None,
6256            python_kernel_name="torch.ops.mkl._mkl_linear",
6257            cpp_kernel_name="mkl::_mkl_linear",
6258        )
6259        self.cpp_kernel_key = "mkl_linear"
6260        self.cpp_op_schema = """
6261            at::Tensor(
6262                const at::Tensor& self,
6263                const at::Tensor& mkl_weight_t,
6264                const at::Tensor& origin_weight_t,
6265                const c10::optional<at::Tensor>& bias_opt,
6266                const int64_t prepack_batch_size)"""
6267
6268    def codegen(self, wrapper):
6269        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6270            self.get_name(),
6271            self.python_kernel_name,
6272            self.cpp_kernel_name,
6273            self.codegen_args(),
6274            self.cpp_op_schema,
6275            self.cpp_kernel_key,
6276        )
6277
6278    @classmethod
6279    def create(cls, x, packed_w, orig_w, B, batch_size):
6280        x = cls.require_stride1(cls.realize_input(x))
6281        orig_w = cls.require_stride1(cls.realize_input(orig_w))
6282        *m, _ = x.get_size()
6283        oc, _ = orig_w.get_size()
6284        output_size = list(m) + [oc]
6285        output_stride = FlexibleLayout.contiguous_strides(output_size)
6286        inputs = [x, packed_w, orig_w]
6287        constant_args = [batch_size]
6288        if B is not None:
6289            inputs += [B]
6290        else:
6291            constant_args.insert(0, None)
6292
6293        return MKLPackedLinear(
6294            layout=FixedLayout(
6295                x.get_device(), x.get_dtype(), output_size, output_stride
6296            ),
6297            inputs=inputs,
6298            constant_args=constant_args,
6299        )
6300
6301
6302class LinearUnary(ExternKernelAlloc):
6303    def __init__(
6304        self,
6305        layout,
6306        inputs,
6307        constant_args=(),
6308    ):
6309        super().__init__(
6310            layout,
6311            inputs,
6312            constant_args,
6313            None,
6314            python_kernel_name="torch.ops.mkldnn._linear_pointwise",
6315            cpp_kernel_name="mkldnn::_linear_pointwise",
6316        )
6317        self.cpp_kernel_key = "linear_pointwise"
6318        self.cpp_op_schema = """
6319            at::Tensor(
6320                const at::Tensor& input_t,
6321                const at::Tensor& weight_t,
6322                const c10::optional<at::Tensor>& bias_opt,
6323                c10::string_view attr,
6324                torch::List<c10::optional<at::Scalar>> scalars,
6325                c10::optional<c10::string_view> algorithm)"""
6326
6327    def codegen(self, wrapper):
6328        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6329            self.get_name(),
6330            self.python_kernel_name,
6331            self.cpp_kernel_name,
6332            self.codegen_args(),
6333            self.cpp_op_schema,
6334            self.cpp_kernel_key,
6335        )
6336
6337    @classmethod
6338    def create(cls, x, w, b, attr, scalars, algorithm):
6339        x = cls.require_contiguous(cls.realize_input(x))
6340        w = cls.require_contiguous(cls.realize_input(w))
6341
6342        *m, ic = x.get_size()
6343        oc, ic = w.get_size()
6344        inputs = [x, w]
6345        constant_args = [attr, scalars if scalars else [-1], algorithm]
6346        if b is not None:
6347            b = cls.require_contiguous(cls.realize_input(b))
6348            inputs.append(b)
6349        else:
6350            constant_args.insert(0, None)
6351
6352        return LinearUnary(
6353            layout=FlexibleLayout(
6354                device=x.get_device(),
6355                dtype=x.get_dtype(),
6356                size=list(m) + [oc],
6357            ),
6358            inputs=inputs,
6359            constant_args=constant_args,
6360        )
6361
6362    def apply_constraint(self):
6363        pass
6364
6365
6366class LinearBinary(ExternKernelAlloc):
6367    kernel = "torch.ops.mkldnn._linear_pointwise.binary"
6368
6369    def __init__(
6370        self,
6371        layout,
6372        inputs,
6373        constant_args=(),
6374    ):
6375        super().__init__(
6376            layout,
6377            inputs,
6378            constant_args,
6379            None,
6380            python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary",
6381            cpp_kernel_name="mkldnn::_linear_pointwise",
6382        )
6383        self.cpp_kernel_overload_name = "binary"
6384        self.cpp_kernel_key = "linear_pointwise_binary"
6385        self.cpp_op_schema = """
6386            at::Tensor(
6387                const at::Tensor& input_t,
6388                const at::Tensor& other_t,
6389                const at::Tensor& weight_t,
6390                const c10::optional<at::Tensor>& bias_opt,
6391                c10::string_view attr)
6392        """
6393
6394    def codegen(self, wrapper):
6395        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6396            self.get_name(),
6397            self.python_kernel_name,
6398            self.cpp_kernel_name,
6399            self.codegen_args(),
6400            self.cpp_op_schema,
6401            self.cpp_kernel_key,
6402            self.cpp_kernel_overload_name,
6403        )
6404
6405    @classmethod
6406    def create(cls, x, y, w, B, attr):
6407        x = cls.require_contiguous(cls.realize_input(x))
6408        y = cls.require_contiguous(cls.realize_input(y))
6409        w = cls.require_contiguous(cls.realize_input(w))
6410
6411        *m, ic = x.get_size()
6412        oc, ic = w.get_size()
6413
6414        inputs = [x, y, w]
6415        constant_args = [attr]
6416        if B is not None:
6417            B = cls.require_contiguous(cls.realize_input(B))
6418            inputs.append(B)
6419        else:
6420            constant_args.insert(0, B)
6421
6422        return LinearBinary(
6423            layout=FlexibleLayout(
6424                device=x.get_device(),
6425                dtype=x.get_dtype(),
6426                size=list(m) + [oc],
6427            ),
6428            inputs=inputs,
6429            constant_args=constant_args,
6430        )
6431
6432    def apply_constraint(self):
6433        pass
6434
6435
6436class ConvolutionTransposeUnary(ExternKernelAlloc):
6437    def __init__(
6438        self,
6439        layout,
6440        inputs,
6441        constant_args=(),
6442    ):
6443        super().__init__(
6444            layout,
6445            inputs,
6446            constant_args,
6447            None,
6448            python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise",
6449            cpp_kernel_name="mkldnn::_convolution_transpose_pointwise",
6450        )
6451        self.cpp_kernel_key = "convolution_transpose_pointwise"
6452        self.cpp_op_schema = """
6453            at::Tensor(
6454                const at::Tensor& input_t,
6455                const at::Tensor& weight_t,
6456                const c10::optional<at::Tensor>& bias_opt,
6457                at::IntArrayRef padding,
6458                at::IntArrayRef output_padding,
6459                at::IntArrayRef stride,
6460                at::IntArrayRef dilation,
6461                int64_t groups,
6462                c10::string_view attr,
6463                torch::List<c10::optional<at::Scalar>> scalars,
6464                c10::optional<c10::string_view> algorithm)"""
6465
6466    def codegen(self, wrapper):
6467        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6468            self.get_name(),
6469            self.python_kernel_name,
6470            self.cpp_kernel_name,
6471            self.codegen_args(),
6472            self.cpp_op_schema,
6473            self.cpp_kernel_key,
6474        )
6475
6476    @classmethod
6477    def create(
6478        cls,
6479        x: "TensorBox",
6480        weight: "TensorBox",
6481        bias: "TensorBox",
6482        padding_: List[int],
6483        output_padding_: List[int],
6484        stride_: List[int],
6485        dilation_: List[int],
6486        groups_: int,
6487        attr,
6488        scalars: Optional[List[Any]],
6489        algorithm,
6490    ):
6491        transposed = True
6492        (
6493            inputs,
6494            constant_args,
6495            kernel_layout,
6496            _,
6497        ) = _prepare_convolution_fusion_create(
6498            cls,
6499            x,
6500            weight,
6501            bias,
6502            padding_,
6503            stride_,
6504            dilation_,
6505            groups_,
6506            transposed,
6507            output_padding_,
6508        )
6509        constant_args = constant_args + [
6510            attr,
6511            may_convert_to_optional(scalars),
6512            algorithm,
6513        ]
6514        return ConvolutionTransposeUnary(
6515            layout=kernel_layout,
6516            inputs=inputs,
6517            constant_args=constant_args,
6518        )
6519
6520
6521class MkldnnRnnLayer(ExternKernelAlloc):
6522    def __init__(
6523        self,
6524        layout,
6525        inputs,
6526        constant_args=(),
6527    ):
6528        super().__init__(
6529            layout,
6530            inputs,
6531            constant_args,
6532            None,
6533            python_kernel_name="aten.mkldnn_rnn_layer",
6534            cpp_kernel_name="at::mkldnn_rnn_layer",
6535        )
6536
6537    @classmethod
6538    def create(
6539        cls,
6540        x: "TensorBox",
6541        w0: "TensorBox",
6542        w1: "TensorBox",
6543        w2: "TensorBox",
6544        w3: "TensorBox",
6545        hx: "TensorBox",
6546        cx: "TensorBox",
6547        reverse: bool,
6548        batch_sizes: List[int],
6549        mode: int,
6550        hidden_size: int,
6551        num_layers: int,
6552        has_biases: bool,
6553        bidirectional: bool,
6554        batch_first: bool,
6555        train: bool,
6556    ):
6557        x = cls.require_stride1(cls.realize_input(x))
6558        # If batch_first, x has already been permuted in lstm before entering mkldnn_rnn_layer.
6559        # Make sure x is contiguous in the batch_first case.
6560        x.freeze_layout()
6561        w0 = cls.require_stride1(cls.realize_input(w0))
6562        w1 = cls.require_stride1(cls.realize_input(w1))
6563        w2 = cls.require_stride1(cls.realize_input(w2))
6564        w3 = cls.require_stride1(cls.realize_input(w3))
6565        hx = cls.require_stride1(cls.realize_input(hx))
6566        hx.freeze_layout()
6567        cx = cls.require_stride1(cls.realize_input(cx))
6568        cx.freeze_layout()
6569
6570        input_size = x.get_size()
6571        assert len(input_size) == 3, "Expect lstm input to be 3D"
6572        # batch_first is handled in the lstm OP. When entering
6573        # rnn_layer here, we'll always have batch_first = False
6574        seq_length, mini_batch, input_size = input_size
6575        output_shape = [seq_length, mini_batch, hidden_size]
6576
6577        hy_shape = hx.get_size()
6578        cy_shape = cx.get_size()
6579
6580        res: List[IRNode] = []
6581
6582        inputs = [x, w0, w1, w2, w3, hx, cx]
6583        constant_args = [
6584            reverse,
6585            batch_sizes,
6586            mode,
6587            hidden_size,
6588            num_layers,
6589            has_biases,
6590            bidirectional,
6591            batch_first,
6592            train,
6593        ]
6594
6595        packed = MkldnnRnnLayer(
6596            MultiOutputLayout(x.get_device()),
6597            inputs=inputs,
6598            constant_args=constant_args,
6599        )
6600
6601        def get_strides_of_lstm_output(output_shape, batch_first):
6602            assert len(output_shape) == 3, "Expect output_shape to be 3D"
6603            return FlexibleLayout.contiguous_strides(output_shape)
6604
6605        output_sizes = [output_shape, hy_shape, cy_shape]
6606        output_strides = [
6607            get_strides_of_lstm_output(output_shape, batch_first),
6608            FlexibleLayout.contiguous_strides(hy_shape),
6609            FlexibleLayout.contiguous_strides(cy_shape),
6610        ]
6611        output_ir = [
6612            MultiOutput(
6613                FixedLayout(
6614                    x.get_device(),
6615                    x.get_dtype(),
6616                    output_size,
6617                    output_stride,
6618                ),
6619                packed,
6620                [(tuple, i)],
6621            )
6622            for i, (output_size, output_stride) in enumerate(
6623                zip(output_sizes, output_strides)
6624            )
6625        ]
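        # Each [(tuple, i)] index tells MultiOutput to extract the i-th element
        # of the tuple returned by the packed mkldnn_rnn_layer call, so output_ir
        # holds [output, hy, cy] as separate fixed-layout buffers.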
6626
6627        return output_ir
6628
6629
6630class QConvPointWisePT2E(ExternKernelAlloc):
6631    def __init__(
6632        self,
6633        layout,
6634        inputs,
6635        constant_args=(),
6636    ):
6637        """
6638        if bias is not None
6639            - inputs = [x, w, b, weight_scale, weight_zp]
6640            - const_args = [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp,
6641              output_dtype, unary_attr, unary_scalars, unary_algorithm]
6642        else
6643            - inputs = [x, w, weight_scale, weight_zp]
6644            - const_args = [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp,
6645              output_dtype, unary_attr, unary_scalars, unary_algorithm]
6646        """
6647        self.has_bias = len(inputs) == 5
6648        super().__init__(
6649            layout,
6650            inputs,
6651            constant_args,
6652            None,
6653            python_kernel_name="torch.ops.onednn.qconv2d_pointwise",
6654            cpp_kernel_name="onednn::qconv2d_pointwise",
6655        )
6656        self.cpp_kernel_key = "qconv2d_pointwise"
6657        self.cpp_op_schema = """
6658            at::Tensor(
6659                at::Tensor act,
6660                double act_scale,
6661                int64_t act_zero_point,
6662                at::Tensor weight,
6663                at::Tensor weight_scales,
6664                at::Tensor weight_zero_points,
6665                c10::optional<at::Tensor> bias,
6666                torch::List<int64_t> stride,
6667                torch::List<int64_t> padding,
6668                torch::List<int64_t> dilation,
6669                int64_t groups,
6670                double output_scale,
6671                int64_t output_zero_point,
6672                c10::optional<c10::ScalarType> output_dtype,
6673                c10::string_view attr,
6674                torch::List<c10::optional<at::Scalar>> scalars,
6675                c10::optional<c10::string_view> algorithm)"""
6676
6677    def codegen(self, wrapper):
6678        # Parse the inputs and constant args
6679        args = [x.codegen_reference() for x in self.inputs]
6680        const_args = []
6681        const_args.extend(self.codegen_const_args())
6682
6683        x = args[0]
6684        packed_weight = args[1]
6685        bias = args[2] if self.has_bias else const_args[0]
6686        w_scale, w_zp = args[-2], args[-1]
6687        (
6688            stride,
6689            padding,
6690            dilation,
6691            groups,
6692            x_scale,
6693            x_zp,
6694            o_inv_scale,
6695            o_zp,
6696            output_dtype,
6697            unary_attr,
6698            unary_scalars,
6699            unary_algorithm,
6700        ) = const_args[-12:]
6701
6702        codegen_args = (
6703            x,
6704            x_scale,
6705            x_zp,
6706            packed_weight,
6707            w_scale,
6708            w_zp,
6709            bias,
6710            stride,
6711            padding,
6712            dilation,
6713            groups,
6714            o_inv_scale,
6715            o_zp,
6716            output_dtype,
6717            unary_attr,
6718            unary_scalars,
6719            unary_algorithm,
6720        )
6721        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6722            self.get_name(),
6723            self.python_kernel_name,
6724            self.cpp_kernel_name,
6725            codegen_args,
6726            self.cpp_op_schema,
6727            self.cpp_kernel_key,
6728        )
6729        if isinstance(self.layout, Layout):
6730            self.codegen_size_asserts(wrapper)
6731
6732    @classmethod
6733    def create(
6734        cls,
6735        x: "TensorBox",
6736        x_scale: float,
6737        x_zp: int,
6738        weight: "TensorBox",  # packed_weight
6739        w_scale: "TensorBox",
6740        w_zp: "TensorBox",
6741        bias: "TensorBox",
6742        stride_: List[int],
6743        padding_: List[int],
6744        dilation_: List[int],
6745        groups: int,
6746        o_inv_scale: float,
6747        output_zero_point: int,
6748        output_dtype,
6749        unary_attr,
6750        unary_scalars,
6751        unary_algorithm,
6752    ):
6753        transposed = False
6754        output_padding = None
6755        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
6756            cls,
6757            x,
6758            weight,
6759            bias,
6760            padding_,
6761            stride_,
6762            dilation_,
6763            groups,
6764            transposed,
6765            output_padding,
6766        )
6767        # swap padding and stride to align with functional conv arg order
6768        if bias is None:
6769            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
6770        else:
6771            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]
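        # e.g. with bias=None, [bias, padding, stride, dilation, groups] becomes
        # [bias, stride, padding, dilation, groups] after the swap.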
6772
6773        w_scale.realize()
6774        w_zp.realize()
6775        inputs = inputs + [w_scale, w_zp]
6776        constant_args = constant_args + [
6777            x_scale,
6778            x_zp,
6779            o_inv_scale,
6780            output_zero_point,
6781            output_dtype,
6782            unary_attr,
6783            may_convert_to_optional(unary_scalars),
6784            unary_algorithm,
6785        ]
6786
6787        if output_dtype is not None:
6788            assert output_dtype in [torch.float32, torch.bfloat16]
6789            # _prepare_convolution_fusion_create uses x.dtype (uint8) to create kernel_layout;
6790            # when output_dtype is not None, the output buffer dtype should be output_dtype instead of uint8.
6791            kernel_layout.dtype = output_dtype
6792
6793        return QConvPointWisePT2E(
6794            layout=kernel_layout,
6795            inputs=inputs,
6796            constant_args=constant_args,
6797        )
6798
6799
6800class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
6801    def __init__(
6802        self,
6803        layout,
6804        inputs,
6805        constant_args=(),
6806    ):
6807        """
6808        Needs input/weight/output qparams
6809        if bias is not None
6810            - inputs = [x, w, b, accum, w_scale, w_zp]
6811            - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp,
6812            output_dtype, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
6813        else
6814            - inputs = [x, w, accum, w_scale, w_zp]
6815            - const_args = [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale,
6816            accum_zp, o_inv_scale, o_zp, output_dtype, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
6817        """
6818        self.has_bias = len(inputs) == 6
6819        self.idx_for_inplace_sum = 3 if self.has_bias else 2
6820        super().__init__(
6821            layout,
6822            inputs,
6823            constant_args,
6824            None,
6825            python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary",
6826            cpp_kernel_name="onednn::qconv2d_pointwise",
6827        )
6828        self.cpp_kernel_overload_name = "binary"
6829        self.cpp_kernel_key = "qconv2d_pointwise_binary"
6830        self.cpp_op_schema = """
6831            at::Tensor(
6832                at::Tensor act,
6833                double act_scale,
6834                int64_t act_zero_point,
6835                at::Tensor accum,
6836                double accum_scale,
6837                int64_t accum_zero_point,
6838                at::Tensor weight,
6839                at::Tensor weight_scales,
6840                at::Tensor weight_zero_points,
6841                c10::optional<at::Tensor> bias,
6842                torch::List<int64_t> stride,
6843                torch::List<int64_t> padding,
6844                torch::List<int64_t> dilation,
6845                int64_t groups,
6846                double output_scale,
6847                int64_t output_zero_point,
6848                c10::optional<c10::ScalarType> output_dtype,
6849                c10::string_view binary_attr,
6850                c10::optional<at::Scalar> alpha,
6851                c10::optional<c10::string_view> attr,
6852                torch::List<c10::optional<at::Scalar>> scalars,
6853                c10::optional<c10::string_view> algorithm)"""
6854
6855    def codegen(self, wrapper):
6856        # Parse the inputs and constant args
6857        args = [x.codegen_reference() for x in self.inputs]
6858        const_args = []
6859        const_args.extend(self.codegen_const_args())
6860
6861        x = args[0]
6862        packed_weight = args[1]
6863        bias = args[2] if self.has_bias else const_args[0]
6864        accum, w_scale, w_zp = args[-3], args[-2], args[-1]
6865        (
6866            stride,
6867            padding,
6868            dilation,
6869            groups,
6870            x_scale,
6871            x_zp,
6872            accum_scale,
6873            accum_zp,
6874            o_inv_scale,
6875            o_zp,
6876            output_dtype,
6877            binary_attr,
6878            alpha,
6879            unary_attr,
6880            unary_scalars,
6881            unary_algorithm,
6882        ) = const_args[-16:]
6883        conv_args = (
6884            x,
6885            x_scale,
6886            x_zp,
6887            accum,
6888            accum_scale,
6889            accum_zp,
6890            packed_weight,
6891            w_scale,
6892            w_zp,
6893            bias,
6894            stride,
6895            padding,
6896            dilation,
6897            groups,
6898            o_inv_scale,
6899            o_zp,
6900            output_dtype,
6901            binary_attr,
6902            alpha,
6903            unary_attr,
6904            unary_scalars,
6905            unary_algorithm,
6906        )
6907        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6908            self.get_name(),
6909            self.python_kernel_name,
6910            self.cpp_kernel_name,
6911            conv_args,
6912            self.cpp_op_schema,
6913            self.cpp_kernel_key,
6914            self.cpp_kernel_overload_name,
6915        )
6916        if isinstance(self.layout, Layout):
6917            self.codegen_size_asserts(wrapper)
6918
6919    def get_mutation_names(self):
6920        return [self.inputs[self.idx_for_inplace_sum].get_name()]
6921
6922    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
6923        return set()
6924
6925    @classmethod
6926    def create(
6927        cls,
6928        x: "TensorBox",
6929        x_scale,
6930        x_zp,
6931        accum: "TensorBox",
6932        accum_scale,
6933        accum_zp,
6934        weight: "TensorBox",  # packed_weight
6935        w_scale,
6936        w_zp,
6937        bias: "TensorBox",
6938        stride_: List[int],
6939        padding_: List[int],
6940        dilation_: List[int],
6941        groups: int,
6942        o_inv_scale: "TensorBox",
6943        output_zero_point: "TensorBox",
6944        output_dtype,
6945        binary_attr,
6946        alpha,
6947        unary_attr,
6948        unary_scalars,
6949        unary_algorithm,
6950    ):
6951        transposed = False
6952        output_padding = None
6953        (
6954            inputs,
6955            constant_args,
6956            kernel_layout,
6957            req_stride_order,
6958        ) = _prepare_convolution_fusion_create(
6959            cls,
6960            x,
6961            weight,
6962            bias,
6963            padding_,
6964            stride_,
6965            dilation_,
6966            groups,
6967            transposed,
6968            output_padding,
6969        )
6970
6971        accum = cls.require_stride_order(accum, req_stride_order)
6972        inputs.append(accum)
6973
6974        # swap padding and stride to align with functional conv arg order
6975        if bias is None:
6976            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
6977        else:
6978            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]
6979
6980        w_scale.realize()
6981        w_zp.realize()
6982        inputs = inputs + [w_scale, w_zp]
6983        constant_args = constant_args + [
6984            x_scale,
6985            x_zp,
6986            accum_scale,
6987            accum_zp,
6988            o_inv_scale,
6989            output_zero_point,
6990            output_dtype,
6991            binary_attr,
6992            alpha,
6993            unary_attr,
6994            may_convert_to_optional(unary_scalars),
6995            unary_algorithm,
6996        ]
6997
6998        assert (
6999            binary_attr == "sum"
7000        ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
7001
7002        packed = QConvPointWiseBinaryPT2E(
7003            layout=NoneLayout(accum.get_device()),
7004            inputs=inputs,
7005            constant_args=constant_args,
7006        )
7007        mark_node_as_mutating(packed, accum)
7008
7009        # Return accum since it has been changed in place.
7010        return packed.inputs[packed.idx_for_inplace_sum]
7011
7012
7013class QLinearPointwisePT2E(ExternKernelAlloc):
7014    def __init__(
7015        self,
7016        layout,
7017        inputs,
7018        constant_args=(),
7019        has_bias=True,
7020        x_scale_zp_are_tensors=False,
7021    ):
7022        """
7023        if bias is not None
7024            - inputs = [x, w, b, weight_scale, weight_zp]
7025            - const_args = [x_scale, x_zp, o_inv_scale, o_zp,
7026              output_dtype, unary_attr, unary_scalars, unary_algorithm]
7027        else
7028            - inputs = [x, w, weight_scale, weight_zp]
7029            - const_args = [bias, x_scale, x_zp, o_inv_scale, o_zp,
7030              output_dtype, unary_attr, unary_scalars, unary_algorithm]
7031        """
7032        self.has_bias = has_bias
7033        self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
7034        super().__init__(
7035            layout,
7036            inputs,
7037            constant_args,
7038            None,
7039            python_kernel_name=(
7040                "torch.ops.onednn.qlinear_pointwise.tensor"
7041                if x_scale_zp_are_tensors
7042                else "torch.ops.onednn.qlinear_pointwise.default"
7043            ),
7044            cpp_kernel_name="onednn::qlinear_pointwise",
7045        )
7046        self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else ""
7047        self.cpp_kernel_key = "qlinear_pointwise"
7048        x_scale_type_str, x_zp_type_str = (
7049            ("at::Tensor", "at::Tensor")
7050            if x_scale_zp_are_tensors
7051            else ("double", "int64_t")
7052        )
7053        self.cpp_op_schema = f"""
7054            at::Tensor(
7055                at::Tensor act,
7056                {x_scale_type_str} act_scale,
7057                {x_zp_type_str} act_zero_point,
7058                at::Tensor weight,
7059                at::Tensor weight_scales,
7060                at::Tensor weight_zero_points,
7061                c10::optional<at::Tensor> bias,
7062                double output_scale,
7063                int64_t output_zero_point,
7064                c10::optional<c10::ScalarType> output_dtype,
7065                c10::string_view post_op_name,
7066                torch::List<c10::optional<at::Scalar>> post_op_args,
7067                c10::string_view post_op_algorithm)"""
7068
7069    def codegen(self, wrapper):
7070        # Parse the inputs and constant args
7071        args = [x.codegen_reference() for x in self.inputs]
7072        const_args = []
7073        const_args.extend(self.codegen_const_args())
7074
7075        x = args[0]
7076        packed_weight = args[1]
7077        bias = args[2] if self.has_bias else const_args[0]
7078        w_scale, w_zp = args[-2], args[-1]
7079        if self.x_scale_zp_are_tensors:
7080            assert len(args) >= 4
7081            x_scale, x_zp = args[-4], args[-3]
7082            (
7083                o_inv_scale,
7084                o_zp,
7085                output_dtype,
7086                unary_attr,
7087                unary_scalars,
7088                unary_algorithm,
7089            ) = const_args[-6:]
7090        else:
7091            assert len(const_args) >= 8
7092            (
7093                x_scale,
7094                x_zp,
7095                o_inv_scale,
7096                o_zp,
7097                output_dtype,
7098                unary_attr,
7099                unary_scalars,
7100                unary_algorithm,
7101            ) = const_args[-8:]
7102
7103        codegen_args = (
7104            x,
7105            x_scale,
7106            x_zp,
7107            packed_weight,
7108            w_scale,
7109            w_zp,
7110            bias,
7111            o_inv_scale,
7112            o_zp,
7113            output_dtype,
7114            unary_attr,
7115            unary_scalars,
7116            unary_algorithm,
7117        )
7118        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
7119            self.get_name(),
7120            self.python_kernel_name,
7121            self.cpp_kernel_name,
7122            codegen_args,
7123            self.cpp_op_schema,
7124            self.cpp_kernel_key,
7125            self.cpp_kernel_overload_name,
7126        )
7127        if isinstance(self.layout, Layout):
7128            self.codegen_size_asserts(wrapper)
7129
7130    @classmethod
7131    def create(
7132        cls,
7133        x: "TensorBox",
7134        x_scale: float,
7135        x_zp: int,
7136        weight: "TensorBox",  # packed_weight
7137        w_scale: "TensorBox",
7138        w_zp: "TensorBox",
7139        bias: "TensorBox",
7140        o_inv_scale: float,
7141        output_zero_point: int,
7142        output_dtype,
7143        unary_attr,
7144        unary_scalars,
7145        unary_algorithm,
7146    ):
7147        (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create(
7148            cls,
7149            x,
7150            weight,
7151            bias,
7152        )
7153
7154        if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox):
7155            x_scale.realize()
7156            x_zp.realize()
7157            inputs = inputs + [x_scale, x_zp]
7158            x_scale_zp_are_tensors = True
7159        else:
7160            assert isinstance(x_scale, float) and isinstance(x_zp, int)
7161            constant_args = constant_args + [x_scale, x_zp]
7162            x_scale_zp_are_tensors = False
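        # Tensor-valued x_scale/x_zp (e.g. scales computed at runtime, as in
        # dynamic quantization) select the ".tensor" overload in __init__, while
        # Python scalars select ".default".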
7163        w_scale.realize()
7164        w_zp.realize()
7165        inputs = inputs + [w_scale, w_zp]
7166        constant_args = constant_args + [
7167            o_inv_scale,
7168            output_zero_point,
7169            output_dtype,
7170            unary_attr,
7171            may_convert_to_optional(unary_scalars),
7172            unary_algorithm,
7173        ]
7174
7175        if output_dtype is not None:
7176            assert output_dtype in [torch.float32, torch.bfloat16]
7177            # _prepare_linear_fusion_create uses x.dtype (uint8) to create kernel_layout;
7178            # when output_dtype is set, the output buffer dtype should be output_dtype instead of uint8.
7179            kernel_layout.dtype = output_dtype
7180
7181        return QLinearPointwisePT2E(
7182            layout=kernel_layout,
7183            inputs=inputs,
7184            constant_args=constant_args,
7185            has_bias=(bias is not None),
7186            x_scale_zp_are_tensors=x_scale_zp_are_tensors,
7187        )
7188
7189
7190class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
7191    def __init__(
7192        self,
7193        layout,
7194        inputs,
7195        constant_args=(),
7196        has_bias=True,
7197        x_scale_zp_are_tensors=False,
7198    ):
7199        """
7200        if bias is not None
7201            - inputs = [x, w, b, weight_scale, weight_zp, x2]
7202            - const_args = [x_scale, x_zp, o_inv_scale, o_zp, output_dtype,
7203              other_scale, other_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
7204        else
7205            - inputs = [x, w, weight_scale, weight_zp, x2]
7206            - const_args = [bias, x_scale, x_zp, o_inv_scale, o_zp, output_dtype,
7207              other_scale, other_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
7208        """
7209        self.has_bias = has_bias
7210        self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
7211        super().__init__(
7212            layout,
7213            inputs,
7214            constant_args,
7215            None,
7216            python_kernel_name=(
7217                "torch.ops.onednn.qlinear_pointwise.binary_tensor"
7218                if x_scale_zp_are_tensors
7219                else "torch.ops.onednn.qlinear_pointwise.binary"
7220            ),
7221            cpp_kernel_name="onednn::qlinear_pointwise",
7222        )
7223        self.cpp_kernel_overload_name = (
7224            "binary_tensor" if x_scale_zp_are_tensors else "binary"
7225        )
7226        self.cpp_kernel_key = "qlinear_pointwise_binary"
7227        x_scale_type_str, x_zp_type_str = (
7228            ("at::Tensor", "at::Tensor")
7229            if x_scale_zp_are_tensors
7230            else ("double", "int64_t")
7231        )
7232        self.cpp_op_schema = f"""
7233            at::Tensor(
7234                at::Tensor act,
7235                {x_scale_type_str} act_scale,
7236                {x_zp_type_str} act_zero_point,
7237                at::Tensor weight,
7238                at::Tensor weight_scales,
7239                at::Tensor weight_zero_points,
7240                c10::optional<at::Tensor> bias,
7241                double inv_output_scale,
7242                int64_t output_zero_point,
7243                c10::optional<c10::ScalarType> output_dtype,
7244                c10::optional<at::Tensor> other,
7245                double other_scale,
7246                int64_t other_zero_point,
7247                c10::string_view binary_post_op,
7248                double binary_alpha,
7249                c10::string_view unary_post_op,
7250                torch::List<c10::optional<at::Scalar>> unary_post_op_args,
7251                c10::string_view unary_post_op_algorithm)"""
7252
7253    def codegen(self, wrapper):
7254        # Parse the inputs and constant args
7255        args = [x.codegen_reference() for x in self.inputs]
7256        const_args = []
7257        const_args.extend(self.codegen_const_args())
7258
7259        x = args[0]
7260        packed_weight = args[1]
7261        bias = args[2] if self.has_bias else const_args[0]
7262        w_scale, w_zp, other = args[-3], args[-2], args[-1]
7263        if self.x_scale_zp_are_tensors:
7264            assert len(args) >= 5
7265            x_scale, x_zp = args[-5], args[-4]
7266            (
7267                o_inv_scale,
7268                o_zp,
7269                output_dtype,
7270                other_scale,
7271                other_zp,
7272                binary_attr,
7273                alpha,
7274                unary_attr,
7275                unary_scalars,
7276                unary_algorithm,
7277            ) = const_args[-10:]
7278        else:
7279            assert len(const_args) >= 8
7280            (
7281                x_scale,
7282                x_zp,
7283                o_inv_scale,
7284                o_zp,
7285                output_dtype,
7286                other_scale,
7287                other_zp,
7288                binary_attr,
7289                alpha,
7290                unary_attr,
7291                unary_scalars,
7292                unary_algorithm,
7293            ) = const_args[-12:]
7294
7295        codegen_args = (
7296            x,
7297            x_scale,
7298            x_zp,
7299            packed_weight,
7300            w_scale,
7301            w_zp,
7302            bias,
7303            o_inv_scale,
7304            o_zp,
7305            output_dtype,
7306            other,
7307            other_scale,
7308            other_zp,
7309            binary_attr,
7310            alpha,
7311            unary_attr,
7312            unary_scalars,
7313            unary_algorithm,
7314        )
7315        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
7316            self.get_name(),
7317            self.python_kernel_name,
7318            self.cpp_kernel_name,
7319            codegen_args,
7320            self.cpp_op_schema,
7321            self.cpp_kernel_key,
7322            self.cpp_kernel_overload_name,
7323        )
7324        if isinstance(self.layout, Layout):
7325            self.codegen_size_asserts(wrapper)
7326
7327    @classmethod
7328    def create(
7329        cls,
7330        x: "TensorBox",
7331        x_scale: float,
7332        x_zp: int,
7333        weight: "TensorBox",  # packed_weight
7334        w_scale: "TensorBox",
7335        w_zp: "TensorBox",
7336        bias: "TensorBox",
7337        o_inv_scale: float,
7338        output_zero_point: int,
7339        output_dtype,
7340        other: "TensorBox",
7341        other_scale,
7342        other_zp,
7343        binary_attr,
7344        alpha,
7345        unary_attr,
7346        unary_scalars,
7347        unary_algorithm,
7348    ):
7349        (
7350            inputs,
7351            constant_args,
7352            kernel_layout,
7353            req_stride_order,
7354        ) = _prepare_linear_fusion_create(
7355            cls,
7356            x,
7357            weight,
7358            bias,
7359        )
7360
7361        if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox):
7362            x_scale.realize()
7363            x_zp.realize()
7364            inputs = inputs + [x_scale, x_zp]
7365            x_scale_zp_are_tensors = True
7366        else:
7367            assert isinstance(x_scale, float) and isinstance(x_zp, int)
7368            constant_args = constant_args + [x_scale, x_zp]
7369            x_scale_zp_are_tensors = False
7370        w_scale.realize()
7371        w_zp.realize()
7372        inputs = inputs + [w_scale, w_zp]
7373        if binary_attr == "sum":
7374            other = cls.require_stride_order(other, req_stride_order)
7375        inputs.append(other)
7376        constant_args = constant_args + [
7377            o_inv_scale,
7378            output_zero_point,
7379            output_dtype,
7380            other_scale,
7381            other_zp,
7382            binary_attr,
7383            alpha,
7384            unary_attr,
7385            may_convert_to_optional(unary_scalars),
7386            unary_algorithm,
7387        ]
7388
7389        if binary_attr == "sum":
7390            packed = QLinearPointwiseBinaryPT2E(
7391                layout=NoneLayout(other.get_device()),
7392                inputs=inputs,
7393                constant_args=constant_args,
7394                has_bias=(bias is not None),
7395                x_scale_zp_are_tensors=x_scale_zp_are_tensors,
7396            )
7397            mark_node_as_mutating(packed, other)
7398            # Return other since it has been changed in place.
7399            return packed.inputs[-1]
7400
7401        if output_dtype is not None:
7402            assert output_dtype in [torch.float32, torch.bfloat16]
7403            # _prepare_linear_fusion_create uses x.dtype (uint8) to create kernel_layout;
7404            # when output_dtype is set, the output buffer dtype should be output_dtype instead of uint8.
7405            kernel_layout.dtype = output_dtype
7406
7407        return QLinearPointwiseBinaryPT2E(
7408            layout=kernel_layout,
7409            inputs=inputs,
7410            constant_args=constant_args,
7411            has_bias=(bias is not None),
7412            x_scale_zp_are_tensors=x_scale_zp_are_tensors,
7413        )
7414
7415
7416@dataclasses.dataclass
7417class MutableBox(IRNode):
7418    """
7419    TensorBox / StorageBox allow in-place mutation of Tensors
7420    """
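    # A typical boxed chain looks like TensorBox(StorageBox(Pointwise(...)));
    # rebinding .data in place (e.g. in StorageBox.realize below) changes the
    # value seen through every outstanding reference to the box.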
7421
7422    data: IRNode
7423
7424    def __getattr__(self, name):
7425        fn = getattr(self.data, name)
7426        if callable(fn):
7427            return fn
7428        raise AttributeError(f"{type(self.data).__name__}.{name} not callable")
7429
7430    def realize(self):
7431        return self.data.realize()
7432
7433    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
7434        return self.data.get_unbacked_symbol_uses()
7435
7436    def codegen_reference(self, writer=None):
7437        return self.data.codegen_reference(writer)
7438
7439    @property
7440    def layout(self):
7441        return self.data.get_layout()
7442
7443    def get_layout(self):
7444        return self.layout
7445
7446    def get_size(self):
7447        return self.data.get_size()
7448
7449    @property
7450    def dtype(self):
7451        return self.data.dtype
7452
7453    def __str__(self):
7454        if isinstance(self.data, MutableBox):
7455            line0 = f"{type(self).__name__}({type(self.data).__name__}("
7456            endl = "))"
7457            inner = self.data.data
7458        else:
7459            line0 = f"{type(self).__name__}("
7460            inner = self.data
7461            endl = ")"
7462
7463        lines = [
7464            line0,
7465            indent(str(inner)),
7466            endl,
7467        ]
7468        return "\n".join(lines)
7469
7470    __repr__ = __str__
7471
7472
7473class TensorBox(MutableBox):
7474    @staticmethod
7475    def create(data):
7476        return TensorBox(StorageBox(data))
7477
7478
7479class StorageBox(MutableBox):
7480    def is_input_buffer(self):
7481        if isinstance(self.data, (InputBuffer, ReinterpretView)):
7482            return self.data.get_name() in V.graph.graph_inputs
7483        return False
7484
7485    def is_module_buffer(self):
7486        return (
7487            isinstance(self.data, (ConstantBuffer))
7488            and self.data.get_name() in V.graph.constants
7489        )
7490
7491    def realize(self):
7492        if isinstance(
7493            self.data,
7494            (
7495                ComputedBuffer,
7496                InputsKernel,
7497                InputBuffer,
7498                ReinterpretView,
7499                TemplateBuffer,
7500            ),
7501        ):
7502            return self.data.get_name()
7503        assert isinstance(self.data, (Pointwise, Reduction, Scan)), type(self.data)
7504        origin_node = self.data.get_origin_node()
7505        traceback = self.data.get_traceback()
7506        self.data = ComputedBuffer(
7507            name=None,
7508            layout=FlexibleLayout(
7509                device=self.data.get_device(),
7510                dtype=self.data.get_dtype(),
7511                size=self.data.get_size(),
7512            ),
7513            data=self.data,
7514        )
7515        self.data.name = V.graph.register_buffer(self.data)
7516        self.data.origins = self.origins
7517        self.data.origin_node = origin_node
7518        self.data.traceback = traceback
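        # At this point self.data has been swapped from a Pointwise/Reduction/Scan
        # to a named ComputedBuffer (e.g. "buf0", assuming the usual
        # register_buffer naming), so later uses read it from memory instead of
        # re-inlining the compute.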
7519        return self.data.name
7520
7521    def realize_hint(self):
7522        """
7523        Called on buffers we expect to be forced to realize later.
7524        """
7525        if (
7526            isinstance(self.data, (Pointwise, Reduction))
7527            and self.num_reads() > 1
7528            and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one()
7529        ):
7530            self.realize()
7531
7532    def has_exceeded_max_reads(self):
7533        return isinstance(self.data, Pointwise) and (
7534            self.num_reads() > config.realize_acc_reads_threshold
7535            or self.has_large_inner_fn()
7536        )
7537
7538    def mark_reuse(self, users):
7539        """
7540        A heuristic to decide if we should realize a tensor
7541        that is used multiple times.
7542        """
7543
7544        def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
7545            """
7546            The heuristic for realizing the reused result of heavy ops on CPU.
7547            """
7548            heavy_ops = ["exp"]  # a list of heavy ops
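            # Substring match on the printed inner fn: any body containing
            # "exp(" is considered heavy enough that realizing it once beats
            # recomputing it for each of its multiple users.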
7549            fn_str = loops.inner_fn_str()
7550            return any((op + "(") in fn_str for op in heavy_ops)
7551
7552        if (
7553            users > 1
7554            and isinstance(self.data, (Pointwise, Reduction))
7555            and (
7556                self.num_reads() > config.realize_reads_threshold
7557                or self.has_large_inner_fn()
7558                or (is_cpu(self.data) and should_realize_on_cpu(self.data))
7559            )
7560        ):
7561            self.realize()
7562
7563    @cache_on_self
7564    def num_reads(self):
7565        data = self.data
7566        if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
7567            return 1
7568        if isinstance(data, ComputedBuffer):
7569            read_writes = data.get_read_writes()
7570        else:
7571            assert isinstance(data, (Pointwise, Reduction)), type(data)
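            # Wrap the unrealized Loops node in a throwaway ComputedBuffer so
            # dependency extraction can run; the buffer is never registered.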
7572            read_writes = ComputedBuffer(
7573                name=None,
7574                layout=FlexibleLayout(
7575                    device=data.get_device(),
7576                    dtype=data.get_dtype(),
7577                    size=data.get_size(),
7578                ),
7579                data=data,
7580            ).get_read_writes()
7581        return len(read_writes.reads)
7582
7583    @cache_on_self
7584    def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self):
7585        # Skip the check for non-Pointwise instances
7586        return (
7587            (sum(read.index != 0 for read in self.data.get_reads()) > 1)
7588            if isinstance(self.data, Pointwise)
7589            and all(
7590                not isinstance(read, dependencies.StarDep)
7591                for read in self.data.get_reads()
7592            )
7593            else True
7594        )
7595
7596
7597@dataclasses.dataclass
7598class Subgraph(IRNode):
7599    name: str
7600    graph_module: torch.fx.GraphModule
7601    graph: Optional["GraphLowering"] = None
7602
7603
7604def _has_aliased_buffers(buffers):
7605    buffers = [
7606        buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
7607        for buffer in buffers
7608    ]
7609    # assuming the same buffer is represented by the same IRNode object
7610    return len({id(buffer) for buffer in buffers}) < len(buffers)
7611
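# Illustrative sketch, not part of the original file: why views are unwrapped
# before the identity check above. _Base and _View are hypothetical stand-ins;
# two distinct view objects over one base must still count as aliases, which
# identity on the views alone would miss.
def _demo_aliasing_check():
    class _Base:
        pass

    class _View:
        def __init__(self, base):
            self.base = base

        def unwrap_view(self):
            return self.base

    base = _Base()
    views = [_View(base), _View(base)]
    assert len({id(v) for v in views}) == 2  # views look distinct...
    unwrapped = [v.unwrap_view() for v in views]
    assert len({id(b) for b in unwrapped}) == 1  # ...but share one storage
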
7612
7613@dataclasses.dataclass
7614class Conditional(ExternKernel):
7615    predicate: Optional[IRNode] = None
7616    operands: Optional[List[TensorBox]] = None
7617    true_subgraph: Optional[Subgraph] = None
7618    false_subgraph: Optional[Subgraph] = None
7619    outputs: Optional[List[MultiOutput]] = None
7620
7621    def __init__(
7622        self,
7623        predicate: IRNode,
7624        operands: List[TensorBox],
7625        true_subgraph: Subgraph,
7626        false_subgraph: Subgraph,
7627        layout: MultiOutputLayout,
7628    ):
7629        self.predicate = predicate
7630        self.operands = operands
7631        self.true_subgraph = true_subgraph
7632        self.false_subgraph = false_subgraph
7633
7634        inputs = []
7635        if not isinstance(predicate, ShapeAsConstantBuffer):
7636            inputs.append(predicate)
7637        inputs.extend(operands)
7638
7639        super().__init__(
7640            name=None,
7641            layout=layout,  # type: ignore[arg-type]
7642            inputs=inputs,  # type: ignore[list-item]
7643        )
7644
7645        self.name = V.graph.register_buffer(self)
7646
7647    @classmethod
7648    def create(
7649        cls,
7650        predicate: TensorBox,
7651        true_fn: Subgraph,
7652        false_fn: Subgraph,
7653        operands: List[TensorBox],
7654    ):
7655        predicate = cls.realize_input(predicate)
7656        operands = [cls.realize_input(x) for x in operands]
7657
7658        fx_operands = V.graph.current_node.args[-1]
7659        fake_operands = [x.meta["val"] for x in fx_operands]  # type: ignore[union-attr]
7660
7661        for subgraph in (true_fn, false_fn):
7662            if subgraph.graph is None:
7663                # create and lower subgraphs
7664                subgraph.graph = V.graph.make_subgraph(
7665                    gm=subgraph.graph_module,
7666                    example_inputs=fake_operands,
7667                    subgraph_name=subgraph.name,
7668                )
7669                with V.set_graph_handler(subgraph.graph):
7670                    subgraph.graph.run(*fake_operands)
7671
7672        true_outputs = true_fn.graph.graph_outputs  # type: ignore[union-attr]
7673    false_outputs = false_fn.graph.graph_outputs  # type: ignore[union-attr]
7674
7675        for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
7676            if _has_aliased_buffers(outputs):
7677                raise AssertionError(
7678                    "Output aliasing is currently not supported in compiled torch.cond. "
7679                    f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}"
7680                )
7681
7682        # make sure true and false outputs are structurally equivalent
7683        assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
7684        for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)):
7685            assert to.get_size() == fo.get_size(), (i, to, fo)
7686            assert to.get_stride() == fo.get_stride(), (i, to, fo)
7687            assert to.get_device() == fo.get_device(), (i, to, fo)
7688            assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
7689            assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
7690
7691        if not isinstance(predicate, ShapeAsConstantBuffer):
7692            # use the predicate's device for consistent codegen
7693            device = predicate.get_device()
7694        else:
7695            # predicate is not a Tensor: use first operand's device
7696            assert (
7697                len(operands) > 0
7698            ), "When predicate is not a Tensor, there must be at least one operand in torch.cond."
7699            device = operands[0].get_device()
7700
7701        conditional = Conditional(
7702            predicate=predicate,
7703            operands=operands,
7704            true_subgraph=true_fn,
7705            false_subgraph=false_fn,
7706            layout=MultiOutputLayout(device),
7707        )
7708
7709        outputs = [
7710            MultiOutput(
7711                FixedLayout(
7712                    device=output.get_device(),
7713                    dtype=output.get_dtype(),
7714                    size=output.get_size(),
7715                    stride=output.get_stride(),
7716                    offset=output.get_layout().offset,
7717                ),
7718                conditional,
7719                [(list, i)],
7720            )
7721            # as the true and false outputs are equivalent,
7722            # we can use either of them here as a "template"
7723            for i, output in enumerate(true_outputs)
7724        ]
7725
7726        conditional.outputs = outputs
7727        return outputs
7728
7729    def codegen(self, wrapper):
7730        wrapper.codegen_conditional(self)
7731
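# Illustrative sketch, not part of the original file: the kind of user program
# whose lowering reaches Conditional.create above. Assumes a PyTorch build
# that exports torch.cond; the function below is hypothetical and exists only
# for exposition. Both branches must return tensors with matching
# size/stride/dtype/device, which create() asserts at lowering time.
def _demo_cond_program(x):
    def true_fn(x):
        return x.sin()

    def false_fn(x):
        return x.cos()

    # predicate is a boolean scalar tensor; operands is a tuple of tensors
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
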
7732
7733@dataclasses.dataclass
7734class WhileLoop(ExternKernel):
7735    carried_inputs: Optional[List[TensorBox]] = None
7736    additional_inputs: Optional[List[TensorBox]] = None
7737    cond_subgraph: Optional[Subgraph] = None
7738    body_subgraph: Optional[Subgraph] = None
7739    outputs: Optional[List[MultiOutput]] = None
7740
7741    def __init__(
7742        self,
7743        carried_inputs: List[TensorBox],
7744        additional_inputs: List[TensorBox],
7745        cond_subgraph: Subgraph,
7746        body_subgraph: Subgraph,
7747        layout: MultiOutputLayout,
7748    ):
7749        self.carried_inputs = carried_inputs
7750        self.additional_inputs = additional_inputs
7751        self.cond_subgraph = cond_subgraph
7752        self.body_subgraph = body_subgraph
7753
7754        super().__init__(
7755            name=None,
7756            layout=layout,  # type: ignore[arg-type]
7757            inputs=carried_inputs + additional_inputs,  # type: ignore[list-item]
7758        )
7759
7760        self.name = V.graph.register_buffer(self)
7761
7762    @classmethod
7763    def create(
7764        cls,
7765        cond_fn: Subgraph,
7766        body_fn: Subgraph,
7767        carried_inputs: List[TensorBox],
7768        additional_inputs: List[TensorBox],
7769    ):
7770        carried_inputs = [cls.realize_input(x) for x in carried_inputs]
7771        additional_inputs = [cls.realize_input(x) for x in additional_inputs]
7772        all_inputs = carried_inputs + additional_inputs
7773
7774        fx_all_inputs = V.graph.current_node.args[-2] + V.graph.current_node.args[-1]  # type: ignore[operator]
7775        fake_all_inputs = [x.meta["val"] for x in fx_all_inputs]  # type: ignore[union-attr]
7776
7777        for subgraph in (cond_fn, body_fn):
7778            if subgraph.graph is None:
7779                # create and lower subgraphs
7780                subgraph.graph = V.graph.make_subgraph(
7781                    gm=subgraph.graph_module,
7782                    example_inputs=fx_all_inputs,  # type: ignore[arg-type]
7783                    subgraph_name=subgraph.name,
7784                )
7785                with V.set_graph_handler(subgraph.graph):
7786                    subgraph.graph.run(*fake_all_inputs)
7787
7788        cond_outputs = cond_fn.graph.graph_outputs  # type: ignore[union-attr]
7789        body_outputs = body_fn.graph.graph_outputs  # type: ignore[union-attr]
7790
7791        if _has_aliased_buffers(body_outputs):
7792            raise AssertionError(
7793                "Output aliasing is currently not supported in compiled torch.while_loop. "
7794                f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}"
7795            )
7796
7797        # make sure cond_fn returns a boolean scalar Tensor
7798        assert len(cond_outputs) == 1, cond_outputs
7799        assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs
7800        assert len(cond_outputs[0].get_size()) == 0, cond_outputs
7801
7802        assert (
7803            len(all_inputs) > 0
7804        ), "torch.while_loop is assumed to have at least one operand."
7805
7806        device = all_inputs[0].get_device()
7807
7808        # make sure carried_inputs and body outputs are structurally equivalent
7809        assert len(carried_inputs) == len(body_outputs), (carried_inputs, body_outputs)
7810        for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)):
7811            assert op.get_size() == bo.get_size(), (i, op, bo)
7812            assert op.get_stride() == bo.get_stride(), (i, op, bo)
7813            # assume all carried_inputs and outputs are on the same device
7814            # as the MultiOutputLayout below requires single device
7815            assert op.get_device() == bo.get_device() == device, (i, op, bo, device)
7816            assert op.get_dtype() == bo.get_dtype(), (i, op, bo)
7817            assert op.get_layout().offset == bo.get_layout().offset, (i, op, bo)
7818
7819        while_loop = WhileLoop(
7820            carried_inputs=carried_inputs,
7821            additional_inputs=additional_inputs,
7822            cond_subgraph=cond_fn,
7823            body_subgraph=body_fn,
7824            # asserted above that there is at least one operand
7825            layout=MultiOutputLayout(device),
7826        )
7827
7828        outputs = [
7829            MultiOutput(
7830                FixedLayout(
7831                    device=output.get_device(),
7832                    dtype=output.get_dtype(),
7833                    size=output.get_size(),
7834                    stride=output.get_stride(),
7835                    offset=output.get_layout().offset,
7836                ),
7837                while_loop,
7838                [(list, i)],
7839            )
7840            for i, output in enumerate(body_outputs)
7841        ]
7842
7843        for inp, out in zip(carried_inputs, outputs):
7844            if inp.get_name() in V.graph.graph_inputs:
7845                # if a carried input of the while_loop is a graph input,
7846                # it can be returned as is when the number of iterations
7847                # is zero. Due to this, we can't (generally) reuse the
7848                # output buffers corresponding to the graph inputs, as
7849                # the inputs may end up being mutated.
7850                V.graph.never_reuse_buffers.add(out.get_name())
7851
7852        while_loop.outputs = outputs
7853        return outputs
7854
7855    def codegen(self, wrapper):
7856        wrapper.codegen_while_loop(self)
7857
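# Illustrative sketch, not part of the original file: a program whose lowering
# reaches WhileLoop.create above. The import path is assumed (it may differ
# across versions); the function is hypothetical and for exposition only.
# cond_fn must return a boolean scalar tensor, and body_fn's outputs must be
# structurally identical to the carried inputs, matching create()'s asserts.
def _demo_while_loop_program(x):
    from torch._higher_order_ops.while_loop import while_loop

    def cond_fn(i, x):
        return i < 3  # 0-dim torch.bool tensor

    def body_fn(i, x):
        return i + 1, x.sin()  # same sizes/dtypes as the carried inputs

    i = torch.zeros((), dtype=torch.int64)
    return while_loop(cond_fn, body_fn, (i, x))
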
7858
7859class EffectfulKernel(FallbackKernel):
7860    def __init__(
7861        self,
7862        layout,
7863        kernel,
7864        tensor_args,
7865        nontensor_args,
7866        unflatten_args,
7867        kwargs=None,
7868        *,
7869        unbacked_bindings=None,
7870    ):
7871        super().__init__(
7872            layout,
7873            kernel,
7874            tensor_args,
7875            nontensor_args,
7876            unflatten_args,
7877            kwargs=None,
7878            unbacked_bindings=unbacked_bindings,
7879        )
7880
7881        from torch._higher_order_ops.effects import get_effect_key
7882
7883        effect_type = get_effect_key(kernel, (*nontensor_args, *tensor_args), kwargs)
7884        assert effect_type is not None
7885        self.effect_type = effect_type
7886        self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
7887        V.graph.effectful_ops[effect_type] = self
7888
7889    def get_read_writes(self):
7890        read_writes = super().get_read_writes()
7891
7892        if self.prev_effect_buffer is not None:
7893            read_writes.reads.add(
7894                dependencies.StarDep(self.prev_effect_buffer.get_name())
7895            )
7896
7897        return read_writes
7898
7899    def has_side_effects(self):
7900        return True
7901
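# Illustrative sketch (hypothetical kernels k1, k2 sharing one effect type):
# the bookkeeping above serializes effectful ops of the same kind. Each new
# EffectfulKernel snapshots the previous holder of its effect key and adds a
# StarDep on it in get_read_writes(), so the scheduler cannot reorder k2
# before k1:
#
#   before k1:  V.graph.effectful_ops == {}
#   after  k1:  V.graph.effectful_ops == {key: k1}, k1.prev_effect_buffer None
#   after  k2:  V.graph.effectful_ops == {key: k2}, k2 reads StarDep(k1 name)
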
7902
7903@dataclasses.dataclass
7904class TorchBindObject(IRNode):
7905    name: str
7906    value: torch._C.ScriptObject
7907
7908    def get_name(self):
7909        return self.name
7910
7911    def get_device(self):
7912        return None  # is there a device??
7913
7914    def codegen_reference(self, writer=None):
7915        return self.name
7916
7917
7918class InterpreterShim(torch.fx.Interpreter):
7919    @staticmethod
7920    @functools.lru_cache(None)
7921    def _dummy_gm():
7922        return torch.fx.symbolic_trace(identity)
7923
7924    def __init__(self, graph, submodules):
7925        # call super() with a placeholder to avoid constructing a
7926        # GraphModule which is very expensive (it does codegen).
7927        super().__init__(self._dummy_gm(), garbage_collect_values=False)
7928        self.module = self  # type: ignore[assignment]
7929        self.graph = graph
7930        self.submodules = submodules
7931        self.extra_traceback = False
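        # call_module targets resolve through fetch_attr; route those lookups
        # straight into the submodules dict instead of GraphModule attributes.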
7932        self.fetch_attr = submodules.__getitem__
7933        self.current_node = None
7934
7935    def run_node(self, n: torch.fx.Node) -> Any:
7936        self.current_node = n
7937        return super().run_node(n)
7938
7939    def run(self, *args, **kwargs):
7940        with V.set_interpreter_handler(self):
7941            return super().run(*args, **kwargs)
7942
7943
7944class LoopBody:
7945    """
7946    Captures the body of a Loops subclass into an FX graph.  Persists any
7947    indexing simplifications and makes it easier to analyze loop bodies.
7948    """
7949
7950    def __init__(self, fn, args, var_ranges):
7951        super().__init__()
7952        self.var_ranges = var_ranges
7953        self.indexing_exprs = {}
7954        self.indexing_exprs_name = {}
7955        self.reads = []
7956        self.writes = []
7957        self.reads_name2expr = {}
7958        self.writes_name2expr = {}
7959        self.other = []
7960        self.submodules = {"get_index": self.get_index}
7961        self.subblocks = {}
7962        self.indirect_vars = []
7963        self.root_block = LoopBodyBlock(self, fn, args)
7964        self.indexing = None
7965
7966    @cache_on_self
7967    def get_nodes(self):
7968        all_graphs = itertools.chain(
7969            (self.root_block.graph,),
7970            (block.graph for block in self.subblocks.values()),
7971        )
7972        return [node for graph in all_graphs for node in graph.nodes]
7973
7974    @cache_on_self
7975    def bounds(self):
7976        # Doing a local import to avoid dumping all the code here
7977        from .bounds import BoundVars
7978
7979        return BoundVars(self)
7980
7981    def debug_str(self):
7982        lines = [f"var_ranges = {dict(self.var_ranges)}"]
7983        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
7984        lines.extend(
7985            [
7986                block.debug_str(name)
7987                for name, block in itertools.chain(
7988                    [("body", self.root_block)], self.subblocks.items()
7989                )
7990            ]
7991        )
7992        return "\n".join(lines)
7993
7994    def add_index_expr(self, expr: sympy.Expr, category, buf_name):
7995        getattr(self, category).append(expr)
7996        if buf_name is not None:
7997            getattr(self, f"{category}_name2expr")[buf_name] = expr
7998        if expr not in self.indexing_exprs_name:
7999            name = f"index{len(self.indexing_exprs)}"
8000            self.indexing_exprs_name[expr] = name
8001            self.indexing_exprs[name] = expr
8002        return self.indexing_exprs_name[expr]
8003
8004    def add_submodule(self, block, prefix):
8005        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
8006        if prefix[-1].isnumeric() and prefix not in self.submodules:
8007            name = prefix
8008        else:
8009            name = f"{prefix}{len(self.submodules)}"
8010        self.submodules[name] = block
8011        return name
8012
8013    def add_indirect(self, size):
8014        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
8015        self.indirect_vars.append(var)
8016        return var
8017
8018    def replace_indirect(self, old, new):
8019        """Swap in a variable used in indirect indexing"""
8020        if str(old) == str(new):
8021            return
8022        assert self.indexing is not None
8023        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
8024
8025    def get_index(self, name):
8026        assert self.indexing is not None
8027        return self.indexing[name]
8028
8029    def __call__(self, *indices):
8030        index = list(itertools.chain.from_iterable(indices))
8031        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
8032        assert all(v not in self.var_ranges for v in index)
8033        replacements = dict(zip(self.var_ranges.keys(), index))
8034        self.indexing = {
8035            name: sympy_subs(expr, replacements)
8036            for name, expr in self.indexing_exprs.items()
8037        }
8038        result = self.root_block()
8039        self.indexing = None
8040        return result
8041
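# Illustrative walk-through (hypothetical fn and ranges): driving a LoopBody.
# For var_ranges {i0: 16} with a captured indexing expression
# {"index0": 2*i0}, a call such as
#
#   body([sympy.Integer(3)])
#
# substitutes i0 -> 3 (so get_index("index0") returns 6 inside the traced
# blocks), runs the root block under the current ops handler, and then resets
# self.indexing to None.
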
8042
8043class LoopBodyBlock:
8044    """
8045    Captures the body of a Loops subclass into an FX graph.
8046    In normal cases there will be a 1:1 mapping between LoopBody and
8047    LoopBodyBlock, however in the case of ops.masked() the masked-out
8048    operations will manifest as an extra LoopBodyBlock.
8049    """
8050
8051    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
8052        self.body = body
8053
8054        def add_index(expr, category, buf_name=None):
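            # Record the sympy expression on the LoopBody and emit a
            # call_module node targeting the "get_index" submodule, so the
            # traced graph refers to indexing expressions by name.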
8055            return tracer.create_proxy(
8056                "call_module",
8057                "get_index",
8058                (self.body.add_index_expr(expr, category, buf_name),),
8059                {},
8060            )
8061
8062        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
8063            self.name = "CaptureIndexing"
8064
8065            def load(self, name: str, index: sympy.Expr):
8066                index = add_index(index, "reads", name)
8067                return self._inner.load(name, index)
8068
8069            def store(self, name, index, value, mode=None):
8070                index = add_index(index, "writes", name)
8071                return self._inner.store(name, index, value, mode)
8072
8073            def store_reduction(self, name, index, value):
8074                index = add_index(index, "writes", name)
8075                return self._inner.store_reduction(name, index, value)
8076
8077            def reduction(self, dtype, src_dtype, reduction_type, value):
8078                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
8079                if "welford" in reduction_type:
8080                    return tuple(result[i] for i in range(3))
8081                return result
8082
8083            def index_expr(self, index, dtype):
8084                if isinstance(index, (int, sympy.Integer)):
8085                    return self._inner.constant(int(index), dtype)
8086                index = add_index(index, "other")
8087                return self._inner.index_expr(index, dtype)
8088
8089            def check_bounds(self, index, size, lower, upper):
8090                index = add_index(index, "other")
8091                size = add_index(size, "other")
8092                return self._inner.check_bounds(index, size, lower, upper)
8093
8094            def bucketize(
8095                self,
8096                values,
8097                offsets_name: str,
8098                offsets_size: sympy.Expr,
8099                indexing_dtype: torch.dtype,
8100                right: bool,
8101            ):
8102                offsets_size = add_index(offsets_size, "other")
8103                return self._inner.bucketize(
8104                    values, offsets_name, offsets_size, indexing_dtype, right
8105                )
8106
8107            @staticmethod
8108            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
8109                """
8110                Recursively capture the masked out body in another LoopBodyBlock
8111                """
8112
8113                subblock: LoopBodyBlock
8114
8115                def shim(mask, other):
8116                    return V.ops.masked(mask, subblock, other)
8117
8118                name = self.body.add_submodule(shim, "masked_subblock")
8119                subblock = LoopBodyBlock(self.body, masked_body, [])
8120                self.body.subblocks[name] = subblock
8121                return tracer.create_proxy(
8122                    "call_module", name, (mask_proxy, other_proxy), {}
8123                )
8124
8125            @staticmethod
8126            def scan(
8127                dtype_proxy,
8128                combine_fn: Callable[
8129                    [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]
8130                ],
8131                value_proxy,
8132            ):
8133                def shim(dtypes, values):
8134                    return V.ops.scan(dtypes, combine_fn, values)
8135
8136                name = self.body.add_submodule(shim, "scan")
8137                result = tracer.create_proxy(
8138                    "call_module",
8139                    name,
8140                    (dtype_proxy, value_proxy),
8141                    {},
8142                )
8143                # Proxies are iterable, but some methods expect tuples/lists
8144                return tuple(result[i] for i in range(len(value_proxy)))
8145
8146            def frexp(self, value_proxy):
8147                result = self._inner.frexp(value_proxy)
8148                # Proxies are iterable, but some methods expect tuples/lists
8149                return (result[0], result[1])
8150
8151            @staticmethod
8152            def indirect_indexing(index_proxy, size, check=True):
8153                """
8154                Flow data from tensors into indexing formulas.
8155                Introduce a call_module to update the indexing.
8156                """
8157
8158                var = self.body.add_indirect(size)
8159
8160                def set_indirect(new_var):
8161                    self.body.replace_indirect(
8162                        var, V.ops.indirect_indexing(new_var, size, check)
8163                    )
8164
8165                tracer.create_proxy(
8166                    "call_module",
8167                    self.body.add_submodule(set_indirect, f"set_{var}"),
8168                    (index_proxy,),
8169                    {},
8170                )
8171                return var
8172
8173            @staticmethod
8174            def output(result):
8175                tracer.create_proxy("output", "output", (result,), {})
8176
8177        tracer = torch.fx.Tracer()
8178        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
8179        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
8180
8181        from .index_propagation import IndexPropagation
8182        from .sizevars import SimplifyIndexing
8183
8184        handler: Any = SimplifyIndexing(
8185            CaptureIndexing(proxy_ops), self.body.var_ranges
8186        )
8187        if config.constant_and_index_propagation:
8188            handler = IndexPropagation(handler, self.body.var_ranges)
8189
8190        with V.set_ops_handler(handler):
8191            # This indirection is just a cute way to get IndexPropagation to
8192            # unwrap the return value.
8193            ops.output(fn(*args))
8194        self.graph = tracer.graph
8195
8196    def __call__(self):
8197        graph = self.graph
8198        submodules = self.body.submodules
8199
8200        return InterpreterShim(graph, submodules).run(V.get_ops_handler())
8201
8202    def debug_str(self, name="block"):
8203        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
8204        return re.sub(
8205            # strip `; del var0` suffixes to make output prettier
8206            r";[^\n]*",
8207            "",
8208            code.strip().replace("def forward(", f"def {name}("),
8209        )
8210
8211
8212class _CollectiveKernel(FallbackKernel):
8213    def should_allocate(self):
8214        return False
8215
8216    def has_side_effects(self):
8217        return True
8218
8219    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
8220    # part that checks against input aliasing and mutation.
8221    def set_cpp_kernel(self, kernel):
8222        from .codegen.wrapper import get_cpp_op_schema
8223
8224        self.cpp_kernel_name = kernel._schema.name
8225        self.cpp_kernel_overload_name = kernel._schema.overload_name
8226        self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"  # type: ignore[union-attr]
8227
8228        self.cpp_op_schema = get_cpp_op_schema(kernel)
8229        self.ordered_kwargs_for_cpp_kernel = [
8230            x.name for x in kernel._schema.arguments if x.kwarg_only
8231        ]
8232
8233    # NOTE: [In-Place Collective Safety]
8234    # Between the initiation and completion of an in-place collective, the
8235    # input buffers are subject to both volatile reads and volatile writes.
8236    # They must not be read, written to or reused by another kernel. To ensure
8237    # the constraints hold, we model collective -> wait_tensor as a two-step
8238    # mutation of the input buffers.
8239    @classmethod
8240    def create_inplace(
8241        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
8242    ) -> None:
8243        cpp_kernel_name = kernel._name
8244        python_kernel_name = cpp_kernel_name.replace("::", ".")
8245        with V.graph.fake_mode:
8246            (
8247                example_output,
8248                tensor_args,
8249                non_tensor_args,
8250                unflatten_args,
8251                unbacked_bindings,
8252            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
8253        assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
8254        for tensor_arg in tensor_args:
8255            tensor_arg.realize()
8256
8257        packed = cls(
8258            NoneLayout(tensor_args[0].get_device()),
8259            kernel,
8260            tensor_args,
8261            non_tensor_args,
8262            unflatten_args,
8263        )
8264        packed.cpp_kernel_name = cpp_kernel_name
8265        packed.python_kernel_name = python_kernel_name
8266
8267        mark_node_as_mutating(packed, *pytree.tree_leaves(inputs))
8268
8269    # NOTE: [Out-of-Place Collective Safety]
8270    # Between the initiation and completion of an out-of-place collective:
8271    #
8272    # Input buffers:
8273    # - Are subject to volatile reads
8274    # - Can be read by another kernel
8275    # - Must not be written to or reused by another kernel
8276    #
8277    # Output buffers:
8278    # - Are subject to volatile writes
8279    # - Must not be read, written to or reused by another kernel
8280    #
8281    # To ensure the safety of input buffers without sacrificing read
8282    # availability, we add input buffers as read deps of wait_tensor kernels.
8283    #
8284    # To ensure the safety of output buffers, we model wait_tensor as a
8285    # mutation to the output buffer. Note we also assume the user program is
8286    # correct and that the output buffer is not consumed by kernels other than
8287    # wait_tensor.
8288    #
8289    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
8290    # usage in the user program.
8291    @classmethod
8292    def create_out_of_place(
8293        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
8294    ):
8295        cpp_kernel_name = kernel._name
8296        python_kernel_name = cpp_kernel_name.replace("::", ".")
8297        with V.graph.fake_mode:
8298            (
8299                example_output,
8300                tensor_args,
8301                non_tensor_args,
8302                unflatten_args,
8303                unbacked_bindings,
8304            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
8305        assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}"
8306        for tensor_arg in tensor_args:
8307            tensor_arg.realize()
8308
8309        if isinstance(example_output, list):
8310            device = cls.find_device(tensor_args, example_output)
8311            packed = cls(
8312                MultiOutputLayout(device),
8313                kernel,
8314                tensor_args,
8315                non_tensor_args,
8316                unflatten_args,
8317            )
8318            packed.cpp_kernel_name = cpp_kernel_name
8319            packed.python_kernel_name = python_kernel_name
8320            packed.outputs = [
8321                MultiOutput(
8322                    cls.tensor_to_layout(tensor),
8323                    packed,
8324                    [(list, i)],
8325                )
8326                for i, tensor in enumerate(example_output)
8327            ]
8328            return packed.outputs
8329        else:
8330            packed = cls(
8331                cls.tensor_to_layout(example_output),
8332                kernel,
8333                tensor_args,
8334                non_tensor_args,
8335                unflatten_args,
8336            )
8337            packed.cpp_kernel_name = cpp_kernel_name
8338            packed.python_kernel_name = python_kernel_name
8339            packed.outputs = [packed]
8340            return packed
8341
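# Illustrative sketch (hypothetical buffers): the dependency shape the two
# collective kernel classes produce for an in-place collective,
#
#   all_reduce_(inp)    # _CollectiveKernel.create_inplace: 1st mutation of inp
#   wait_tensor(inp)    # _WaitKernel.create_wait:          2nd mutation of inp
#
# Modeling both steps as mutations of the same buffer keeps other kernels from
# reading, writing, or reusing `inp` between initiation and completion, per
# NOTE: [In-Place Collective Safety].
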
8342
8343class _WaitKernel(_CollectiveKernel):
8344    def get_volatile_reads(self):
8345        inp = self.inputs[0]
8346        if isinstance(inp, _CollectiveKernel):
8347            # Out-of-place single-output
8348            return [inp.inputs[0]]
8349        elif isinstance(inp, MultiOutput):
8350            # This can be two things:
8351            # 1. Out-of-place multi-output coll
8352            # 2. In-place coll with inputs coming from another MultiOutput
8353            coll = inp.inputs[0]
8354            # Case 1
8355            if isinstance(coll, _CollectiveKernel):
8356                _, idx = inp.indices[0]
8357                return [coll.inputs[idx]]
8358            # Case 2
8359            return []
8360        else:
8361            # In-place requires no additional deps handling for volatile
8362            # reads since the inputs are mutated.
8363            return []
8364
8365    @classmethod
8366    def create_wait(cls, kernel, inp: TensorBox) -> None:
8367        with V.graph.fake_mode:
8368            (
8369                example_output,
8370                tensor_args,
8371                non_tensor_args,
8372                unflatten_args,
8373                unbacked_bindings,
8374            ) = cls.process_kernel(kernel, inp)
8375        assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
8376        packed = cls(
8377            NoneLayout(inp.get_device()),
8378            kernel,
8379            tensor_args,
8380            non_tensor_args,
8381            unflatten_args,
8382        )
8383
8384        mark_node_as_mutating(packed, inp)
8385
8386    def get_read_writes(self):
8387        read_writes = super().get_read_writes()
8388        # See [Out-of-Place Collective Safety].
8389        volatile_reads = self.get_volatile_reads()
8390        for vr in volatile_reads:
8391            read_writes.reads.add(dependencies.StarDep(vr.get_name()))
8392        return read_writes
8393
8394
8395# NB: recursive structure here reflects val_to_arg_str, avoid
8396# calling free_unbacked_symbols on "exotic" types that don't get pexpr
8397# treatment
8398def maybe_free_unbacked_symbols(s):
8399    if isinstance(s, (SymTypes, sympy.Expr)):
8400        # This branch should be impossible in return position
8401        return free_unbacked_symbols(s)
8402    elif isinstance(s, (tuple, list)):
8403        r = set()
8404        for t in s:
8405            r |= maybe_free_unbacked_symbols(t)
8406        return r
8407    elif isinstance(s, torch.Tensor):
8408        # This branch is impossible in constant-args position
8409        return free_unbacked_symbols(s)
8410    else:
8411        return set()
8412
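# Illustrative examples (u0 denotes a hypothetical unbacked SymInt symbol):
#
#   maybe_free_unbacked_symbols((u0 + 1, [2, u0], "tag"))  -> {u0}
#   maybe_free_unbacked_symbols("just a string")           -> set()
#
# Nested tuples/lists are unioned recursively; plain ints and strings never
# carry unbacked symbols, so they contribute the empty set.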