1# Owner(s): ["module: inductor"] 2import contextlib 3import re 4from unittest.mock import patch 5 6import functorch 7import torch 8import torch._inductor.config as config 9import torch.autograd 10from torch._inductor import metrics 11from torch._inductor.compile_fx import compile_fx, compile_fx_inner 12from torch._inductor.test_case import TestCase as InductorTestCase 13from torch._inductor.utils import run_and_get_code 14 15######################## 16# Explanation of Tests # 17######################## 18# These tests are all testing *memory accesses* of TorchInductor. 19# They are intended to be deterministic performance tests. 20# The expect tests are all measuring the number of memory bytes read/written by 21# the code that Inductor has generated 22# 23# If the test is failing because the number became smaller, feel free to lower it. 24# On the other hand, if the test is failing because the number became larger, 25# that means that your change is leading to *more* memory accesses on this test. 26# 27# That may still be aceeptable, but be aware that you are likely lowering 28# performance for that setting. 29# 30# Defines all the kernels for tests 31from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda 32 33 34if HAS_CUDA: 35 import triton 36 import triton.language as tl 37 38 from torch.testing._internal.triton_utils import add_kernel 39 40aten = torch.ops.aten 41 42 43def compile_but_use_eager(gm, example_inputs): 44 def inner_compile(gm, *args, **kwargs): 45 compile_fx_inner(gm, *args, **kwargs) 46 return gm 47 48 return compile_fx(gm, example_inputs, inner_compile=inner_compile) 49 50 51def count_numel(f, *args): 52 """ 53 Assumes all inputs are fp32 54 """ 55 metrics.reset() 56 torch.compile(f, backend=compile_but_use_eager)(*args) 57 print(metrics.nodes_num_elem) 58 return str(metrics.num_bytes_accessed // 4) 59 60 61def count_numel_train(f, *args): 62 """ 63 Assumes all inputs are fp32 64 """ 65 metrics.reset() 66 67 f = torch.compile(f, backend=compile_but_use_eager) 68 out = f(*args) 69 res = 0 70 for o in out: 71 res += o.mean() 72 res.backward() 73 print(metrics.nodes_num_elem) 74 return str(metrics.num_bytes_accessed // 4) 75 76 77DEVICE = "cuda" 78 79 80def T(*size, dtype=torch.float32, device=DEVICE, grad=False): 81 return torch.randn(size, dtype=dtype, device=device, requires_grad=grad) 82 83 84def TI(*size, mx=10, dtype=torch.int32, device=DEVICE): 85 return torch.randint(0, mx, size, dtype=dtype, device=device) 86 87 88class TestCase(InductorTestCase): 89 device = DEVICE 90 91 92class NumBytesMetricTests(TestCase): 93 """ 94 Primarily used for sanity testing that the num_bytes_accessed metrics is correct. 95 """ 96 97 def test_pointwise(self): 98 def f(x): 99 return x.cos() 100 101 inp = (T(10),) 102 self.assertExpectedInline(count_numel(f, *inp), """20""") 103 104 def f(x, y): 105 return x + y 106 107 inp = (T(10), T(10)) 108 self.assertExpectedInline(count_numel(f, *inp), """30""") 109 110 def f(x, y): 111 return x + y 112 113 inp = (T(10, 10), T(10)) 114 self.assertExpectedInline(count_numel(f, *inp), """210""") 115 116 def f(x): 117 return x + x 118 119 inp = (T(10),) 120 self.assertExpectedInline(count_numel(f, *inp), """20""") 121 122 def f(x): 123 return x + x.t() 124 125 inp = (T(10, 10),) 126 self.assertExpectedInline(count_numel(f, *inp), """200""") 127 128 def f(a, b, c): 129 return a.cos(), b.sin() + c.sin() 130 131 inp = (T(10), T(10), T(10)) 132 self.assertExpectedInline(count_numel(f, *inp), """50""") 133 134 def test_reduction(self): 135 def f(x): 136 return x.sum(dim=1) 137 138 inp = (T(10, 10),) 139 self.assertExpectedInline(count_numel(f, *inp), """110""") 140 141 def f(x): 142 return x.sum(dim=0) 143 144 inp = (T(10, 10),) 145 self.assertExpectedInline(count_numel(f, *inp), """110""") 146 147 def test_extern(self): 148 def f(x): 149 return torch.mm(x, x) 150 151 inp = (T(10, 10),) 152 self.assertExpectedInline(count_numel(f, *inp), """200""") 153 154 def f(a, b): 155 return torch.mm(a, b) 156 157 inp = (T(10, 10), T(10, 10)) 158 self.assertExpectedInline(count_numel(f, *inp), """300""") 159 160 def f(x): 161 x = x.cos() 162 x = torch.mm(x, x) 163 x = x.cos() 164 return x 165 166 inp = (T(10, 10),) 167 self.assertExpectedInline(count_numel(f, *inp), """600""") 168 169 def f(x): 170 a = x.cos() 171 b = x.sin() 172 x = torch.mm(a, b) 173 return x 174 175 inp = (T(10, 10),) 176 self.assertExpectedInline(count_numel(f, *inp), """600""") 177 178 def test_cat(self): 179 def f(a, b): 180 return torch.cat([a.sin(), b.sin()]) 181 182 inp = (T(10), T(10)) 183 self.assertExpectedInline(count_numel(f, *inp), """40""") 184 185 def f(a, b): 186 return torch.cat([a, b]) 187 188 inp = (T(10), T(10)) 189 self.assertExpectedInline(count_numel(f, *inp), """40""") 190 191 def f(a, b): 192 return torch.cat([a.cos(), b]) 193 194 inp = (T(10), T(10)) 195 self.assertExpectedInline(count_numel(f, *inp), """40""") 196 197 def f(a): 198 return torch.cat([a.cos(), a.sin()]) 199 200 inp = (T(10),) 201 self.assertExpectedInline(count_numel(f, *inp), """30""") 202 203 def f(a, b): 204 return torch.cat([torch.mm(a, a), b.sin()]) 205 206 inp = (T(10, 10), T(10, 10)) 207 self.assertExpectedInline(count_numel(f, *inp), """400""") 208 209 def f(a, b, c): 210 return torch.cat((a + 1, b + 2, c + 3)) + 10 211 212 inp = (T(10, 10), T(10, 10), T(10, 10)) 213 self.assertExpectedInline(count_numel(f, *inp), """600""") 214 215 def f(a, b, c, d, e): 216 return torch.cat((a + 1, b + 2, c + 3, d + 4, e + 5)) + 10 217 218 inp = [T(10, 10) for _ in range(5)] 219 self.assertExpectedInline(count_numel(f, *inp), """1000""") 220 221 def f(a, b): 222 return torch.cat([a.sum(dim=0), b.sum(dim=0)]) + 10 223 224 inp = [T(10, 10, 10), T(10, 10, 10)] 225 self.assertExpectedInline(count_numel(f, *inp), """2600""") 226 227 def test_cat_pointwise(self): 228 def f(a, b): 229 return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)]) 230 231 inp = (T(10, 10), T(10, 10)) 232 self.assertExpectedInline(count_numel(f, *inp), """400""") 233 234 def f(a, b): 235 return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)]).cos() 236 237 inp = (T(10, 10), T(10, 10)) 238 self.assertExpectedInline(count_numel(f, *inp), """680""") 239 240 # Should turn into pointwise even if only some of inputs are pointwise. 241 def f(a, b): 242 out = torch.cat([a.cos(), torch.mm(b, b)]) 243 return out.cos() 244 245 inp = (T(10, 10), T(10, 10)) 246 self.assertExpectedInline(count_numel(f, *inp), """600""") 247 248 # Should not turn into pointwise if all inputs are not pointwise 249 def f(a, b): 250 out = torch.cat([torch.mm(a, a), torch.mm(b, b)]) 251 return out.cos() 252 253 inp = (T(10, 10), T(10, 10)) 254 self.assertExpectedInline(count_numel(f, *inp), """800""") 255 256 def f(a, b): 257 out = torch.cat([a, b]) 258 return out.cos() 259 260 inp = (T(10, 10), T(10, 10)) 261 self.assertExpectedInline(count_numel(f, *inp), """400""") 262 263 def f(a, b): 264 b = b.cos() 265 return torch.cat([a, b]) 266 267 inp = (T(10, 10), T(10, 10)) 268 self.assertExpectedInline(count_numel(f, *inp), """400""") 269 270 def f(a, b): 271 a = a @ a 272 return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5) 273 274 inp = (T(10, 10), T(10, 10)) 275 self.assertExpectedInline(count_numel(f, *inp), """680""") 276 277 @patch.object(config, "split_cat_fx_passes", False) 278 @patch.object( 279 config, 280 "pre_grad_fusion_options", 281 { 282 "batch_linear": {}, 283 "batch_linear_lhs": {}, 284 "batch_layernorm": {}, 285 "batch_tanh": {}, 286 "batch_relu": {}, 287 "batch_sigmoid": {}, 288 }, 289 ) 290 @patch.object(config, "post_grad_fusion_options", {}) 291 def test_cat_pointwise_many_complex_inputs(self): 292 def f(*inputs): 293 input = [torch.nn.functional.gelu(val) for val in inputs] 294 return torch.cat(input) + 10 295 296 inp = (T(10, 10) for _ in range(16)) 297 self.assertExpectedInline(count_numel(f, *inp), """6400""") 298 299 @patch.object(config, "split_cat_fx_passes", False) 300 @patch.object( 301 config, 302 "pre_grad_fusion_options", 303 { 304 "batch_linear": {}, 305 "batch_linear_lhs": {}, 306 "batch_layernorm": {}, 307 "batch_tanh": {}, 308 "batch_relu": {}, 309 "batch_sigmoid": {}, 310 }, 311 ) 312 @patch.object(config, "post_grad_fusion_options", {}) 313 def test_cat_pointwise_many_simple_inputs(self): 314 def f(*inputs): 315 input = [torch.nn.functional.relu(val) for val in inputs] 316 return torch.cat(input) + 10 317 318 inp = (T(10, 10) for _ in range(16)) 319 self.assertExpectedInline(count_numel(f, *inp), """9600""") 320 321 @patch.object(config, "max_pointwise_cat_inputs", 0) 322 def test_cat_pointwise_config_option(self): 323 def f(a, b): 324 return torch.cat([a + 1, b + 2]) + 3 325 326 inp = (T(10, 10), T(10, 10)) 327 self.assertExpectedInline(count_numel(f, *inp), """400""") 328 329 def test_index(self): 330 def f(a, b): 331 return a[b] 332 333 inp = (T(10), TI(10, mx=10)) 334 self.assertExpectedInline(count_numel(f, *inp), """30""") 335 336 337class FusionTests(TestCase): 338 """ 339 Tests that things can be fused into a single kernel 340 """ 341 342 def test_horizontal_reduction_pointwise(self): 343 def f(a): 344 b = a.sum(dim=1) 345 c = a.cos() 346 return b, c 347 348 inp = (T(10, 10),) 349 self.assertExpectedInline(count_numel(f, *inp), """210""") 350 351 def test_horizontal_reduction_reduction(self): 352 def f(a): 353 b = a.sum(dim=1) 354 c = a.amax(dim=1) 355 return b, c 356 357 inp = (T(10, 10),) 358 self.assertExpectedInline(count_numel(f, *inp), """120""") 359 360 def test_horizontal_reduction_pointwise2(self): 361 def f(a, b): 362 c = a.sum(dim=1) 363 b = b.cos() 364 return b + c 365 366 inp = (T(10, 10), T(10)) 367 self.assertExpectedInline(count_numel(f, *inp), """120""") 368 369 def test_horizontal_reduction_outer_pointwise(self): 370 def f(a, b): 371 c = a.sum(dim=0) 372 b = b.cos() 373 return b + c 374 375 inp = (T(10, 10), T(10)) 376 self.assertExpectedInline(count_numel(f, *inp), """120""") 377 378 def test_horizontal_sum_pw_broadcast(self): 379 def f(a, b): 380 a = a.sum(dim=1, keepdim=True) 381 b = b.cos() 382 return a * b 383 384 inp = (T(10, 10), T(10)) 385 self.assertExpectedInline(count_numel(f, *inp), """210""") 386 387 def test_vertical_sum_pw(self): 388 def f(a): 389 a = a.cos() 390 a = a.sum(dim=1) 391 return a.cos() 392 393 inp = (T(10, 10),) 394 self.assertExpectedInline(count_numel(f, *inp), """110""") 395 396 def test_norm_chain(self): 397 def f(a): 398 b = a.sum(dim=1, keepdim=True) 399 a = a * b 400 b = a.sum(dim=1, keepdim=True) 401 a = a * b 402 b = a.sum(dim=1, keepdim=True) 403 a = a * b 404 return a 405 406 inp = (T(10, 10),) 407 self.assertExpectedInline(count_numel(f, *inp), """200""") 408 409 def test_softmax_inner(self): 410 def f(a): 411 return torch.softmax(a, dim=1) 412 413 inp = (T(10, 10),) 414 self.assertExpectedInline(count_numel(f, *inp), """200""") 415 416 def test_layer_norm(self): 417 # TODO: Suboptimal! We shouldn't need to save normalization stats. 418 mod = torch.nn.LayerNorm(10, device=self.device) 419 420 def f(x): 421 return mod(x) 422 423 inp = (T(10, 10),) 424 with torch.no_grad(): 425 self.assertExpectedInline(count_numel(f, *inp), """220""") 426 427 def test_double_softmax(self): 428 def f(x): 429 x = torch.softmax(x, dim=1) 430 x = torch.softmax(x, dim=1) 431 return x 432 433 inp = (T(10, 10),) 434 self.assertExpectedInline(count_numel(f, *inp), """200""") 435 436 def test_softmax_backward(self): 437 def f(grad_out, out): 438 return aten._softmax_backward_data(grad_out, out, 1, torch.float32) 439 440 inp = (T(10, 10), T(10, 10)) 441 self.assertExpectedInline(count_numel(f, *inp), """300""") 442 443 def test_neighbor(self): 444 def f(a, b): 445 return ((a - b) ** 2).sum(dim=-1).amax(dim=1) 446 447 inp = (T(10, 1, 4), T(1, 10, 4)) 448 self.assertExpectedInline(count_numel(f, *inp), """90""") 449 450 def test_factory_reduction(self): 451 def f(): 452 a = torch.ones(10, device=self.device) 453 b = torch.ones(10, 10, device=self.device) 454 return (a + b).sum(dim=-1) 455 456 inp = () 457 self.assertExpectedInline(count_numel(f, *inp), """10""") 458 459 def test_index_pointwise(self): 460 def f(a, b): 461 return a[b].cos() 462 463 inp = (T(10, 10), TI(20, mx=10)) 464 self.assertExpectedInline(count_numel(f, *inp), """320""") 465 466 def test_index_reduction(self): 467 def f(a, b): 468 return a[b].cos().sum(dim=1) 469 470 inp = (T(10, 10), TI(20, mx=10)) 471 self.assertExpectedInline(count_numel(f, *inp), """140""") 472 473 def test_mutation_fusion(self): 474 def f(a, b, c): 475 a0 = a.add(c) 476 b0 = b.add(a0) 477 b.copy_(b0) 478 a.copy_(a0) 479 480 inp = (T(10, 10), T(10, 10), T(10, 10)) 481 self.assertExpectedInline(count_numel(f, *inp), """500""") 482 483 def test_reduction_pointwise_multi_level_reduction(self): 484 hidden_size = 4096 485 layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float() 486 487 @torch.inference_mode() 488 def f(x, scale, amax_keep_dim): 489 x = layer_norm(x.to(dtype=torch.float)) 490 amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim) 491 x_scaled = x * scale 492 y = torch.nn.functional.sigmoid(x_scaled) 493 return (y, amax) 494 495 inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float)) 496 497 # 2 kernels: 498 # kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), first-level amax (split-reduction)) 499 # kernel 2: (input = first-level amax, output = final amax) 500 # scale (1) + X (4*2048*hidden_size) * 2 + LN scale (hidden_size) + LN bias (hidden_size) + amax (4 * 2048 * 2 + 1) 501 expected_numel = ( 502 1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1 503 ) 504 self.assertExpectedInline(count_numel(f, *inp, True), str(expected_numel)) 505 self.assertExpectedInline(count_numel(f, *inp, False), str(expected_numel)) 506 507 def test_pointwise_multi_level_reduction(self): 508 # TODO: this can be optimized by having the first pointwise kernel leveraging block sizes 509 # of the first-level reduction kernel. 510 hidden_size = 4096 511 512 def f(x, scale, amax_keep_dim): 513 x = x * 1.1 514 amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim) 515 x_scaled = x * scale 516 y = torch.nn.functional.sigmoid(x_scaled) 517 return (y, amax) 518 519 inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float)) 520 521 compiled_f = torch.compile(f) 522 compiled_f(*inp, True) 523 524 # 3 kernels: 525 # kernel 1: (input = X, scale, output = pointwise(X)) 526 # kernel 2: (input = X, output = first-level amax) 527 # kernel 3: (input = first-level amax, output = final amax) 528 # scale (1) + X (4*2048*hidden_size) * 3 + amax (num_splits * 2 + 1) 529 # num_splits depends on SM architectures. 530 expected_numel = 1 + 4 * 2048 * hidden_size * 3 + 1 531 actual_numel_amax_keep_dim = count_numel(f, *inp, True) 532 actual_numel_amax_no_keep_dim = count_numel(f, *inp, False) 533 self.assertEqual(actual_numel_amax_keep_dim, actual_numel_amax_no_keep_dim) 534 self.assertGreaterAlmostEqual(actual_numel_amax_keep_dim, str(expected_numel)) 535 536 537class SchedulerFusionTests(TestCase): 538 """ 539 Testing the fusion group creation heuristic (i.e. cases where we can't fuse 540 everything into a single kernel) 541 Disables inductor rematerialization for easier reasoning of tests. 542 """ 543 544 @classmethod 545 def setUpClass(cls): 546 super().setUpClass() 547 cls._stack = contextlib.ExitStack() 548 cls._stack.enter_context(patch.object(config, "realize_opcount_threshold", 0)) 549 550 @classmethod 551 def tearDownClass(cls): 552 cls._stack.close() 553 super().tearDownClass() 554 555 @patch.object(config, "pattern_matcher", False) 556 def test_fusion_choice1(self): 557 # Doesn't matter where we break fusion group here 558 def f(a): 559 c = a.cos() 560 d = torch.mm(c, c) 561 e = c.cos() 562 return d + e 563 564 inp = (T(10, 10),) 565 self.assertExpectedInline(count_numel(f, *inp), """700""") 566 567 @patch.object(config, "pattern_matcher", False) 568 def test_fusion_choice2(self): 569 # We should materialize e (it's smaller!) 570 # [c, e]: 210, [f]: 210, [d]: 200 571 def f(a): 572 c = a.cos() 573 d = torch.mm(c, c) 574 e = c.sum(dim=1) 575 f = d + e 576 return f 577 578 inp = (T(10, 10),) 579 self.assertExpectedInline(count_numel(f, *inp), """620""") 580 581 @patch.object(config, "pattern_matcher", False) 582 def test_fusion_choice3(self): 583 # We should materialize e. 584 # [c, e]: 300, [f]: 300, [d]: 200 585 def f(a): 586 c = a.cos() 587 d = torch.mm(c, c) 588 e = c + a 589 f = d + e 590 return f, e 591 592 inp = (T(10, 10),) 593 self.assertExpectedInline(count_numel(f, *inp), """800""") 594 595 @patch.object(config, "pattern_matcher", False) 596 def test_fusion_choice4_cpu(self): 597 # Fuse nodes with same number of elements and compatible orginal var ranges 598 # [buf0: {d0: 60, d1: 11}, buf1: {d0: 660}] -> buf0_buf1 599 def f(x, w): 600 o1 = x * w 601 output = o1 + 1.0 602 return output 603 604 inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu")) 605 self.assertExpectedInline(count_numel(f, *inp), """1331""") 606 607 # [buf0_buf1: {d0: 60, d1: 11}, buf2: {d0: 660}] -> buf0_buf1_buf2 608 def f(x, w1, w2): 609 o1 = x * w1 610 o2 = x * w2 611 output = o1 + o2 612 return output 613 614 inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu"), T(11, device="cpu")) 615 self.assertExpectedInline(count_numel(f, *inp), """1342""") 616 617 618class TilingTests(TestCase): 619 def test_tiling_simple(self): 620 def f(a, b): 621 return a + b.t() 622 623 inp = (T(10, 10), T(10, 10)) 624 self.assertExpectedInline(count_numel(f, *inp), """300""") 625 626 def f(a, b): 627 return a.t() + b 628 629 inp = (T(10, 10), T(10, 10)) 630 self.assertExpectedInline(count_numel(f, *inp), """300""") 631 632 def test_tiling_three(self): 633 def f(a, b, c): 634 return a + b.permute(1, 2, 0) + c.permute(2, 0, 1) 635 636 inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10)) 637 self.assertExpectedInline(count_numel(f, *inp), """4000""") 638 639 640class MinCutPartitioningTests(TestCase): 641 def test_partitioning_full_remat(self): 642 def f(x): 643 return x.cos().cos().cos() 644 645 inp = (T(10, grad=True),) 646 self.assertExpectedInline(count_numel_train(f, *inp), """50""") 647 648 def test_partitioning_partial_remat(self): 649 def f(a, b, c, d): 650 x = a + b + c + d 651 return x.cos().cos() 652 653 inp = (T(10, grad=True), T(10, grad=True), T(10, grad=True), T(10, grad=True)) 654 self.assertExpectedInline(count_numel_train(f, *inp), """90""") 655 656 def test_partitioning_dtype(self): 657 def f(x): 658 return (x < 0) * x 659 660 inp = (T(100, grad=True),) 661 self.assertExpectedInline(count_numel_train(f, *inp), """450""") 662 663 @patch.object(functorch.compile.config, "max_dist_from_bw", 1000) 664 def test_partitioning_unremat_bw(self): 665 def f(x): 666 return torch.mm(x, x.new_ones(x.shape)).tanh().tanh() 667 668 inp = (T(10, 10, grad=True),) 669 self.assertExpectedInline(count_numel_train(f, *inp), """1300""") 670 671 @patch.object(config, "pattern_matcher", False) 672 def test_partitioning_unremat_bw2(self): 673 def f(a): 674 a = torch.mm(a, a) 675 a = a + 1 676 b = a + 2 677 c = torch.mm(a, b) 678 return c 679 680 inp = (T(10, 10, grad=True),) 681 self.assertExpectedInline(count_numel_train(f, *inp), """2600""") 682 683 def test_partitioning_keops(self): 684 def f(a, b): 685 return (a * b).cos().sum(dim=1) 686 687 inp = (T(20, 1, grad=True), T(1, 20, grad=True)) 688 self.assertExpectedInline(count_numel_train(f, *inp), """220""") 689 690 def test_partitioning_cat(self): 691 def f(a, b): 692 a = torch.tanh(a) 693 return torch.cat([a, b]) 694 695 inp = (T(10, grad=True), T(10, grad=True)) 696 self.assertExpectedInline(count_numel_train(f, *inp), """70""") 697 698 def test_partitioning_with_view(self): 699 class Foo(torch.autograd.Function): 700 @staticmethod 701 def forward(ctx, x): 702 y = x.sin() 703 x = x.cos() 704 x = x.view(10, 10) 705 ctx.save_for_backward(x, y) 706 x = x.cos() 707 return x 708 709 @staticmethod 710 def backward(ctx, gradOut): 711 x, y = ctx.saved_tensors 712 return torch.mm(gradOut, x).view(100) * y 713 714 def f(a): 715 return Foo.apply(a) 716 717 inp = (T(100, grad=True),) 718 # We do not want to recompute the x.cos().view() chain, as it's 719 # materialized in backwards 720 self.assertExpectedInline(count_numel_train(f, *inp), """900""") 721 722 @patch.object(config, "pattern_matcher", False) 723 def test_partitioning_long_chain_add(self): 724 def f(x): 725 orig = x 726 for _ in range(2): 727 x = x * x 728 x = torch.mm(x, x) 729 x = x * 2 730 x = orig + x 731 orig = x 732 return x 733 734 inp = (T(10, 10, grad=True),) 735 self.assertExpectedInline(count_numel_train(f, *inp), """3900""") 736 737 738def unfusible(x): 739 # For the purpose of noop tests, we want inductor to fall back to 740 # eager mode, so, below we must use a aten operator that does not 741 # have decomposition nor lowering: 742 return aten._lazy_clone(x) 743 744 745class NoopTests(TestCase): 746 def test_noop_clones(self): 747 def f(a): 748 b = a.clone() 749 b = unfusible(b) 750 return b 751 752 inp = T(10) 753 self.assertExpectedInline(count_numel(f, inp), """20""") 754 755 def f(a): 756 b = a.clone() 757 c = unfusible(b) 758 return b, c 759 760 self.assertExpectedInline(count_numel(f, inp), """40""") 761 762 def test_noop_slice_scatter(self): 763 def f(a): 764 b = aten.slice_scatter(a, a) 765 c = unfusible(b) 766 return c 767 768 inp = T(10) 769 self.assertExpectedInline(count_numel(f, inp), """20""") 770 771 def test_noop_dtype_conversion(self): 772 def f(a): 773 b = torch.ops.prims.convert_element_type(a, torch.float32) 774 c = unfusible(b) 775 return c 776 777 inp = T(10) 778 self.assertExpectedInline(count_numel(f, inp), """20""") 779 780 def test_noop_device_conversion(self): 781 def f(a): 782 b = torch.ops.prims.device_put(a, "cuda") 783 c = unfusible(b) 784 return c 785 786 inp = T(10) 787 self.assertExpectedInline(count_numel(f, inp), """20""") 788 789 def test_noop_int_ops(self): 790 def f1(a): 791 b = torch.ceil(a) 792 c = unfusible(b) 793 return c 794 795 def f2(a): 796 d = torch.floor(a) 797 e = unfusible(d) 798 return e 799 800 def f3(a): 801 f = torch.round(a) 802 g = unfusible(f) 803 return g 804 805 def f4(a): 806 f = torch.pow(a, 1) 807 g = unfusible(f) 808 return g 809 810 inp = TI(10) 811 self.assertExpectedInline(count_numel(f1, inp), """20""") 812 self.assertExpectedInline(count_numel(f2, inp), """20""") 813 self.assertExpectedInline(count_numel(f3, inp), """20""") 814 self.assertExpectedInline(count_numel(f4, inp), """20""") 815 816 def test_noop_cat(self): 817 def f1(a): 818 b = torch.cat([a]) 819 return unfusible(b) 820 821 inp = T(10) 822 self.assertExpectedInline(count_numel(f1, inp), """20""") 823 824 def f2(a): 825 b = torch.cat([a]) 826 c = torch.cat([b]) 827 return c 828 829 self.assertExpectedInline(count_numel(f2, inp), """20""") 830 831 832class InplacingTests(TestCase): 833 def test_inplace_scatter(self): 834 def f(a, b): 835 a = a.cos() 836 a[b] = 1 837 return a 838 839 inp = (T(10), TI(2, mx=5)) 840 self.assertExpectedInline(count_numel(f, *inp), """26""") 841 842 def f(a, b): 843 out = aten.index_put(a, (b,), torch.tensor(1.0)) 844 return a.copy_(out) 845 846 inp = (T(10), TI(2, mx=5)) 847 self.assertExpectedInline(count_numel(f, *inp), """6""") 848 849 def f(a, b): 850 out = aten._unsafe_index_put(a, (b,), torch.tensor(1.0)) 851 return a.copy_(out) 852 853 inp = (T(10), TI(2, mx=5)) 854 self.assertExpectedInline(count_numel(f, *inp), """6""") 855 856 def test_inplace_scatter_noop_view(self): 857 def f(a, b): 858 a[:, b] = 1 859 return a 860 861 inp = (T(10, 10), TI(2, mx=5)) 862 self.assertExpectedInline(count_numel(f, *inp), """42""") 863 864 @requires_cuda 865 def test_inplace_triton_kernel_training(self): 866 @triton.jit 867 def sin_kernel( 868 in_ptr0, 869 out_ptr, 870 n_elements, 871 BLOCK_SIZE: "tl.constexpr", 872 ): 873 pid = tl.program_id(axis=0) 874 block_start = pid * BLOCK_SIZE 875 offsets = block_start + tl.arange(0, BLOCK_SIZE) 876 mask = offsets < n_elements 877 x = tl.load(in_ptr0 + offsets, mask=mask) 878 output = tl.sin(x) 879 tl.store(out_ptr + offsets, output, mask=mask) 880 881 def sin_triton(x, out): 882 n_elements = x.numel() 883 sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) 884 885 factory_op = torch.empty_like 886 887 class MySin(torch.autograd.Function): 888 @staticmethod 889 def forward(ctx, x): 890 out = factory_op(x) 891 sin_triton(x, out) 892 ctx.save_for_backward(out) 893 return out 894 895 @staticmethod 896 def backward(ctx, grad): 897 (saved,) = ctx.saved_tensors 898 out = factory_op(grad) 899 sin_triton(saved, out) 900 return out 901 902 def f(x): 903 return MySin.apply(x) 904 905 x = T(3, grad=True) 906 self.assertExpectedInline(count_numel_train(f, x), """9""") 907 908 @requires_cuda 909 def test_inplace_custom_op_training_two_mutated_inputs(self): 910 @torch.library.custom_op( 911 "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"} 912 ) 913 def sin_cos( 914 x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor 915 ) -> None: 916 out_sin.copy_(x.sin()) 917 out_cos.copy_(x.cos()) 918 919 def f(x): 920 out0 = torch.empty_like(x) 921 out1 = torch.empty_like(x) 922 sin_cos(x, out0, out1) 923 return x.clone(), out0, out1 924 925 x = T(3, grad=True) 926 self.assertExpectedInline(count_numel(f, x), """21""") 927 928 @requires_cuda 929 def test_inplace_custom_op_training(self): 930 @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) 931 def sin(x: torch.Tensor, result: torch.Tensor) -> None: 932 result.copy_(x.sin()) 933 934 factory_op = torch.empty_like 935 936 class MySin(torch.autograd.Function): 937 @staticmethod 938 def forward(ctx, x): 939 out = factory_op(x) 940 sin(x, out) 941 ctx.save_for_backward(out) 942 return out 943 944 @staticmethod 945 def backward(ctx, grad): 946 (saved,) = ctx.saved_tensors 947 out = factory_op(grad) 948 sin(saved, out) 949 return out 950 951 def f(x): 952 return MySin.apply(x) 953 954 x = T(3, grad=True) 955 self.assertExpectedInline(count_numel_train(f, x), """9""") 956 957 @requires_cuda 958 def test_inplace_custom_op(self): 959 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 960 m.define("foo(Tensor x, Tensor(a!) out) -> ()") 961 962 def foo(x: torch.Tensor, out: torch.Tensor) -> None: 963 out.copy_(x.sin()) 964 965 m.impl("foo", foo, "CompositeExplicitAutograd") 966 967 def f(x, out): 968 torch.ops.mylib.foo(x, out) 969 torch.ops.mylib.foo(out, out) 970 torch.ops.mylib.foo(out, out) 971 return out 972 973 x = T(3) 974 out = T(3) 975 976 compiled_out, (code,) = run_and_get_code( 977 torch.compile(f, fullgraph=True), x, out 978 ) 979 self.assertEqual(compiled_out, x.sin().sin().sin()) 980 981 # Check that we are allocating the minimum number of intermediate buffers 982 matches = re.findall(r"empty_strided_\w+\(", code) 983 self.assertEqual(len(matches), 0) 984 985 self.assertExpectedInline(count_numel(f, x, out), """21""") 986 987 @requires_cuda 988 def test_inplace_custom_op_intermediate(self): 989 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 990 m.define("foo(Tensor x, Tensor(a!) out) -> ()") 991 992 def foo(x: torch.Tensor, out: torch.Tensor) -> None: 993 out.copy_(x.sin()) 994 995 m.impl("foo", foo, "CompositeExplicitAutograd") 996 997 def f(x, out): 998 out = torch.empty_like(x) 999 torch.ops.mylib.foo(x, out) 1000 torch.ops.mylib.foo(out, out) 1001 torch.ops.mylib.foo(out, out) 1002 return out 1003 1004 x = T(3) 1005 out = T(3) 1006 1007 compiled_out, (code,) = run_and_get_code( 1008 torch.compile(f, fullgraph=True), x, out 1009 ) 1010 self.assertEqual(compiled_out, x.sin().sin().sin()) 1011 1012 # Check that we are allocating the minimum number of intermediate buffers 1013 matches = re.findall(r"empty_strided_\w+\(", code) 1014 self.assertEqual(len(matches), 1) 1015 1016 self.assertExpectedInline(count_numel(f, x, out), """21""") 1017 1018 @requires_cuda 1019 def test_inplace_custom_op_two_mutated_inputs(self): 1020 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 1021 m.define("foo(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache) -> Tensor") 1022 1023 def foo(q, k_cache, v_cache): 1024 k_cache.add_(1) 1025 v_cache.add_(1) 1026 return q + 1 1027 1028 m.impl("foo", foo, "CompositeExplicitAutograd") 1029 1030 q = T(3) 1031 k_cache = T(3) 1032 v_cache = torch.rand_like(k_cache) 1033 1034 def f(): 1035 x = 0 1036 for _ in range(2): 1037 x = x + torch.ops.mylib.foo(q, k_cache, v_cache) 1038 return x 1039 1040 compiled_out, (code,) = run_and_get_code( 1041 torch.compile(f, fullgraph=True), 1042 ) 1043 1044 # Check that we are allocating the minimum number of intermediate buffers 1045 matches = re.findall(r"empty_strided_\w+\(", code) 1046 self.assertEqual(len(matches), 1) 1047 1048 self.assertExpectedInline(count_numel(f), """39""") 1049 1050 @requires_cuda 1051 def test_inplace_triton_kernel_v1(self): 1052 def f(x: torch.Tensor, y: torch.Tensor): 1053 output = torch.zeros_like(x) 1054 n_elements = output.numel() 1055 grid = (n_elements,) 1056 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) 1057 return output 1058 1059 inp = (T(10), T(10)) 1060 self.assertExpectedInline(count_numel(f, *inp), """50""") 1061 1062 @requires_cuda 1063 def test_inplace_triton_kernel_v2(self): 1064 def f(x: torch.Tensor, y: torch.Tensor): 1065 output = torch.zeros_like(x) 1066 n_elements = output.numel() 1067 grid = (n_elements,) 1068 tmp = torch.add(x, 1) 1069 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) 1070 return output, tmp 1071 1072 inp = (T(10), T(10)) 1073 self.assertExpectedInline(count_numel(f, *inp), """70""") 1074 1075 @requires_cuda 1076 def test_inplace_triton_kernel_v3(self): 1077 def f(x: torch.Tensor, y: torch.Tensor): 1078 output = torch.zeros_like(x) 1079 n_elements = output.numel() 1080 grid = (n_elements,) 1081 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) 1082 x.add_(1) 1083 return output 1084 1085 inp = (T(10), T(10)) 1086 self.assertExpectedInline(count_numel(f, *inp), """80""") 1087 1088 @requires_cuda 1089 def test_inplace_triton_kernel_v4(self): 1090 def f(x: torch.Tensor, y: torch.Tensor): 1091 x_view = x.view(-1) 1092 output = torch.zeros_like(x) 1093 n_elements = output.numel() 1094 grid = (n_elements,) 1095 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) 1096 output2 = x_view.mul(2) 1097 return output, output2 1098 1099 inp = (T(10), T(10)) 1100 self.assertExpectedInline(count_numel(f, *inp), """70""") 1101 1102 @requires_cuda 1103 def test_inplace_triton_kernel_v5(self): 1104 def f(x: torch.Tensor, y: torch.Tensor): 1105 x_view = x.view(-1) 1106 output = torch.zeros_like(x) 1107 n_elements = output.numel() 1108 grid = (n_elements,) 1109 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) 1110 x_view.mul_(2) 1111 return output 1112 1113 inp = (T(10), T(10)) 1114 self.assertExpectedInline(count_numel(f, *inp), """80""") 1115 1116 @requires_cuda 1117 def test_inplace_triton_kernel_v6(self): 1118 def f(x: torch.Tensor, y: torch.Tensor): 1119 output = torch.zeros_like(x) 1120 n_elements = output.numel() 1121 grid = (n_elements,) 1122 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) 1123 return output 1124 1125 t = T(10) 1126 inp = (t, t.view(-1)) 1127 self.assertExpectedInline(count_numel(f, *inp), """50""") 1128 1129 def test_inplace_randperm_scatter(self): 1130 def scaled_index_add(x, y, scale_y): 1131 index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] 1132 out = x.index_add_(dim=0, source=y * scale_y, index=index) 1133 return out 1134 1135 inp = (T(10, 10), T(5, 10), T(10)) 1136 self.assertExpectedInline(count_numel(scaled_index_add, *inp), """250""") 1137 1138 1139# Test cases where we don't do the right thing yet. 1140class WouldBeNiceIfItWorked: 1141 def test_horizontal(self): 1142 def f(a): 1143 b = a.sum(dim=0) 1144 c = a.cos() 1145 return b, c 1146 1147 inp = (T(10, 10),) 1148 self.assertExpectedInline(count_numel(f, *inp), """210""") 1149 1150 # TODO: We aren't fusing outer dim softmaxes 1151 def test_softmax_outer(self): 1152 def f(a): 1153 return torch.softmax(a, dim=0) 1154 1155 inp = (T(10, 10),) 1156 self.assertExpectedInline(count_numel(f, *inp), """200""") 1157 1158 # TODO: The greedy fusion strategy results in suboptimal grouping 1159 @patch.object(config, "realize_opcount_threshold", 0) 1160 def test_fusion_choice4(self): 1161 def f(a, b, b2): 1162 c = a + b 1163 d = torch.mm(c, c) 1164 e = c + b + b2 1165 f = d + e + b2 1166 return f, e 1167 1168 inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10)) 1169 self.assertExpectedInline(count_numel(f, *inp), """1000""") 1170 1171 # TODO: We materialize the intermediate if we don't unroll the reduction 1172 def test_neighbor(self): 1173 def f(a, b): 1174 return ((a - b) ** 2).sum(dim=-1).amax(dim=1) 1175 1176 inp = (T(10, 1, 8), T(1, 10, 8)) 1177 self.assertExpectedInline(count_numel(f, *inp), """170""") 1178 1179 1180if __name__ == "__main__": 1181 from torch._inductor.test_case import run_tests 1182 1183 if HAS_CUDA: 1184 run_tests(needs="filelock") 1185