xref: /aosp_15_r20/external/pytorch/torch/_inductor/sizevars.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import itertools
4import logging
5from typing import (
6    Any,
7    Callable,
8    cast,
9    Dict,
10    Iterable,
11    List,
12    Optional,
13    Sequence,
14    Set,
15    Tuple,
16    Union,
17)
18
19import sympy
20from sympy import Expr
21
22from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
23from torch.utils._sympy.functions import FloorDiv, ModularIndexing
24from torch.utils._sympy.symbol import symbol_is_type, SymT
25from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges
26
27from .runtime.runtime_utils import is_power_of_2
28from .utils import (
29    has_free_symbols,
30    sympy_index_symbol,
31    sympy_index_symbol_with_prefix,
32    sympy_subs,
33    VarRanges,
34)
35from .virtualized import V
36
37
# Module-level logger; failed static simplifications are reported here at DEBUG level.
log: logging.Logger = logging.getLogger(__name__)
39
40
def evaluate_expr(
    shape_env: ShapeEnv,
    expr: Union[sympy.Basic, bool],
    axioms: Optional[Tuple[sympy.Expr, ...]] = None,
    var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges[Any]], ...]] = None,
) -> bool:
    """Return True only if ``expr`` can be decided as true statically.

    This is the backend of ``SizeVarAllocator.is_expr_static_and_true`` and
    therefore of the ``statically_known_*`` helpers.

    Args:
        shape_env: the ShapeEnv whose ``_maybe_evaluate_static`` is queried.
        expr: a sympy boolean/relational expression, or a literal bool.
        axioms: optional extra facts (any number of them) forwarded to
            ``_maybe_evaluate_static``.
        var_to_range: optional ``(symbol, ValueRanges)`` pairs (any number of
            them) forwarded to ``_maybe_evaluate_static``.

    Returns:
        ``bool(expr)`` when ``expr`` compares equal to ``True``/``False``;
        otherwise ``bool`` of the static evaluation result. Returns False when
        the ShapeEnv cannot decide ``expr`` (returns None) or raises while
        trying; the exception is logged at DEBUG level, not propagated.
    """
    # Literal booleans need no symbolic reasoning.
    if expr in (True, False):
        return bool(expr)

    try:
        simplified = shape_env._maybe_evaluate_static(
            expr,
            axioms=axioms,
            var_to_range=var_to_range,
        )
        if simplified is not None:
            return bool(simplified)
    except Exception:
        # Deliberate best effort: an undecidable query is answered "not known true".
        log.debug("Could not simplify  %s", expr, exc_info=True)

    return False
62
63
64# This class is a little awkward, because ShapeEnv is doing most of the heavy
65# lifting and in some cases we should be directly passing through to ShapeEnv,
66# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
    """Inductor-side helper for reasoning about symbolic sizes and indices.

    Most questions are forwarded to ``self.shape_env``. On top of that this
    class keeps:

    * memoized simplifiers (``simplify_with_ranges``, ``_simplify_loops`` and
      ``stride_vars``) whose caches are dropped whenever the number of
      ShapeEnv replacements changes;
    * the precomputed-size ("ps") symbol tables described in ``__init__``;
    * three query families: ``statically_known_*`` (answer only when provable,
      never guard), ``guard_*`` (assert a fact the caller already knows) and
      ``evaluate_*`` (decide using size hints and guard on the decision).
    """

    def __init__(self, shape_env=None) -> None:
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        # These two are the ShapeEnv's own dicts (not copies), so entries the
        # ShapeEnv adds later are visible here too; the caches built below
        # compare len(self.replacements) to notice new replacements.
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0).  Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        """Expand ``expr`` and rewrite it with the ShapeEnv's replacements."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results

        The returned function is keyed on ``(expr, *var_ranges.items())``. The
        whole cache is cleared whenever ``len(self.replacements)`` changes,
        because new replacements can change the simplified result.
        """
        cache: Dict[Tuple[Any, ...], Expr] = {}
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_loops_impl() can be expensive, cache its results

        The returned function is keyed on the flattened
        ``(*index_vars, *sizes, *index_formulas)``. The whole cache is cleared
        whenever ``len(self.replacements)`` changes.
        """
        cache: Dict[Tuple[Any, ...], Any] = {}
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.

        Each iteration variable ``k`` with extent ``v`` in ``var_ranges`` is
        treated as lying in ``[0, v - 1]`` (upper bound infinite when ``v`` is
        symbolic). With that knowledge:

        * ``ModularIndexing(base, div, mod)`` becomes ``FloorDiv(base, div)``
          when ``0 <= base < mod * div`` is provable;
        * a bare iteration variable ``v`` is dropped from the base of a
          ModularIndexing/FloorDiv when the base is provably non-negative, the
          rest of the base is a multiple of the divisor and ``v < divisor`` is
          provable.

        The rewrite is repeated until the expression stops changing.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        var_to_range = dict(self.shape_env.var_to_range)
        # Iteration variables: [0, extent - 1], unbounded above if the extent
        # itself is symbolic.
        var_to_range.update(
            {
                k: ValueRanges(
                    0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity()
                )
                for k, v in var_ranges.items()
            }
        )
        # Any other unknown symbol is assumed non-negative.
        for var in expr.free_symbols:
            if var not in var_to_range:
                var_to_range[var] = ValueRanges(0, IntInfinity())

        var_to_range_tuple = cast(
            Tuple[Tuple[sympy.Symbol, ValueRanges[sympy.Expr]], ...],
            tuple(var_to_range.items()),
        )

        axioms = []
        for var, upper_bound in var_ranges.items():
            axioms.append(0 <= var)
            axioms.append(var < upper_bound)
        axioms = tuple(axioms) + self.shape_env.get_axioms()

        def statically_known(expr):
            # True only when the ShapeEnv can prove ``expr``; an undecided
            # query (None) counts as False.
            evaluated = self.shape_env._maybe_evaluate_static(
                expr,
                axioms=axioms,
                var_to_range=var_to_range_tuple,
            )
            return bool(evaluated)

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            if not statically_known(base >= 0):
                return base

            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if statically_known(v < divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)

            # If base never leaves the first period, the modulus is a no-op.
            can_remove_mod = statically_known(base >= 0) and statically_known(
                base < modulus * divisor
            )

            if can_remove_mod:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                    sympy.Wild("modulus", integer=True),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                ),
                visit_indexing_div,
            )

        # One rewrite can enable another; iterate to a fixed point.
        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: List[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop

        Returns ``(new_sizes, reindex, prune)``:
            new_sizes: sizes of the surviving dimensions, in original order.
            reindex: maps an index over the surviving dimensions back to a
                full-length index, inserting ``0`` for removed dimensions.
            prune: drops the entries of a full-length index that belong to
                removed dimensions.
        """
        sizes = list(map(self.simplify, sizes))

        strides = [
            # index_formulas may contain boolean expressions (e.g. s0 < 10),
            # for which "strides" don't make sense so we ignore them here.
            # NOTE: These expressions may still block merging dims in the sound
            # substitution test performed in can_merge_dims.
            self.stride_vars(x, index_vars)
            if isinstance(x, sympy.Expr)
            else [0] * len(index_vars)
            for x in index_formulas
        ]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # Dim ``a`` can be folded into dim ``b`` when, for every formula,
            # stride[b] == stride[a] * size[a] and a substitution check agrees.
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    m1 = sympy_index_symbol("_merge_tester1")
                    m2 = sympy_index_symbol("_merge_tester2")
                    # NOTE: can't sub vb=0 here in case va * vb appears in the expression,
                    # in which case both expr1 and expr2 would be zero!
                    expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
    # true, we add a guard and return True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #   if size_hinted_check(args):
    #       return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def maybe_guard_foo(args):
    #   if all_static(args):
    #       return True # Safe to apply optimization
    #   else:
    #       return False # No guard, no optim

    # See Note - [On Statically Known]

    def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool:
        """Return True only if ``expr`` is provably true without adding a guard."""
        return evaluate_expr(self.shape_env, expr)

    def statically_known_equals(
        self, left: Union[Expr, int], right: Union[Expr, int]
    ) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        return len(left) == len(right) and all(
            self.statically_known_equals(l, r) for l, r in zip(left, right)
        )

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right.
        """
        expr = left >= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than right.
        """
        expr = left > right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(
        self, numerator: Expr, denominator: Union[Expr, int]
    ) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.

        Always False when either side contains unbacked symbols.
        """
        if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator):
            return False
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_power_of_2(self, expr: Expr) -> bool:
        """
        Returns a bool indicating if expr is a concrete sympy.Integer that is a power of 2.
        """
        return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr))

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds.  If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        """Guard on ``left == right`` and return ``left``.

        Precomputed-size symbols are expanded first (they cannot be guarded
        on), so the returned ``left`` is the expanded form. The assert fires if
        ``self.shape_env.evaluate_expr`` reports the equality as false.
        """
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        """Guard on ``left <= right``, expressed as ``left < right + 1``
        (equivalent for integer-valued sizes)."""
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        """Guard on ``left < right``; asserts the ShapeEnv evaluates it as true."""
        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))

    def guarded_order(self, seq):
        """
        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.

        ``order[i]`` is the sorted position of ``seq[i]``; sorting uses size
        hints (ties broken by original position) and each adjacent pair in the
        sorted order is guarded with ``guard_leq``.
        """
        seq = [*map(self.remove_precomputed_replacements, seq)]
        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
        seq.sort()
        order = [-1] * len(seq)
        last_var = None
        for new_index, (_, orig_index, var) in enumerate(seq):
            order[orig_index] = new_index
            if last_var is not None:
                self.guard_leq(last_var, var)
            last_var = var
        return order

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
        """Forward ``left`` to ``self.shape_env.evaluate_expr`` and return its result."""
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(sympy.sympify(left))

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice

        Precomputed-size symbols are expanded first, and the returned value is
        the expanded operand. When size hints are unavailable (unbacked
        symbols), only statically provable orderings or a gcd-based shortcut
        are used; otherwise a TypeError is raised.
        """
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        try:
            lv = self.size_hint(left)
            rv = self.size_hint(right)
        except TypeError:  # unbacked symints
            if left == right or self.statically_known_leq(left, right):
                return left
            if self.statically_known_leq(right, left):
                return right
            gcd = sympy.gcd(left, right)
            if left == gcd:  # handle `min(10*u0, u0)` etc
                return left
            if right == gcd:
                return right
            raise TypeError(
                f"evaluate_min({left}, {right}) with unbacked symints"
            ) from None
        if lv <= rv:
            self.guard_leq(left, right)
            return left
        else:
            self.guard_leq(right, left)
            return right

    def evaluate_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # evaluate_min() expands precomputed-size symbols and returns one of
        # the *expanded* operands.  Expand them here too, so the identity
        # check below compares against the same objects evaluate_min saw.
        # Without this, an expanded ``left`` that is the minimum is not
        # ``left`` itself, and we would wrongly return the minimum.
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        # Always choose the opposite of eval min for consistency
        # This means min(a, b) and max(a, b) produce the same guards
        min_val = self.evaluate_min(left, right)
        return right if min_val is left else left

    def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
        """Return the concrete size hint of ``left`` and guard that ``left``
        equals it. Python ints are returned unchanged."""
        if isinstance(left, int):
            return left
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> List[int]:
        """Apply ``evaluate_static_shape`` to every element of ``left``."""
        return [self.evaluate_static_shape(x) for x in left]

    def remove_precomputed_replacements(self, expr: Expr) -> Expr:
        """Expand any precomputed-size (ps) symbols in ``expr`` back to their
        defining expressions; ``expr`` is returned as-is if it has none."""
        if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols):  # type: ignore[attr-defined]
            return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        return expr

    def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
        """Substitute the known hints (``self.var_to_val``) into ``expr``.

        Symbols without a hint (unbacked) stay symbolic, so the result may
        still be a sympy expression. Constant expressions come back as ``int``
        when convertible.
        """
        if isinstance(expr, int):
            return expr
        # Substitute all hints into expr, but leave unbacked symints alone
        expr = self.simplify(expr)
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            try:
                return int(expr)  # type: ignore[return-value]
            except TypeError:
                return expr  # inf/nan/I
        expr = self.remove_precomputed_replacements(expr)
        return sympy_subs(expr, self.var_to_val)

    def size_hint(
        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
    ) -> int:
        """Return a concrete int hint for ``expr``.

        If ``expr`` is still symbolic after substituting hints and ``fallback``
        is given, ``fallback`` is returned. When every remaining symbol has a
        known value range, ``fallback`` is first clamped into the bounds those
        ranges imply. Without a fallback, ``int(...)`` is attempted and any
        exception from it is logged and re-raised.
        """
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            unbacked_sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols
            }
            if all(vr is not None for vr in unbacked_sym_vrs.values()):
                hint_vr = bound_sympy(out, unbacked_sym_vrs)  # type: ignore[arg-type]
                if isinstance(hint_vr.lower, (int, sympy.Integer)):
                    fallback = max(fallback, int(hint_vr.lower))
                if isinstance(hint_vr.upper, (int, sympy.Integer)):
                    fallback = min(fallback, int(hint_vr.upper))
            return fallback

        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> Tuple[int, ...]:
        """Apply ``size_hint`` (with the same ``fallback``) to each expression."""
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        """Build a cached front end for ``_stride_vars``.

        ``support_vars`` defaults to ``vars`` when it is falsy (None or empty);
        both sequences are converted to tuples so they can serve as cache keys.
        """
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: Sequence[sympy.Symbol],
            support_vars: Optional[Sequence[sympy.Symbol]] = None,
        ) -> List[Expr]:
            if not support_vars:
                support_vars = vars
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Sequence[sympy.Symbol],
    ) -> List[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        The stride of ``vars[i]`` is computed as
        ``index(v=1) - index(v=0)`` after removing the constant offset and
        zeroing every other support variable. Entries of ``vars`` equal to 0
        get stride 0.
        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.Integer(0) for v in support_vars if v != 0}
        )
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    support_vars[j]: sympy.Integer(0)
                    for j in range(len(support_vars))
                    if vars[i] != support_vars[j] and support_vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.Integer(0))
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.Integer(1)})
                    - sympy_subs(index_dim, {v: sympy.Integer(0)})
                )
        return strides

    def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Optional[Sequence[sympy.Symbol]] = None,
    ) -> List[int]:
        """Return concrete stride hints for ``vars`` in ``index``.

        Indirect-indexing symbols are replaced by 0 first; a stride whose hint
        cannot be computed (``size_hint`` raises TypeError) is reported as 0.
        """
        for v in index.free_symbols:
            if symbol_is_type(v, SymT.INDIRECT):  # type: ignore[attr-defined]
                index = sympy_subs(index, {v: 0})  # type: ignore[dict-item]
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        """Return dimension positions sorted by ascending absolute stride hint,
        with zero-stride dimensions placed last."""
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr) -> Expr:
        """Return the precomputed-size (ps) symbol standing for ``expr``.

        Ints, numbers and plain symbols are returned unchanged. Otherwise a new
        ps symbol is allocated on first use (numbered by the current table
        size) and recorded in both the forward and inverse tables.
        """
        if (
            isinstance(expr, (int, sympy.Symbol, sympy.Number))
            or expr.is_number
            or expr.is_symbol
        ):
            return expr
        expr = self.remove_precomputed_replacements(expr)
        if expr not in self.precomputed_replacements:
            sym = sympy_index_symbol_with_prefix(
                SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements)
            )
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def free_symbols(self) -> Set[sympy.Symbol]:
        """Symbols that have a hint value and have not been replaced."""
        return set(self.var_to_val.keys()) - set(self.replacements.keys())

    def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr:
        """
        A pair of special ModularIndexing can be combined.

        E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b)
        We can simplify this to ModularIndexing(x, 1, b), if
        1. x is non negative integer
        2. a and b are positive integers
        3. a is a multiple of b.

        Any other ``index`` is returned unchanged.
        """

        def _check_args(x, div, mod, is_first):
            if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer):
                return False
            if div != 1:
                return False
            if mod <= 0:
                return False

            if is_first:
                # first ModularIndexing should contain a nested ModularIndexing
                if not isinstance(x, ModularIndexing):
                    return False
            else:
                # second ModularIndexing should contain a non-negative
                # symbol
                if not isinstance(x, sympy.Symbol) or not self.statically_known_geq(
                    x, 0
                ):
                    return False
            return True

        if isinstance(index, ModularIndexing):
            x, div, mod = index.args

            if not _check_args(x, div, mod, True):
                return index

            x2, div2, mod2 = x.args

            if not _check_args(x2, div2, mod2, False):
                return index

            if mod2 % mod != 0:
                return index

            return ModularIndexing(x2, 1, mod)

        return index

    def expand_floor_div(
        self, index: sympy.Expr
    ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]:
        """
        Expand the FloorDiv to the entire expression so that the expression may
        be simplified.

        E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables
        x1, x2, index expression 'x1 * 2b + x2' can be easily combined.
        But index expression 'x1 * b + x2 // 2' can not.
        By expanding the FloorDiv to the entire expression, we get
        '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops
        for the numerator!

        Return False if this optimization can not be applied;
        otherwise return the new expression and the denominator.
        The original expression will be equivalent to 'new_expression // denominator'

        Only sums of ``integer * symbol`` terms plus exactly one
        ``FloorDiv(symbol, integer)`` term, with every symbol provably
        non-negative, are handled.
        """
        if not isinstance(index, sympy.Add):
            return False
        terms = index.args

        if len(terms) < 2:
            return False
        floor_div_index = -1
        varlist = []
        factorlist = []
        for idx, term in enumerate(terms):
            if isinstance(term, sympy.Mul):
                # For dynamic shape, term like '2*s1*x1' has 3 child nodes.
                # - An integer for 2
                # - A symbol for s1
                # - A symbol for x1
                # Skip for now.
                if len(term.args) != 2:
                    return False
                factor, var = term.args
                varlist.append(var)
                factorlist.append(factor)
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                # It's easier to reason about the correctness of the transformation
                # for non-negative integers.
                if not self.statically_known_geq(var, 0):
                    return False
            elif isinstance(term, FloorDiv):
                var, factor = term.args
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                if not self.statically_known_geq(var, 0):
                    return False
                if floor_div_index >= 0:
                    # can not handle multi FloorDiv yet
                    return False

                floor_div_index = idx
                varlist.append(var)
                # this factor is denominator
                factorlist.append(factor)
            else:
                return False

        if floor_div_index < 0:
            return False

        # Construct the new expression and remember the denominator
        denominator = factorlist[floor_div_index]
        new_index = sympy.Integer(0)

        for var, factor, idx in zip(varlist, factorlist, itertools.count()):
            if idx == floor_div_index:
                new_index += var
            else:
                # k * x == (k * d * x) // d, so scale every non-FloorDiv term.
                new_index += (factor * denominator) * var

        return new_index, denominator
803
804
def join_dimensions(expr: Expr) -> Expr:
    """
    Merge ModularIndexing/FloorDiv "digits" of an index expression.

    Only a sum that contains at least one ModularIndexing term can be
    rewritten. Anything else is returned untouched, without consulting the
    memoized matcher.
    """
    can_join = isinstance(expr, sympy.Add) and expr.has(ModularIndexing)
    return _join_dimensions_cached(expr) if can_join else expr
809
810
@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    Fuse index "digits" that were split apart, e.g. by view operations.

    Two rewrites are attempted, in this order. The first one that applies is
    performed, and the result is passed back through ``join_dimensions`` so
    that further fusions can happen. If neither applies, ``expr`` is
    returned unchanged.

    1. Two ModularIndexing digits over the same base become one wider digit::

           s*ModularIndexing(x, d, m1) + s*m1*ModularIndexing(x, d*m1, m2)
               -> s*ModularIndexing(x, d, m1*m2)

       e.g. ``ModularIndexing(i0, 1, 32) + 32*ModularIndexing(i0, 32, 4)``
       becomes ``ModularIndexing(i0, 1, 128)``.

    2. A ModularIndexing digit plus the FloorDiv term holding the
       higher-order part become a single FloorDiv::

           s*ModularIndexing(x, d, m) + s*m*FloorDiv(x, d*m)
               -> s*FloorDiv(x, d)

       e.g. ``ModularIndexing(i0, 1, 32) + 32*FloorDiv(i0, 32)`` becomes
       ``i0``.

    Results are memoized (up to 256 entries) by ``functools.lru_cache``.
    """
    assert isinstance(expr, sympy.Add)

    scale = sympy.Wild("scale", exclude=[0], integer=True)
    base = sympy.Wild("base", integer=True)
    divisor = sympy.Wild("divisor", integer=True)
    mod1 = sympy.Wild("modulus", integer=True)
    mod2 = sympy.Wild("modulus2", integer=True)
    low_digit = scale * ModularIndexing(base, divisor, mod1)
    terms = expr.args

    # Rewrite 1: ModularIndexing + ModularIndexing -> wider ModularIndexing.
    for low_term in terms:
        low = low_term.match(low_digit)
        if not low:
            continue
        s, x, d, m = low[scale], low[base], low[divisor], low[mod1]
        high_digit = s * m * ModularIndexing(x, d * m, mod2)
        for high_term in terms:
            high = high_term.match(high_digit)
            # A term must not be paired with itself.
            if not high or low_term == high_term:
                continue
            fused = s * ModularIndexing(x, d, m * high[mod2])
            return join_dimensions(expr - low_term - high_term + fused)

    # Rewrite 2: ModularIndexing + FloorDiv -> FloorDiv.
    for low_term in terms:
        low = low_term.match(low_digit)
        if not low:
            continue
        s, x, d, m = low[scale], low[base], low[divisor], low[mod1]
        high_quotient = s * m * FloorDiv(x, d * m)
        for high_term in terms:
            # This pattern contains no Wilds, so a successful match is the
            # empty (falsy) dict ``{}``; only ``None`` signals failure.
            if high_term.match(high_quotient) is None:
                continue
            fused = s * FloorDiv(x, d)
            return join_dimensions(expr - low_term - high_term + fused)

    return expr
864
865
class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops-handler wrapper that simplifies index expressions before forwarding.

    The index argument of ``load``, ``store``, ``store_reduction``,
    ``index_expr`` and ``check_bounds`` is rewritten with
    ``V.graph.sizevars.simplify_with_ranges(index, var_ranges)``. That call
    uses variable-range information to simplify ModularIndexing/FloorDiv.
    The call is then forwarded to the wrapped handler with every other
    argument unchanged.
    """

    def __init__(self, inner, var_ranges: VarRanges) -> None:
        super().__init__(inner)
        self.name = "SimplifyIndexing"

        def simplify(index: Expr) -> Expr:
            # ``V.graph`` is resolved on every call, not captured at
            # construction time.
            return V.graph.sizevars.simplify_with_ranges(index, var_ranges)

        self._simplify: Callable[[Expr], Expr] = simplify

    def load(self, name: str, index: sympy.Expr):
        forward = self._inner.load
        return forward(name, self._simplify(index))

    def store(self, name, index, value, mode=None):
        forward = self._inner.store
        return forward(name, self._simplify(index), value, mode=mode)

    def store_reduction(self, name, index, value):
        forward = self._inner.store_reduction
        return forward(name, self._simplify(index), value)

    def index_expr(self, index, dtype):
        forward = self._inner.index_expr
        return forward(self._simplify(index), dtype)

    def check_bounds(self, index, size, lower, upper):
        forward = self._inner.check_bounds
        return forward(self._simplify(index), size, lower, upper)
893