# Owner(s): ["oncall: pt2"]

import functools
import itertools
import math
import sys
from typing import Callable, List, Optional, Tuple, Type

import sympy
from sympy.core.relational import is_ge, is_gt, is_le, is_lt

import torch.fx as fx
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import (
    TEST_Z3,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges


# Names of the analysis methods exercised by the parametrized tests below.
# Each name is looked up with getattr() on ReferenceAnalysis,
# PythonReferenceAnalysis and/or ValueRangeAnalysis.
UNARY_OPS = [
    "reciprocal",
    "square",
    "abs",
    "neg",
    "exp",
    "log",
    "sqrt",
    "floor",
    "ceil",
]
BINARY_OPS = [
    "truediv", "floordiv",
    # "truncdiv",  # TODO
    # NB: pow is float_pow
    "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod"
]

UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]

# a mix of constants, powers of two, primes
CONSTANTS = [
    -1,
    0,
    1,
    2,
    3,
    4,
    5,
    8,
    16,
    32,
    64,
    100,
    101,
    2**24,
    2**32,
    2**37 - 1,
    sys.maxsize - 1,
    sys.maxsize,
]
# fewer constants for N^2 situations
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
# SymPy relational types.
RELATIONAL_TYPES = [sympy.Eq, sympy.Ne, sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le]


def valid_unary(fn, v):
    """Return whether the unary op named ``fn`` should be tested on input ``v``.

    Rejects log of a non-positive number, reciprocal of zero and sqrt of a
    negative number; every other (fn, v) combination is accepted.
    """
    if fn == "log" and v <= 0:
        return False
    elif fn == "reciprocal" and v == 0:
        return False
    elif fn == "sqrt" and v < 0:
        return False
    return True


def valid_binary(fn, a, b):
    """Return whether the binary op named ``fn`` should be tested on ``(a, b)``.

    The rejected combinations are the ones listed in the branches below;
    anything not matched by a branch is accepted.
    """
    if fn == "pow" and (
        # sympy will expand to x*x*... for integral b; don't do it if it's big
        b > 4
        # no imaginary numbers; since a == 0 is rejected here as well, the
        # undefined 0**0 case needs no separate check
        or a <= 0
    ):
        return False
    elif fn == "pow_by_natural" and (
        # sympy will expand to x*x*... for integral b; don't do it if it's big
        b > 4
        or b < 0
        # 0**0 is undefined
        or (a == b == 0)
    ):
        return False
    elif fn == "mod" and (a < 0 or b <= 0):
        return False
    elif (fn in ["div", "truediv", "floordiv"]) and b == 0:
        return False
    return True


def generate_range(vals):
    """Yield a ``ValueRanges(a1, a2)`` for every ordered pair drawn from ``vals``.

    Pairs whose bounds are in the wrong order are skipped (for booleans that
    is ``(true, false)``), as are ranges that only admit infinite values.
    """
    for a1, a2 in itertools.product(vals, repeat=2):
        if a1 in [sympy.true, sympy.false]:
            if a1 == sympy.true and a2 == sympy.false:
                continue
        else:
            if a1 > a2:
                continue
        # ranges that only admit infinite values are not interesting
        if a1 == sympy.oo or a2 == -sympy.oo:
            continue
        yield ValueRanges(a1, a2)


class TestNumbers(TestCase):
    """Arithmetic, comparison and conversion checks for ``int_oo``."""

    def test_int_infinity(self):
        self.assertIsInstance(int_oo, IntInfinity)
        self.assertIsInstance(-int_oo, NegativeIntInfinity)
        self.assertTrue(int_oo.is_integer)
        # is tests here are for singleton-ness, don't use it for comparisons
        # against numbers
        self.assertIs(int_oo + int_oo, int_oo)
        self.assertIs(int_oo + 1, int_oo)
        self.assertIs(int_oo - 1, int_oo)
        self.assertIs(-int_oo - 1, -int_oo)
        self.assertIs(-int_oo + 1, -int_oo)
        self.assertIs(-int_oo + (-int_oo), -int_oo)
        self.assertIs(-int_oo - int_oo, -int_oo)
        self.assertIs(1 + int_oo, int_oo)
        self.assertIs(1 - int_oo, -int_oo)
        self.assertIs(int_oo * int_oo, int_oo)
        self.assertIs(2 * int_oo, int_oo)
        self.assertIs(int_oo * 2, int_oo)
        self.assertIs(-1 * int_oo, -int_oo)
        self.assertIs(-int_oo * int_oo, -int_oo)
        self.assertIs(2 * -int_oo, -int_oo)
        self.assertIs(-int_oo * 2, -int_oo)
        self.assertIs(-1 * -int_oo, int_oo)
        self.assertIs(int_oo / 2, sympy.oo)
        self.assertIs(-(-int_oo), int_oo)  # noqa: B002
        self.assertIs(abs(int_oo), int_oo)
        self.assertIs(abs(-int_oo), int_oo)
        self.assertIs(int_oo ** 2, int_oo)
        self.assertIs((-int_oo) ** 2, int_oo)
        self.assertIs((-int_oo) ** 3, -int_oo)
        self.assertEqual(int_oo ** -1, 0)
        self.assertEqual((-int_oo) ** -1, 0)
        self.assertIs(int_oo ** int_oo, int_oo)
        self.assertTrue(int_oo == int_oo)
        self.assertFalse(int_oo != int_oo)
        self.assertTrue(-int_oo == -int_oo)
        self.assertFalse(int_oo == 2)
        self.assertTrue(int_oo != 2)
        self.assertFalse(int_oo == sys.maxsize)
        self.assertTrue(int_oo >= sys.maxsize)
        self.assertTrue(int_oo >= 2)
        self.assertTrue(int_oo >= -int_oo)

    def test_relation(self):
        self.assertIs(sympy.Add(2, int_oo), int_oo)
        self.assertFalse(-int_oo > 2)

    def test_lt_self(self):
        self.assertFalse(int_oo < int_oo)
        self.assertIs(min(-int_oo, -4), -int_oo)
        self.assertIs(min(-int_oo, -int_oo), -int_oo)

    def test_float_cast(self):
        self.assertEqual(float(int_oo), math.inf)
        self.assertEqual(float(-int_oo), -math.inf)

    def test_mixed_oo_int_oo(self):
        # Arbitrary choice
        self.assertTrue(int_oo < sympy.oo)
        self.assertFalse(int_oo > sympy.oo)
        self.assertTrue(sympy.oo > int_oo)
        self.assertFalse(sympy.oo < int_oo)
        self.assertIs(max(int_oo, sympy.oo), sympy.oo)
        self.assertTrue(-int_oo > -sympy.oo)
        self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)


class TestValueRanges(TestCase):
    """Compare ``ValueRangeAnalysis`` results against ``ReferenceAnalysis``."""

    @parametrize("fn", UNARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_unary_ref(self, fn, dtype):
        dtype = {"int": sympy.Integer, "float": sympy.Float}[dtype]
        for v in CONSTANTS:
            if not valid_unary(fn, v):
                continue
            with self.subTest(v=v):
                v = dtype(v)
                ref_r = getattr(ReferenceAnalysis, fn)(v)
                r = getattr(ValueRangeAnalysis, fn)(v)
                # On a constant input the range must collapse to the single
                # reference value, with matching integer-ness.
                self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(ref_r, r.lower)

    def test_pow_half(self):
        # No assertion on the result: this only checks the call does not raise.
        ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))

    @parametrize("fn", BINARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_binary_ref(self, fn, dtype):
        to_dtype = {"int": sympy.Integer, "float": sympy.Float}
        # Don't test float on int only methods
        if dtype == "float" and fn in ["pow_by_natural", "mod"]:
            return
        dtype = to_dtype[dtype]
        for a, b in itertools.product(CONSTANTS, repeat=2):
            if not valid_binary(fn, a, b):
                continue
            a = dtype(a)
            b = dtype(b)
            with self.subTest(a=a, b=b):
                r = getattr(ValueRangeAnalysis, fn)(a, b)
                if r == ValueRanges.unknown():
                    continue
                ref_r = getattr(ReferenceAnalysis, fn)(a, b)

                self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r, r.lower)

    def test_mul_zero_unknown(self):
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
            ValueRanges.wrap(0),
        )
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
            ValueRanges.wrap(0.0),
        )

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a in generate_range(vals):
            with self.subTest(a=a):
                ref_r = getattr(ValueRangeAnalysis, fn)(a)
                unique = set()
                for a0 in vals:
                    if a0 not in a:
                        continue
                    with self.subTest(a0=a0):
                        r = getattr(ReferenceAnalysis, fn)(a0)
                        self.assertIn(r, ref_r)
                        unique.add(r)
                # The computed range must be tight: a singleton range means
                # exactly one concrete result, otherwise both booleans occur.
                if ref_r.lower == ref_r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", BINARY_BOOL_OPS)
    def test_binary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
                unique = set()
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        r = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(r, ref_r)
                        unique.add(r)
                # Same tightness check as in test_unary_bool_ref_range.
                if ref_r.lower == ref_r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        # TODO: bring back sympy.oo testing for float unary fns
        vals = CONSTANTS
        for a in generate_range(vals):
            with self.subTest(a=a):
                ref_r = getattr(ValueRangeAnalysis, fn)(a)
                for a0 in CONSTANTS:
                    if a0 not in a:
                        continue
                    if not valid_unary(fn, a0):
                        continue
                    with self.subTest(a0=a0):
                        r = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                        self.assertIn(r, ref_r)

    # This takes about 4s for all the variants
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        # TODO: bring back sympy.oo testing for float unary fns
        vals = LESS_CONSTANTS
        for a, b in itertools.product(generate_range(vals), repeat=2):
            # don't attempt pow on exponents that are too large (but oo is OK)
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    if not valid_binary(fn, a0, b0):
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        # The range analysis is deliberately run only once a
                        # valid concrete pair exists for (a, b).
                        ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
                        r = getattr(ReferenceAnalysis, fn)(
                            sympy.Integer(a0), sympy.Integer(b0)
                        )
                        if r.is_finite:
                            self.assertIn(r, ref_r)


class TestSympyInterp(TestCase):
    """Check ``sympy_interp`` against direct evaluation of the same ops."""

    @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
    def test_interp(self, fn):
        # SymPy does not implement truncation for Expressions
        if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
            return

        is_integer = None
        if fn == "pow_by_natural":
            is_integer = True

        x = sympy.Dummy('x', integer=is_integer)
        y = sympy.Dummy('y', integer=is_integer)

        vals = CONSTANTS
        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            vals = [True, False]
        arity = 1
        if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
            arity = 2
        symbols = [x]
        if arity == 2:
            symbols = [x, y]
        for args in itertools.product(vals, repeat=arity):
            if arity == 1 and not valid_unary(fn, *args):
                continue
            elif arity == 2 and not valid_binary(fn, *args):
                continue
            with self.subTest(args=args):
                sargs = [sympy.sympify(a) for a in args]
                sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
                ref_r = getattr(ReferenceAnalysis, fn)(*sargs)
                # Yes, I know this is a longwinded way of saying xreplace; the
                # point is to test sympy_interp
                r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr)
                self.assertEqual(ref_r, r)

    @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
    def test_python_interp_fx(self, fn):
        # These never show up from symbolic_shapes
        if fn in ("log", "exp"):
            return

        # Sympy does not support truncation on symbolic shapes
        if fn in ("truncdiv", "mod"):
            return

        vals = CONSTANTS
        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            vals = [True, False]

        arity = 1
        if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
            arity = 2

        is_integer = None
        if fn == "pow_by_natural":
            is_integer = True

        x = sympy.Dummy('x', integer=is_integer)
        y = sympy.Dummy('y', integer=is_integer)

        symbols = [x]
        if arity == 2:
            symbols = [x, y]

        for args in itertools.product(vals, repeat=arity):
            # valid_binary already rejects floordiv by zero and a zero base
            # with a non-positive exponent for both pow variants, so no
            # further per-op filtering is needed here.
            if arity == 1 and not valid_unary(fn, *args):
                continue
            elif arity == 2 and not valid_binary(fn, *args):
                continue
            with self.subTest(args=args):
                # Workaround mpf from symbol error
                if fn == "minimum":
                    sympy_expr = sympy.Min(x, y)
                elif fn == "maximum":
                    sympy_expr = sympy.Max(x, y)
                else:
                    sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)

                if arity == 1:
                    def trace_f(px):
                        return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
                else:
                    def trace_f(px, py):
                        return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)

                gm = fx.symbolic_trace(trace_f)

                # The traced graph module must agree with direct interpretation.
                self.assertEqual(
                    sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
                    gm(*args)
                )


def type_name_fn(type: Type) -> str:
    """Return the class name of ``type``; passed to ``parametrize`` as ``name_fn``."""
    return type.__name__


def parametrize_relational_types(*types):
    """Parametrize a test over ``op`` for the given SymPy relational types.

    When called without arguments, every entry of ``RELATIONAL_TYPES`` is used.
    """
    def wrapper(f: Callable):
        return parametrize("op", types or RELATIONAL_TYPES, name_fn=type_name_fn)(f)
    return wrapper


class TestSympySolve(TestCase):
    """Tests for ``try_solve`` and ``simple_floordiv_gcd``."""

    def _create_integer_symbols(self) -> List[sympy.Symbol]:
        """Return three fresh integer symbols named ``a``, ``b`` and ``c``."""
        # sympy.symbols returns a tuple; convert so the result matches the
        # annotated return type.
        return list(sympy.symbols("a b c", integer=True))

    def test_give_up(self):
        from sympy import Eq, Ne

        a, b, c = self._create_integer_symbols()

        cases = [
            # Not a relational operation.
            a + b,
            # 'a' appears on both sides.
            Eq(a, a + 1),
            # 'a' doesn't appear on either side.
            Eq(b, c + 1),
            # Result is a 'sympy.And'.
            Eq(FloorDiv(a, b), c),
            # Result is a 'sympy.Or'.
            Ne(FloorDiv(a, b), c),
        ]

        for case in cases:
            e = try_solve(case, a)
            self.assertIsNone(e)

    @parametrize_relational_types()
    def test_noop(self, op):
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = a, 42 * b
        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, expr)
        self.assertEqual(r_rhs, rhs)

    @parametrize_relational_types()
    def test_noop_rhs(self, op):
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = 42 * b, a

        mirror = mirror_rel_op(op)
        self.assertIsNotNone(mirror)

        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, mirror(rhs, lhs))
        self.assertEqual(r_rhs, lhs)

    def _test_cases(
        self,
        cases: List[Tuple[sympy.Basic, Optional[sympy.Basic]]],
        thing: sympy.Basic,
        op: Type[sympy.Rel],
        **kwargs,
    ):
        """Run ``try_solve(source, thing, **kwargs)`` for each case and check it.

        Each case is ``(source, expected)``. An ``expected`` of ``None`` means
        ``try_solve`` must return ``None``; otherwise it must return the pair
        ``(op(thing, expected), expected)``.
        """
        for source, expected in cases:
            r = try_solve(source, thing, **kwargs)

            # Either both are None or neither is.
            self.assertTrue(
                (r is None) == (expected is None),
                msg=f"try_solve({source}, {thing}) returned {r}, expected rhs {expected}",
            )

            if r is not None:
                r_expr, r_rhs = r
                self.assertEqual(r_rhs, expected)
                self.assertEqual(r_expr, op(thing, expected))

    def test_addition(self):
        from sympy import Eq

        a, b, c = self._create_integer_symbols()

        cases = [
            (Eq(a + b, 0), -b),
            (Eq(a + 5, b - 5), b - 10),
            (Eq(a + c * b, 1), 1 - c * b),
        ]

        self._test_cases(cases, a, Eq)

    @parametrize_relational_types(sympy.Eq, sympy.Ne)
    def test_multiplication_division(self, op):
        a, b, c = self._create_integer_symbols()

        cases = [
            (op(a * b, 1), 1 / b),
            (op(a * 5, b - 5), (b - 5) / 5),
            (op(a * b, c), c / b),
        ]

        self._test_cases(cases, a, op)

    @parametrize_relational_types(*INEQUALITY_TYPES)
    def test_multiplication_division_inequality(self, op):
        a, b, _ = self._create_integer_symbols()
        intneg = sympy.Symbol("neg", integer=True, negative=True)
        intpos = sympy.Symbol("pos", integer=True, positive=True)

        cases = [
            # Divide/multiply both sides by positive number.
            (op(a * intpos, 1), 1 / intpos),
            (op(a / (5 * intpos), 1), 5 * intpos),
            (op(a * 5, b - 5), (b - 5) / 5),
            # 'b' is not strictly positive nor negative, so we can't
            # divide/multiply both sides by 'b'.
            (op(a * b, 1), None),
            (op(a / b, 1), None),
            (op(a * b * intpos, 1), None),
        ]

        mirror_cases = [
            # Divide/multiply both sides by negative number.
            (op(a * intneg, 1), 1 / intneg),
            (op(a / (5 * intneg), 1), 5 * intneg),
            (op(a * -5, b - 5), -(b - 5) / 5),
        ]
        mirror_op = mirror_rel_op(op)
        assert mirror_op is not None

        self._test_cases(cases, a, op)
        self._test_cases(mirror_cases, a, mirror_op)

    @parametrize_relational_types()
    def test_floordiv(self, op):
        from sympy import Eq, Ne, Gt, Ge, Lt, Le

        a, b, c = sympy.symbols("a b c")
        pos = sympy.Symbol("pos", positive=True)
        integer = sympy.Symbol("integer", integer=True)

        # (Eq(FloorDiv(a, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Eq(FloorDiv(a + 5, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Ne(FloorDiv(a, pos), integer), Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))),

        special_case = {
            # 'FloorDiv' turns into 'And', which can't be simplified any further.
            Eq: (Eq(FloorDiv(a, pos), integer), None),
            # 'FloorDiv' turns into 'Or', which can't be simplified any further.
            Ne: (Ne(FloorDiv(a, pos), integer), None),
            Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos),
            Ge: (Ge(FloorDiv(a, pos), integer), integer * pos),
            Lt: (Lt(FloorDiv(a, pos), integer), integer * pos),
            Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
        }[op]

        cases: List[Tuple[sympy.Basic, Optional[sympy.Basic]]] = [
            # 'b' is not strictly positive
            (op(FloorDiv(a, b), integer), None),
            # 'c' is not strictly positive
            (op(FloorDiv(a, pos), c), None),
        ]

        # The result might change after 'FloorDiv' transformation.
        # Specifically:
        #   - [Ge, Gt] => Ge
        #   - [Le, Lt] => Lt
        if op in (sympy.Gt, sympy.Ge):
            r_op = sympy.Ge
        elif op in (sympy.Lt, sympy.Le):
            r_op = sympy.Lt
        else:
            r_op = op

        self._test_cases([special_case, *cases], a, r_op)
        # With floordiv_inequality=False the special case must also give up.
        self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False)

    def test_floordiv_eq_simplify(self):
        from sympy import Eq, Lt, Le

        a = sympy.Symbol("a", positive=True, integer=True)

        def check(expr, expected):
            r = try_solve(expr, a)
            self.assertIsNotNone(r)
            r_expr, _ = r
            self.assertEqual(r_expr, expected)

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10))

        # (10 - a) // 2 == 4
        # =====================================
        # 4 * 2 <= 10 - a         (not sure)
        #          10 - a < 5 * 2 (always true)
        check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10)))

    @skipIf(not TEST_Z3, "Z3 not installed")
    def test_z3_proof_floordiv_eq_simplify(self):
        import z3
        from sympy import Eq, Lt

        a = sympy.Symbol("a", positive=True, integer=True)
        a_ = z3.Int("a")

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        solver = z3.SolverFor("QF_NRA")

        # Add assertions for 'a_'.
        solver.add(a_ > 0)

        expr = Eq(FloorDiv(a + 10, 3), 3)
        r_expr, _ = try_solve(expr, a)

        # Check 'try_solve' really returns the 'expected' below.
        expected = Lt(a, (3 + 1) * 3 - 10)
        self.assertEqual(r_expr, expected)

        # Check whether there is an integer 'a_' such that the
        # equation below is satisfied.
        solver.add(
            # expr
            (z3.ToInt((a_ + 10) / 3.0) == 3)
            !=
            # expected
            (a_ < (3 + 1) * 3 - 10)
        )

        # Assert that there's no such an integer.
        # i.e. the transformation is sound.
        r = solver.check()
        self.assertEqual(r, z3.unsat)

    def test_simple_floordiv_gcd(self):
        x, y, z = sympy.symbols("x y z")

        # positive tests
        self.assertEqual(simple_floordiv_gcd(x, x), x)
        self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128)
        self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(x * y, x), x)
        self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y)
        self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x)

        # negative tests
        self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)


class TestSingletonInt(TestCase):
    """Equality, ordering and multiplication behavior of ``SingletonInt``."""

    def test_basic(self):
        j1 = SingletonInt(1, coeff=1)
        j1_copy = SingletonInt(1, coeff=1)
        j2 = SingletonInt(2, coeff=1)
        j1x2 = SingletonInt(1, coeff=2)

        def test_eq(a, b, expected):
            self.assertEqual(sympy.Eq(a, b), expected)
            self.assertEqual(sympy.Ne(b, a), not expected)

        # eq, ne
        test_eq(j1, j1, True)
        test_eq(j1, j1_copy, True)
        test_eq(j1, j2, False)
        test_eq(j1, j1x2, False)
        test_eq(j1, sympy.Integer(1), False)
        test_eq(j1, sympy.Integer(3), False)

        def test_ineq(a, b, expected, *, strict=True):
            # 'expected' is either a bool (the result of a > b / a >= b) or a
            # regex that the raised ValueError message must match.
            greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge)
            less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le)

            if isinstance(expected, bool):
                # expected is always True
                for fn in greater:
                    self.assertEqual(fn(a, b), expected)
                    self.assertEqual(fn(b, a), not expected)
                for fn in less:
                    self.assertEqual(fn(b, a), expected)
                    self.assertEqual(fn(a, b), not expected)
            else:
                for fn in greater:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(a, b)
                for fn in less:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(b, a)

        # ge, le, gt, lt
        for strict in (True, False):
            _test_ineq = functools.partial(test_ineq, strict=strict)
            _test_ineq(j1, sympy.Integer(0), True)
            _test_ineq(j1, sympy.Integer(3), "indeterminate")
            _test_ineq(j1, j2, "indeterminate")
            _test_ineq(j1x2, j1, True)

        # Special cases for ge, le, gt, lt:
        for ge in (sympy.Ge, is_ge):
            self.assertTrue(ge(j1, j1))
            self.assertTrue(ge(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                ge(sympy.Integer(2), j1)
        for le in (sympy.Le, is_le):
            self.assertTrue(le(j1, j1))
            self.assertTrue(le(sympy.Integer(2), j1))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                le(j1, sympy.Integer(2))

        for gt in (sympy.Gt, is_gt):
            self.assertFalse(gt(j1, j1))
            self.assertFalse(gt(sympy.Integer(2), j1))
            # it is only known to be that j1 >= 2, j1 > 2 is indeterminate
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                gt(j1, sympy.Integer(2))

        for lt in (sympy.Lt, is_lt):
            self.assertFalse(lt(j1, j1))
            self.assertFalse(lt(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                lt(sympy.Integer(2), j1)

        # mul
        self.assertEqual(j1 * 2, j1x2)
        # Unfortunately, this doesn't automatically simplify to 2*j1
        # since sympy.Mul doesn't trigger __mul__ unlike the above.
        self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)

        with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
            j1 * j2

        self.assertEqual(j1.free_symbols, set())


instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
instantiate_parametrized_tests(TestSympySolve)


if __name__ == "__main__":
    run_tests()