# Owner(s): ["module: inductor"]
import os
import unittest
from typing import Callable, List, Optional

import torch
from torch import multiprocessing as mp, nn
from torch._dynamo import reset
from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.testing import rand_strided, reset_rng_state
from torch._inductor import config
from torch._inductor.autotune_process import (
    BenchmarkRequest,
    CUDA_VISIBLE_DEVICES,
    TuningProcessPool,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
    AlgorithmSelectorCache,
    TritonTemplateCaller,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


try:
    from .mock_cache import global_stats, PatchCaches
except ImportError:
    from mock_cache import global_stats, PatchCaches  # @manual


torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")


def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.
    """
    path_envs = os.environ.get("PATH", "").split(":")
    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
    return ":".join(path_envs)


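# Helper executed in a spawned child process: benchmark the given choice and
# copy the result into the shared `timings` tensor so the parent can read it.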
def benchmark_choice(choice, args, out, expected_out, timings):
    result = choice.benchmark(*args, out=out)
    if expected_out is not None:
        torch.testing.assert_close(out, expected_out)

    timings.copy_(torch.tensor(result))


class FailChoiceCaller(ChoiceCaller):
    def benchmark(self, *args, out):
        raise RuntimeError("This choice caller will always throw")


@instantiate_parametrized_tests
class TestMaxAutotune(TestCase):
    def _create_buffer(self, name, shape):
        return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape))

    def test_benchmark_choice_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            # expected_out = (mat1 @ mat2) + (mat3 @ mat4)
            expected_out = None

            choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout)
            # use a tensor since a mutation to a python list in a sub-process
            # is not synced back to the parent process
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertEqual(0, child.exitcode)
            print(f"timings is {timings}, out {out}, expected_out {expected_out}")

    def test_benchmark_choice_fail_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)

        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))

            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)

            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            expected_out = (mat1 @ mat2) + (mat3 @ mat4)

            choice = FailChoiceCaller("fail_choice_caller", [], None)

            # use a tensor since a python list is not synced back
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertNotEqual(0, child.exitcode)

    @parametrize("autotune_in_subproc", (True, False))
    @parametrize("autotune_multi_device", (True, False))
    def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device):
        """
        This used to crash due to a triton issue: https://github.com/openai/triton/issues/1298 .
        With autotuning in a subprocess, we don't crash anymore.
        """
        m, n, k = 2048, 1536, 64

        def mm_plus_mm(a, b, c, d):
            return a @ b + c @ d

        a = torch.randn(m, k).cuda()
        b = torch.randn(k, n).cuda()
        c = torch.randn(m, k).cuda()
        d = torch.randn(k, n).cuda()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": autotune_in_subproc,
                "autotune_multi_device": autotune_multi_device,
            }
        ):
            torch.compile(mm_plus_mm)(a, b, c, d)

    @parametrize("dynamic", (False, True))
    def test_max_autotune_mm_plus_mm_zero_size_input(self, dynamic):
        """
        Make sure autotuning mm_plus_mm with zero-size input works without crashes.
        """
        m, n, k = 0, 1536, 64

        def mm_plus_mm(a, b, c, d):
            return a @ b + c @ d

        a = torch.randn(m, k).cuda()
        b = torch.randn(k, n).cuda()
        c = torch.randn(m, k).cuda()
        d = torch.randn(k, n).cuda()

        with config.patch({"max_autotune": True}):
            torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d)

    @parametrize("dynamic", (False, True))
    def test_max_autotune_regular_mm(self, dynamic: bool):
        """
        Make sure autotuning mm in subprocesses works without crashes.
        """

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            torch.compile(mm, dynamic=dynamic)(a, b)

    @parametrize("dynamic", (False, True))
    def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool):
        """
        Make sure autotuning mm with zero-size input works without crashes.
        """

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(0, 10).cuda()
        b = torch.randn(10, 100).cuda()

        with config.patch({"max_autotune": True}):
            torch.compile(mm, dynamic=dynamic)(a, b)

    @skipIfRocm
    def test_precompilation_threads(self):
        import threading
        from typing import Any, Dict
        from unittest.mock import Mock, patch

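        # ChoiceCaller stub that records which thread ran precompile(), so the
        # test can verify precompilation happened off the main thread.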
        class FakeChoiceCaller(ChoiceCaller):
            def __init__(self) -> None:
                super().__init__("none", [], Mock())
                self.thread_id = None

            def precompile(self):
                self.thread_id = threading.get_ident()

            def call_name(self) -> str:
                return None

            def to_callable(self):
                return None

            def hash_key(self) -> str:
                return str(hash(self))

            def output_node(self) -> "TensorBox":  # noqa: F821
                return None

        fake_choices = [FakeChoiceCaller() for i in range(10)]
        fake_lookup_result = dict.fromkeys(fake_choices, 0.123)

        def no_lookup(
            choices: List[ChoiceCaller],
            op: str,
            inputs: str,
            benchmark: Callable[[Any], Dict[ChoiceCaller, float]],
        ) -> Optional[Dict[ChoiceCaller, float]]:
            if benchmark is not None:
                return benchmark(choices)

        asc = AlgorithmSelectorCache()

        def fake_benchmark_fn(*args, **kwargs):
            return fake_lookup_result

        main_thread_id = threading.get_ident()
        mock_debug_handler = Mock()
        old_debug_handler = V.debug
        try:
            V.set_debug_handler(mock_debug_handler)
            with patch.object(asc, "lookup", new=no_lookup):
                with patch.object(
                    asc, "make_benchmark_fn", return_value=fake_benchmark_fn
                ):
                    with config.patch(
                        {
                            "autotune_in_subproc": False,
                            "compile_threads": len(fake_choices),
                        }
                    ):
                        asc("test_call", fake_choices, [], Mock())
                        for fake_choice in fake_choices:
                            assert (
                                fake_choice.thread_id is not None
                            ), "Expected all ChoiceCallers' precompile methods to have been called"
                            assert (
                                fake_choice.thread_id != main_thread_id
                            ), "Expected all ChoiceCallers' precompile methods to have been called on a separate thread"
        finally:
            V.set_debug_handler(old_debug_handler)

    @parametrize("dynamic", (False, True))
    def test_max_autotune_addmm(self, dynamic=False):
        """
        Make sure autotuning addmm in subprocesses works without crashes.
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def addmm(x, a, b):
            return torch.addmm(x, a, b)

        x = torch.randn(100).cuda()
        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()
        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            Y_compiled = torch.compile(addmm, dynamic=dynamic)(x, a, b)
            Y = addmm(x, a, b)
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @parametrize("dynamic", (False, True))
    def test_max_autotune_addmm_zero_size_input(self, dynamic):
        """
        Make sure autotuning addmm with zero-size input works without crashes.
        """

        def addmm(x, a, b):
            return torch.addmm(x, a, b)

        x = torch.randn(100).cuda()
        a = torch.randn(0, 10).cuda()
        b = torch.randn(10, 100).cuda()
        with config.patch({"max_autotune": True}):
            torch.compile(addmm, dynamic=dynamic)(x, a, b)

    @skipIfRocm
    def test_autotune_conv1x1(self):
        # Assuming the input has 3 channels and we want to produce 16 channels as output
        conv1x1 = (
            torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
            .to(memory_format=torch.channels_last)
            .cuda()
        )

        # Example input tensor: batch size = 4, channels = 3, height = 32, width = 32
        # The memory format is set to `channels_last`
        input_tensor = (
            torch.randn(4, 3, 32, 32)
            .contiguous(memory_format=torch.channels_last)
            .cuda()
        )

        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            with torch.no_grad():
                out, code = run_and_get_code(foo, conv1x1, input_tensor)

            FileCheck().check_not("extern_kernels.convolution").run(code[0])
            self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)

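    # Recompiling the same function with a warm autotune cache should not
    # trigger any new precompilations.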
    @skipIfRocm
    def test_filled_cache_precompile(self):
        def fn(a, b, c):
            a = (a @ b) @ c
            a, b, c = (t.to(torch.float16) for t in [a, b, c])
            return (a @ b) @ c

        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
        from torch._dynamo.utils import counters

        self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)

        torch._dynamo.reset()
        counters.clear()

        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)

    @skipIfRocm
    @fresh_inductor_cache()
    @config.patch(max_autotune=True, max_fusion_size=2)
    def test_jit_fusion_matches_aot_fusion(self):
        # In this example, AOTInductor's JIT compile will fuse(buf1, buf2) due
        # to proximity; we want to make sure the AOT compile pass does the same.
        # AOT could do fuse(buf2, buf4) instead if buf3 were pushed to the end
        # of the V.graph.buffers list, because fuse(buf2, buf4) would have a
        # better proximity score than fuse(buf1, buf2). This scenario is possible
        # since finalizing MultiTemplateBuffers needs to replace buffers.
        def fn(x, number):
            buf0 = x + x
            buf1 = number.item()
            buf2 = x * x
            buf3 = x @ x  # MultiTemplateBuffer
            buf4 = x**2
            return buf0, buf1, buf2, buf3, buf4

        inputs = (torch.rand([256, 256], device="cuda"), torch.tensor(3, device="cuda"))
        torch._export.aot_compile(fn, args=inputs)

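    # With both the local and the remote autotune cache disabled, the fp32 and
    # fp16 matmul chains are each expected to trigger their own precompilation,
    # giving a count of exactly 2.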
    @config.patch(autotune_local_cache=False, autotune_remote_cache=False)
    @skipIfRocm
    def test_precompilations(self):
        def fn(a, b, c):
            a = (a @ b) @ c
            a, b, c = (t.to(torch.float16) for t in [a, b, c])
            return (a @ b) @ c

        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]

        torch.testing.assert_close(fn_c(*inputs), fn(*inputs), atol=1e-2, rtol=1e-2)

        from torch._dynamo.utils import counters

        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2)

    def test_cat_addmm(self):
        def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
            return torch.cat(
                [
                    torch.addmm(a, b, c),
                    torch.addmm(b, c, a),
                ],
                1,
            )

        args = [
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
        ]
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

    def test_triton_template_with_epilogues_and_dynamic_shape(self):
        def fn(
            x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor, mul: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.nn.functional.relu(
                    torch.matmul(torch.transpose(x, 0, 1), torch.transpose(w, 0, 1))
                    + bias
                )
                * mul
            )

        M0 = 5
        M1 = 8
        K = 4
        N = 3
        w = torch.rand(N, K).cuda().half()
        b = torch.rand(N).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            compiled_fn = torch.compile(
                fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs"
            )

            x0 = torch.rand(K, M0).cuda().half()
            mul0 = torch.rand(M0, N).cuda().half()
            y0 = compiled_fn(x0, w, b, mul0)
            y0_expected = fn(x0, w, b, mul0)
            torch.testing.assert_close(y0, y0_expected)

            x1 = torch.rand(K, M1).cuda().half()
            mul1 = torch.rand(M1, N).cuda().half()
            y1 = compiled_fn(x1, w, b, mul1)
            y1_expected = fn(x1, w, b, mul1)
            torch.testing.assert_close(y1, y1_expected)

    @config.patch(
        benchmark_kernel=True,
        fallback_random=True,
        max_autotune_gemm=True,
    )
    @parametrize("device", ("cpu", "cuda"))
    def test_matmul_dropout(self, device):
        def fwd(a, b):
            x = a @ b
            x = torch.nn.functional.dropout(x, 0.1)
            return x

        def fn(a, b):
            x = fwd(a, b).sum()
            x.backward()
            return a.grad

        N = 128
        a = torch.randn(N, N, device=device, requires_grad=True)
        b = torch.randn(N, N, device=device)

        opt_fn = torch.compile(fn)
        reset_rng_state()
        ref = fn(a, b)
        reset_rng_state()
        act = opt_fn(a, b)

        if N <= 8:
            print(f"ref\n{ref}\nact\n{act}")
        torch.testing.assert_close(ref, act, atol=1e-1, rtol=1e-1)

    @config.patch(
        max_autotune_gemm=True,
    )
    @unittest.skipIf(
        torch.cuda.device_count() < 2, "Need at least 2 devices for this test"
    )
    def test_autotune_device_guard(self):
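        # Compile a matmul whose inputs live on a non-default device (cuda:1);
        # autotuning has to benchmark under the right device guard for the
        # compiled result to match eager.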
        x = torch.randn(1024, 1024, device="cuda:1")
        y = torch.randn(1024, 1024, device="cuda:1")

        def f(x, y):
            return x @ y

        with fresh_inductor_cache():
            act = torch.compile(f)(x, y)
        ref = f(x, y)
        self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))

    @config.patch(max_autotune=True)
    def test_empty_conv_input(self, kernel_size=3):
        x = torch.randn(0, 256, 14, 14, device="cuda")
        weight = torch.randn(256, 256, kernel_size, kernel_size, device="cuda")

        def f(x, weight):
            return torch.convolution(
                x,
                weight,
                bias=None,
                stride=[1, 1],
                padding=[0, 0],
                dilation=[1, 1],
                transposed=False,
                output_padding=[0, 0],
                groups=1,
            )

        opt_f = torch.compile(f)
        ref = f(x, weight)
        act = opt_f(x, weight)
        self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3))

    @config.patch(max_autotune=True)
    def test_empty_conv_input_with_1x1_kernel(self):
        self.test_empty_conv_input(kernel_size=1)

    @config.patch(max_autotune=True)
    def test_conv1x1_with_free_symbols(self):
        """
        Make sure there is no exception due to free symbols.
        """
        conv = nn.Conv2d(
            3, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
        ).to(device="cuda")

        @torch.compile
        def f(x, y, z):
            h = y.nonzero().size(0)
            w = z.nonzero().size(0)
            x = x[:, :, :h, :w]
            x = conv(x)
            return x

        x = torch.randn(4, 3, 224, 224).to(
            memory_format=torch.channels_last, device="cuda"
        )
        for _ in range(2):
            y = torch.randint(0, 10, (224,)).to(device="cuda")
            z = torch.randint(0, 10, (224,)).to(device="cuda")
            f(x, y, z)

    def test_conv3d(self):
        fn = torch.nn.functional.conv3d
        image = torch.randn([1, 3, 8, 16, 32])
        filt = torch.randn([3, 3, 7, 7, 7])

        with config.patch({"max_autotune": True}):
            expected = fn(image, filt)
            actual = torch.compile(fn)(image, filt)
            torch.testing.assert_close(actual, expected, atol=6e-5, rtol=0.001)

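    # With no conv backend allowed and layout optimization disabled, lowering
    # the convolution should fail with NoValidChoicesError.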
    @config.patch(
        max_autotune=True, max_autotune_conv_backends="", layout_optimization=False
    )
    def test_conv_backend(self):
        m = torch.nn.Sequential(
            torch.nn.Conv2d(3, 3, 1, 1),
        ).cuda()
        inp = torch.randn([2, 3, 16, 16]).cuda()

        with self.assertRaises(BackendCompilerFailed) as context:
            torch.compile(m)(inp)

        self.assertIn("NoValidChoicesError", str(context.exception))

    def test_non_contiguous_input_mm(self):
        """
        Make sure the triton template can work with non-contiguous inputs without crashing.
        Check https://github.com/pytorch/pytorch/issues/125437 for more details.
        """
        x = rand_strided(
            (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return x @ y

        ref = x @ y
        act = f(x, y)
        torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

    def test_non_contiguous_input_addmm(self):
        b = torch.randn((768), dtype=torch.bfloat16, device="cuda")
        x = rand_strided(
            (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return torch.addmm(b, x, y)

        ref = torch.addmm(b, x, y)
        act = f(x, y)
        torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

    def test_non_contiguous_input_bmm(self):
        x = rand_strided(
            (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device="cuda"
        )
        y = rand_strided(
            (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device="cuda"
        )

        @torch.compile(mode="max-autotune")
        def f(x, y):
            return torch.bmm(x, y)

        ref = torch.bmm(x, y)
        act = f(x, y)
        torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

    def test_non_contiguous_input_mm_plus_mm(self):
        x1 = rand_strided((50257, 32768), (1, 50304), device="cuda")
        y1 = rand_strided((32768, 768), (768, 1), device="cuda")

        x2 = rand_strided((50257, 32768), (1, 50304), device="cuda")
        y2 = rand_strided((32768, 768), (768, 1), device="cuda")

        @torch.compile(mode="max-autotune")
        def f(x1, y1, x2, y2):
            return x1 @ y1 + x2 @ y2

        ref = x1 @ y1 + x2 @ y2
        act = f(x1, y1, x2, y2)
        torch.testing.assert_close(act, ref, atol=1e-2, rtol=1e-2)

    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="",
        autotune_fallback_to_aten=False,
    )
    def test_no_valid_choices(self):
        a = torch.zeros([2, 2], device="cuda")
        b = torch.zeros([2, 2], device="cuda")
        with self.assertRaises(BackendCompilerFailed) as context:
            torch.compile(lambda a, b: a.matmul(b))(a, b)
        self.assertIn("NoValidChoicesError", str(context.exception))

    @parametrize("multi_template", (True, False))
    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
        autotune_fallback_to_aten=False,
    )
    def test_inf_timing(self, multi_template):
        from unittest.mock import patch

        lookup = AlgorithmSelectorCache.lookup

        def mock_lookup(self, *args, **kwargs):
            timings = lookup(self, *args, **kwargs)
            return {choice: float("inf") for choice in timings.keys()}

        a = torch.zeros([16, 16], device="cuda")
        b = torch.zeros([16, 16], device="cuda")
        with patch.object(AlgorithmSelectorCache, "lookup", mock_lookup), config.patch(
            benchmark_epilogue_fusion=multi_template
        ):
            with self.assertRaises(BackendCompilerFailed) as context:
                torch.compile(lambda a, b: a.matmul(b))(a, b)
            self.assertIn("NoValidChoicesError", str(context.exception))


@instantiate_parametrized_tests
class TestMaxAutotuneRemoteCache(TestCase):
    def setUp(self):
        super().setUp()
        PatchCaches.setUp()

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()

    @skipIfRocm
    @parametrize("dynamic", (False, True))
    def test_max_autotune_remote_caching(self, dynamic: bool):
        from unittest.mock import patch

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        def f(x, y):
            return Model()(x, y)

        x = torch.randn(100, 100).cuda()
        y = torch.randn(100, 100).cuda()

        with config.patch(
            {
                "autotune_local_cache": False,
                "autotune_remote_cache": True,
            }
        ), patch.dict(os.environ), PatchCaches():
            os.environ.pop("TRITON_CACHE_MANAGER", None)
            with config.patch({"max_autotune": True}):
                for _ in range(4):
                    with fresh_inductor_cache():
                        torch.compile(mm, dynamic=dynamic)(a, b)
                    reset()

                global_stats.report()
                self.assertEqual(global_stats.autotune.num_get_hit, 3)
                self.assertEqual(global_stats.autotune.num_get_miss, 1)
                self.assertEqual(global_stats.autotune.num_put, 1)

            global_stats.reset()
            for _ in range(4):
                with fresh_inductor_cache():
                    torch.compile(f, dynamic=dynamic)(x, y)
                reset()
            global_stats.report()
            self.assertEqual(global_stats.autotune.num_get_hit, 3)
            self.assertEqual(global_stats.autotune.num_get_miss, 1)
            self.assertEqual(global_stats.autotune.num_put, 1)


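# A stub BenchmarkRequest that performs no real benchmarking: it checks that
# CUDA_VISIBLE_DEVICES was set up as expected in the tuning sub-process and
# returns a fixed timing value.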
class TestBenchmarkRequest(BenchmarkRequest):
    def __init__(
        self, value: float, multi_device: bool, parent_visible_devices: Optional[str]
    ) -> None:
        self.value = value
        self.multi_device = multi_device
        self.parent_visible_devices = parent_visible_devices

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        # Verify that the visible devices env var is set correctly. If multi-device
        # auto-tuning is disabled, the visible devices should be unmanipulated from
        # the parent process. If multi-device auto-tuning is enabled, the visible
        # devices should be a _single_ valid device number. Note that we can't perform
        # this validation directly from the test body because benchmarks execute in a
        # separate process. If the check fails, however, the test will detect the
        # failure by virtue of not receiving the expected result back.
        visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)
        if not self.multi_device:
            assert visible_devices == self.parent_visible_devices
        else:
            assert self.parent_visible_devices is not None
            valid_devices = self.parent_visible_devices.split(",")
            assert visible_devices in valid_devices

        return self.value


class TestTritonTemplateCaller(TritonTemplateCaller):
    def __init__(self, bmreq: TestBenchmarkRequest):
        self.bmreq = bmreq

    def __str__(self) -> str:
        return "test"


class TestTuningProcess(TestCase):
    def test_tuning_pool_crash(self):
        # Use only one device/subprocess so we test that the process restarts
        # and is usable after a "crash".
        with config.patch({"autotune_multi_device": False}):
            tuning_pool = TuningProcessPool()
            tuning_pool.initialize()

            # First force the tuning process to "crash" by setting a bogus
            # string for the expected visible devices.
            bmreq = TestBenchmarkRequest(3.14, False, "invalid")
            choice = TestTritonTemplateCaller(bmreq)

            timings = tuning_pool.benchmark([choice])
            self.assertTrue(choice in timings)
            self.assertEqual(timings[choice], float("inf"))

            # Then send another request and make sure the sub-process
            # has restarted and is operational. 'valid_devices' expected
            # to be None because autotune_multi_device is off.
            choice.bmreq.parent_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)

            timings = tuning_pool.benchmark([choice])
            self.assertTrue(choice in timings)
            self.assertEqual(timings[choice], bmreq.value)

            tuning_pool.terminate()

    def test_tuning_pool_multiple_devices(self):
        with config.patch({"autotune_multi_device": True}):
            # Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
            # is already set in the environment); use a subset of the available devices
            # to ensure only the subset are visible to the sub-processes.
            if CUDA_VISIBLE_DEVICES in os.environ:
                visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
            else:
                visible_devices = [str(d) for d in range(torch.cuda.device_count())]

            parent_visible_devices = ",".join(visible_devices[-2:])
            os.environ[CUDA_VISIBLE_DEVICES] = parent_visible_devices

            tuning_pool = TuningProcessPool()
            tuning_pool.initialize()

            choice1 = TestTritonTemplateCaller(
                TestBenchmarkRequest(3.14, True, parent_visible_devices),
            )
            choice2 = TestTritonTemplateCaller(
                TestBenchmarkRequest(2.718, True, parent_visible_devices),
            )

            timings = tuning_pool.benchmark([choice1, choice2])
            self.assertEqual(timings[choice1], choice1.bmreq.value)
            self.assertEqual(timings[choice2], choice2.bmreq.value)

            tuning_pool.terminate()


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Set env to make it work in CI.
    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
        run_tests()