# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import sympy
from sympy import Expr

from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges

from .runtime.runtime_utils import is_power_of_2
from .utils import (
    has_free_symbols,
    sympy_index_symbol,
    sympy_index_symbol_with_prefix,
    sympy_subs,
    VarRanges,
)
from .virtualized import V


log = logging.getLogger(__name__)


def evaluate_expr(
    shape_env: ShapeEnv,
    expr: Union[sympy.Basic, bool],
    axioms: Optional[Tuple[sympy.Expr, ...]] = None,
    var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges[Any]], ...]] = None,
) -> bool:
    """
    Best-effort static evaluation of ``expr``.

    Returns True only when ``shape_env._maybe_evaluate_static`` resolves
    ``expr`` to a truthy value. Returns False when it resolves to a falsy
    value, when it cannot decide (returns None), or when it raises. Failures
    are deliberately swallowed and logged at DEBUG level.
    """
    if expr in (True, False):
        return bool(expr)

    try:
        simplified = shape_env._maybe_evaluate_static(
            expr,
            axioms=axioms,
            var_to_range=var_to_range,
        )
        if simplified is not None:
            return bool(simplified)
    except Exception:
        log.debug("Could not simplify %s", expr, exc_info=True)

    return False


# This class is a little awkward, because ShapeEnv is doing most of the heavy
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
    def __init__(self, shape_env=None) -> None:
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0). Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        """Expand ``expr`` and apply the ShapeEnv's symbol replacements."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results.

        The cache is keyed on ``(expr, *var_ranges.items())``. It is cleared
        whenever the number of ShapeEnv replacements changes, because a new
        replacement can change the simplified result.
        """
        cache: Dict[Tuple[Any, ...], Expr] = {}
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_loops_impl() can be expensive, cache its results.

        The cache is keyed on ``(*index_vars, *sizes, *index_formulas)``. It is
        cleared whenever the number of ShapeEnv replacements changes.
        """
        cache: Dict[Tuple[Any, ...], Any] = {}
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.

        Each iteration variable ``k`` with size ``v`` is assumed to lie in
        ``[0, v - 1]``. That lets ModularIndexing / FloorDiv terms be dropped
        or rewritten when provably safe. The rewrite is repeated until a fixed
        point is reached.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        var_to_range = dict(self.shape_env.var_to_range)
        var_to_range.update(
            {
                k: ValueRanges(
                    0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity()
                )
                for k, v in var_ranges.items()
            }
        )
        for var in expr.free_symbols:
            if var not in var_to_range:
                var_to_range[var] = ValueRanges(0, IntInfinity())

        var_to_range_tuple = cast(
            Tuple[Tuple[sympy.Symbol, ValueRanges[sympy.Expr]], ...],
            tuple(var_to_range.items()),
        )

        axioms = []
        for var, upper_bound in var_ranges.items():
            axioms.append(0 <= var)
            axioms.append(var < upper_bound)
        axioms = tuple(axioms) + self.shape_env.get_axioms()

        def statically_known(expr):
            evaluated = self.shape_env._maybe_evaluate_static(
                expr,
                axioms=axioms,
                var_to_range=var_to_range_tuple,
            )
            return bool(evaluated)

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            if not statically_known(base >= 0):
                return base

            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if statically_known(v < divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)

            can_remove_mod = statically_known(base >= 0) and statically_known(
                base < modulus * divisor
            )

            if can_remove_mod:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                    sympy.Wild("modulus", integer=True),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                ),
                visit_indexing_div,
            )

        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: List[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fusing contiguous dimensions into a single loop

        Returns ``(new_sizes, reindex, prune)``:
            - ``new_sizes``: sizes of the surviving dimensions.
            - ``reindex(index)``: maps an index over ``new_sizes`` back to one
              over the original dimensions, inserting 0 for removed dims.
            - ``prune(index)``: maps an index over the original dimensions to
              one over ``new_sizes`` by dropping removed dims.
        """
        sizes = list(map(self.simplify, sizes))

        strides = [
            # index_formulas may contain boolean expressions (e.g. s0 < 10),
            # for which "strides" don't make sense so we ignore them here.
            # NOTE: These expressions may still block merging dims in the sound
            # substitution test performed in can_merge_dims.
            self.stride_vars(x, index_vars)
            if isinstance(x, sympy.Expr)
            else [0] * len(index_vars)
            for x in index_formulas
        ]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # Dim b can absorb dim a if, for every index formula, stepping b
            # by one is the same as stepping a through its full range.
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    m1 = sympy_index_symbol("_merge_tester1")
                    m2 = sympy_index_symbol("_merge_tester2")
                    # NOTE: can't sub vb=0 here in case va * vb appears in the expression,
                    # in which case both expr1 and expr2 would be zero!
                    expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
    # true, we add a guard and return True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #   if not size_hinted_check(args):
    #       return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def statically_known_foo(args):
    #   if all_static(args):
    #       return True # Safe to apply optimization
    #   else:
    #       return False # No guard, no optim

    # See Note - [On Statically Known]

    def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool:
        """Return True only if ``expr`` statically evaluates to true (no guard)."""
        return evaluate_expr(self.shape_env, expr)

    def statically_known_equals(
        self, left: Union[Expr, int], right: Union[Expr, int]
    ) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        return len(left) == len(right) and all(
            self.statically_known_equals(l, r) for l, r in zip(left, right)
        )

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right.
        """
        expr = left >= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than right.
        """
        expr = left > right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(
        self, numerator: Expr, denominator: Union[Expr, int]
    ) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.

        Always False when either side contains unbacked symbols.
        """
        if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator):
            return False
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_power_of_2(self, expr: Expr) -> bool:
        """
        Returns a bool indicating if x is known to be a power of 2.

        Only literal sympy Integers can qualify; any symbolic expression gives False.
        """
        return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr))

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds. If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        """
        Guard on ``left == right`` (after undoing precomputed replacements) and
        return the resulting ``left``. The caller must already know the
        equality holds; the assertion fires if ShapeEnv evaluates it to False.
        """
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        # Evaluate outside the assert: evaluate_expr installs the guard, and
        # asserts are stripped under `python -O`.
        holds = self.shape_env.evaluate_expr(sympy.Eq(left, right))
        assert holds
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        """Guard on ``left <= right`` (expressed as ``left < right + 1``)."""
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        """Guard on ``left < right``; the caller must already know it holds."""
        # Evaluate outside the assert so the guard survives `python -O`.
        holds = self.shape_env.evaluate_expr(sympy.Lt(left, right))
        assert holds

    def guarded_order(self, seq):
        """
        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
        """
        seq = [*map(self.remove_precomputed_replacements, seq)]
        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
        # orig_idx is unique, so ties on the hint never compare the exprs.
        seq.sort()
        order = [-1] * len(seq)
        last_var = None
        for new_index, (_, orig_index, var) in enumerate(seq):
            order[orig_index] = new_index
            if last_var is not None:
                self.guard_leq(last_var, var)
            last_var = var
        return order

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(sympy.sympify(left))

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """
        return the smaller of left and right, and guard on that choice

        Both sides have precomputed replacements undone first, and the
        (possibly rewritten) chosen side is returned. For unbacked symbols no
        guard is added. Instead the choice must be statically provable or
        follow from a gcd check. Otherwise TypeError is raised.
        """
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        try:
            lv = self.size_hint(left)
            rv = self.size_hint(right)
        except TypeError:  # unbacked symints
            if left == right or self.statically_known_leq(left, right):
                return left
            if self.statically_known_leq(right, left):
                return right
            gcd = sympy.gcd(left, right)
            if left == gcd:  # handle `min(10*u0, u0)` etc
                return left
            if right == gcd:
                return right
            raise TypeError(
                f"evaluate_min({left}, {right}) with unbacked symints"
            ) from None
        if lv <= rv:
            self.guard_leq(left, right)
            return left
        else:
            self.guard_leq(right, left)
            return right

    def evaluate_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # evaluate_min rewrites its arguments through inv_precomputed_replacements
        # and returns the rewritten object. Apply the same rewrite here first, so
        # the identity check below compares like with like. Otherwise an argument
        # containing a precomputed size symbol would make `min_val is left` False,
        # and we would wrongly return the minimum.
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        # Always choose the opposite of eval min for consistency
        # This means min(a, b) and max(a, b) produce the same guards
        min_val = self.evaluate_min(left, right)
        return right if min_val is left else left

    def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
        """Return the concrete size hint of ``left``, guarding that it equals that value."""
        if isinstance(left, int):
            return left
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> List[int]:
        """Apply evaluate_static_shape to every element of ``left``."""
        return [self.evaluate_static_shape(x) for x in left]

    def remove_precomputed_replacements(self, expr: Expr) -> Expr:
        """Undo precomputed-size symbol replacements in ``expr``, if it contains any."""
        if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols):  # type: ignore[attr-defined]
            return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        return expr

    def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
        """
        Substitute backed hints (``var_to_val``) into ``expr``. Symbols without a
        hint (unbacked) are left in place. Constant expressions become ints
        when possible.
        """
        if isinstance(expr, int):
            return expr
        # Substitute all hints into expr, but leave unbacked symints alone
        expr = self.simplify(expr)
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            try:
                return int(expr)  # type: ignore[return-value]
            except TypeError:
                return expr  # inf/nan/I
        expr = self.remove_precomputed_replacements(expr)
        return sympy_subs(expr, self.var_to_val)

    def size_hint(
        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
    ) -> int:
        """
        Return a concrete int hint for ``expr``.

        If the hint is still symbolic and ``fallback`` is given, return
        ``fallback``. When every remaining symbol has a known value range, the
        fallback is first clamped to the expression's bounds. Without a
        fallback, the exception from ``int()`` is logged and re-raised.
        """
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            unbacked_sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols
            }
            if all(vr is not None for vr in unbacked_sym_vrs.values()):
                hint_vr = bound_sympy(out, unbacked_sym_vrs)  # type: ignore[arg-type]
                if isinstance(hint_vr.lower, (int, sympy.Integer)):
                    fallback = max(fallback, int(hint_vr.lower))
                if isinstance(hint_vr.upper, (int, sympy.Integer)):
                    fallback = min(fallback, int(hint_vr.upper))
            return fallback

        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> Tuple[int, ...]:
        """Apply size_hint to each expression in ``exprs``."""
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        """
        Build a cached ``stride_vars(index, vars, support_vars=None)``.
        ``support_vars`` defaults to ``vars``.
        """
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: Sequence[sympy.Symbol],
            support_vars: Optional[Sequence[sympy.Symbol]] = None,
        ) -> List[Expr]:
            if not support_vars:
                support_vars = vars
            # Tuples so the arguments are hashable for lru_cache.
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Sequence[sympy.Symbol],
    ) -> List[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.Integer(0) for v in support_vars if v != 0}
        )
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    support_vars[j]: sympy.Integer(0)
                    for j in range(len(support_vars))
                    if vars[i] != support_vars[j] and support_vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.Integer(0))
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.Integer(1)})
                    - sympy_subs(index_dim, {v: sympy.Integer(0)})
                )
        return strides

    def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Optional[Sequence[sympy.Symbol]] = None,
    ) -> List[int]:
        """
        Concrete stride hints for ``index`` w.r.t. ``vars``. Indirect-indexing
        symbols are zeroed first. A stride whose hint raises TypeError
        (e.g. unbacked) is reported as 0.
        """
        for v in index.free_symbols:
            if symbol_is_type(v, SymT.INDIRECT):  # type: ignore[attr-defined]
                index = sympy_subs(index, {v: 0})  # type: ignore[dict-item]
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        """Dim indices sorted by increasing absolute stride hint, zero strides last."""
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr) -> Expr:
        """
        Return the precomputed-size symbol standing for ``expr``, allocating a
        new one on first use. Numbers and plain symbols are returned unchanged.
        """
        if (
            isinstance(expr, (int, sympy.Symbol, sympy.Number))
            or expr.is_number
            or expr.is_symbol
        ):
            return expr
        expr = self.remove_precomputed_replacements(expr)
        if expr not in self.precomputed_replacements:
            sym = sympy_index_symbol_with_prefix(
                SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements)
            )
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def free_symbols(self) -> Set[sympy.Symbol]:
        """Symbols that have a backed value and have not been replaced."""
        return set(self.var_to_val.keys()) - set(self.replacements.keys())

    def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr:
        """
        A pair of special ModularIndexing can be combined.

        E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b)
        We can simplify this to ModularIndexing(x, 1, b), if
        1. x is non negative integer
        2. a and b are positive integers
        3. a is a multiple of b.
        """

        def _check_args(x, div, mod, is_first):
            if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer):
                return False
            if div != 1:
                return False
            if mod <= 0:
                return False

            if is_first:
                # first ModularIndexing should contain a nested ModularIndexing
                if not isinstance(x, ModularIndexing):
                    return False
            else:
                # second ModularIndexing should contain a non-negative
                # symbol
                if not isinstance(x, sympy.Symbol) or not self.statically_known_geq(
                    x, 0
                ):
                    return False
            return True

        if isinstance(index, ModularIndexing):
            x, div, mod = index.args

            if not _check_args(x, div, mod, True):
                return index

            x2, div2, mod2 = x.args

            if not _check_args(x2, div2, mod2, False):
                return index

            if mod2 % mod != 0:
                return index

            return ModularIndexing(x2, 1, mod)

        return index

    def expand_floor_div(
        self, index: sympy.Expr
    ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]:
        """
        Expand the FloorDiv to the entire expression so that the expression may
        be simplified.

        E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables
        x1, x2, index expression 'x1 * 2b + x2' can be easily combined.
        But index expression 'x1 * b + x2 // 2' can not.
        By expanding the FloorDiv to the entire expression, we get
        '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops
        for the numerator!

        Return False if this optimization cannot be applied;
        Return the new expression and the denominator otherwise.
        The original expression will be equivalent to 'new_expression // denominator'
        """
        if not isinstance(index, sympy.Add):
            return False
        terms = index.args

        if len(terms) < 2:
            return False
        floor_div_index = -1
        varlist = []
        factorlist = []
        for idx, term in enumerate(terms):
            if isinstance(term, sympy.Mul):
                # For dynamic shape, term like '2*s1*x1' has 3 child nodes.
                # - A integer for 2
                # - A symbol for s1
                # - A symbol for x1
                # Skip for now.
                if len(term.args) != 2:
                    return False
                factor, var = term.args
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                # It's easier to reason about the correctness of the transformation
                # for non-negative integers.
                if not self.statically_known_geq(var, 0):
                    return False
                varlist.append(var)
                factorlist.append(factor)
            elif isinstance(term, FloorDiv):
                var, factor = term.args
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                if not self.statically_known_geq(var, 0):
                    return False
                if floor_div_index >= 0:
                    # can not handle multi FloorDiv yet
                    return False

                floor_div_index = idx
                varlist.append(var)
                # this factor is denominator
                factorlist.append(factor)
            else:
                return False

        if floor_div_index < 0:
            return False

        # Construct the new expression and remember the denominator
        denominator = factorlist[floor_div_index]
        new_index = sympy.Integer(0)

        for idx, (var, factor) in enumerate(zip(varlist, factorlist)):
            if idx == floor_div_index:
                new_index += var
            else:
                new_index += (factor * denominator) * var

        return new_index, denominator


def join_dimensions(expr: Expr) -> Expr:
    """
    Merge ModularIndexing/FloorDiv terms that describe one contiguous range.
    Only sums containing ModularIndexing are considered; others are returned
    unchanged without touching the cache.
    """
    if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing):
        return expr  # fast exit path
    return _join_dimensions_cached(expr)
@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    Merge a pair of terms of the sum ``expr`` that together cover one
    contiguous range of the same base expression.

    Examples::

        ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
            -> ModularIndexing(i0, 1, 128)

        ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
            -> i0

    This type of pattern can come from view operations. At most one pair is
    merged per call. The merged sum is passed back through join_dimensions,
    so any remaining pairs are handled recursively.
    """
    assert isinstance(expr, sympy.Add)

    w_scale = sympy.Wild("scale", exclude=[0], integer=True)
    w_base = sympy.Wild("base", integer=True)
    w_div = sympy.Wild("divisor", integer=True)
    w_inner_mod = sympy.Wild("modulus", integer=True)
    w_outer_mod = sympy.Wild("modulus2", integer=True)
    terms = expr.args

    # Pass 1:  s*ModIdx(b, d, m) + s*m*ModIdx(b, d*m, m2)  ->  s*ModIdx(b, d, m*m2)
    for lo_term in terms:
        lo = lo_term.match(w_scale * ModularIndexing(w_base, w_div, w_inner_mod))
        if not lo:
            continue
        hi_scale = lo[w_scale] * lo[w_inner_mod]
        hi_divisor = lo[w_div] * lo[w_inner_mod]
        for hi_term in terms:
            hi = hi_term.match(
                hi_scale * ModularIndexing(lo[w_base], hi_divisor, w_outer_mod)
            )
            if not hi or lo_term == hi_term:
                continue
            merged = lo[w_scale] * ModularIndexing(
                lo[w_base], lo[w_div], lo[w_inner_mod] * hi[w_outer_mod]
            )
            return join_dimensions(expr - lo_term - hi_term + merged)

    # Pass 2:  s*ModIdx(b, d, m) + s*m*FloorDiv(b, d*m)  ->  s*FloorDiv(b, d)
    for lo_term in terms:
        lo = lo_term.match(w_scale * ModularIndexing(w_base, w_div, w_inner_mod))
        if not lo:
            continue
        hi_pattern = (
            lo[w_scale]
            * lo[w_inner_mod]
            * FloorDiv(lo[w_base], lo[w_div] * lo[w_inner_mod])
        )
        for hi_term in terms:
            # The pattern has no wildcards left, so a successful match yields
            # an empty (falsy) dict; only None means "no match".
            if hi_term.match(hi_pattern) is None:
                continue
            merged = lo[w_scale] * FloorDiv(lo[w_base], lo[w_div])
            return join_dimensions(expr - lo_term - hi_term + merged)

    return expr


class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops-handler wrapper that rewrites every index expression with
    ``V.graph.sizevars.simplify_with_ranges`` (using the iteration ranges
    given at construction) before forwarding the call to the wrapped handler.
    This simplifies ModularIndexing/FloorDiv terms.
    """

    def __init__(self, inner, var_ranges: VarRanges) -> None:
        super().__init__(inner)
        self.name = "SimplifyIndexing"

        def _simplify_index(index: Expr) -> Expr:
            # V.graph is looked up at call time, not at construction time.
            return V.graph.sizevars.simplify_with_ranges(index, var_ranges)

        self._simplify: Callable[[Expr], Expr] = _simplify_index

    def load(self, name: str, index: sympy.Expr):
        simplified = self._simplify(index)
        return self._inner.load(name, simplified)

    def store(self, name, index, value, mode=None):
        simplified = self._simplify(index)
        return self._inner.store(name, simplified, value, mode=mode)

    def store_reduction(self, name, index, value):
        simplified = self._simplify(index)
        return self._inner.store_reduction(name, simplified, value)

    def index_expr(self, index, dtype):
        simplified = self._simplify(index)
        return self._inner.index_expr(simplified, dtype)

    def check_bounds(self, index, size, lower, upper):
        simplified = self._simplify(index)
        return self._inner.check_bounds(simplified, size, lower, upper)