# Owner(s): ["module: inductor"]

import contextlib
import unittest

import numpy as np

import torch
from torch import nn
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config as inductor_config, ir, metrics
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.graph import GraphLowering
from torch._inductor.scheduler import SchedulerNode
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import ModularIndexing


if HAS_CUDA:
    torch.set_default_device("cuda")


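# MockScheduler is a minimal stand-in for torch._inductor.scheduler.Scheduler:
# ImplDetailTest builds SchedulerNodes by hand, so only the attributes those
# code paths appear to touch (available_buffer_names and get_backend) exist here.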
class MockScheduler:
    available_buffer_names = ()

    @staticmethod
    def get_backend(cls, *args):
        return TritonScheduling(cls)


@inductor_config.patch(loop_ordering_after_fusion=True)
class ImplDetailTest(TestCase):
    _exit_stack = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

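        # Lowering a trivial fx graph gives us a GraphLowering instance; all we
        # need is something to install as the global graph handler (V.graph) so
        # that IR can be created outside of a real compilation. The mock
        # scheduler stands in for the piece SchedulerNode construction consults.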
        gm = torch.fx.symbolic_trace(lambda: 0)
        graph = GraphLowering(gm)
        graph.scheduler = MockScheduler
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(V.set_graph_handler(graph))

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        cls._exit_stack.close()

    @staticmethod
    def _get_snode_body_sym_prefix(snode):
        body = snode._body
        prefix = ""

        for var in body.var_ranges:
            prefix = str(var)[0]
            break

        assert prefix
        return prefix

    @staticmethod
    def _create_computed_buffer_ax2(sizes=(32, 64), strides=None):
        """
        Create a ComputedBuffer for 'a x 2'
        """
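        # The buffer loads from an input named "a" with the requested layout and
        # multiplies each element by 2; realize() plus decide_layout() below turn
        # the lazy Pointwise into a ComputedBuffer with a concrete layout.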
        if strides is None:
            strides = ir.FlexibleLayout.contiguous_strides(sizes)

        box_a = ir.TensorBox.create(
            ir.Buffer(
                "a", ir.FixedLayout(torch.device("cuda"), torch.float32, sizes, strides)
            )
        )
        box_a_loader = box_a.make_loader()

        def inner_fn(index):
            return box_a_loader(index) * 2

        buf = ir.Pointwise.create(
            device=box_a.get_device(),
            dtype=box_a.get_dtype(),
            inner_fn=inner_fn,
            ranges=box_a.get_size(),
        )
        buf.realize()
        computed_buf = buf.data.data
        computed_buf.decide_layout()
        return computed_buf

    def test_reorder_twice(self):
        """
        This may happen in practice if we pick an order when fusing A and B,
        and then pick another order for AB when we fuse C into it.

        E.g. this happens for BertForMaskedLM.
        """

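        # Reordering rewrites the body's iteration variables; doing it a second
        # time must not leave stale or doubly-renamed symbols behind, so the
        # symbol prefix is expected to stay "z" after both reorders.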
        buf = self._create_computed_buffer_ax2()
        snode = SchedulerNode(V.graph.scheduler, buf)
        snode.apply_new_loop_order([1, 0])
        prefix1 = self._get_snode_body_sym_prefix(snode)
        self.assertTrue(prefix1 == "z")
        snode.apply_new_loop_order([1, 0])
        prefix2 = self._get_snode_body_sym_prefix(snode)
        self.assertTrue(prefix2 == "z")

    def test_reorder_and_merge_loops(self):
        sizes = (1024, 2048)
        strides = (1, 1024)
        buf = self._create_computed_buffer_ax2(sizes, strides)
        old_sizes, old_body = buf.simplify_and_reorder()

        # Make sure loop reordering happens here
        self.assertTrue(tuple(old_sizes[0]) == tuple(reversed(sizes)), f"{old_sizes=}")
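        # After the reorder, iteration follows the (1, 1024) strides contiguously,
        # so the two loops can be merged into a single flat loop over all
        # 1024 * 2048 elements (checked below).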
        new_body = old_body.merge_loops()
        new_sizes = new_body.sizes
        self.assertTrue(tuple(new_sizes[0]) == (np.prod(sizes),), f"{new_sizes=}")

    def test_reorder_modular_indexing(self):
        """
        There was a bug where we wrongly mapped i0 to the dimension with size 49
        when reordering the loop, causing ModularIndexing to get optimized away
        as a no-op.
        """

        def _create_computed_buffer():
            def inner_fn(index):
                i0, i1, i2, i3 = index
                return ops.load(
                    "primal", i3 + 49 * i2 + 2401 * ModularIndexing(i0, 1, 64)
                )

            buf = ir.Pointwise.create(
                device=torch.device("cuda"),
                dtype=torch.float32,
                inner_fn=inner_fn,
                ranges=[128, 4, 49, 49],
            )
            buf.realize()
            cbuf = buf.data.data
            cbuf.decide_layout()
            return cbuf

        buf = _create_computed_buffer()
        _, body = buf.simplify_and_reorder()
        new_body = body.reorder_iter_loops([1, 2, 3, 0])

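        # reorder_iter_loops([1, 2, 3, 0]) makes the old z1/z2/z3 the new leading
        # loops and moves the size-128 dim (old z0) to the end as the new z3.
        # The assertions below check that ModularIndexing follows that dim rather
        # than being attached to one of the size-49 dims.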
        z0, z1, z2, z3 = (sympy_index_symbol(f"z{i}") for i in range(4))
        self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49})
        self.assertEqual(
            body.indexing_exprs["index0"],
            z3 + 49 * z2 + 2401 * ModularIndexing(z0, 1, 64),
        )
        self.assertEqual(new_body.var_ranges, {z0: 4, z1: 49, z2: 49, z3: 128})
        self.assertEqual(
            new_body.indexing_exprs["index0"],
            z2 + 49 * z1 + 2401 * ModularIndexing(z3, 1, 64),
        )


@inductor_config.patch(
    {
        "benchmark_kernel": True,
        "loop_ordering_after_fusion": True,
        "triton.unique_kernel_names": True,
    }
)
class LoopOrderingTest(TestCase):
    def do_acc_test(self, f, *args, cast_fp8=True):
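        """
        Run f both eagerly and compiled and check that the results match.
        With cast_fp8 set, any fp8 outputs are upcast to fp32 before the
        comparison since allclose-style checks do not support fp8 dtypes.
        """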
        expect = f(*args)
        actual = torch.compile(f)(*args)

        if cast_fp8:

            def _cast(x):
                if isinstance(x, torch.Tensor) and x.dtype in (
                    torch.float8_e5m2,
                    torch.float8_e4m3fn,
                ):
                    return x.to(torch.float32)
                return x

            # Workaround for the issue that calling allclose on fp8 tensors
            # triggers the error:
            #   RuntimeError: "mul_cuda" not implemented for 'Float8_e4m3fn'
            expect = tree_map(_cast, expect)
            actual = tree_map(_cast, actual)
        self.assertTrue(same(expect, actual, tol=1e-3))

    def setUp(self):
        super().setUp()
        metrics.reset()

    def test_for_reordering_reindex(self):
        """
        ComputedBuffer.iter_reoredering_reindex can cause some fusion
        opportunities to be skipped.

        Inductor previously generated 2 triton kernels for this test case.
        By removing ComputedBuffer.iter_reoredering_reindex, we can fuse those
        two kernels into a single one.
        """

        def f(x, y):
            """
            Add a matmul since inductor may force a layout for the output.
            """
            return (x.sum(dim=-1) + 1) @ y

        A, B = 20, 30
        # Make the first 2 dimensions unmergeable on purpose so that
        # ComputedBuffer.iter_reoredering_reindex will be updated.
        x = rand_strided([A, A, B], [B, B * A + 300, 1], device="cuda")
        y = torch.randn(A, A)

        self.do_acc_test(f, x, y)
        self.assertEqual(1, metrics.generated_kernel_count)
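        # Rough accounting of the expected memory traffic (everything is fp32):
        # the fused reduction kernel reads x (A * A * B elements) and writes the
        # (A, A) intermediate, and the matmul reads its two (A, A) operands and
        # writes an (A, A) output.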
        expected_num_bytes = 0
        expected_num_bytes += A * A * B + A * A  # for the fused reduction
        expected_num_bytes += A * A * 3  # for matmul
        expected_num_bytes *= x.itemsize
        self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed)

    def test_apbt_realize(self):
        M = 1024
        N = 2048

        def f(x, y):
            """
            Without loop ordering after fusion, 2 kernels are generated:
              https://gist.github.com/shunting314/44df83f71de2c110232c50ac6638ed69
            """
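            # x is contiguous while y is a transposed view, so without loop
            # reordering the y * 3 kernel iterates in a different order and does
            # not fuse with the rest (see the gist above). With loop ordering
            # after fusion, all three ops are expected to land in one kernel.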
            x = realize(x * 2)
            y = realize(y * 3)
            return x + y

        x = torch.randn(M, N)
        y = torch.randn(N, M).t()

        self.do_acc_test(f, x, y)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_sum_and_t(self):
        N = 1024

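        # The row-wise sum reads x in row-major order while the transpose-copy
        # effectively walks it column-major; getting a single kernel here is what
        # loop ordering after fusion is expected to enable.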
        def f(x):
            return x.sum(dim=-1), x.t().contiguous()

        x = torch.randn(N, N * 2)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_pw_outer_red(self):
        def f(x):
            x = realize(x + 1)
            return x.sum(dim=[0, 1])

        # make the first 2 dimensions small so we don't split the reduction
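        # The realized pointwise kernel iterates x contiguously while the outer
        # reduction keeps only the last dim; fusing them into a single kernel is
        # what loop ordering after fusion is expected to enable here.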
        x = torch.randn(2, 4, 512)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_pw_outer_red_2(self):
        """
        The pointwise kernel is itself a fused kernel
        """

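        # Same shape as test_pw_outer_red, but the producer is a chain of three
        # realized pointwise ops; they should fuse with each other first and then,
        # after loop reordering, with the outer reduction, still giving one kernel.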
        def f(x):
            x = realize(x + 1)
            x = realize(x - 2)
            x = realize(x * 3)
            return x.sum(dim=[0, 1])

        # make the first 2 dimensions small so we don't split the reduction
        x = torch.randn(2, 4, 512)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    @inductor_config.patch(split_reductions=False)
    def test_different_reduction_order(self):
        """
        We should not reorder loops in this case, since reordering loops does
        not help!
        """

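        # sum(dim=0) and sum(dim=1) reduce over different dimensions, so no loop
        # order lets the two reduction kernels share one iteration space; the
        # metric below checks that no reordering is even attempted.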
        def f(x):
            return x.sum(dim=0), x.sum(dim=1)

        x = torch.randn(1024, 2048)
        self.do_acc_test(f, x)
        self.assertEqual(2, metrics.generated_kernel_count)
        self.assertEqual(0, metrics.num_loop_reordering)

    def test_keep_fake_dep(self):
        """
        In this model, there are fake dependencies (StarDep) between Scatter
        and a following mutation kernel that computes the gradients of
        the embedding tables.

        When we do loop reordering for the mutation kernel, we re-analyze
        the node's dependencies. But the analysis result does not contain
        those fake dependencies, so we have to add them back manually.
        """
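        # (Roughly: StarDep is a coarse, whole-buffer dependency used e.g. for
        # mutation ordering; it is not reconstructed by re-analyzing index
        # expressions, so it has to be carried over explicitly on reorder.)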
        V = 2048
        hidden_size = 64
        max_seqlen = 512
        batch_size = 8

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.word_embeddings = nn.Embedding(V, hidden_size)
                self.position_embeddings = nn.Embedding(max_seqlen, hidden_size)
                self.layer_norm = nn.LayerNorm(hidden_size)

            def forward(self, input_ids, labels, position_ids):
                emb = self.word_embeddings(input_ids) + self.position_embeddings(
                    position_ids
                )
                return self.layer_norm(emb)

        m = Model()

        @torch.compile
        def f(*args):
            m(*args).sum().backward()

        input_ids = torch.randint(0, V, (batch_size, max_seqlen))
        labels = torch.randint(0, V, (batch_size, max_seqlen))
        position_ids = torch.arange(max_seqlen)[None, :]
        # Make sure this line does not raise exceptions. If we miss
        # fake dependencies after loop reordering, we may get an exception that
        # some buffer is used before being defined.
        f(input_ids, labels, position_ids)

    def test_different_broadcast_shapes(self):
        def f(x, y, c):
            return x + c, y + c

        x = torch.randn(4, 256, 1024)
        y = torch.randn(2, 512, 1024)
        c = torch.randn(1024)
        self.do_acc_test(f, x, y, c)

        # The two kernels are not fused because c is broadcast to different shapes
        self.assertEqual(2, metrics.generated_kernel_count)

    def test_view(self):
        """
        Passing this test relies on comparing normalized MemoryDep.
        Normalization here means merging contiguous loops.

        To make loop reordering work, we don't merge loops when creating
        SchedulerNode. Thus we need to explicitly normalize MemoryDep when
        we check whether two MemoryDeps match.
        """

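        # x.sin() reads x as a flat 100-element tensor while the realized view
        # writes it as (10, 10); after normalization both accesses describe the
        # same contiguous region, so the two nodes can still fuse into one kernel.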
        def f(x):
            y = x.sin()
            x = realize(x.view(10, 10))
            return x, y

        x = torch.randn(100)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
    def test_fp8_cast_and_t(self):
        """
        This test reproduces the unable-to-fuse issue in
        https://github.com/pytorch/pytorch/issues/130015
        for fp8 cast and transpose.
        """

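        # The clamp + fp8 cast reads x row-major while x.t().contiguous().t()
        # produces a column-major copy; per the issue above these used to end up
        # in separate kernels, and the assertion below expects a single fused one.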
        def f(x, scale):
            x = x * scale
            x = x.clamp(-1 * E4M3_MAX_POS, E4M3_MAX_POS)
            x = x.to(torch.float8_e4m3fn)
            x_t = x.t().contiguous().t()
            return x, x_t

        x = torch.randn(4096, 4096, dtype=torch.bfloat16)
        scale = torch.Tensor([10.0]).cuda()
        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

        self.do_acc_test(f, x, scale)
        self.assertEqual(1, metrics.generated_kernel_count)


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()