xref: /aosp_15_r20/external/pytorch/test/test_sympy_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: pt2"]
2
3import itertools
4import math
5import sys
6
7import sympy
8from typing import Callable, List, Tuple, Type
9from torch.testing._internal.common_device_type import skipIf
10from torch.testing._internal.common_utils import (
11    TEST_Z3,
12    instantiate_parametrized_tests,
13    parametrize,
14    run_tests,
15    TestCase,
16)
17from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
18from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
19from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
20from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
21from torch.utils._sympy.interp import sympy_interp
22from torch.utils._sympy.singleton_int import SingletonInt
23from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
24from sympy.core.relational import is_ge, is_le, is_gt, is_lt
25import functools
26import torch.fx as fx
27
28
29
# Names of the unary operator methods looked up (via getattr) on the
# analysis classes under test.
UNARY_OPS = [
    "reciprocal", "square", "abs", "neg", "exp",
    "log", "sqrt", "floor", "ceil",
]
# Binary operator method names. NB: "pow" is float_pow.
# "truncdiv" is not listed yet (TODO).
BINARY_OPS = [
    "truediv",
    "floordiv",
    "add",
    "mul",
    "sub",
    "pow",
    "pow_by_natural",
    "minimum",
    "maximum",
    "mod",
]

# Boolean and comparison operator method names.
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]
51
# Sample integers: small values, powers of two, primes, and values close to
# sys.maxsize.
CONSTANTS = [
    -1, 0, 1, 2, 3, 4, 5, 8,
    16, 32, 64, 100, 101,
    2**24, 2**32, 2**37 - 1,
    sys.maxsize - 1, sys.maxsize,
]
# A smaller set for the O(N^2) (pairwise) sweeps.
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
# The SymPy relational classes. parametrize_relational_types() falls back to
# all of them when it is called without explicit types.
RELATIONAL_TYPES = [
    sympy.Eq,
    sympy.Ne,
    sympy.Gt,
    sympy.Ge,
    sympy.Lt,
    sympy.Le,
]
77
78
def valid_unary(fn, v):
    """Report whether unary op ``fn`` may be evaluated at ``v`` in the tests.

    Points outside the real domain are rejected: ``log`` of a non-positive
    value, ``reciprocal`` of zero, and ``sqrt`` of a negative value. Every
    other op is accepted for any ``v``.
    """
    if fn == "log":
        return not (v <= 0)
    if fn == "reciprocal":
        return not (v == 0)
    if fn == "sqrt":
        return not (v < 0)
    return True
87
88
def valid_binary(fn, a, b):
    """Report whether binary op ``fn`` should be exercised at ``(a, b)``.

    Rejected inputs:
      * ``pow``: exponents above 4 (sympy would expand x*x*... for integral
        exponents), non-positive bases (no imaginary numbers), and 0**0.
      * ``pow_by_natural``: exponents above 4 or below 0, and 0**0.
      * ``mod``: negative dividends or non-positive divisors.
      * ``div`` / ``truediv`` / ``floordiv``: a zero divisor.
    Everything else is accepted.
    """
    if fn == "pow":
        return not (b > 4 or a <= 0 or (a == b == 0))
    if fn == "pow_by_natural":
        return not (b > 4 or b < 0 or (a == b == 0))
    if fn == "mod":
        return not (a < 0 or b <= 0)
    if fn in ("div", "truediv", "floordiv"):
        return not (b == 0)
    return True
111
112
def generate_range(vals):
    """Yield a ValueRanges for every ordered (lower, upper) pair drawn from vals.

    Empty ranges are skipped: lower > upper for numbers, or (true, false) for
    SymPy booleans. Ranges that can only hold an infinite value (lower == oo
    or upper == -oo) are skipped as well.
    """
    sympy_bools = [sympy.true, sympy.false]
    for lo, hi in itertools.product(vals, repeat=2):
        if lo in sympy_bools:
            is_empty = lo == sympy.true and hi == sympy.false
        else:
            is_empty = lo > hi
        if is_empty:
            continue
        # ranges that only admit infinite values are not interesting
        if lo == sympy.oo or hi == -sympy.oo:
            continue
        yield ValueRanges(lo, hi)
125
126
class TestNumbers(TestCase):
    """Arithmetic, comparison and float-conversion behavior of int_oo / -int_oo."""

    def test_int_infinity(self):
        # int_oo is an IntInfinity, its negation a NegativeIntInfinity, and it
        # reports itself as an integer.
        self.assertIsInstance(int_oo, IntInfinity)
        self.assertIsInstance(-int_oo, NegativeIntInfinity)
        self.assertTrue(int_oo.is_integer)
        # assertIs is used here to check singleton-ness (results are the very
        # same int_oo / -int_oo objects); don't use `is` to compare against
        # ordinary numbers. Both operand orders (e.g. int_oo + 1 and
        # 1 + int_oo) are exercised on purpose.
        self.assertIs(int_oo + int_oo, int_oo)
        self.assertIs(int_oo + 1, int_oo)
        self.assertIs(int_oo - 1, int_oo)
        self.assertIs(-int_oo - 1, -int_oo)
        self.assertIs(-int_oo + 1, -int_oo)
        self.assertIs(-int_oo + (-int_oo), -int_oo)
        self.assertIs(-int_oo - int_oo, -int_oo)
        self.assertIs(1 + int_oo, int_oo)
        self.assertIs(1 - int_oo, -int_oo)
        # Multiplication by positive / negative factors keeps / flips the sign.
        self.assertIs(int_oo * int_oo, int_oo)
        self.assertIs(2 * int_oo, int_oo)
        self.assertIs(int_oo * 2, int_oo)
        self.assertIs(-1 * int_oo, -int_oo)
        self.assertIs(-int_oo * int_oo, -int_oo)
        self.assertIs(2 * -int_oo, -int_oo)
        self.assertIs(-int_oo * 2, -int_oo)
        self.assertIs(-1 * -int_oo, int_oo)
        # True division produces SymPy's general sympy.oo, not int_oo.
        self.assertIs(int_oo / 2, sympy.oo)
        self.assertIs(-(-int_oo), int_oo)  # noqa: B002
        self.assertIs(abs(int_oo), int_oo)
        self.assertIs(abs(-int_oo), int_oo)
        # Powers: even powers of -int_oo are positive, odd ones negative.
        self.assertIs(int_oo ** 2, int_oo)
        self.assertIs((-int_oo) ** 2, int_oo)
        self.assertIs((-int_oo) ** 3, -int_oo)
        # A negative exponent yields a plain zero, so compare by value.
        self.assertEqual(int_oo ** -1, 0)
        self.assertEqual((-int_oo) ** -1, 0)
        self.assertIs(int_oo ** int_oo, int_oo)
        # Equality / ordering against itself and against finite integers,
        # including sys.maxsize.
        self.assertTrue(int_oo == int_oo)
        self.assertFalse(int_oo != int_oo)
        self.assertTrue(-int_oo == -int_oo)
        self.assertFalse(int_oo == 2)
        self.assertTrue(int_oo != 2)
        self.assertFalse(int_oo == sys.maxsize)
        self.assertTrue(int_oo >= sys.maxsize)
        self.assertTrue(int_oo >= 2)
        self.assertTrue(int_oo >= -int_oo)

    def test_relation(self):
        # sympy.Add (not only the + operator) absorbs finite terms into int_oo.
        self.assertIs(sympy.Add(2, int_oo), int_oo)
        self.assertFalse(-int_oo > 2)

    def test_lt_self(self):
        # int_oo is not strictly less than itself, and min() returns the
        # -int_oo singleton itself.
        self.assertFalse(int_oo < int_oo)
        self.assertIs(min(-int_oo, -4), -int_oo)
        self.assertIs(min(-int_oo, -int_oo), -int_oo)

    def test_float_cast(self):
        # float() maps +/-int_oo to the IEEE infinities.
        self.assertEqual(float(int_oo), math.inf)
        self.assertEqual(float(-int_oo), -math.inf)

    def test_mixed_oo_int_oo(self):
        # Arbitrary (but fixed) choice: int_oo orders strictly below sympy.oo,
        # and -int_oo strictly above -sympy.oo.
        self.assertTrue(int_oo < sympy.oo)
        self.assertFalse(int_oo > sympy.oo)
        self.assertTrue(sympy.oo > int_oo)
        self.assertFalse(sympy.oo < int_oo)
        self.assertIs(max(int_oo, sympy.oo), sympy.oo)
        self.assertTrue(-int_oo > -sympy.oo)
        self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)
193
194
class TestValueRanges(TestCase):
    """Cross-check ValueRangeAnalysis against the point-wise ReferenceAnalysis.

    Naming convention used in every test: ``r`` is a ValueRangeAnalysis
    result (a range) and ``ref_r`` is a ReferenceAnalysis result (a point).
    """

    @parametrize("fn", UNARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_unary_ref(self, fn, dtype):
        # On a constant input the computed range must collapse to the single
        # point given by the reference implementation, with matching
        # integrality.
        dtype = {"int": sympy.Integer, "float": sympy.Float}[dtype]
        for v in CONSTANTS:
            if not valid_unary(fn, v):
                continue
            with self.subTest(v=v):
                v = dtype(v)
                ref_r = getattr(ReferenceAnalysis, fn)(v)
                r = getattr(ValueRangeAnalysis, fn)(v)
                self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(ref_r, r.lower)

    def test_pow_half(self):
        # Only checks that this call does not raise; the result is not
        # inspected.
        ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))

    @parametrize("fn", BINARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_binary_ref(self, fn, dtype):
        to_dtype = {"int": sympy.Integer, "float": sympy.Float}
        # Don't test float on int only methods
        if dtype == "float" and fn in ["pow_by_natural", "mod"]:
            return
        dtype = to_dtype[dtype]
        for a, b in itertools.product(CONSTANTS, repeat=2):
            if not valid_binary(fn, a, b):
                continue
            a = dtype(a)
            b = dtype(b)
            with self.subTest(a=a, b=b):
                r = getattr(ValueRangeAnalysis, fn)(a, b)
                # Nothing to compare when the range analysis gives up.
                if r == ValueRanges.unknown():
                    continue
                ref_r = getattr(ReferenceAnalysis, fn)(a, b)

                self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r, r.lower)

    def test_mul_zero_unknown(self):
        # 0 * anything is exactly 0, even when the other factor is unknown;
        # the int / float flavour of the zero is preserved.
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
            ValueRanges.wrap(0),
        )
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
            ValueRanges.wrap(0.0),
        )

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a in generate_range(vals):
            with self.subTest(a=a):
                r = getattr(ValueRangeAnalysis, fn)(a)
                unique = set()
                for a0 in vals:
                    if a0 not in a:
                        continue
                    with self.subTest(a0=a0):
                        ref_r = getattr(ReferenceAnalysis, fn)(a0)
                        self.assertIn(ref_r, r)
                        unique.add(ref_r)
                # The range must be tight: a single-point range iff exactly
                # one distinct reference value was observed.
                if r.lower == r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", BINARY_BOOL_OPS)
    def test_binary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                r = getattr(ValueRangeAnalysis, fn)(a, b)
                unique = set()
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        ref_r = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(ref_r, r)
                        unique.add(ref_r)
                # The range must be tight: a single-point range iff exactly
                # one distinct reference value was observed.
                if r.lower == r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        # TODO: bring back sympy.oo testing for float unary fns
        vals = CONSTANTS
        for a in generate_range(vals):
            with self.subTest(a=a):
                r = getattr(ValueRangeAnalysis, fn)(a)
                # Soundness: every reference value at a point inside the
                # input range must lie inside the computed range.
                for a0 in CONSTANTS:
                    if a0 not in a:
                        continue
                    if not valid_unary(fn, a0):
                        continue
                    with self.subTest(a0=a0):
                        ref_r = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                        self.assertIn(ref_r, r)

    # Sweeps every pair of ranges over LESS_CONSTANTS, so this is relatively
    # slow across all variants.
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        # TODO: bring back sympy.oo testing for float binary fns
        vals = LESS_CONSTANTS
        for a, b in itertools.product(generate_range(vals), repeat=2):
            # don't attempt pow on exponents that are too large (but oo is OK)
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                # The range result depends only on (a, b). Compute it once,
                # lazily, at the first valid sample point (inside that
                # point's subTest, as before) instead of once per (a0, b0).
                r = None
                for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    if not valid_binary(fn, a0, b0):
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        if r is None:
                            r = getattr(ValueRangeAnalysis, fn)(a, b)
                        ref_r = getattr(ReferenceAnalysis, fn)(
                            sympy.Integer(a0), sympy.Integer(b0)
                        )
                        # NOTE(review): for COMPARE_OPS ref_r is a SymPy
                        # boolean; if its is_finite is None, this check is
                        # skipped for comparisons. Confirm they are really
                        # being exercised.
                        if ref_r.is_finite:
                            self.assertIn(ref_r, r)
325
326
class TestSympyInterp(TestCase):
    """Check sympy_interp by re-evaluating ReferenceAnalysis expressions."""

    @staticmethod
    def _arity(fn):
        """Return 2 for binary, binary-boolean and comparison ops, else 1."""
        if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
            return 2
        return 1

    @staticmethod
    def _values(fn):
        """Return the concrete argument values swept for ``fn``."""
        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            return [True, False]
        return CONSTANTS

    @staticmethod
    def _dummies(fn):
        """Return fresh (x, y) dummies; integer-typed only for pow_by_natural."""
        is_integer = True if fn == "pow_by_natural" else None
        x = sympy.Dummy('x', integer=is_integer)
        y = sympy.Dummy('y', integer=is_integer)
        return x, y

    @staticmethod
    def _valid_args(fn, args):
        """Dispatch to valid_unary / valid_binary according to len(args)."""
        if len(args) == 1:
            return valid_unary(fn, *args)
        return valid_binary(fn, *args)

    @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
    def test_interp(self, fn):
        # SymPy does not implement truncation for Expressions (div, truncdiv,
        # mod); minimum / maximum are skipped here as well.
        if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
            return

        x, y = self._dummies(fn)
        arity = self._arity(fn)
        symbols = [x, y][:arity]
        for args in itertools.product(self._values(fn), repeat=arity):
            if not self._valid_args(fn, args):
                continue
            with self.subTest(args=args):
                sargs = [sympy.sympify(a) for a in args]
                sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
                ref_r = getattr(ReferenceAnalysis, fn)(*sargs)
                # Yes, I know this is a longwinded way of saying xreplace; the
                # point is to test sympy_interp
                r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr)
                self.assertEqual(ref_r, r)

    @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
    def test_python_interp_fx(self, fn):
        # These never show up from symbolic_shapes
        if fn in ("log", "exp"):
            return

        # Sympy does not support truncation on symbolic shapes
        if fn in ("truncdiv", "mod"):
            return

        x, y = self._dummies(fn)
        arity = self._arity(fn)
        symbols = [x, y][:arity]

        # _valid_args already rejects every degenerate input (zero divisors
        # for floordiv, 0 ** non-positive for pow / pow_by_natural), so no
        # extra per-op filtering is needed here.
        for args in itertools.product(self._values(fn), repeat=arity):
            if not self._valid_args(fn, args):
                continue
            with self.subTest(args=args):
                # Workaround mpf from symbol error
                if fn == "minimum":
                    sympy_expr = sympy.Min(x, y)
                elif fn == "maximum":
                    sympy_expr = sympy.Max(x, y)
                else:
                    sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)

                if arity == 1:
                    def trace_f(px):
                        return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
                else:
                    def trace_f(px, py):
                        return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)

                gm = fx.symbolic_trace(trace_f)

                # Interpreting directly and running the traced graph must agree.
                self.assertEqual(
                    sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
                    gm(*args)
                )
426
427
def type_name_fn(type: Type) -> str:
    """Return the bare class name of ``type``; used to name parametrized tests."""
    name = getattr(type, "__name__")
    return name
430
def parametrize_relational_types(*types):
    """Build a decorator that parametrizes a test's ``op`` argument.

    The test is instantiated once per relational class in ``types``, or once
    per entry of RELATIONAL_TYPES when no types are given. Generated tests
    are named via type_name_fn.
    """
    def wrapper(f: Callable):
        selected = types if types else RELATIONAL_TYPES
        decorate = parametrize("op", selected, name_fn=type_name_fn)
        return decorate(f)

    return wrapper
435
436
class TestSympySolve(TestCase):
    """Tests for try_solve: isolating a symbol on the lhs of a relational."""

    def _create_integer_symbols(self) -> Tuple[sympy.Symbol, ...]:
        # sympy.symbols returns a tuple with one Symbol per name.
        return sympy.symbols("a b c", integer=True)

    def test_give_up(self):
        from sympy import Eq, Ne

        a, b, c = self._create_integer_symbols()

        cases = [
            # Not a relational operation.
            a + b,
            # 'a' appears on both sides.
            Eq(a, a + 1),
            # 'a' appears on neither side.
            Eq(b, c + 1),
            # Result is a 'sympy.And'.
            Eq(FloorDiv(a, b), c),
            # Result is a 'sympy.Or'.
            Ne(FloorDiv(a, b), c),
        ]

        for case in cases:
            e = try_solve(case, a)
            self.assertEqual(e, None)

    @parametrize_relational_types()
    def test_noop(self, op):
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = a, 42 * b
        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertNotEqual(r, None)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, expr)
        self.assertEqual(r_rhs, rhs)

    @parametrize_relational_types()
    def test_noop_rhs(self, op):
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = 42 * b, a

        mirror = mirror_rel_op(op)
        self.assertNotEqual(mirror, None)

        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertNotEqual(r, None)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, mirror(rhs, lhs))
        self.assertEqual(r_rhs, lhs)

    def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs):
        """Solve each ``source`` for ``thing`` and compare with ``expected``.

        ``expected is None`` means try_solve must give up (return None).
        Otherwise it must return ``(op(thing, expected), expected)``. Extra
        ``kwargs`` are forwarded to try_solve.
        """
        for source, expected in cases:
            r = try_solve(source, thing, **kwargs)

            # Solved iff a solution was expected. The message names the
            # offending case instead of a bare "False is not true".
            self.assertEqual(
                r is None,
                expected is None,
                msg=f"try_solve({source}, {thing}) returned {r}; expected rhs {expected}",
            )

            if r is not None:
                r_expr, r_rhs = r
                self.assertEqual(r_rhs, expected)
                self.assertEqual(r_expr, op(thing, expected))

    def test_addition(self):
        from sympy import Eq

        a, b, c = self._create_integer_symbols()

        cases = [
            (Eq(a + b, 0), -b),
            (Eq(a + 5, b - 5), b - 10),
            (Eq(a + c * b, 1), 1 - c * b),
        ]

        self._test_cases(cases, a, Eq)

    @parametrize_relational_types(sympy.Eq, sympy.Ne)
    def test_multiplication_division(self, op):
        a, b, c = self._create_integer_symbols()

        cases = [
            (op(a * b, 1), 1 / b),
            (op(a * 5, b - 5), (b - 5) / 5),
            (op(a * b, c), c / b),
        ]

        self._test_cases(cases, a, op)

    @parametrize_relational_types(*INEQUALITY_TYPES)
    def test_multiplication_division_inequality(self, op):
        a, b, _ = self._create_integer_symbols()
        intneg = sympy.Symbol("neg", integer=True, negative=True)
        intpos = sympy.Symbol("pos", integer=True, positive=True)

        cases = [
            # Divide/multiply both sides by positive number.
            (op(a * intpos, 1), 1 / intpos),
            (op(a / (5 * intpos), 1), 5 * intpos),
            (op(a * 5, b - 5), (b - 5) / 5),
            # 'b' is neither strictly positive nor strictly negative, so we
            # can't divide/multiply both sides by 'b'.
            (op(a * b, 1), None),
            (op(a / b, 1), None),
            (op(a * b * intpos, 1), None),
        ]

        mirror_cases = [
            # Divide/multiply both sides by negative number.
            (op(a * intneg, 1), 1 / intneg),
            (op(a / (5 * intneg), 1), 5 * intneg),
            (op(a * -5, b - 5), -(b - 5) / 5),
        ]
        # Multiplying/dividing by a negative number flips the relation.
        mirror_op = mirror_rel_op(op)
        self.assertIsNotNone(mirror_op)

        self._test_cases(cases, a, op)
        self._test_cases(mirror_cases, a, mirror_op)

    @parametrize_relational_types()
    def test_floordiv(self, op):
        from sympy import Eq, Ne, Gt, Ge, Lt, Le

        a, b, c = sympy.symbols("a b c")
        pos = sympy.Symbol("pos", positive=True)
        integer = sympy.Symbol("integer", integer=True)

        # For reference, the full FloorDiv expansions for Eq / Ne are:
        #   Eq(FloorDiv(a, pos), integer)
        #     -> And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))
        #   Eq(FloorDiv(a + 5, pos), integer)
        #     -> And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))
        #   Ne(FloorDiv(a, pos), integer)
        #     -> Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))

        special_case = {
            # 'FloorDiv' turns into 'And', which can't be simplified any further.
            Eq: (Eq(FloorDiv(a, pos), integer), None),
            # 'FloorDiv' turns into 'Or', which can't be simplified any further.
            Ne: (Ne(FloorDiv(a, pos), integer), None),
            Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos),
            Ge: (Ge(FloorDiv(a, pos), integer), integer * pos),
            Lt: (Lt(FloorDiv(a, pos), integer), integer * pos),
            Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
        }[op]

        cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
            # 'b' is not strictly positive
            (op(FloorDiv(a, b), integer), None),
            # 'c' is not known to be an integer (unlike 'integer' above)
            (op(FloorDiv(a, pos), c), None),
        ]

        # The result might change after 'FloorDiv' transformation.
        # Specifically:
        #   - [Ge, Gt] => Ge
        #   - [Le, Lt] => Lt
        if op in (sympy.Gt, sympy.Ge):
            r_op = sympy.Ge
        elif op in (sympy.Lt, sympy.Le):
            r_op = sympy.Lt
        else:
            r_op = op

        self._test_cases([special_case, *cases], a, r_op)
        # With the FloorDiv inequality rewrite disabled, the special case must
        # not be solved either.
        self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False)

    def test_floordiv_eq_simplify(self):
        from sympy import Eq, Lt, Le

        a = sympy.Symbol("a", positive=True, integer=True)

        def check(expr, expected):
            r = try_solve(expr, a)
            self.assertNotEqual(r, None)
            r_expr, _ = r
            self.assertEqual(r_expr, expected)

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10))

        # (10 - a) // 2 == 4
        # =====================================
        # 4 * 2 <= 10 - a         (not sure)
        #          10 - a < 5 * 2 (always true)
        check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10)))

    @skipIf(not TEST_Z3, "Z3 not installed")
    def test_z3_proof_floordiv_eq_simplify(self):
        import z3
        from sympy import Eq, Lt

        a = sympy.Symbol("a", positive=True, integer=True)
        a_ = z3.Int("a")

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        solver = z3.SolverFor("QF_NRA")

        # Add assertions for 'a_'.
        solver.add(a_ > 0)

        expr = Eq(FloorDiv(a + 10, 3), 3)
        r_expr, _ = try_solve(expr, a)

        # Check 'try_solve' really returns the 'expected' below.
        expected = Lt(a, (3 + 1) * 3 - 10)
        self.assertEqual(r_expr, expected)

        # Check whether there is an integer 'a_' such that the
        # equation below is satisfied.
        solver.add(
            # expr
            (z3.ToInt((a_ + 10) / 3.0) == 3)
            !=
            # expected
            (a_ < (3 + 1) * 3 - 10)
        )

        # Assert that there's no such an integer.
        # i.e. the transformation is sound.
        r = solver.check()
        self.assertEqual(r, z3.unsat)

    def test_simple_floordiv_gcd(self):
        x, y, z = sympy.symbols("x y z")

        # positive tests
        self.assertEqual(simple_floordiv_gcd(x, x), x)
        self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128)
        self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(x * y, x), x)
        self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y)
        self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x)

        # negative tests
        self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
686
687
class TestSingletonInt(TestCase):
    """Equality, ordering and multiplication semantics of SingletonInt."""

    def test_basic(self):
        # j1 and j1_copy are built from identical arguments; j2 differs in
        # the first argument, and j1x2 in coeff (it equals j1 * 2, see below).
        j1 = SingletonInt(1, coeff=1)
        j1_copy = SingletonInt(1, coeff=1)
        j2 = SingletonInt(2, coeff=1)
        j1x2 = SingletonInt(1, coeff=2)

        def test_eq(a, b, expected):
            """Assert Eq(a, b) == expected and Ne(b, a) == (not expected)."""
            self.assertEqual(sympy.Eq(a, b), expected)
            self.assertEqual(sympy.Ne(b, a), not expected)

        # eq, ne
        test_eq(j1, j1, True)
        test_eq(j1, j1_copy, True)
        test_eq(j1, j2, False)
        test_eq(j1, j1x2, False)
        test_eq(j1, sympy.Integer(1), False)
        test_eq(j1, sympy.Integer(3), False)

        def test_ineq(a, b, expected, *, strict=True):
            """Check "a > b" (or "a >= b" when strict is False).

            Each relation is checked through both the SymPy relational class
            and the matching sympy.core.relational ``is_*`` helper.

            If ``expected`` is a bool, greater(a, b) and less(b, a) must equal
            it, and greater(b, a) and less(a, b) must equal its negation.
            Otherwise ``expected`` is a regex: greater(a, b) and less(b, a)
            must raise a ValueError matching it.
            """
            greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge)
            less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le)

            if isinstance(expected, bool):
                # expected is a bool (the calls below only ever pass True);
                # the flipped relations must give the opposite answer.
                for fn in greater:
                    self.assertEqual(fn(a, b), expected)
                    self.assertEqual(fn(b, a), not expected)
                for fn in less:
                    self.assertEqual(fn(b, a), expected)
                    self.assertEqual(fn(a, b), not expected)
            else:
                for fn in greater:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(a, b)
                for fn in less:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(b, a)

        # ge, le, gt, lt (each checked in both strict and non-strict form)
        for strict in (True, False):
            _test_ineq = functools.partial(test_ineq, strict=strict)
            _test_ineq(j1, sympy.Integer(0), True)
            _test_ineq(j1, sympy.Integer(3), "indeterminate")
            _test_ineq(j1, j2, "indeterminate")
            _test_ineq(j1x2, j1, True)

        # Special cases for ge, le, gt, lt: j1 is known to be >= 2 (and equal
        # to itself), but any stronger claim relative to 2 is indeterminate.
        for ge in (sympy.Ge, is_ge):
            self.assertTrue(ge(j1, j1))
            self.assertTrue(ge(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                ge(sympy.Integer(2), j1)
        for le in (sympy.Le, is_le):
            self.assertTrue(le(j1, j1))
            self.assertTrue(le(sympy.Integer(2), j1))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                le(j1, sympy.Integer(2))

        for gt in (sympy.Gt, is_gt):
            self.assertFalse(gt(j1, j1))
            self.assertFalse(gt(sympy.Integer(2), j1))
            # it is only known that j1 >= 2, so j1 > 2 is indeterminate
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                gt(j1, sympy.Integer(2))

        for lt in (sympy.Lt, is_lt):
            self.assertFalse(lt(j1, j1))
            self.assertFalse(lt(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                lt(sympy.Integer(2), j1)

        # mul
        self.assertEqual(j1 * 2, j1x2)
        # Unfortunately, this doesn't automatically simplify to 2*j1, since
        # sympy.Mul doesn't trigger __mul__, unlike the expression above.
        self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)

        # Two SingletonInts with different first arguments cannot be multiplied.
        with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
            j1 * j2

        # A SingletonInt carries no free symbols.
        self.assertEqual(j1.free_symbols, set())
770
771
# Expand the @parametrize-decorated tests of these classes into concrete test
# methods. TestNumbers and TestSingletonInt have no parametrized tests.
for _parametrized_cls in (TestValueRanges, TestSympyInterp, TestSympySolve):
    instantiate_parametrized_tests(_parametrized_cls)
del _parametrized_cls


if __name__ == "__main__":
    run_tests()
779