1# Owner(s): ["module: inductor"] 2import contextlib 3 4import torch 5from torch._inductor.dependencies import MemoryDep 6from torch._inductor.graph import GraphLowering 7from torch._inductor.ir import Buffer, FixedLayout, Pointwise 8from torch._inductor.test_case import TestCase as InductorTestCase 9from torch._inductor.utils import sympy_index_symbol 10from torch._inductor.virtualized import ops, V 11from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU 12 13 14class TestDependencies(InductorTestCase): 15 def _create_buffer(self, name, shape, dtype=torch.float32): 16 return Buffer(name, FixedLayout(torch.device(GPU_TYPE), dtype, shape)) 17 18 def setUp(self): 19 super().setUp() 20 21 class DummyModule(torch.nn.Module): 22 def forward(self, x): 23 return x * 2 24 25 self._gm = torch.fx.symbolic_trace(DummyModule()) 26 self._graph = GraphLowering(self._gm) 27 28 self._stack = contextlib.ExitStack() 29 self._stack.enter_context(V.set_graph_handler(self._graph)) 30 31 def tearDown(self): 32 self._stack.close() 33 super().tearDown() 34 35 def test_bucketize_dependencies(self): 36 offsets = self._create_buffer("offsets", (1025,), torch.int32) 37 38 def inner_fn(index): 39 idx = index[0] 40 return ops.bucketize( 41 values=idx, 42 offsets_name=offsets.get_name(), 43 offsets_size=offsets.get_size()[0], 44 indexing_dtype=torch.int32, 45 right=True, 46 ) 47 48 pointwise = Pointwise.create( 49 device=torch.device(GPU_TYPE), 50 dtype=torch.int32, 51 inner_fn=inner_fn, 52 ranges=[1024 * 4], 53 ) 54 55 self.assertEqual(len(pointwise.get_reads()), 1) 56 57 def test_get_offset(self): 58 x = sympy_index_symbol("x") 59 y = sympy_index_symbol("y") 60 var_ranges = { 61 x: 1024, 62 y: 2048, 63 } 64 dep1 = MemoryDep( 65 "dep1", 66 x * 2048 + y, 67 list(var_ranges.keys()), 68 list(var_ranges.values()), 69 ) 70 dep2 = MemoryDep( 71 "dep2", 72 x * 2048 + y + 1024, 73 list(var_ranges.keys()), 74 list(var_ranges.values()), 75 ) 76 self.assertEqual(dep1.get_offset(), 0) 77 self.assertEqual(dep2.get_offset(), 1024) 78 79 def test_normalize_with_stride_order_equal(self): 80 x = sympy_index_symbol("x") 81 y = sympy_index_symbol("y") 82 var_ranges = { 83 x: 1024, 84 y: 2048, 85 } 86 87 loop_order1 = MemoryDep( 88 "access_the_same_buffer", 89 x * 2048 + y, 90 [x, y], 91 [1024, 2048], 92 ) 93 loop_order2 = MemoryDep( 94 "access_the_same_buffer", 95 x * 2048 + y, 96 [y, x], 97 [2048, 1024], 98 ) 99 self.assertTrue(loop_order1 != loop_order2) 100 normalized_loop_order1 = loop_order1.normalize_with_stride_order() 101 normalized_loop_order2 = loop_order2.normalize_with_stride_order() 102 self.assertTrue(normalized_loop_order1 == normalized_loop_order2) 103 104 def test_normalize_with_stride_order_unequal(self): 105 x = sympy_index_symbol("x") 106 y = sympy_index_symbol("y") 107 var_ranges = { 108 x: 1024, 109 y: 2048, 110 } 111 112 loop_order1 = MemoryDep( 113 "access_the_same_buffer", 114 x * 2048 + y, 115 [x, y], 116 [1024, 2048], 117 ) 118 loop_order2 = MemoryDep( 119 "access_the_same_buffer", 120 x * 2048 + y + 5, 121 [y, x], 122 [2048, 1024], 123 ) 124 self.assertTrue(loop_order1 != loop_order2) 125 normalized_loop_order1 = loop_order1.normalize_with_stride_order() 126 normalized_loop_order2 = loop_order2.normalize_with_stride_order() 127 # unequal due to different offset 128 self.assertTrue(normalized_loop_order1 != normalized_loop_order2) 129 130 131if __name__ == "__main__": 132 from torch._inductor.test_case import run_tests 133 134 if HAS_CPU and HAS_GPU: 135 run_tests("sympy") 136