# Owner(s): ["module: inductor"]
import functools
import os
import pickle
import unittest
from typing import List
from unittest import mock

import torch
from torch._dynamo import reset
from torch._dynamo.utils import counters
from torch._inductor import config, metrics
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.codecache import (
    cuda_compile_command,
    CUDACodeCache,
    FxGraphCachePickler,
    FxGraphHashDetails,
    PyCodeCache,
    TensorMetadata,
    TensorMetadataAndValues,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_CUDA,
    HAS_GPU,
    HAS_MULTIGPU,
    requires_gpu,
)
from torch.utils._triton import has_triton


try:
    from .mock_cache import global_stats, patch_fbcode, PatchCaches
except ImportError:
    from mock_cache import global_stats, patch_fbcode, PatchCaches  # @manual


HAS_TRITON = has_triton()

if HAS_TRITON:
    import triton  # @manual

    from torch.testing._internal.triton_utils import add_kernel

requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")

torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True


class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)

    def forward(self, inp):
        return self.fc1(inp)


def _run_codecache_test(start_method):
    with torch._inductor.config.patch(
        worker_start_method=start_method, compile_threads=16
    ):
        AsyncCompile.warm_pool()

        model = MyModel().to(device=GPU_TYPE)
        model = torch.compile(model)
        inp = torch.rand(10, 10).to(device=GPU_TYPE)
        model(inp).sum().backward()


@requires_gpu()
def test_codecache_spawn():
    _run_codecache_test("spawn")


@requires_gpu()
def test_codecache_fork():
    _run_codecache_test("fork")


class MyModelConv2d(torch.nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
        self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        torch._dynamo.graph_break()
        x = self.conv2(x)
        return x


@instantiate_parametrized_tests
class TestFxGraphCache(TestCase):
    device_type = GPU_TYPE

    def setUp(self):
        super().setUp()
        counters.clear()
        PatchCaches.setUp()

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()

    def reset(self):
        torch._dynamo.reset()
        clear_inductor_caches()

    @requires_triton()
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    @parametrize("dynamic", (False, True))
    def test_cache_load_function(self, device, dtype, dynamic):
        """
        Verify that we can populate and load functions from the cache.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25, dtype=dtype, device=device)
        b = torch.rand(5, 5, dtype=dtype, device=device)

        compiled_fn = torch.compile(fn, dynamic=dynamic)

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)

        # A second call should hit. (First reset so in-memory guards
        # don't prevent compilation).
        for m in torch._inductor.codecache.PyCodeCache.cache.values():
            os.remove(m.__file__)
        self.reset()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)

    @requires_triton()
    @config.patch({"fx_graph_remote_cache": True})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    @parametrize("dynamic", (False, True))
    def test_remote_cache_load_function(self, device, dtype, dynamic):
        from unittest.mock import patch

        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25, dtype=dtype, device=device)
        b = torch.rand(5, 5, dtype=dtype, device=device)

        with config.patch(
            {
                "fx_graph_remote_cache": True,
            }
        ), patch.dict(os.environ), PatchCaches():
            os.environ.pop("TRITON_CACHE_MANAGER", None)
            for _ in range(4):
                with fresh_inductor_cache():
                    compiled_fn = torch.compile(fn, dynamic=dynamic)
                    self.assertEqual(fn(a, b), compiled_fn(a, b))
                reset()

        global_stats.report()
        self.assertEqual(global_stats.fx_graph.num_get_hit, 3)
        self.assertEqual(global_stats.fx_graph.num_get_miss, 1)
        self.assertEqual(global_stats.fx_graph.num_put, 1)

    @requires_triton()
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.float64))
    @parametrize("dynamic", (False, True))
    def test_cache_load_model(self, device, dtype, dynamic):
        """
        Verify that we can populate and load models from the cache.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")

        def fn(mod, x):
            mod.zero_grad()
            mod(x).sum().backward()
            return [p.grad for p in mod.parameters()]

        compiled_fn = torch.compile(fn, dynamic=dynamic)

        mod = MyModelConv2d().to(device=device, dtype=dtype)
        inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)

        # The first call should see all cache misses.
        counters.clear()
        grads1 = compiled_fn(mod, inp)
        self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        # The second should see all hits. (First reset so in-memory guards
        # don't prevent compilation).
        counters.clear()
        self.reset()
        grads2 = compiled_fn(mod, inp)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)

        # And the results should be the same.
        self.assertEqual(grads1, grads2)

    @largeTensorTest("64GB", device=GPU_TYPE)
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE,))
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_cache_load_with_guards_int32_bounds(self, device, dtype):
        """
        Test caching the same graph, but under conditions that introduce guards
        for tensor sizes < int32.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires CUDA SM80 or later")

        def fn(x, y):
            return (x + x, y + y)

        compiled_fn = torch.compile(fn, dynamic=True)

        # Iterate over different shapes, varying whether the total
        # size is below or above int32. For each combination, we expect
        # different guards around whether the symbolic sizes do or do
        # not exceed int32.
        shapes = (
            ((5, 6), (7, 8)),
            ((5, 6), (47000, 47001)),
            ((47000, 47001), (5, 6)),
        )
        for a_shape, b_shape in shapes:
            a = torch.rand(a_shape, device=device, dtype=dtype)
            b = torch.rand(b_shape, device=device, dtype=dtype)

            # AVOID a dynamo reset here. We expect guards to have been
            # added that will be violated with the new shape. We should
            # see a recompilation (along with a cache miss).
            counters.clear()
            res1 = compiled_fn(a, b)
            self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

            # A second call should hit. (Reset here to force compilation).
            counters.clear()
            self.reset()
            res2 = compiled_fn(a, b)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)

            self.assertEqual(res1, res2)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    def test_cache_load_with_guards_static_bounds(self, device, dtype):
        """
        Test caching the same graph, but under conditions that introduce guards
        for static bounds.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        # See lowering; for all of the pooling operators, we always guard and
        # make the height/width static.
        def fn(x):
            return torch.nn.functional.adaptive_avg_pool2d(x, [5, 7])

        compiled_fn = torch.compile(fn, dynamic=True)

        # Iterate over different input shapes. Each new shape should cause
        # a cache miss.
        shapes = ((1, 64, 8, 9), (1, 64, 9, 10), (1, 64, 10, 11))
        for shape in shapes:
            x = torch.rand(shape, device=device, dtype=dtype)

            # AVOID a dynamo reset here. For each cache hit, we expect guards
            # to have been added that will be violated with each new shape.
            # We should see a recompilation (along with a cache miss).
            counters.clear()
            res1 = compiled_fn(x)
            self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

            # A second call should hit.
            counters.clear()
            self.reset()
            res2 = compiled_fn(x)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)

            self.assertEqual(res1, res2)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @parametrize("device", (GPU_TYPE, "cpu"))
    def test_constant_handling(self, device):
        """
        Test that different constants are recognized correctly.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")

        def fn1(x):
            return x + torch.tensor(list(range(0, 12)), device=device)

        def fn2(x):
            return x + torch.tensor(list(range(1, 13)), device=device)

        a = torch.rand(12, device=device)

        compiled_fn1 = torch.compile(fn1)
        compiled_fn2 = torch.compile(fn2)

        # A call to fn1 should miss in the cache.
        self.assertEqual(fn1(a), compiled_fn1(a))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        # A call to fn2 should also miss (the constant is different)
        self.assertEqual(fn2(a), compiled_fn2(a))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

    @requires_gpu()
    @requires_triton()
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_higher_order_op_bypass(self):
        """
        Verify that we bypass the cache when we have higher order ops.
        """

        def fn(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (  # noqa: E731
                triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
            )
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        compiled_fn = torch.compile(fn, fullgraph=True)

        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        compiled_fn(x, y)

        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_generated_kernel_count(self):
        """
        Test that we bump the generated_kernel_count metric on a cache hit.
        """

        def fn(x, y):
            return (x * y + y,)

        a = torch.rand(5, 5)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn)

        metrics.reset()
        self.assertEqual(metrics.generated_kernel_count, 0)

        # Verify the "miss" case.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        self.assertEqual(metrics.generated_kernel_count, 1)

        # Verify the "hit" case
        self.reset()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
        self.assertEqual(metrics.generated_kernel_count, 2)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_inductor_counters(self):
        """
        Test that we bump the inductor counters on a cache hit.
        """
        compile_to_fn = GraphLowering.compile_to_fn

        counter_name = "a_test_counter"
        counter_incr = 7

        def bump_counter(self):
            # Mock that bumps some arbitrary test counter by a set amount, then calls
            # the original GraphLowering.compile_to_fn.
            counters["inductor"][counter_name] += counter_incr
            return compile_to_fn(self)

        with mock.patch.object(GraphLowering, "compile_to_fn", bump_counter):

            def fn(a, b):
                return torch.mm(a, b)

            a = torch.rand(8, 32, device="cpu")
            b = torch.rand(32, 8, device="cpu")

            compiled_fn = torch.compile(fn)

            # Verify the "miss" case.
            counter_val = 2
            counters["inductor"][counter_name] = counter_val
            self.assertEqual(fn(a, b), compiled_fn(a, b))
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(
                counters["inductor"][counter_name], counter_val + counter_incr
            )

            # Verify the "hit" case.
            self.reset()
            counter_val = 5
            counters["inductor"][counter_name] = counter_val
            self.assertEqual(fn(a, b), compiled_fn(a, b))
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
            self.assertEqual(
                counters["inductor"][counter_name], counter_val + counter_incr
            )

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_clear(self):
        """
        Test clearing the cache.
        """

        def fn(x, y):
            return (x * y,)

        a = torch.rand(5, 5)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn)

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        # A second call should hit.
        counters.clear()
        self.reset()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

        # Clear the cache; now we should miss.
        counters.clear()
        self.reset()
        torch._inductor.codecache.FxGraphCache.clear()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_with_nt(self):
        def gen_nt(r):
            values = torch.randn(r, 16)
            offsets = torch.tensor([0, 2, 3, 6, 13, r])
            return torch.nested.nested_tensor_from_jagged(values, offsets)

        def fn(nt):
            if nt.values().size(0) % 16 == 0:
                return nt.sin()
            return nt.cos()

        inp1 = gen_nt(19)
        inp2 = gen_nt(20)

        counters.clear()
        torch.compile(fn)(inp1)
        torch.compile(fn)(inp2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.reset()
        counters.clear()
        torch.compile(fn)(inp1)
        torch.compile(fn)(inp2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_with_symint_non_arg_guard(self):
        def fn(x, ref_id):
            self_id = 22
            if self_id == ref_id:
                x = torch.mul(x, 1.0)
            else:
                x = torch.mul(x, 0)
            return x

        x = torch.ones(2)

        counters.clear()
        torch.compile(fn, fullgraph=True, dynamic=True)(x, 2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.reset()
        counters.clear()
        torch.compile(fn, fullgraph=True, dynamic=True)(x, 2)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    def test_cache_guard(self):
        def f(x, val):
            if val > 5:
                return x.sin()
            else:
                return x.cos()

        x = torch.ones(2)
        a = torch.compile(f, dynamic=True)(x, 6)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.reset()
        counters.clear()
        b = torch.compile(f, dynamic=True)(x, 4)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

        self.assertNotEqual(a, b)


class TestFxGraphCacheHashing(TestCase):
    def test_tensor_constants(self):
        """
        Test the hashing of tensor constants.
        """
        data = FxGraphCachePickler.dumps(torch.tensor(list(range(9))))
        self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)

    def test_hash_fake_tensors(self):
        """
        Test hashing (pickling) FakeTensors with various characteristics.
        """
        with torch._subclasses.FakeTensorMode():
            # Verify that FakeTensors get pickled into a TensorMetadata:
            data = FxGraphCachePickler.dumps(torch.randn(1))
            self.assertIsInstance(pickle.loads(data), TensorMetadata)

            # Different shapes:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3)),
                FxGraphCachePickler.dumps(torch.randn(3)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3)),
                FxGraphCachePickler.dumps(torch.randn(4)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3)),
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
            )

            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(3, 4)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(4, 3)),
            )

            # Different strides:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(
                    torch.randn(3, 3).transpose(0, 1).transpose(0, 1)
                ),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, 3)),
                FxGraphCachePickler.dumps(torch.randn(3, 3).transpose(0, 1)),
            )

            # Different storage offsets:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
            )
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
                FxGraphCachePickler.dumps(torch.randn(2)),
            )

            # Different dtypes:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float64)),
            )

            # Different 'requires_grad':
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=False)),
            )

            # Different memory formats:
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(1, 2, 3, 4)),
                FxGraphCachePickler.dumps(
                    torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)
                ),
            )

            # Different devices:
            self.assertEqual(
                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
            )
            self.assertNotEqual(
                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
                FxGraphCachePickler.dumps(torch.randn(3, device="cpu")),
            )

            if HAS_MULTIGPU:
                self.assertEqual(
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
                )
                self.assertNotEqual(
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:0")),
                    FxGraphCachePickler.dumps(torch.randn(3, device=f"{GPU_TYPE}:1")),
                )

    def test_hash_kwargs(self):
        """
        Test the special handling of the kwargs when hashing, i.e.,
        ordering of the kwargs dict and any set arguments.
        """
        # Dict order of the kwargs should not affect hashes.
        details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, [])
        details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0}, [])
        self.assertEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

        # Different kwarg values should affect hashes.
        details1 = FxGraphHashDetails(None, [], {"a": 0}, [])
        details2 = FxGraphHashDetails(None, [], {"a": 1}, [])
        self.assertNotEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

        # Set order should not affect hashes. Sets are unordered, but
        # sorting and creating a new set seems to change the order.
        set1 = {"a", "b", "c", "d", "e", "f", "g"}
        set2 = set(sorted(set1))  # noqa: C414
        details1 = FxGraphHashDetails(None, [], {"a": set1}, [])
        details2 = FxGraphHashDetails(None, [], {"a": set2}, [])
        self.assertEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

        # But different set contents should affect hashes.
        details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}}, [])
        details2 = FxGraphHashDetails(None, [], {"a": {1, 2}}, [])
        self.assertNotEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )

    def test_hash_config_changes(self):
        """
        Test that different config settings affect hashes.
        """
        with config.patch({"max_autotune": False}):
            details1 = FxGraphHashDetails(None, [], {}, [])
            details2 = FxGraphHashDetails(None, [], {}, [])

        with config.patch({"max_autotune": True}):
            details3 = FxGraphHashDetails(None, [], {}, [])

        self.assertEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details2),
        )
        self.assertNotEqual(
            FxGraphCachePickler.dumps(details1),
            FxGraphCachePickler.dumps(details3),
        )

    @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_cuda_compile_command(self):
        cmd_no_extra_args: str = cuda_compile_command(
            ["abc.cu", "def.cu"], "output", "so"
        )
        assert "nvcc " in cmd_no_extra_args, cmd_no_extra_args
        assert "abc.cu" in cmd_no_extra_args, cmd_no_extra_args
        assert "def.cu" in cmd_no_extra_args, cmd_no_extra_args
        assert "output" in cmd_no_extra_args, cmd_no_extra_args
        cmd_extra_args: str = cuda_compile_command(
            ["abc.cu", "def.cu"], "output", "so", ["-Wwhatever", "-nothing"]
        )
        assert "nvcc " in cmd_extra_args, cmd_extra_args
        assert " -Wwhatever" in cmd_extra_args, cmd_extra_args
        assert " -nothing" in cmd_extra_args, cmd_extra_args
        assert "abc.cu" in cmd_extra_args, cmd_extra_args
        assert "def.cu" in cmd_extra_args, cmd_extra_args
        assert "output " in cmd_extra_args, cmd_extra_args
        with mock.patch("subprocess.check_output") as check_output_mock:
            CUDACodeCache.compile("test123.cu", "so", ["-Wsomething"])
            check_output_mock.assert_called()
            cmd_parts: List[str] = check_output_mock.call_args[0][0]
            assert cmd_parts[0] == "nvcc", cmd_parts
            assert "-Wsomething" in cmd_parts, cmd_parts
            assert "-DNDEBUG" in cmd_parts, cmd_parts


@instantiate_parametrized_tests
class TestAutotuneCache(TestCase):
    device_type = GPU_TYPE

    def setUp(self):
        super().setUp()
        counters.clear()
        PatchCaches.setUp()

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()

    def reset(self):
        torch._dynamo.reset()
        clear_inductor_caches()

    @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
    @unittest.skipIf(not SM80OrLater, "Requires SM80+")
    @config.patch({"fx_graph_cache": False})
    @config.patch({"fx_graph_remote_cache": False})
    @config.patch({"autotune_local_cache": False})
    @config.patch({"autotune_remote_cache": True})
    @config.patch({"max_autotune": True})
    @parametrize("fbcode", (False,) + (True,) * config.is_fbcode())
    def test_autotune_cache(self, fbcode: bool):
        class Model(torch.nn.Module):
            def forward(self, x, y, a, b):
                return x + y, a + b

        def f(x, y, a, b):
            return Model()(x, y, a, b)

        x = torch.randn(100, 100).cuda()
        y = torch.randn(100, 100).cuda()
        a = torch.randn(1000, 100).cuda()
        b = torch.randn(1000, 100).cuda()
        f_compiled = torch.compile(f, fullgraph=True)

        with PatchCaches(), patch_fbcode(fbcode):
            f_compiled(x, y, a, b)

            self.assertEqual(global_stats.autotune.num_get_hit, 0)
            self.assertEqual(global_stats.autotune.num_get_miss, 2)
            self.assertEqual(global_stats.autotune.num_put, 2)

            self.reset()
            f_compiled(x, y, a, b)

        global_stats.report()
        self.assertEqual(global_stats.autotune.num_get_hit, 2)
        self.assertEqual(global_stats.autotune.num_get_miss, 2)
        self.assertEqual(global_stats.autotune.num_put, 2)


class TestUtils(TestCase):
    @config.patch({"fx_graph_remote_cache": False})
    def test_fresh_inductor_cache(self):
        def fn(x, y):
            return x + y

        a = torch.rand(10)
        b = torch.rand(10)

        with fresh_inductor_cache():
            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
            res1 = torch.compile(fn)(a, b)
            cache_dir1 = cache_dir()

        torch._dynamo.reset()
        with fresh_inductor_cache():
            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
            res2 = torch.compile(fn)(a, b)
            cache_dir2 = cache_dir()

        self.assertEqual(res1, res2)
        self.assertNotEqual(cache_dir1, cache_dir2)


if __name__ == "__main__":
    run_tests()