1# Owner(s): ["module: inductor"] 2import os 3import sys 4import unittest 5 6import sympy 7 8import torch 9from torch._inductor.codegen.cpp import cexpr 10from torch._inductor.codegen.triton import texpr 11from torch._inductor.codegen.wrapper import pexpr 12from torch._inductor.runtime.benchmarking import benchmarker 13from torch._inductor.sizevars import SizeVarAllocator 14from torch._inductor.test_case import TestCase as InductorTestCase 15from torch._inductor.utils import run_and_get_triton_code 16from torch.testing._internal.common_utils import ( 17 instantiate_parametrized_tests, 18 parametrize, 19) 20from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU 21from torch.utils._sympy.functions import ( 22 FloorDiv, 23 ModularIndexing, 24 RoundDecimal, 25 RoundToInt, 26) 27 28 29DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" 30 31 32class TestIndexingSimplification(InductorTestCase): 33 def test_indexing_simplification(self): 34 sizevars = SizeVarAllocator() 35 i0 = sympy.Symbol("i0", integer=True) 36 i1 = sympy.Symbol("i1", integer=True) 37 i2 = sympy.Symbol("i2", integer=True) 38 r3 = sympy.Symbol("r3", integer=True) 39 40 var_ranges = {i0: 3136, i1: 64, i2: 32, r3: 3} 41 expr = ( 42 128 * i2 43 + ModularIndexing(i1, 1, 64) 44 + 64 * ModularIndexing(i1 + 64 * r3, 64, 2) 45 ) 46 # check that `i1//64` is removed when i1 is always less than 64, 47 # and the next simplificaton doesn't happen 48 self.assertEqual( 49 sizevars.simplify_with_ranges(expr, var_ranges), 50 i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2), 51 ) 52 # all the modular indexing should be removed when the body cant be larger than the modulus 53 var_ranges[r3] = 2 54 self.assertEqual( 55 sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3 56 ) 57 # if there are negative terms in ModularIndexing base, we cannot replace it with FloorDiv 58 expr = ModularIndexing(i1 - 15, 1, 64) 59 self.assertEqual( 60 sizevars.simplify_with_ranges(expr, var_ranges), 61 ModularIndexing(i1 - 15, 1, 64), 62 ) 63 # small terms should be kept if the rest is not guaranteed to be divisible 64 self.assertEqual( 65 sizevars.simplify_with_ranges(FloorDiv(r3 + i2 + i1, 32), var_ranges), 66 FloorDiv(r3 + i2 + i1, 32), 67 ) 68 69 expr = ModularIndexing(2 * i2 + r3, 1, 64) 70 # modular indexing is removed if base is smaller than modulo 71 self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), 2 * i2 + r3) 72 73 # check the same thing but with symbolic divisor 74 self.assertEqual(FloorDiv(r3 * i0, r3), i0) 75 self.assertEqual(ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10)) 76 77 # (10*i) % 10 is always zero and should get optimized away 78 self.assertEqual( 79 ModularIndexing(i0 + i1 * 10, 1, 10), ModularIndexing(i0, 1, 10) 80 ) 81 82 # ((20*i)//2) % 10 is always zero and should get optimized away 83 self.assertEqual( 84 ModularIndexing(i0 + i1 * 20, 2, 10), ModularIndexing(i0, 2, 10) 85 ) 86 87 # the same things happens with symbolic divisor 88 self.assertEqual( 89 ModularIndexing(i0 + i1 * i2 * r3, i2, r3), ModularIndexing(i0, i2, r3) 90 ) 91 92 # if there are negative terms, we cannot optimize away zero terms due to https://github.com/openai/triton/issues/619 93 self.assertEqual( 94 ModularIndexing(-i0 + i1 * 20, 2, 10), ModularIndexing(-i0 + i1 * 20, 2, 10) 95 ) 96 self.assertEqual( 97 ModularIndexing(-15 + i1 * 20, 2, 10), ModularIndexing(-15 + i1 * 20, 2, 10) 98 ) 99 100 # Constant fold from divisor into base 101 self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10)) 102 self.assertEqual(FloorDiv(i0 * 4, 2), i0 * 2) 103 104 # Nested modular indexing is correctly simplified 105 var_ranges = {i1: 13, i2: 121} 106 expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28) 107 self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) 108 expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28) 109 self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) 110 var_ranges = {i2: 784} 111 expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4) 112 expected = FloorDiv(ModularIndexing(i2, 1, 28), 7) 113 self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected) 114 expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4) 115 self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) 116 117 def test_indexing_join(self): 118 sizevars = SizeVarAllocator() 119 i0 = sympy.Symbol("i0", integer=True) 120 i1 = sympy.Symbol("i1", integer=True) 121 i2 = sympy.Symbol("i2", integer=True) 122 123 # join two ModularIndexing calls into one larger one when possible 124 expr1 = ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4) 125 self.assertEqual( 126 sizevars.simplify_with_ranges(expr1, {}), ModularIndexing(i0, 1, 128) 127 ) 128 129 # it should also work with a scale 130 self.assertEqual( 131 sizevars.simplify_with_ranges(2 * expr1, {}), 132 2 * ModularIndexing(i0, 1, 128), 133 ) 134 135 # it should work when divisor is not 1 136 expr2 = ModularIndexing(i0, 3, 32) + 32 * ModularIndexing(i0, 32 * 3, 4) 137 simplified = sizevars.simplify_with_ranges(expr2, {}) 138 self.assertEqual(simplified, ModularIndexing(i0, 3, 128)) 139 self.assertEqual(expr2.subs({i0: 39485}), simplified.subs({i0: 39485})) 140 141 # it should not happen in this case as the modulus is wrong 142 expr3 = ModularIndexing(i0, 1, 30) + 32 * ModularIndexing(i0, 32, 4) 143 self.assertEqual(sizevars.simplify_with_ranges(expr3, {}), expr3) 144 145 # check that it also works with a modulus>1 146 expr4 = ModularIndexing(i0, 10, i1) + i1 * ModularIndexing(i0, i1 * 10, i2) 147 res0 = expr4.subs({i0: 24056, i1: 13, i2: 19}) 148 simplified = sizevars.simplify_with_ranges(expr4, {}) 149 res1 = simplified.subs({i0: 24056, i1: 13, i2: 19}) 150 self.assertEqual(res0, res1) 151 self.assertEqual(simplified, ModularIndexing(i0, 10, i1 * i2)) 152 153 # and also works with an offset 154 self.assertEqual( 155 sizevars.simplify_with_ranges(expr4 + 10, {}), 156 ModularIndexing(i0, 10, i1 * i2) + 10, 157 ) 158 159 # works for ModularIndexing + FloorDiv 160 expr5 = 197 * FloorDiv(i0, 197) + ModularIndexing(i0, 1, 197) 161 simplified = sizevars.simplify_with_ranges(expr5, {}) 162 self.assertEqual(simplified, i0) 163 self.assertEqual(expr5.subs({i0: 39485}), simplified.subs({i0: 39485})) 164 165 # works with a scale 166 self.assertEqual( 167 sizevars.simplify_with_ranges(2 * expr5, {}), 168 2 * i0, 169 ) 170 171 # divisor != 1 172 expr6 = 197 * FloorDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197) 173 simplified = sizevars.simplify_with_ranges(expr6, {}) 174 self.assertEqual(simplified, FloorDiv(i0, 3)) 175 self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485})) 176 177 def test_modular_indexing_pairs_merged(self): 178 sizevars = SizeVarAllocator() 179 x = sympy.Symbol("x", integer=True, positive=True) 180 a = 1024 181 b = 32 182 expr1 = ModularIndexing(x, 1, a) 183 expr2 = ModularIndexing(expr1, 1, b) 184 expected = ModularIndexing(x, 1, b) 185 186 actual = sizevars.combine_modular_indexing_pairs(expr2) 187 self.assertEqual(expected, actual) 188 self.assertNotEqual(expr2, actual) 189 190 def test_modular_indexing_pairs_not_merged(self): 191 sizevars = SizeVarAllocator() 192 x = sympy.Symbol("x", integer=True, positive=True) 193 a = 1024 194 b = 3 # pick a 'b' that we can not merge 195 expr1 = ModularIndexing(x, 1, a) 196 expr2 = ModularIndexing(expr1, 1, b) 197 198 actual = sizevars.combine_modular_indexing_pairs(expr2) 199 self.assertEqual(expr2, actual) 200 self.assertNotEqual(ModularIndexing(x, 1, b), actual) 201 202 def test_expand_floor_div_skipped(self): 203 sizevars = SizeVarAllocator() 204 x = sympy.Symbol("x", integer=True, positive=True) 205 y = sympy.Symbol("y", integer=True, positive=True) 206 207 expr = FloorDiv(x, 2) + FloorDiv(y, 3) 208 # The expression can not be simplified since there are multiple 209 # FloorDiv. We return False in that case 210 self.assertFalse(sizevars.expand_floor_div(expr)) 211 212 def test_expand_floor_div_applied(self): 213 sizevars = SizeVarAllocator() 214 x = sympy.Symbol("x", integer=True, positive=True) 215 y = sympy.Symbol("y", integer=True, positive=True) 216 217 expr = x * 5 + FloorDiv(y, 3) 218 actual, denominator = sizevars.expand_floor_div(expr) 219 self.assertNotEqual(expr, actual) 220 expected = FloorDiv(x * 15 + y, 3) 221 self.assertEqual(expected, FloorDiv(actual, denominator)) 222 223 @unittest.skipUnless(HAS_GPU, "Need GPU for this test") 224 def test_int8_unpack(self): 225 @torch.compile 226 def f(x): 227 first_elements = x >> 4 228 second_elements = x & 15 229 unpacked = torch.stack([first_elements, second_elements], dim=-1).view( 230 *x.size()[:-1], -1 231 ) 232 return unpacked * 2 233 234 x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device=GPU_TYPE) 235 236 triton_code = run_and_get_triton_code(f, x) 237 # Make sure the 2 load uses simpified indexing rather than something like 238 # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), 239 self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),")) 240 if DO_PERF_TEST: 241 ms = benchmarker.benchmark_gpu(lambda: f(x)) 242 print(f"{ms=:.03f}") 243 244 245class ExprPrinterTests(InductorTestCase): 246 def test_print_pow(self): 247 s1 = sympy.Symbol("foo", integer=True) 248 s2 = sympy.Symbol("bar", integer=True) 249 s3 = sympy.Symbol("baz", integer=True) 250 251 common_cases = [ 252 # expr, result 253 # Test Pow directly. 254 ( 255 sympy.Pow(s1 + s2, 0), 256 lambda _, L: f"1{L}", 257 ), # note: simplified before _print_Pow 258 ] 259 260 gpu_cases = common_cases + [ 261 (sympy.Pow(s1 + s2, 2), lambda c, L: "(bar + foo)*(bar + foo)") 262 ] 263 cpu_cases = common_cases + [ 264 ( 265 sympy.Pow(s1 + s2, 2), 266 lambda c, L: "static_cast<int64_t>((bar + foo)*(bar + foo))", 267 ) 268 ] 269 for expr, result in gpu_cases: 270 self.assertEqual(texpr(expr), result(1, "")) 271 self.assertEqual(pexpr(expr), result(1, "")) 272 for expr, result in cpu_cases: 273 self.assertEqual( 274 cexpr(expr), 275 result(1.0, "LL") 276 if sys.platform in ["darwin", "win32"] 277 else result(1.0, "L"), 278 ) # 1.0 for FP div 279 280 def test_print_floor(self): 281 for integer in [True, False]: 282 s1 = sympy.Symbol("s1", integer=integer) 283 expr = sympy.floor(s1 / 2) 284 if integer: 285 self.assertEqual(pexpr(expr), "math.floor((1/2)*s1)") 286 self.assertEqual( 287 cexpr(expr), "static_cast<int64_t>(std::floor((1.0/2.0)*s1))" 288 ) 289 else: 290 self.assertExpectedInline(pexpr(expr), """math.floor((1/2)*s1)""") 291 self.assertExpectedInline( 292 texpr(expr), 293 """libdevice.floor((1/2)*s1).to(tl.int64)""", 294 ) 295 self.assertExpectedInline(cexpr(expr), """std::floor((1.0/2.0)*s1)""") 296 297 def test_print_ceil(self): 298 for integer in [True, False]: 299 s1 = sympy.Symbol("s1", integer=integer) 300 expr = sympy.ceiling(s1 / 2) 301 if integer: 302 self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""") 303 self.assertExpectedInline( 304 cexpr(expr), """static_cast<int64_t>(std::ceil((1.0/2.0)*s1))""" 305 ) 306 else: 307 self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""") 308 self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") 309 310 def test_print_round(self): 311 expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) 312 self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") 313 self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") 314 self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") 315 316 @parametrize("ndigits", [-1, 0, 1]) 317 def test_print_round_decimal(self, ndigits): 318 expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits) 319 self.assertEqual(pexpr(expr), f"round((1/2)*x, {ndigits})") 320 self.assertEqual( 321 cexpr(expr), 322 f"static_cast<double>(std::nearbyint(1e{ndigits} * ((1.0/2.0)*x)) * 1e{-ndigits})", 323 ) 324 self.assertEqual( 325 texpr(expr), 326 f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", 327 ) 328 329 def test_print_floor_div(self): 330 s1 = sympy.Symbol("s1", integer=True) 331 s2 = sympy.Symbol("s2", integer=True) 332 expr = FloorDiv(s1, s2) 333 self.assertEqual(pexpr(expr), "(s1 // s2)") 334 self.assertEqual( 335 cexpr(expr), 336 "c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))", 337 ) 338 339 s1 = sympy.Symbol("s1", integer=True) 340 s2 = sympy.S(-1) 341 expr = FloorDiv(s1, s2) 342 self.assertEqual(pexpr(expr), "(-1)*s1") 343 self.assertEqual(cexpr(expr), "(-1LL)*s1") if sys.platform in [ 344 "darwin", 345 "win32", 346 ] else "(-1L)*s1" 347 348 def test_print_Min_Max(self): 349 cases = ( 350 (sympy.Min, "min", "<"), 351 (sympy.Max, "max", ">"), 352 ) 353 for f, s, cmp in cases: 354 x = sympy.Symbol("x", integer=True) 355 expr = f(-2, x) 356 self.assertEqual( 357 texpr(expr), f"((-2) * ((-2) {cmp}= (x)) + (x) * ((x) {cmp} (-2)))" 358 ) 359 self.assertEqual( 360 cexpr(expr), 361 f"std::{s}(static_cast<int64_t>(-2LL), static_cast<int64_t>(x))" 362 if sys.platform in ["darwin", "win32"] 363 else f"std::{s}(static_cast<int64_t>(-2L), static_cast<int64_t>(x))", 364 ) 365 366 expr = f(x, 2 * x, 3 * x) 367 self.assertEqual( 368 texpr(expr), 369 f"((x) * ((x) {cmp}= (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x))))) + (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) * ((((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) {cmp} (x)))", # noqa: B950 line too long 370 ) 371 self.assertEqual( 372 cexpr(expr), 373 f"std::{s}({{x, 2LL*x, 3LL*x}})" 374 if sys.platform in ["darwin", "win32"] 375 else f"std::{s}({{x, 2L*x, 3L*x}})", 376 ) 377 378 379instantiate_parametrized_tests(ExprPrinterTests) 380 381 382if __name__ == "__main__": 383 from torch._inductor.test_case import run_tests 384 385 if HAS_CPU or HAS_GPU: 386 run_tests("sympy") 387