1import logging 2from typing import Dict, Optional, Tuple, Type 3 4import sympy 5 6from torch.utils._sympy.functions import FloorDiv 7 8 9log = logging.getLogger(__name__) 10 11_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = { 12 sympy.Eq: sympy.Eq, 13 sympy.Ne: sympy.Ne, 14 sympy.Ge: sympy.Le, 15 sympy.Gt: sympy.Lt, 16 sympy.Le: sympy.Ge, 17 sympy.Lt: sympy.Gt, 18} 19 20INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le) 21 22 23def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]: 24 return _MIRROR_REL_OP.get(type, None) 25 26 27# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side. 28# 29# Returns a tuple of: 30# 1. The simplified expression 31# 2. The expression on the right-hand side 32# 33# Returns 'None' if it can't reach a state where the only thing in the left 34# hand side is 'thing'. 35# 36# 'trials': number of times 'try_solve' will try to isolate 'thing' to the 37# left-hand side. 38# 39# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into 40# inequalities. 41def try_solve( 42 expr: sympy.Basic, 43 thing: sympy.Basic, 44 trials: int = 5, 45 floordiv_inequality: bool = True, 46) -> Optional[Tuple[sympy.Rel, sympy.Basic]]: 47 mirror = mirror_rel_op(type(expr)) 48 49 # Ignore unsupported expressions: 50 # - Those that are not relational operations 51 # - Those that don't have a mirror (just avoiding unexpected classes) 52 if not isinstance(expr, sympy.Rel) or mirror is None: 53 log.debug("expression with unsupported type: %s", type(expr)) 54 return None 55 56 lhs_has_thing = expr.lhs.has(thing) 57 rhs_has_thing = expr.rhs.has(thing) 58 59 # Give up when 'thing' appears on both sides of the relational expression. 60 # That is because, as is, we assume the thing we are trying to isolate is 61 # only on the right-hand side. 62 if lhs_has_thing and rhs_has_thing: 63 log.debug("thing (%s) found in both sides of expression: %s", thing, expr) 64 return None 65 66 # Try considering both LHS and RHS by mirroring the original expression: 67 # a < b ==> b > a 68 expressions = [] 69 70 # Add each version of 'expr' if 'thing' is in its left-hand side. 71 if lhs_has_thing: 72 expressions.append(expr) 73 if rhs_has_thing: 74 expressions.append(mirror(expr.rhs, expr.lhs)) 75 76 for e in expressions: 77 if e is None: 78 continue 79 80 assert isinstance(e, sympy.Rel) 81 82 for _ in range(trials): 83 trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality) 84 # Stop if there was no change in this trial. 85 if trial == e: 86 break 87 e = trial # type: ignore[assignment] 88 89 # Return if we were able to isolate 'thing' on the left-hand side. 90 if isinstance(e, sympy.Rel) and e.lhs == thing: 91 log.debug("solved: %s ---> %s", expr, e) 92 return e, e.rhs 93 94 return None 95 96 97def _try_isolate_lhs( 98 e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool 99) -> sympy.Basic: 100 op = type(e) 101 102 if isinstance(e, sympy.Rel): 103 # Move any constants in the left-hand side to the right-hand side. 104 lhs_not_thing = ( 105 sum(a for a in e.lhs.args if not a.has(thing)) 106 if isinstance(e.lhs, sympy.Add) 107 else 0 108 ) 109 e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing) # type: ignore[attr-defined] 110 111 # Divide both sides by the factors that don't contain thing. 112 if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul): 113 lhs, rhs = e.args 114 other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)]) 115 116 # If we can't tell whether 'other' is negative or positive, we do nothing. 117 # That is because we don't know whether we have mirror the operation or not. 118 if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None): 119 # Divide both sides by 'other'. 120 lhs = lhs / other 121 rhs = rhs / other 122 123 # If 'e' is an inequality and 'other' is negative, we have to 124 # mirror the expression. 125 if isinstance(e, INEQUALITY_TYPES) and other.is_negative: 126 op = mirror_rel_op(op) # type: ignore[assignment] 127 128 assert op is not None 129 e = op(lhs, rhs) 130 131 ################################################################################ 132 # left-hand side is FloorDiv 133 ################################################################################ 134 # 135 # Given the expression: a // b op c 136 # where 'op' is a relational operation, these rules only work if: 137 # - b > 0 138 # - c is an integer 139 if ( 140 floordiv_inequality 141 and isinstance(e, sympy.Rel) 142 and isinstance(e.lhs, FloorDiv) 143 and e.lhs.divisor.is_positive 144 and e.rhs.is_integer 145 ): 146 # a // b == expr 147 # => a >= (b * expr) and a < (b * (expr + 1)) 148 if isinstance(e, sympy.Eq): 149 numerator, denominator = e.lhs.args 150 return sympy.And( 151 sympy.Ge(numerator, (e.rhs * denominator)), # type: ignore[arg-type] 152 sympy.Lt(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type] 153 ) 154 # a // b != expr 155 # => a < (b * expr) or a >= (b * (expr + 1)) 156 if isinstance(e, sympy.Ne): 157 numerator, denominator = e.lhs.args 158 return sympy.Or( 159 sympy.Lt(numerator, (e.rhs * denominator)), # type: ignore[arg-type] 160 sympy.Ge(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type] 161 ) 162 # The transformations below only work if b is positive. 163 # Note: we only have this information for constants. 164 # a // b > expr => a >= b * (expr + 1) 165 # a // b >= expr => a >= b * expr 166 if isinstance(e, (sympy.Gt, sympy.Ge)): 167 quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1) # type: ignore[arg-type] 168 return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type] 169 # a // b < expr => a < b * expr 170 # a // b <= expr => a < b * (expr + 1) 171 if isinstance(e, (sympy.Lt, sympy.Le)): 172 quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1) # type: ignore[arg-type] 173 return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type] 174 175 return e 176