1# Owner(s): ["module: inductor"] 2 3from sympy import Symbol 4 5from torch._inductor.test_case import run_tests, TestCase 6from torch._inductor.utils import sympy_subs 7 8 9class TestUtils(TestCase): 10 def testSympySubs(self): 11 # integer and nonnegetaive attributes are preserved. 12 expr = Symbol("x") 13 result = sympy_subs(expr, {expr: "y"}) 14 self.assertEqual(result.name, "y") 15 self.assertEqual(result.is_integer, None) 16 self.assertEqual(result.is_nonnegative, None) 17 18 expr = Symbol("x", integer=True, nonnegative=False) 19 result = sympy_subs(expr, {expr: "y"}) 20 self.assertEqual(result.name, "y") 21 self.assertEqual(result.is_integer, True) 22 self.assertEqual(result.is_nonnegative, False) 23 24 # invalid replacement. 25 expr = Symbol("x", integer=True) 26 result = sympy_subs(expr, {Symbol("x"): Symbol("y")}) 27 self.assertEqual(result.name, "x") 28 29 # valid replacement since properties match. 30 expr = Symbol("x", integer=True) 31 result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")}) 32 self.assertEqual(result.name, "y") 33 34 # invalid replacement. 35 expr = Symbol("x", integer=None) 36 result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")}) 37 self.assertEqual(result.name, "x") 38 39 # replaced cant be string 40 self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"}) 41 42 # replaced can be an expression 43 expr = Symbol("x") 44 expr = abs(expr) 45 self.assertEqual(expr.is_integer, None) 46 self.assertEqual(expr.is_nonnegative, None) 47 # replace abs(x) with y 48 # propagte abs(x) sympy properties. 49 result = sympy_subs(expr, {expr: Symbol("y")}) 50 self.assertEqual(result.name, "y") 51 self.assertEqual(result.is_integer, None) 52 self.assertEqual(result.is_nonnegative, None) 53 54 55if __name__ == "__main__": 56 run_tests() 57