xref: /aosp_15_r20/external/pytorch/test/inductor/test_indexing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import os
3import sys
4import unittest
5
6import sympy
7
8import torch
9from torch._inductor.codegen.cpp import cexpr
10from torch._inductor.codegen.triton import texpr
11from torch._inductor.codegen.wrapper import pexpr
12from torch._inductor.runtime.benchmarking import benchmarker
13from torch._inductor.sizevars import SizeVarAllocator
14from torch._inductor.test_case import TestCase as InductorTestCase
15from torch._inductor.utils import run_and_get_triton_code
16from torch.testing._internal.common_utils import (
17    instantiate_parametrized_tests,
18    parametrize,
19)
20from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
21from torch.utils._sympy.functions import (
22    FloorDiv,
23    ModularIndexing,
24    RoundDecimal,
25    RoundToInt,
26)
27
28
# Opt-in benchmarking switch: export DO_PERF_TEST=1 to make perf-aware tests
# (currently test_int8_unpack) also benchmark and print their timings.
DO_PERF_TEST = os.getenv("DO_PERF_TEST") == "1"
30
31
class TestIndexingSimplification(InductorTestCase):
    """Tests for symbolic index simplification.

    Covers ``SizeVarAllocator`` rewrites of ``ModularIndexing`` / ``FloorDiv``
    expressions (with and without variable ranges), the folding performed by
    those constructors themselves, and the Triton indexing emitted for an
    int8-unpacking kernel.
    """

    def test_indexing_simplification(self):
        """Redundant ModularIndexing/FloorDiv terms fold away, but only when sound."""
        sizevars = SizeVarAllocator()
        i0 = sympy.Symbol("i0", integer=True)
        i1 = sympy.Symbol("i1", integer=True)
        i2 = sympy.Symbol("i2", integer=True)
        r3 = sympy.Symbol("r3", integer=True)

        var_ranges = {i0: 3136, i1: 64, i2: 32, r3: 3}
        expr = (
            128 * i2
            + ModularIndexing(i1, 1, 64)
            + 64 * ModularIndexing(i1 + 64 * r3, 64, 2)
        )
        # check that `i1//64` is removed when i1 is always less than 64,
        # and the next simplification doesn't happen
        self.assertEqual(
            sizevars.simplify_with_ranges(expr, var_ranges),
            i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2),
        )
        # all the modular indexing should be removed when the body can't be larger than the modulus
        var_ranges[r3] = 2
        self.assertEqual(
            sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3
        )
        # if there are negative terms in ModularIndexing base, we cannot replace it with FloorDiv
        expr = ModularIndexing(i1 - 15, 1, 64)
        self.assertEqual(
            sizevars.simplify_with_ranges(expr, var_ranges),
            ModularIndexing(i1 - 15, 1, 64),
        )
        # small terms should be kept if the rest is not guaranteed to be divisible
        self.assertEqual(
            sizevars.simplify_with_ranges(FloorDiv(r3 + i2 + i1, 32), var_ranges),
            FloorDiv(r3 + i2 + i1, 32),
        )

        expr = ModularIndexing(2 * i2 + r3, 1, 64)
        # modular indexing is removed if base is smaller than modulo
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), 2 * i2 + r3)

        # check the same thing but with symbolic divisor
        self.assertEqual(FloorDiv(r3 * i0, r3), i0)
        self.assertEqual(ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10))

        # (10*i) % 10 is always zero and should get optimized away
        self.assertEqual(
            ModularIndexing(i0 + i1 * 10, 1, 10), ModularIndexing(i0, 1, 10)
        )

        # ((20*i)//2) % 10 is always zero and should get optimized away
        self.assertEqual(
            ModularIndexing(i0 + i1 * 20, 2, 10), ModularIndexing(i0, 2, 10)
        )

        # the same things happens with symbolic divisor
        self.assertEqual(
            ModularIndexing(i0 + i1 * i2 * r3, i2, r3), ModularIndexing(i0, i2, r3)
        )

        # if there are negative terms, we cannot optimize away zero terms due to https://github.com/openai/triton/issues/619
        self.assertEqual(
            ModularIndexing(-i0 + i1 * 20, 2, 10), ModularIndexing(-i0 + i1 * 20, 2, 10)
        )
        self.assertEqual(
            ModularIndexing(-15 + i1 * 20, 2, 10), ModularIndexing(-15 + i1 * 20, 2, 10)
        )

        # Constant fold from divisor into base
        self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10))
        self.assertEqual(FloorDiv(i0 * 4, 2), i0 * 2)

        # Nested modular indexing is correctly simplified
        var_ranges = {i1: 13, i2: 121}
        expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
        expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
        var_ranges = {i2: 784}
        expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4)
        expected = FloorDiv(ModularIndexing(i2, 1, 28), 7)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected)
        expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)

    def test_indexing_join(self):
        """Adjacent ModularIndexing/FloorDiv terms over one base merge into a wider term."""
        sizevars = SizeVarAllocator()
        i0 = sympy.Symbol("i0", integer=True)
        i1 = sympy.Symbol("i1", integer=True)
        i2 = sympy.Symbol("i2", integer=True)

        # join two ModularIndexing calls into one larger one when possible
        expr1 = ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
        self.assertEqual(
            sizevars.simplify_with_ranges(expr1, {}), ModularIndexing(i0, 1, 128)
        )

        # it should also work with a scale
        self.assertEqual(
            sizevars.simplify_with_ranges(2 * expr1, {}),
            2 * ModularIndexing(i0, 1, 128),
        )

        # it should work when divisor is not 1
        expr2 = ModularIndexing(i0, 3, 32) + 32 * ModularIndexing(i0, 32 * 3, 4)
        simplified = sizevars.simplify_with_ranges(expr2, {})
        self.assertEqual(simplified, ModularIndexing(i0, 3, 128))
        self.assertEqual(expr2.subs({i0: 39485}), simplified.subs({i0: 39485}))

        # it should not happen in this case as the modulus is wrong
        expr3 = ModularIndexing(i0, 1, 30) + 32 * ModularIndexing(i0, 32, 4)
        self.assertEqual(sizevars.simplify_with_ranges(expr3, {}), expr3)

        # check that it also works with a modulus>1
        expr4 = ModularIndexing(i0, 10, i1) + i1 * ModularIndexing(i0, i1 * 10, i2)
        res0 = expr4.subs({i0: 24056, i1: 13, i2: 19})
        simplified = sizevars.simplify_with_ranges(expr4, {})
        res1 = simplified.subs({i0: 24056, i1: 13, i2: 19})
        self.assertEqual(res0, res1)
        self.assertEqual(simplified, ModularIndexing(i0, 10, i1 * i2))

        # and also works with an offset
        self.assertEqual(
            sizevars.simplify_with_ranges(expr4 + 10, {}),
            ModularIndexing(i0, 10, i1 * i2) + 10,
        )

        # works for ModularIndexing + FloorDiv
        expr5 = 197 * FloorDiv(i0, 197) + ModularIndexing(i0, 1, 197)
        simplified = sizevars.simplify_with_ranges(expr5, {})
        self.assertEqual(simplified, i0)
        self.assertEqual(expr5.subs({i0: 39485}), simplified.subs({i0: 39485}))

        # works with a scale
        self.assertEqual(
            sizevars.simplify_with_ranges(2 * expr5, {}),
            2 * i0,
        )

        # divisor != 1
        expr6 = 197 * FloorDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197)
        simplified = sizevars.simplify_with_ranges(expr6, {})
        self.assertEqual(simplified, FloorDiv(i0, 3))
        self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485}))

    def test_modular_indexing_pairs_merged(self):
        """ModularIndexing(ModularIndexing(x, 1, 1024), 1, 32) collapses to ModularIndexing(x, 1, 32)."""
        sizevars = SizeVarAllocator()
        x = sympy.Symbol("x", integer=True, positive=True)
        a = 1024
        b = 32
        expr1 = ModularIndexing(x, 1, a)
        expr2 = ModularIndexing(expr1, 1, b)
        expected = ModularIndexing(x, 1, b)

        actual = sizevars.combine_modular_indexing_pairs(expr2)
        self.assertEqual(expected, actual)
        self.assertNotEqual(expr2, actual)

    def test_modular_indexing_pairs_not_merged(self):
        """ModularIndexing(ModularIndexing(x, 1, 1024), 1, 3) must be left unchanged."""
        sizevars = SizeVarAllocator()
        x = sympy.Symbol("x", integer=True, positive=True)
        a = 1024
        b = 3  # pick a 'b' that we can not merge
        expr1 = ModularIndexing(x, 1, a)
        expr2 = ModularIndexing(expr1, 1, b)

        actual = sizevars.combine_modular_indexing_pairs(expr2)
        self.assertEqual(expr2, actual)
        self.assertNotEqual(ModularIndexing(x, 1, b), actual)

    def test_expand_floor_div_skipped(self):
        """expand_floor_div gives a falsy result when the sum contains several FloorDivs."""
        sizevars = SizeVarAllocator()
        x = sympy.Symbol("x", integer=True, positive=True)
        y = sympy.Symbol("y", integer=True, positive=True)

        expr = FloorDiv(x, 2) + FloorDiv(y, 3)
        # The expression can not be simplified since there are multiple
        # FloorDiv. We return False in that case
        self.assertFalse(sizevars.expand_floor_div(expr))

    def test_expand_floor_div_applied(self):
        """expand_floor_div returns (numerator, denominator) for a single-FloorDiv sum."""
        sizevars = SizeVarAllocator()
        x = sympy.Symbol("x", integer=True, positive=True)
        y = sympy.Symbol("y", integer=True, positive=True)

        expr = x * 5 + FloorDiv(y, 3)
        actual, denominator = sizevars.expand_floor_div(expr)
        self.assertNotEqual(expr, actual)
        expected = FloorDiv(x * 15 + y, 3)
        self.assertEqual(expected, FloorDiv(actual, denominator))

    @unittest.skipUnless(HAS_GPU, "Need GPU for this test")
    def test_int8_unpack(self):
        """Unpacking uint8 nibbles should index the input simply as ``x2 // 2``."""

        @torch.compile
        def f(x):
            first_elements = x >> 4
            second_elements = x & 15
            unpacked = torch.stack([first_elements, second_elements], dim=-1).view(
                *x.size()[:-1], -1
            )
            return unpacked * 2

        # torch.randint's upper bound is exclusive; 256 covers every uint8 value.
        x = torch.randint(0, 256, (2, 4096, 5504), dtype=torch.uint8, device=GPU_TYPE)

        triton_code = run_and_get_triton_code(f, x)
        # Make sure both loads use simplified indexing rather than something like
        # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
        self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),"))
        if DO_PERF_TEST:
            ms = benchmarker.benchmark_gpu(lambda: f(x))
            print(f"{ms=:.03f}")
243
244
class ExprPrinterTests(InductorTestCase):
    """Tests for the sympy expression printers used by Inductor codegen.

    ``texpr`` comes from the Triton backend, ``pexpr`` from the Python wrapper
    and ``cexpr`` from the C++ backend. The expected C++ integer literals use
    an ``LL`` suffix on macOS/Windows and an ``L`` suffix elsewhere, so several
    expected strings depend on ``sys.platform``.
    """

    def test_print_pow(self):
        """Pow with exponent 0 folds to 1; a square prints as repeated multiplication."""
        s1 = sympy.Symbol("foo", integer=True)
        s2 = sympy.Symbol("bar", integer=True)

        common_cases = [
            # expr, result
            # Test Pow directly.
            (
                sympy.Pow(s1 + s2, 0),
                lambda _, L: f"1{L}",
            ),  # note: simplified before _print_Pow
        ]

        gpu_cases = common_cases + [
            (sympy.Pow(s1 + s2, 2), lambda c, L: "(bar + foo)*(bar + foo)")
        ]
        cpu_cases = common_cases + [
            (
                sympy.Pow(s1 + s2, 2),
                lambda c, L: "static_cast<int64_t>((bar + foo)*(bar + foo))",
            )
        ]
        for expr, result in gpu_cases:
            self.assertEqual(texpr(expr), result(1, ""))
            self.assertEqual(pexpr(expr), result(1, ""))
        for expr, result in cpu_cases:
            self.assertEqual(
                cexpr(expr),
                result(1.0, "LL")
                if sys.platform in ["darwin", "win32"]
                else result(1.0, "L"),
            )  # 1.0 for FP div

    def test_print_floor(self):
        """sympy.floor prints via math.floor / std::floor; C++ adds an int64 cast for integer symbols."""
        for integer in [True, False]:
            s1 = sympy.Symbol("s1", integer=integer)
            expr = sympy.floor(s1 / 2)
            if integer:
                self.assertEqual(pexpr(expr), "math.floor((1/2)*s1)")
                self.assertEqual(
                    cexpr(expr), "static_cast<int64_t>(std::floor((1.0/2.0)*s1))"
                )
            else:
                self.assertExpectedInline(pexpr(expr), """math.floor((1/2)*s1)""")
                self.assertExpectedInline(
                    texpr(expr),
                    """libdevice.floor((1/2)*s1).to(tl.int64)""",
                )
                self.assertExpectedInline(cexpr(expr), """std::floor((1.0/2.0)*s1)""")

    def test_print_ceil(self):
        """sympy.ceiling prints via math.ceil / std::ceil; C++ adds an int64 cast for integer symbols."""
        for integer in [True, False]:
            s1 = sympy.Symbol("s1", integer=integer)
            expr = sympy.ceiling(s1 / 2)
            if integer:
                self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""")
                self.assertExpectedInline(
                    cexpr(expr), """static_cast<int64_t>(std::ceil((1.0/2.0)*s1))"""
                )
            else:
                self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""")
                self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""")

    def test_print_round(self):
        """RoundToInt prints as round / std::lrint / libdevice.llrint."""
        expr = RoundToInt(sympy.Symbol("x", integer=True) / 2)
        self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""")
        self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
        self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")

    @parametrize("ndigits", [-1, 0, 1])
    def test_print_round_decimal(self, ndigits):
        """RoundDecimal scales by 1e{ndigits}, rounds, and scales back in C++/Triton."""
        expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits)
        self.assertEqual(pexpr(expr), f"round((1/2)*x, {ndigits})")
        self.assertEqual(
            cexpr(expr),
            f"static_cast<double>(std::nearbyint(1e{ndigits} * ((1.0/2.0)*x)) * 1e{-ndigits})",
        )
        self.assertEqual(
            texpr(expr),
            f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}",
        )

    def test_print_floor_div(self):
        """FloorDiv prints as ``//`` / c10::div_floor_integer; dividing by -1 folds to a negation."""
        s1 = sympy.Symbol("s1", integer=True)
        s2 = sympy.Symbol("s2", integer=True)
        expr = FloorDiv(s1, s2)
        self.assertEqual(pexpr(expr), "(s1 // s2)")
        self.assertEqual(
            cexpr(expr),
            "c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",
        )

        s1 = sympy.Symbol("s1", integer=True)
        s2 = sympy.S(-1)
        expr = FloorDiv(s1, s2)
        self.assertEqual(pexpr(expr), "(-1)*s1")
        # Only the expected string depends on the platform; the assertion
        # itself must execute everywhere.
        self.assertEqual(
            cexpr(expr),
            "(-1LL)*s1" if sys.platform in ["darwin", "win32"] else "(-1L)*s1",
        )

    def test_print_Min_Max(self):
        """Min/Max print as arithmetic selects in Triton and std::min/std::max in C++."""
        cases = (
            (sympy.Min, "min", "<"),
            (sympy.Max, "max", ">"),
        )
        for f, s, cmp in cases:
            x = sympy.Symbol("x", integer=True)
            expr = f(-2, x)
            self.assertEqual(
                texpr(expr), f"((-2) * ((-2) {cmp}= (x)) + (x) * ((x) {cmp} (-2)))"
            )
            self.assertEqual(
                cexpr(expr),
                f"std::{s}(static_cast<int64_t>(-2LL), static_cast<int64_t>(x))"
                if sys.platform in ["darwin", "win32"]
                else f"std::{s}(static_cast<int64_t>(-2L), static_cast<int64_t>(x))",
            )

            expr = f(x, 2 * x, 3 * x)
            self.assertEqual(
                texpr(expr),
                f"((x) * ((x) {cmp}= (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x))))) + (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) * ((((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) {cmp} (x)))",  # noqa: B950 line too long
            )
            self.assertEqual(
                cexpr(expr),
                f"std::{s}({{x, 2LL*x, 3LL*x}})"
                if sys.platform in ["darwin", "win32"]
                else f"std::{s}({{x, 2L*x, 3L*x}})",
            )
377
378
# Needed because ExprPrinterTests uses @parametrize (test_print_round_decimal
# takes an `ndigits` argument): this generates one concrete test method per
# parameter value on the class.
instantiate_parametrized_tests(ExprPrinterTests)
380
381
# Script entry point: run this module's tests only when a CPU or GPU backend
# is available (HAS_CPU / HAS_GPU).
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        # NOTE(review): "sympy" is passed straight through to run_tests —
        # presumably names a module the tests require; confirm against run_tests.
        run_tests("sympy")
387