# mypy: allow-untyped-defs
import abc
import dataclasses
import itertools
import logging
import re
import typing
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch

import sympy

import torch
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet

from .codegen.common import index_prevent_reordering
from .utils import (
    get_dtype_size,
    reduction_num_outputs,
    sympy_index_symbol,
    sympy_str,
    sympy_subs,
    VarRanges,
)
from .virtualized import OpsHandler, ReductionType, V


log = logging.getLogger(__name__)
is_indirect = re.compile(r"indirect|tmp").search


class Dep(abc.ABC):
    name: str
    index: sympy.Expr

    @abc.abstractmethod
    def rename(self, renames: Dict[str, str]) -> "Dep":
        pass

    @abc.abstractmethod
    def get_numel(self) -> sympy.Expr:
        pass

    @abc.abstractmethod
    def numbytes_hint(self):
        pass

    @abc.abstractmethod
    def has_unbacked_symbols(self) -> bool:
        pass

    @abc.abstractmethod
    def is_contiguous(self) -> bool:
        pass

    def normalize_with_stride_order(self, prefix="t"):
        return self


@dataclasses.dataclass(frozen=True)
class MemoryDep(Dep):
    name: str
    index: sympy.Expr
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]
    mode: Optional[str] = None

    def __repr__(self) -> str:
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})"

    @property
    def num_vars(self):
        return len(self.var_names)

    def decide_loop_order_to_match(self, other):
        """
        Can return None if not able to decide the loop order.
        """
        assert self.num_vars == other.num_vars

        # ignore broadcast for now since broadcast causes extra 0 strides
        # which makes it hard to decide the correct loop orders.
        if self.num_vars != len(self.index.free_symbols):
            return None
        if other.num_vars != len(other.index.free_symbols):
            return None

        # bail out if any size is 0 or 1
        # For size == 0, it's an empty tensor, so any strides for that dimension
        # are equivalent. Skip for simplicity; it may not matter that much.
        #
        # For size == 1, it can cause ties between strides of different dimensions.
        # Also, when we first create a LoopBody in ComputedBuffer.simplify_and_reorder,
        # we call dependencies.index_vars_squeeze, which should already squeeze
        # out the size == 1 dimensions.
        if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)):
            return None

        # Extract strides for both expressions
        self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names)
        other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names)

        # Even if the shape contains no 0/1, some complex index expressions may
        # still have duplicate stride values. Here is an example:
        # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129
        # We don't reorder the loop for these cases for now, but in theory
        # we could improve the algorithm to detect the correct loop orders.
        if len(set(self_strides)) != len(self_strides) or len(
            set(other_strides)
        ) != len(other_strides):
            log.debug(
                "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s",
                self,
                other,
                self_strides,
                other_strides,
            )
            return None

        # May happen if self and other are as follows
        # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None)
        # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None)
        if set(self_strides) != set(other_strides):
            return None

        stride_to_index = {s: i for i, s in enumerate(self_strides)}
        order = []
        for s in other_strides:
            order.append(stride_to_index[s])

        assert set(order) == set(range(0, self.num_vars))
        return order
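
    # Illustrative sketch (hypothetical stride values): if
    #   self_strides  == [512, 1, 64]  ->  stride_to_index == {512: 0, 1: 1, 64: 2}
    #   other_strides == [64, 512, 1]
    # then the returned order is [2, 0, 1], i.e. order[i] is the position in
    # self's loop nest of the dimension that matches other's i-th loop variable.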

    def get_offset(self):
        """
        Return the offset by setting every variable to be 0.
        """
        return sympy_subs(self.index, dict.fromkeys(self.var_names, 0))
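
    # Illustrative sketch (hypothetical index): for index = 768*d0 + d1 + 5
    # with var_names (d0, d1), substituting d0 = d1 = 0 gives an offset of 5.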

    def normalize(self) -> "MemoryDep":
        """
        Normalize by merging loops. The difference from normalize_with_stride_order
        is that this method does not reorder loops, while normalize_with_stride_order
        reorders loops based on stride order.
        """
        return MemoryDep(
            self.name,
            *_RecordLoadStoreInner._normalize(self.index, self.ranges),  # type: ignore[arg-type]
            self.mode,
        )
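
    # Illustrative sketch (assuming _simplify_loops merges the contiguous
    # dims): MemoryDep('buf', 64*d0 + d1, var_names=(d0, d1), size=(16, 64))
    # would normalize to MemoryDep('buf', c0, var_names=(c0,), size=(1024,)),
    # with the loop order left unchanged.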

    def normalize_with_stride_order(self, prefix="t"):
        r"""
        Used to decide if two MemoryDeps differ only because of different loop
        orders. More specifically, when dep1 and dep2 are not equal, we can
        normalize both and check if they are equal after that. If yes, then the
        mismatch is caused by different loop orders.
        """
        # import here to avoid circular import
        from torch._inductor import ir

        strides = V.graph.sizevars.stride_hints(self.index, self.var_names)

        # pick a loop order with strides ordered decreasingly
        order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        stride_reorder = ir.same_reorder(order)
        sizes = self.size
        var_names = self.var_names

        new_reordered_sizes = stride_reorder(sizes)
        new_reordered_var_names = stride_reorder(var_names)

        new_simplified_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            new_reordered_var_names,
            new_reordered_sizes,
            index_prevent_reordering(
                [self.index], new_reordered_var_names, new_reordered_sizes
            ),
        )

        # now let's create new symbols with the passed in prefix
        var_ranges, add_var = var_builder(prefix)
        replacement = dict(
            zip(
                new_reordered_var_names,
                reindex([add_var(x) for x in new_simplified_sizes]),
            )
        )
        new_index = sympy_subs(sympy.expand(self.index), replacement)  # type: ignore[arg-type] # next PR

        out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values()))  # type: ignore[arg-type]
        return out
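
    # Illustrative sketch (assuming _simplify_loops fully merges contiguous
    # dims): the two 'addmm_6' deps shown in decide_loop_order_to_match above
    # differ only in loop order, so both would normalize to roughly
    # MemoryDep('addmm_6', t0, (t0,), (6291456,)) and compare equal afterwards
    # (16*512*768 == 64*768*128 == 6291456).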

    @property
    def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...}"""
        return dict(zip(self.var_names, self.size))

    def get_numel(self) -> sympy.Expr:
        if self.is_indirect():
            numel = V.graph.get_numel(self.name)
        else:
            vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols)
            numel = sympy.Integer(1)
            for var, size in zip(self.var_names, self.size):
                if var in vars:
                    numel = numel * size
        return numel  # type: ignore[return-value]
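
    # Illustrative sketch (hypothetical dep): with index = 768*d0 + d2,
    # var_names = (d0, d1, d2) and size = (16, 512, 768), d1 does not appear
    # in the index, so get_numel() returns 16 * 768 = 12288 rather than the
    # full 16 * 512 * 768.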

    def rename(self, renames: Dict[str, str]) -> "MemoryDep":
        if self.name in renames:
            return MemoryDep(
                renames[self.name],
                self.index,
                var_names=self.var_names,
                size=self.size,
                mode=self.mode,
            )
        return self

    def numbytes_hint(self):
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        return isinstance(self.index, sympy.Symbol) and self.index in self.var_names

    def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool:
        """
        Whether the stride for the last dimension is 1.
        """
        # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16
        # will exercise this corner case.
        if len(self.var_names) == 0:
            return True

        terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index]

        last_sym = self.var_names[-1]
        for term in terms:
            if term is last_sym:
                return True

            # Having a stride > 1 for the last dimension is bad for perf,
            # so return False.
            if (
                isinstance(term, sympy.Mul)
                and len(term.args) == 2
                and term.args[1] is last_sym
                and isinstance(term.args[0], (int, sympy.Integer))
                and term.args[0] > 1
            ):
                return False

        return result_for_complex_expression
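
    # Illustrative sketch (hypothetical indexes), with var_names = (d0, d1):
    #   768*d0 + d1    -> True  (the last dim has stride 1)
    #   768*d0 + 2*d1  -> False (the last dim has stride 2)
    #   d0*d1          -> result_for_complex_expression (no case matched)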

    def is_scalar(self) -> bool:
        if isinstance(self.index, sympy.Symbol):
            return self.index not in self.var_names and not self.is_indirect()
        return isinstance(self.index, (int, sympy.Integer))

    def is_indirect(self) -> bool:
        return any(is_indirect(v.name) for v in self.index.free_symbols)  # type: ignore[attr-defined]


@dataclasses.dataclass(frozen=True)
class StarDep(Dep):
    name: str
    mode: Optional[str] = None

    # depends on the entire buffer
    @property
    def index(self):
        raise NotImplementedError("StarDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        return V.graph.get_numel(self.name)  # type: ignore[return-value]

    def rename(self, renames: Dict[str, str]) -> "StarDep":
        if self.name in renames:
            return StarDep(renames[self.name], self.mode)
        return self

    def numbytes_hint(self):
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        return False

    def is_scalar(self) -> bool:
        return False

    def is_indirect(self) -> bool:
        return False


# Used for tracking mutation ordering:
# if A reads a buffer and B mutates it,
# B must be ordered after A.
#
# This is useful for a variety of reasons.
# For example, if A's read is never actually used, we can eliminate it.
# Another case: if A's buffer ends up being fused away, we never need to
# materialize that buffer.
@dataclasses.dataclass(frozen=True)
class WeakDep(Dep):
    # Fake dependency on unused buffer
    name: str
    # Buffer that is doing the mutation
    mutating_buf: str

    @property
    def index(self):
        raise NotImplementedError("WeakDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        return sympy.Integer(1)

    def rename(self, renames: Dict[str, str]) -> "WeakDep":
        if self.name in renames:
            return WeakDep(renames[self.name], self.mutating_buf)
        return self

    def numbytes_hint(self):
        return 1  # Purely inserted for ordering, not an actual dep

    def has_unbacked_symbols(self):
        return False

    def is_contiguous(self) -> bool:
        return False


@dataclasses.dataclass(frozen=True)
class IndexExprDep:
    index: sympy.Expr  # type: ignore[assignment]
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]


@dataclasses.dataclass
class ReadWrites:
    reads: OrderedSet[Dep]
    writes: OrderedSet[Dep]
    index_exprs: OrderedSet[IndexExprDep]
    range_vars: Optional[List[sympy.Expr]] = None
    var_ranges: Optional[VarRanges] = None

    def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
        return ReadWrites(
            OrderedSet(dep.rename(renames) for dep in self.reads),
            OrderedSet(dep.rename(renames) for dep in self.writes),
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
        )

    def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites":
        assert isinstance(dep, (WeakDep, StarDep, set))
        if not isinstance(dep, set):
            dep = {dep}
        return ReadWrites(
            OrderedSet.union(self.reads, dep),
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
        )

    def merge(self, other: "ReadWrites"):
        reads = OrderedSet.union(self.reads, other.reads)
        writes = OrderedSet.union(self.writes, other.writes)
        index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs)
        return ReadWrites(reads - writes, writes, index_exprs)

    @staticmethod
    def merge_list(read_writes: List["ReadWrites"]):
        all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
        all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
        all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
        return ReadWrites(all_reads, all_writes, all_index_exprs)

    def remove_reads(self, rem_reads):
        return ReadWrites(
            self.reads - rem_reads,
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
        )

    def reads_and_writes(self):
        return itertools.chain(self.reads, self.writes)

    def buffer_names(self, ignore_integer_index=True):
        """
        Integer index is used for load_seed.
        """
        names: OrderedSet[str] = OrderedSet()
        for dep in self.reads_and_writes():
            if not isinstance(dep, MemoryDep):
                continue
            if not ignore_integer_index or not isinstance(
                dep.index, (int, sympy.Integer)
            ):
                names.add(dep.name)
        return names


class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
    def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
        super().__init__()
        self._reads: OrderedSet[Dep] = OrderedSet()
        self._writes: OrderedSet[MemoryDep] = OrderedSet()
        self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet()
        self._var_ranges: VarRanges = var_ranges
        self._should_normalize: bool = normalize

    @staticmethod
    def drop_unused_symbols(index, var_names, sizes):
        """
        Reductions have the last (reduced) dim in their sizes, but
        downstream users won't.  Normalize this away.
        """
        if not isinstance(index, sympy.Expr):
            # index can be an int
            return
        free_symbols = index.free_symbols
        while var_names and var_names[-1] not in free_symbols:
            var_names.pop()
            sizes.pop()
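
    # Illustrative sketch: with index = d0, var_names = [d0, d1] and
    # sizes = [16, 8], the trailing (reduced) d1 never appears in the index,
    # so both lists are trimmed in place to var_names = [d0], sizes = [16].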

    @classmethod
    def _normalize(
        cls, index: sympy.Expr, var_ranges: VarRanges
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        # Try to further simplify the indexes even if simplify_loops didn't
        # convert them to the simplest form because of interference from
        # different indexing formulas.
        index_vars = [*var_ranges.keys()]
        sizes = tuple(var_ranges.values())  # type: ignore[assignment]
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars,
            sizes,
            index_prevent_reordering([index], index_vars, sizes),
        )

        # assign new variables to each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        new_vars, add_var = var_builder(canonicalization_prefix())
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)

        new_vars = [*new_vars.keys()]
        new_sizes = [*new_sizes]
        cls.drop_unused_symbols(index, new_vars, new_sizes)
        return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[arg-type]
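
    # Illustrative sketch (assuming _simplify_loops merges the contiguous
    # dims): _normalize(64*d0 + d1, {d0: 16, d1: 64}) would return roughly
    # (c0, (c0,), (1024,)) -- both loops collapse into one canonical variable
    # named with the "c" prefix from canonicalization_prefix().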

    def canonicalize(
        self, index: sympy.Expr
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        if not self._should_normalize:
            sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
            var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1]
            sizes = [v for v in sizes if v != 1]

            self.drop_unused_symbols(index, var_names, sizes)

            return index, tuple(var_names), tuple(sizes)  # type: ignore[return-value, arg-type]
        var_ranges = {
            k: V.graph.sizevars.simplify(v)
            for k, v in self._var_ranges.items()
            # TODO(jansel): explore this further normalization
            # if k in free_symbols
        }
        return self._normalize(index, var_ranges)

    def load(self, name: str, index: sympy.Expr) -> str:
        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
        return f"load({name}, {sympy_str(index)})"

    def load_seed(self, name: str, index: int):
        assert isinstance(index, int)
        return self.load(name, sympy.Integer(index))

    def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
        self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode))
        return f"store({name}, {sympy_str(index)}, {value}, {mode})"

    def store_reduction(self, name: str, index, value) -> str:
        return self.store(name, index, f"store_reduction({value})")

    def index_expr(self, index: sympy.Expr, dtype) -> str:
        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
        return f"index_expr({sympy_str(index)}, {dtype})"

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        self._reads.add(StarDep(offsets_name))
        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"


class RecordLoadStore(V.KernelFormatterHandler):  # type: ignore[name-defined]
    def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
        parent_handler = _RecordLoadStoreInner(
            var_ranges=var_ranges, normalize=normalize
        )
        super().__init__(parent_handler=parent_handler)


# TODO: check call sites
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
    cnt = itertools.count()
    var_ranges: VarRanges = {}

    def add_var(length: sympy.Expr) -> sympy.Symbol:
        v = sympy_index_symbol(f"{prefix}{next(cnt)}")
        var_ranges[v] = length
        return v

    return var_ranges, add_var
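
# Usage sketch (hypothetical prefix and lengths):
#   var_ranges, add_var = var_builder("z")
#   v0 = add_var(16)  # the symbol z0
#   v1 = add_var(8)   # the symbol z1
#   # var_ranges is now {z0: 16, z1: 8}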


def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Symbol]] = []
    for size in argsizes:
        args.append(list(map(add_var, size)))
    return args, var_ranges


def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
    from .ir import SqueezeView

    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Expr]] = []
    new_sizes: List[List[sympy.Expr]] = []
    for size in argsizes:
        new_size, reindex = SqueezeView.squeezer(size)
        new_sizes.append(new_size)
        args.append(reindex(list(map(add_var, new_size))))
    return args, var_ranges
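
# Illustrative sketch (assuming SqueezeView.squeezer drops size-1 dims and
# its reindex pads them back with a constant 0): index_vars_squeeze([16, 1, 8])
# with the default prefix "d" allocates only d0 and d1, so roughly
#   args == [(d0, 0, d1)] and var_ranges == {d0: 16, d1: 8}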


def extract_read_writes(
    fn: Callable[..., Any],
    *argsizes: Tuple[sympy.Expr, ...],
    normalize: bool = False,
    prefix: str = "d",
    hidden_args=(),
):
    args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)

    from .loop_body import LoopBody, MemoryUsageType

    if isinstance(fn, LoopBody):
        # Fast path to avoid tracing when we already have a LoopBody
        inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize)
        name_to_index = fn.indexing_from_args([*args, *hidden_args])
        if fn.indirect_vars:
            # mimic the `tmpX` naming tracing gives us
            repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)}
            name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()}
        for entry in fn.memory_usage[MemoryUsageType.LOAD]:
            inner.load(entry.buffer_name, name_to_index[entry.index_name])
        for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]:
            inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name]))
        for entry in fn.memory_usage[MemoryUsageType.STORE]:
            inner.store(
                entry.buffer_name, name_to_index[entry.index_name], None, entry.mode
            )
        for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
            inner.store_reduction(
                entry.buffer_name, name_to_index[entry.index_name], None
            )
        for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
            inner.index_expr(name_to_index[entry.index_name], None)
        for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
            inner.bucketize(
                None, entry.buffer_name, name_to_index[entry.index_name], None, None
            )
        # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
    else:
        # Slow path: trace the function
        rw = RecordLoadStore(var_ranges, normalize=normalize)
        with V.set_ops_handler(rw):
            fn(*args, *hidden_args)
        inner = rw.parent_handler

    if normalize:
        range_vars = []  # Number of vars could differ due to normalization
    else:
        range_vars = [*itertools.chain.from_iterable(args)]

    return ReadWrites(
        OrderedSet(inner._reads),
        OrderedSet(inner._writes),
        inner._index_exprs,
        range_vars,
        var_ranges,
    )
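
# Usage sketch (illustrative; `inp`, `out`, and the size symbol `s0` are
# hypothetical, and the `ops` calls dispatch through the recording handler
# installed via V.set_ops_handler):
#
#   from torch._inductor.virtualized import ops
#
#   def fn(index):
#       (i0,) = index
#       return ops.store("out", i0, ops.load("inp", i0))
#
#   rw = extract_read_writes(fn, [s0])
#   # rw.reads  ~= {MemoryDep('inp', d0, (d0,), (s0,))}
#   # rw.writes ~= {MemoryDep('out', d0, (d0,), (s0,))}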


def extract_input_node_reduction_ranges(
    input_node: "torch._inductor.ir.TensorBox",
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
    """
    Returns the size and reduction size of all inputs, if the sizes and
    reduction_sizes (when they exist) are all the same.
    It's possible that a node has multiple inputs, some of which are Reduction
    nodes and others Pointwise nodes.
    In this case, the reduction_sizes of the Reduction nodes need to be the same.
    Otherwise returns (None, None).
    """

    from .ir import ComputedBuffer, Loops

    if isinstance(input_node.data, ComputedBuffer):
        # Input node has already been realized. Return its size and reduction_size.
        size = input_node.get_size()
        reduction_size = input_node.get_reduction_size()
        if len(reduction_size) > 0:
            return (size, reduction_size)
        else:
            return (None, None)

    if not isinstance(input_node.data.data, Loops):  # type: ignore[attr-defined]
        # Other IRNodes do not have reduction_ranges.
        return (None, None)

    # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
    # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
    # Is there a way to check whether there are permutations in between?
    reads = input_node.get_reads()
    reduction_size = None
    size = None
    while reduction_size is None and len(reads) > 0:
        seen: OrderedSet[str] = OrderedSet()
        new_reads = []
        for read in reads:
            if not isinstance(read, MemoryDep):
                continue
            if read.name in seen:
                continue
            seen.add(read.name)
            buffer = V.graph.try_get_buffer(read.name)
            if buffer is None:
                continue
            op = buffer.get_defining_op()
            if op is None:
                continue

            if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0:
                if reduction_size is None:
                    reduction_size = op.get_reduction_size()
                    size = op.get_size()
                elif reduction_size != op.get_reduction_size() or size != op.get_size():
                    return (None, None)
            else:
                new_reads.extend(op.get_reads())
        if reads == new_reads:
            return (size, reduction_size)
        else:
            reads = new_reads
    return (size, reduction_size)


def canonicalization_prefix():
    return "c"


# ops handler which computes all the free unbacked symbols for an IR
class FreeUnbackedSymbolsOpsHandler:
    symbols: OrderedSet[sympy.Symbol]

    def __init__(self) -> None:
        self.symbols = OrderedSet()

    def __getattr__(self, name: str) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            for a in itertools.chain(args, kwargs.values()):
                if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
                    self.symbols |= free_unbacked_symbols(a)

        return inner

    def indirect_indexing(
        self, index_var, size, check=True, wrap_neg=True
    ) -> sympy.Symbol:
        assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
        self.symbols |= free_unbacked_symbols(size)
        return sympy_index_symbol(f"({str(index_var)})")

    def frexp(self, x):
        return (None,) * 2

    def scan(self, dtypes, combine_fn, values):
        return (None,) * len(values)

    def sort(self, dtypes, values, stable, descending):
        return (None,) * len(values)

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[None, Tuple[None, ...]],
    ) -> Union[None, Tuple[None, ...]]:
        num_values = reduction_num_outputs(reduction_type)
        return (None,) * num_values if num_values > 1 else None


def _typecheck_FreeUnbackedSymbolsOpsHandler(
    h: FreeUnbackedSymbolsOpsHandler,
) -> OpsHandler[None]:
    return h


def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
    from .ir import FlexibleLayout

    args = [index, rindex] if rindex is not None else [index]
    handler = FreeUnbackedSymbolsOpsHandler()
    # NB: I cargo culted the allow_indexing patch here, I don't understand why
    # people do this all over
    with V.set_ops_handler(handler), patch.object(
        FlexibleLayout, "allow_indexing", True
    ):
        fn(*args)
    return handler.symbols
746