xref: /aosp_15_r20/external/pytorch/test/inductor/test_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2
3from sympy import Symbol
4
5from torch._inductor.test_case import run_tests, TestCase
6from torch._inductor.utils import sympy_subs
7
8
class TestUtils(TestCase):
    def testSympySubs(self):
        """Pin the replacement rules of ``sympy_subs``.

        The cases below check that:

        * a string replacement becomes a Symbol named after the string that
          inherits the replaced expression's ``is_integer`` and
          ``is_nonnegative`` properties;
        * a Symbol key only matches when its assumptions match the
          expression's (a plain ``Symbol("x")`` does not replace
          ``Symbol("x", integer=True)``);
        * a string key is rejected with an ``AssertionError``;
        * a non-Symbol expression such as ``abs(x)`` can be used as a key,
          both with an explicit Symbol and with a string replacement.
        """
        # A string replacement inherits the integer and nonnegative
        # attributes of the replaced expression (here both are unknown/None).
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        # Known attributes are carried over to the new Symbol as well.
        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # Invalid replacement: the key lacks the expression's integer
        # assumption, so it does not match and x is left untouched.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # Valid replacement: the key's assumptions match the expression's.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # Invalid replacement: the expression leaves integrality unspecified
        # (integer=None) while the key asserts integer=False, so no match.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # The replaced key can't be a string.
        with self.assertRaises(AssertionError):
            sympy_subs(expr, {"x": "y"})

        # The replaced key can be an arbitrary expression, e.g. abs(x).
        expr = abs(Symbol("x"))
        self.assertEqual(expr.is_integer, None)
        self.assertEqual(expr.is_nonnegative, None)
        # Replacing abs(x) with an explicit Symbol uses that Symbol as-is;
        # y was created without assumptions, so its properties stay None.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        # Propagation from an expression key: with a string replacement, y
        # inherits abs(x)'s properties. The preconditions are asserted first
        # so a failure distinguishes SymPy semantics from sympy_subs itself.
        expr = abs(Symbol("x", integer=True))
        self.assertEqual(expr.is_integer, True)
        self.assertEqual(expr.is_nonnegative, True)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, True)
53
54
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the inductor test runner.
    run_tests()
57