# Owner(s): ["module: inductor"]

import contextlib
import unittest

import numpy as np

import torch
from torch import nn
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config as inductor_config, ir, metrics
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.graph import GraphLowering
from torch._inductor.scheduler import SchedulerNode
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import ModularIndexing


if HAS_CUDA:
    torch.set_default_device("cuda")


class MockScheduler:
    available_buffer_names = ()

    @staticmethod
    def get_backend(cls, *args):
        return TritonScheduling(cls)


@inductor_config.patch(loop_ordering_after_fusion=True)
class ImplDetailTest(TestCase):
    _exit_stack = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        gm = torch.fx.symbolic_trace(lambda: 0)
        graph = GraphLowering(gm)
        graph.scheduler = MockScheduler
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(V.set_graph_handler(graph))

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        cls._exit_stack.close()

    @staticmethod
    def _get_snode_body_sym_prefix(snode):
        # Return the symbol prefix (e.g. "z" in z0, z1, ...) of the first
        # iteration variable in the node body.
        body = snode._body
        prefix = ""

        for var in body.var_ranges:
            prefix = str(var)[0]
            break

        assert prefix
        return prefix

    @staticmethod
    def _create_computed_buffer_ax2(sizes=(32, 64), strides=None):
        """
        Create a ComputedBuffer that computes 'a * 2'.
        """
        if strides is None:
            strides = ir.FlexibleLayout.contiguous_strides(sizes)

        box_a = ir.TensorBox.create(
            ir.Buffer(
                "a", ir.FixedLayout(torch.device("cuda"), torch.float32, sizes, strides)
            )
        )
        box_a_loader = box_a.make_loader()

        def inner_fn(index):
            return box_a_loader(index) * 2

        buf = ir.Pointwise.create(
            device=box_a.get_device(),
            dtype=box_a.get_dtype(),
            inner_fn=inner_fn,
            ranges=box_a.get_size(),
        )
        buf.realize()
        computed_buf = buf.data.data
        computed_buf.decide_layout()
        return computed_buf

    def test_reorder_twice(self):
        """
        Reordering the loops of the same node twice may happen in practice:
        we pick one order when fusing A and B, and then pick another order
        for AB when we fuse C into it.

        E.g. this happens for BertForMaskedLM.
        """

        buf = self._create_computed_buffer_ax2()
        snode = SchedulerNode(V.graph.scheduler, buf)
        snode.apply_new_loop_order([1, 0])
        prefix1 = self._get_snode_body_sym_prefix(snode)
        self.assertTrue(prefix1 == "z")
        snode.apply_new_loop_order([1, 0])
        prefix2 = self._get_snode_body_sym_prefix(snode)
        self.assertTrue(prefix2 == "z")
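
    # The input buffer "a" below is created with sizes=(1024, 2048) and
    # strides=(1, 1024), i.e. column-major. simplify_and_reorder() is therefore
    # expected to flip the iteration order to (2048, 1024) so the innermost
    # loop is contiguous, and merge_loops() can then collapse the two loops
    # into a single loop of 1024 * 2048 iterations, which is what the
    # assertions below check.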
    def test_reorder_and_merge_loops(self):
        sizes = (1024, 2048)
        strides = (1, 1024)
        buf = self._create_computed_buffer_ax2(sizes, strides)
        old_sizes, old_body = buf.simplify_and_reorder()

        # Make sure loop reordering happens here
        self.assertTrue(tuple(old_sizes[0]) == tuple(reversed(sizes)), f"{old_sizes=}")
        new_body = old_body.merge_loops()
        new_sizes = new_body.sizes
        self.assertTrue(tuple(new_sizes[0]) == (np.prod(sizes),), f"{new_sizes=}")

    def test_reorder_modular_indexing(self):
        """
        There was a bug where we wrongly mapped i0 to the dimension of size 49
        when reordering the loop, which caused the ModularIndexing expression
        to be optimized away as a no-op.
        """

        def _create_computed_buffer():
            def inner_fn(index):
                i0, i1, i2, i3 = index
                return ops.load(
                    "primal", i3 + 49 * i2 + 2401 * ModularIndexing(i0, 1, 64)
                )

            buf = ir.Pointwise.create(
                device=torch.device("cuda"),
                dtype=torch.float32,
                inner_fn=inner_fn,
                ranges=[128, 4, 49, 49],
            )
            buf.realize()
            cbuf = buf.data.data
            cbuf.decide_layout()
            return cbuf

        buf = _create_computed_buffer()
        _, body = buf.simplify_and_reorder()
        new_body = body.reorder_iter_loops([1, 2, 3, 0])

        z0, z1, z2, z3 = (sympy_index_symbol(f"z{i}") for i in range(4))
        self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49})
        self.assertEqual(
            body.indexing_exprs["index0"],
            z3 + 49 * z2 + 2401 * ModularIndexing(z0, 1, 64),
        )
        self.assertEqual(new_body.var_ranges, {z0: 4, z1: 49, z2: 49, z3: 128})
        self.assertEqual(
            new_body.indexing_exprs["index0"],
            z2 + 49 * z1 + 2401 * ModularIndexing(z3, 1, 64),
        )


@inductor_config.patch(
    {
        "benchmark_kernel": True,
        "loop_ordering_after_fusion": True,
        "triton.unique_kernel_names": True,
    }
)
class LoopOrderingTest(TestCase):
    def do_acc_test(self, f, *args, cast_fp8=True):
        expect = f(*args)
        actual = torch.compile(f)(*args)

        if cast_fp8:

            def _cast(x):
                if isinstance(x, torch.Tensor) and x.dtype in (
                    torch.float8_e5m2,
                    torch.float8_e4m3fn,
                ):
                    return x.to(torch.float32)
                return x

            # Workaround for the fact that calling allclose on fp8 tensors
            # raises: RuntimeError: "mul_cuda" not implemented for 'Float8_e4m3fn'
            expect = tree_map(_cast, expect)
            actual = tree_map(_cast, actual)
        self.assertTrue(same(expect, actual, tol=1e-3))

    def setUp(self):
        super().setUp()
        metrics.reset()

    def test_for_reordering_reindex(self):
        """
        ComputedBuffer.iter_reoredering_reindex can cause some fusion
        opportunities to be skipped.

        In this test case, Inductor used to generate 2 triton kernels.
        By removing ComputedBuffer.iter_reoredering_reindex, we can fuse those
        two kernels into a single one.
        """

        def f(x, y):
            """
            Add a matmul since inductor may force the layout for the output.
            """
            return (x.sum(dim=-1) + 1) @ y

        A, B = 20, 30
        # Deliberately make the first 2 dimensions non-mergeable so that
        # ComputedBuffer.iter_reoredering_reindex will be updated.
        x = rand_strided([A, A, B], [B, B * A + 300, 1], device="cuda")
        y = torch.randn(A, A)

        self.do_acc_test(f, x, y)
        self.assertEqual(1, metrics.generated_kernel_count)
        expected_num_bytes = 0
        # fused reduction: read x (A * A * B), write the (A, A) result
        expected_num_bytes += A * A * B + A * A
        # matmul: read two (A, A) operands, write an (A, A) output
        expected_num_bytes += A * A * 3
        expected_num_bytes *= x.itemsize
        self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed)
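
    # "apbt" presumably stands for "a plus b-transposed": x is contiguous while
    # y comes from a transpose, so the two realized pointwise buffers prefer
    # opposite loop orders. Without loop ordering after fusion they would end
    # up in separate kernels (see the gist linked in the docstring below).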
    def test_apbt_realize(self):
        M = 1024
        N = 2048

        def f(x, y):
            """
            Without loop ordering after fusion, 2 kernels would be generated:
            https://gist.github.com/shunting314/44df83f71de2c110232c50ac6638ed69
            """
            x = realize(x * 2)
            y = realize(y * 3)
            return x + y

        x = torch.randn(M, N)
        y = torch.randn(N, M).t()

        self.do_acc_test(f, x, y)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_sum_and_t(self):
        N = 1024

        def f(x):
            return x.sum(dim=-1), x.t().contiguous()

        x = torch.randn(N, N * 2)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_pw_outer_red(self):
        def f(x):
            x = realize(x + 1)
            return x.sum(dim=[0, 1])

        # make the first 2 dimensions small so we don't split the reduction
        x = torch.randn(2, 4, 512)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_pw_outer_red_2(self):
        """
        The pointwise kernel is a fused kernel.
        """

        def f(x):
            x = realize(x + 1)
            x = realize(x - 2)
            x = realize(x * 3)
            return x.sum(dim=[0, 1])

        # make the first 2 dimensions small so we don't split the reduction
        x = torch.randn(2, 4, 512)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)
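
    # split_reductions is disabled below, presumably so that each sum lowers to
    # a single reduction kernel and the generated-kernel count asserted by the
    # test stays deterministic.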
    @inductor_config.patch(split_reductions=False)
    def test_different_reduction_order(self):
        """
        We should not reorder loops in this case since reordering does not help.
        """

        def f(x):
            return x.sum(dim=0), x.sum(dim=1)

        x = torch.randn(1024, 2048)
        self.do_acc_test(f, x)
        self.assertEqual(2, metrics.generated_kernel_count)
        self.assertEqual(0, metrics.num_loop_reordering)

    def test_keep_fake_dep(self):
        """
        In this model, there are fake dependencies (StarDep) between Scatter
        and a following mutation kernel that computes the gradients of
        the embedding tables.

        When we do loop reordering for the mutation kernel, we re-analyze
        the node's dependencies. But the analysis result does not contain
        those fake dependencies, so we have to add them back manually.
        """
        V = 2048  # vocabulary size; shadows the imported V only inside this test
        hidden_size = 64
        max_seqlen = 512
        batch_size = 8

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.word_embeddings = nn.Embedding(V, hidden_size)
                self.position_embeddings = nn.Embedding(max_seqlen, hidden_size)
                self.layer_norm = nn.LayerNorm(hidden_size)

            def forward(self, input_ids, labels, position_ids):
                emb = self.word_embeddings(input_ids) + self.position_embeddings(
                    position_ids
                )
                return self.layer_norm(emb)

        m = Model()

        @torch.compile
        def f(*args):
            m(*args).sum().backward()

        input_ids = torch.randint(0, V, (batch_size, max_seqlen))
        labels = torch.randint(0, V, (batch_size, max_seqlen))
        position_ids = torch.arange(max_seqlen)[None, :]
        # Make sure this call does not raise an exception. If we drop the fake
        # dependencies after loop reordering, we may get an error saying that
        # some buffer is used before being defined.
        f(input_ids, labels, position_ids)

    def test_different_broadcast_shapes(self):
        def f(x, y, c):
            return x + c, y + c

        x = torch.randn(4, 256, 1024)
        y = torch.randn(2, 512, 1024)
        c = torch.randn(1024)
        self.do_acc_test(f, x, y, c)

        # The two kernels are not fused because c is broadcast against
        # different shapes.
        self.assertEqual(2, metrics.generated_kernel_count)

    def test_view(self):
        """
        Passing this test relies on comparing normalized MemoryDep.
        Normalization here means merging contiguous loops.

        To make loop reordering work, we don't merge loops when creating a
        SchedulerNode. Thus we need to explicitly normalize MemoryDep when
        we check whether two MemoryDeps match.
        """

        def f(x):
            y = x.sin()
            x = realize(x.view(10, 10))
            return x, y

        x = torch.randn(100)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ or MI300+")
    def test_fp8_cast_and_t(self):
        """
        This test reproduces the failure to fuse the fp8 cast and transpose
        reported in https://github.com/pytorch/pytorch/issues/130015.
        """

        def f(x, scale):
            x = x * scale
            x = x.clamp(-1 * E4M3_MAX_POS, E4M3_MAX_POS)
            x = x.to(torch.float8_e4m3fn)
            x_t = x.t().contiguous().t()
            return x, x_t

        x = torch.randn(4096, 4096, dtype=torch.bfloat16)
        scale = torch.Tensor([10.0]).cuda()
        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

        self.do_acc_test(f, x, scale)
        self.assertEqual(1, metrics.generated_kernel_count)


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()