xref: /aosp_15_r20/external/pytorch/test/inductor/test_dependencies.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import contextlib
3
4import torch
5from torch._inductor.dependencies import MemoryDep
6from torch._inductor.graph import GraphLowering
7from torch._inductor.ir import Buffer, FixedLayout, Pointwise
8from torch._inductor.test_case import TestCase as InductorTestCase
9from torch._inductor.utils import sympy_index_symbol
10from torch._inductor.virtualized import ops, V
11from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
12
13
14class TestDependencies(InductorTestCase):
15    def _create_buffer(self, name, shape, dtype=torch.float32):
16        return Buffer(name, FixedLayout(torch.device(GPU_TYPE), dtype, shape))
17
18    def setUp(self):
19        super().setUp()
20
21        class DummyModule(torch.nn.Module):
22            def forward(self, x):
23                return x * 2
24
25        self._gm = torch.fx.symbolic_trace(DummyModule())
26        self._graph = GraphLowering(self._gm)
27
28        self._stack = contextlib.ExitStack()
29        self._stack.enter_context(V.set_graph_handler(self._graph))
30
31    def tearDown(self):
32        self._stack.close()
33        super().tearDown()
34
35    def test_bucketize_dependencies(self):
36        offsets = self._create_buffer("offsets", (1025,), torch.int32)
37
38        def inner_fn(index):
39            idx = index[0]
40            return ops.bucketize(
41                values=idx,
42                offsets_name=offsets.get_name(),
43                offsets_size=offsets.get_size()[0],
44                indexing_dtype=torch.int32,
45                right=True,
46            )
47
48        pointwise = Pointwise.create(
49            device=torch.device(GPU_TYPE),
50            dtype=torch.int32,
51            inner_fn=inner_fn,
52            ranges=[1024 * 4],
53        )
54
55        self.assertEqual(len(pointwise.get_reads()), 1)
56
57    def test_get_offset(self):
58        x = sympy_index_symbol("x")
59        y = sympy_index_symbol("y")
60        var_ranges = {
61            x: 1024,
62            y: 2048,
63        }
64        dep1 = MemoryDep(
65            "dep1",
66            x * 2048 + y,
67            list(var_ranges.keys()),
68            list(var_ranges.values()),
69        )
70        dep2 = MemoryDep(
71            "dep2",
72            x * 2048 + y + 1024,
73            list(var_ranges.keys()),
74            list(var_ranges.values()),
75        )
76        self.assertEqual(dep1.get_offset(), 0)
77        self.assertEqual(dep2.get_offset(), 1024)
78
79    def test_normalize_with_stride_order_equal(self):
80        x = sympy_index_symbol("x")
81        y = sympy_index_symbol("y")
82        var_ranges = {
83            x: 1024,
84            y: 2048,
85        }
86
87        loop_order1 = MemoryDep(
88            "access_the_same_buffer",
89            x * 2048 + y,
90            [x, y],
91            [1024, 2048],
92        )
93        loop_order2 = MemoryDep(
94            "access_the_same_buffer",
95            x * 2048 + y,
96            [y, x],
97            [2048, 1024],
98        )
99        self.assertTrue(loop_order1 != loop_order2)
100        normalized_loop_order1 = loop_order1.normalize_with_stride_order()
101        normalized_loop_order2 = loop_order2.normalize_with_stride_order()
102        self.assertTrue(normalized_loop_order1 == normalized_loop_order2)
103
104    def test_normalize_with_stride_order_unequal(self):
105        x = sympy_index_symbol("x")
106        y = sympy_index_symbol("y")
107        var_ranges = {
108            x: 1024,
109            y: 2048,
110        }
111
112        loop_order1 = MemoryDep(
113            "access_the_same_buffer",
114            x * 2048 + y,
115            [x, y],
116            [1024, 2048],
117        )
118        loop_order2 = MemoryDep(
119            "access_the_same_buffer",
120            x * 2048 + y + 5,
121            [y, x],
122            [2048, 1024],
123        )
124        self.assertTrue(loop_order1 != loop_order2)
125        normalized_loop_order1 = loop_order1.normalize_with_stride_order()
126        normalized_loop_order2 = loop_order2.normalize_with_stride_order()
127        # unequal due to different offset
128        self.assertTrue(normalized_loop_order1 != normalized_loop_order2)
129
130
if __name__ == "__main__":
    # Imported eagerly, independent of whether the suite actually runs.
    from torch._inductor.test_case import run_tests

    # The tests place buffers on GPU_TYPE, so require both backends.
    # "sympy" is forwarded to run_tests -- presumably a module the suite
    # needs to be importable; confirm against run_tests' signature.
    have_cpu_and_gpu = HAS_GPU and HAS_CPU
    if have_cpu_and_gpu:
        run_tests("sympy")
136