1# Owner(s): ["module: cuda"] 2 3import contextlib 4import ctypes 5import gc 6import json 7import os 8import pickle 9import random 10import subprocess 11import sys 12import tempfile 13import threading 14import unittest 15import warnings 16from copy import deepcopy 17from itertools import product 18from random import randint 19 20import psutil 21 22import torch 23import torch.cuda 24from torch import inf, nan 25from torch.cuda._memory_viz import ( 26 _profile_to_snapshot, 27 profile_plot, 28 segment_plot, 29 trace_plot, 30) 31from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast 32from torch.testing._internal.common_cuda import ( 33 _create_scaling_case, 34 _get_torch_cuda_version, 35 TEST_CUDNN, 36 TEST_MULTIGPU, 37) 38from torch.testing._internal.common_device_type import ( 39 instantiate_device_type_tests, 40 onlyCUDA, 41 onlyNativeDeviceTypes, 42) 43from torch.testing._internal.common_optimizers import ( 44 _get_optim_inputs_including_global_cliquey_kwargs, 45 optim_db, 46 optims, 47 TensorTracker, 48) 49from torch.testing._internal.common_utils import ( 50 EXPANDABLE_SEGMENTS, 51 freeze_rng_state, 52 gcIfJetson, 53 get_cycles_per_ms, 54 instantiate_parametrized_tests, 55 IS_ARM64, 56 IS_FBCODE, 57 IS_JETSON, 58 IS_LINUX, 59 IS_SANDCASTLE, 60 IS_WINDOWS, 61 load_tests, 62 NO_MULTIPROCESSING_SPAWN, 63 parametrize, 64 run_tests, 65 serialTest, 66 skipCUDAMemoryLeakCheckIf, 67 skipCUDANonDefaultStreamIf, 68 skipIfRocm, 69 slowTest, 70 subtest, 71 TemporaryFileName, 72 TEST_CUDA, 73 TEST_CUDA_GRAPH, 74 TEST_NUMPY, 75 TEST_WITH_ROCM, 76 TestCase, 77) 78from torch.utils.checkpoint import checkpoint_sequential 79from torch.utils.viz._cycles import observe_tensor_cycles 80 81 82# load_tests from common_utils is used to automatically filter tests for 83# sharding on sandcastle. This line silences flake warnings 84load_tests = load_tests 85 86try: 87 import torchvision.models # noqa: F401 88 from torchvision.models import resnet18 # noqa: F401 89 90 HAS_TORCHVISION = True 91except ImportError: 92 HAS_TORCHVISION = False 93skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 94 95TEST_CUDAMALLOCASYNC = TEST_CUDA and ( 96 torch.cuda.get_allocator_backend() == "cudaMallocAsync" 97) 98TEST_LARGE_TENSOR = TEST_CUDA 99TEST_MEDIUM_TENSOR = TEST_CUDA 100TEST_BF16 = False 101TEST_PYNVML = not torch.cuda._HAS_PYNVML 102if TEST_CUDA: 103 TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9 104 TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9 105 TEST_BF16 = torch.cuda.is_bf16_supported() 106 107_cycles_per_ms = None 108 109 110@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") 111@torch.testing._internal.common_utils.markDynamoStrictTest 112class TestCuda(TestCase): 113 _do_cuda_memory_leak_check = True 114 _do_cuda_non_default_stream = True 115 FIFTY_MIL_CYCLES = 50000000 116 117 def setUp(self): 118 super().setUp() 119 120 def tearDown(self): 121 super().tearDown() 122 123 @property 124 def expandable_segments(self): 125 return EXPANDABLE_SEGMENTS 126 127 def test_pinned_memory_with_cudaregister(self): 128 torch.cuda.memory._set_allocator_settings( 129 "pinned_use_cuda_host_register:True,pinned_num_register_threads:8" 130 ) 131 t = torch.ones(20) 132 self.assertFalse(t.is_pinned()) 133 try: 134 pinned_t = torch.ones(1 << 21).pin_memory() 135 self.assertTrue(pinned_t.is_pinned()) 136 pinned_t = torch.ones(1 << 24).pin_memory() 137 self.assertTrue(pinned_t.is_pinned()) 138 except RuntimeError as e: 139 # Some GPUs don't support same address space on host and device side 140 pass 141 142 def test_pinned_memory_with_cudaregister_multithread(self): 143 num_threads = 4 144 threads = [ 145 threading.Thread(target=self.test_pinned_memory_with_cudaregister) 146 for t in range(num_threads) 147 ] 148 for thread in threads: 149 thread.start() 150 for thread in threads: 151 thread.join() 152 153 def test_pinned_memory_empty_cache(self): 154 for alloc_settings in (True, False): 155 torch.cuda.memory._set_allocator_settings( 156 f"pinned_use_cuda_host_register:{alloc_settings}" 157 ) 158 try: 159 t = torch.ones(1024 * 1024, pin_memory=True) 160 self.assertTrue(t.is_pinned()) 161 del t 162 torch._C._host_emptyCache() 163 except RuntimeError as e: 164 # Some GPUs don't support same address space on host and device side 165 pass 166 167 def test_cudart_register(self): 168 t = torch.ones(20) 169 self.assertFalse(t.is_pinned()) 170 cudart = torch.cuda.cudart() 171 r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) 172 self.assertEqual(r, 0) 173 self.assertTrue(t.is_pinned()) 174 r = cudart.cudaHostUnregister(t.data_ptr()) 175 self.assertEqual(r, 0) 176 self.assertFalse(t.is_pinned()) 177 178 def test_memory_allocation(self): 179 gc.collect() 180 torch.cuda.empty_cache() 181 mem = None 182 size = 1 183 prev = 0 184 try: 185 prev = torch.cuda.memory_allocated() 186 mem = torch.cuda.caching_allocator_alloc(size) 187 self.assertGreater(torch.cuda.memory_allocated(), prev) 188 finally: 189 if mem is not None: 190 torch.cuda.caching_allocator_delete(mem) 191 self.assertEqual(torch.cuda.memory_allocated(), prev) 192 193 def test_check_error(self): 194 # Assert this call doesn't raise. 195 torch.cuda.check_error(0) 196 197 with self.assertRaisesRegex( 198 torch.cuda.CudaError, "out of memory|hipErrorOutOfMemory" 199 ): 200 torch.cuda.check_error(2) 201 202 def test_cuda_get_device_name(self): 203 # Testing the behaviour with None as an argument 204 current_device = torch.cuda.current_device() 205 current_device_name = torch.cuda.get_device_name(current_device) 206 device_name_None = torch.cuda.get_device_name(None) 207 self.assertEqual(current_device_name, device_name_None) 208 209 # Testing the behaviour for No argument 210 device_name_no_argument = torch.cuda.get_device_name() 211 self.assertEqual(current_device_name, device_name_no_argument) 212 213 def test_cuda_get_device_capability(self): 214 # Testing the behaviour with None as an argument 215 current_device = torch.cuda.current_device() 216 current_device_capability = torch.cuda.get_device_capability(current_device) 217 device_capability_None = torch.cuda.get_device_capability(None) 218 self.assertEqual(current_device_capability, device_capability_None) 219 220 # Testing the behaviour for No argument 221 device_capability_no_argument = torch.cuda.get_device_capability() 222 self.assertEqual(current_device_capability, device_capability_no_argument) 223 224 def test_out_of_memory(self): 225 tensor = torch.zeros(1024, device="cuda") 226 227 oom_regex = ( 228 "would exceed allowed memory" 229 if TEST_CUDAMALLOCASYNC 230 else "Tried to allocate 800000000.00 GiB" 231 ) 232 with self.assertRaisesRegex(RuntimeError, oom_regex): 233 torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="cuda") 234 235 with self.assertRaisesRegex( 236 RuntimeError, "Tried to allocate more than 1EB memory" 237 ): 238 torch.empty( 239 1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="cuda" 240 ) 241 242 # ensure out of memory error doesn't disturb subsequent kernel 243 tensor.fill_(1) 244 self.assertTrue((tensor == 1).all()) 245 246 @unittest.skipIf( 247 TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)" 248 ) 249 @serialTest() 250 def test_out_of_memory_retry(self): 251 torch.cuda.empty_cache() 252 total_memory = torch.cuda.get_device_properties(0).total_memory 253 oom_regex = ( 254 "would exceed allowed memory" 255 if TEST_CUDAMALLOCASYNC 256 else "Tried to allocate" 257 ) 258 size = int(total_memory * 0.5) 259 a = torch.empty(size, dtype=torch.int8, device="cuda") 260 with self.assertRaisesRegex(RuntimeError, oom_regex): 261 b = torch.empty(size, dtype=torch.int8, device="cuda") 262 del a 263 b = torch.empty(size, dtype=torch.int8, device="cuda") 264 del b 265 # We used a lot of memory here, clean up so we don't affect other tests too much 266 torch.cuda.empty_cache() 267 torch.cuda.reset_peak_memory_stats() 268 269 @serialTest() 270 def test_set_per_process_memory_fraction(self): 271 # test invalid fraction value. 272 with self.assertRaisesRegex(TypeError, "Invalid type"): 273 torch.cuda.set_per_process_memory_fraction(1) 274 with self.assertRaisesRegex(ValueError, "Invalid fraction value"): 275 torch.cuda.set_per_process_memory_fraction(-0.1) 276 with self.assertRaisesRegex(ValueError, "Invalid fraction value"): 277 torch.cuda.set_per_process_memory_fraction(2.0) 278 279 tensor = torch.zeros(1024, device="cuda") 280 torch.cuda.empty_cache() 281 total_memory = torch.cuda.get_device_properties(0).total_memory 282 torch.cuda.set_per_process_memory_fraction(0.5, 0) 283 284 # test 0.499 allocation is ok. 285 application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved() 286 tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda") 287 del tmp_tensor 288 torch.cuda.empty_cache() 289 290 application = int(total_memory * 0.5) 291 # it will get OOM when try to allocate more than half memory. 292 oom_regex = ( 293 "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory" 294 ) 295 with self.assertRaisesRegex(RuntimeError, oom_regex): 296 torch.empty(application, dtype=torch.int8, device="cuda") 297 298 # ensure out of memory error doesn't disturb subsequent kernel 299 tensor.fill_(1) 300 self.assertTrue((tensor == 1).all()) 301 302 @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "uuid attribute not yet available") 303 def test_uuid(self): 304 uuid = torch.cuda.get_device_properties(0).uuid 305 self.assertEqual(len(str(uuid)), 36) # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx 306 self.assertEqual(len(uuid.bytes), 16) 307 308 def test_copy_non_blocking(self): 309 def _test_copy_non_blocking(a, b): 310 event = torch.cuda.Event() 311 a.copy_(b, non_blocking=True) 312 event.record() 313 event.synchronize() 314 self.assertEqual(a, b) 315 316 # 10MB copies 317 x = torch.ones(10000000, dtype=torch.uint8).cuda() 318 y = torch.zeros(10000000, dtype=torch.uint8).pin_memory() 319 _test_copy_non_blocking(x, y) 320 321 x = torch.zeros(10000000, dtype=torch.uint8).pin_memory() 322 y = torch.ones(10000000, dtype=torch.uint8).cuda() 323 _test_copy_non_blocking(x, y) 324 325 # Test the case where the pinned data_ptr is not equal to the storage data_ptr. 326 x_base = torch.zeros(10000000, dtype=torch.uint8).pin_memory() 327 x = x_base[1:] 328 self.assertTrue(x.is_pinned()) 329 self.assertTrue(x_base.is_pinned()) 330 self.assertNotEqual(x_base.data_ptr(), x.data_ptr()) 331 self.assertEqual(x_base.storage().data_ptr(), x.storage().data_ptr()) 332 y = torch.ones(10000000 - 1, dtype=torch.uint8).cuda() 333 _test_copy_non_blocking(x, y) 334 335 def test_copy_non_blocking_type_conversion(self): 336 a = torch.ones(1, device="cuda") 337 b = torch.zeros(1, device="cpu", pin_memory=True) 338 c = torch.empty(1, device="cuda", dtype=torch.long) 339 torch.cuda._sleep(int(100 * get_cycles_per_ms())) 340 b.copy_(a, non_blocking=True) 341 c.copy_(b, non_blocking=True) 342 self.assertEqual(a, c, exact_dtype=False) 343 344 @serialTest() 345 def test_to_non_blocking(self): 346 stream = torch.cuda.current_stream() 347 348 def _test_to_non_blocking(a, non_blocking, dst): 349 torch.cuda.synchronize() 350 # Pushes an 0.1 second spin to stream so if the copy is non blocking, 351 # stream will almost surely be active when we query(). 352 torch.cuda._sleep(int(100 * get_cycles_per_ms())) 353 b = a.to(device=dst, non_blocking=non_blocking) 354 self.assertEqual(stream.query(), not non_blocking) 355 stream.synchronize() 356 self.assertEqual(a, b) 357 self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu")) 358 359 for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)): 360 # Creates source on the opposite device from destination. 361 src = torch.randn( 362 1000000, 363 device="cuda" if dst == "cpu" else "cpu", 364 pin_memory=True if dst == "cuda" else False, 365 ) 366 _test_to_non_blocking(src, try_non_blocking, dst) 367 368 def test_to_cpu_blocking_by_default(self): 369 src = torch.randn(1000000, device="cuda") 370 torch.cuda.synchronize() 371 torch.cuda._sleep(int(100 * get_cycles_per_ms())) 372 dst = src.to(device="cpu") 373 self.assertEqual(torch.cuda.current_stream().query(), True) 374 self.assertEqual(src, dst) 375 self.assertFalse(dst.is_pinned()) 376 377 def test_serialization_array_with_storage(self): 378 x = torch.randn(5, 5).cuda() 379 y = torch.IntTensor(2, 5).fill_(0).cuda() 380 q = [x, y, x, y.storage()] 381 with tempfile.NamedTemporaryFile() as f: 382 torch.save(q, f) 383 f.seek(0) 384 q_copy = torch.load(f) 385 self.assertEqual(q_copy, q, atol=0, rtol=0) 386 q_copy[0].fill_(5) 387 self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0) 388 self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor)) 389 self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor)) 390 self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor)) 391 self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage)) 392 self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage)) 393 q_copy[1].fill_(10) 394 self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) 395 396 @unittest.skipIf( 397 TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async" 398 ) 399 @unittest.skipIf( 400 _get_torch_cuda_version() >= (12, 2), 401 "skipped as explicit workspace allocation is removed", 402 ) 403 def test_cublas_workspace_explicit_allocation(self): 404 a = torch.randn(7, 7, device="cuda", requires_grad=False) 405 default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024 # :4096:2:16:8 406 # different size (32 MiB) expected on Hopper GPU 407 if torch.cuda.get_device_capability() == (9, 0): 408 default_workspace_size = 4096 * 8 * 1024 409 410 def check_workspace_size(inp): 411 torch._C._cuda_clearCublasWorkspaces() 412 start = torch.cuda.memory_stats()["active_bytes.all.allocated"] 413 with torch.no_grad(): 414 torch.matmul(inp, inp) 415 finish = torch.cuda.memory_stats()["active_bytes.all.allocated"] 416 return finish - start 417 418 # check default 419 os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" 420 self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288) 421 422 # check default with bad user config 423 os.environ["CUBLAS_WORKSPACE_CONFIG"] = "-1" 424 self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288) 425 426 # check valid config 427 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":128:8:64:16:32:32" 428 self.assertTrue(abs(check_workspace_size(a) - (3072 * 1024)) < 524288) 429 430 torch._C._cuda_clearCublasWorkspaces() 431 432 def test_cublas_allow_tf32_get_set(self): 433 skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( 434 os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] 435 ) 436 if skip_tf32_cublas: 437 self.assertTrue(torch.backends.cuda.matmul.allow_tf32) 438 return 439 440 orig = torch.backends.cuda.matmul.allow_tf32 441 self.assertEqual(torch._C._get_cublas_allow_tf32(), orig) 442 torch.backends.cuda.matmul.allow_tf32 = not orig 443 self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig) 444 torch.backends.cuda.matmul.allow_tf32 = orig 445 446 def test_float32_matmul_precision_get_set(self): 447 orig = torch.get_float32_matmul_precision() 448 skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( 449 os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] 450 ) 451 # this is really just checking that the environment variable is respected during testing 452 # and not overwritten by another function that doesn't revert it to the intitial value 453 if not skip_tf32_cublas: 454 self.assertFalse(torch.backends.cuda.matmul.allow_tf32) 455 self.assertEqual(torch.get_float32_matmul_precision(), "highest") 456 else: 457 self.assertTrue(torch.backends.cuda.matmul.allow_tf32) 458 for p in ("medium", "high"): 459 torch.set_float32_matmul_precision(p) 460 self.assertEqual(torch.get_float32_matmul_precision(), p) 461 self.assertTrue(torch.backends.cuda.matmul.allow_tf32) 462 torch.set_float32_matmul_precision("highest") 463 self.assertEqual(torch.get_float32_matmul_precision(), "highest") 464 self.assertFalse(torch.backends.cuda.matmul.allow_tf32) 465 torch.set_float32_matmul_precision(orig) 466 467 def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self): 468 orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction 469 self.assertEqual( 470 torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig 471 ) 472 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig 473 self.assertEqual( 474 torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig 475 ) 476 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig 477 478 def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self): 479 orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction 480 self.assertEqual( 481 torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig 482 ) 483 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig 484 self.assertEqual( 485 torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig 486 ) 487 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig 488 489 def test_cudnn_allow_tf32_get_set(self): 490 with torch.backends.cudnn.flags( 491 enabled=None, benchmark=None, deterministic=None, allow_tf32=False 492 ): 493 self.assertFalse(torch.backends.cudnn.allow_tf32) 494 with torch.backends.cudnn.flags( 495 enabled=None, benchmark=None, deterministic=None, allow_tf32=True 496 ): 497 self.assertTrue(torch.backends.cudnn.allow_tf32) 498 499 def test_type_conversions(self): 500 x = torch.randn(5, 5) 501 self.assertIsInstance(x.float(), torch.FloatTensor) 502 self.assertIsInstance(x.cuda().double(), torch.cuda.DoubleTensor) 503 self.assertIsInstance(x.cuda().float(), torch.cuda.FloatTensor) 504 self.assertIsInstance(x.cuda().float().cpu(), torch.FloatTensor) 505 self.assertIsInstance(x.cuda().float().cpu().int(), torch.IntTensor) 506 507 y = x.storage() 508 self.assertIsInstance(y.float(), torch.FloatStorage) 509 self.assertIsInstance(y.cuda().double(), torch.cuda.DoubleStorage) 510 self.assertIsInstance(y.cuda().float(), torch.cuda.FloatStorage) 511 self.assertIsInstance(y.cuda().float().cpu(), torch.FloatStorage) 512 self.assertIsInstance(y.cuda().float().cpu().int(), torch.IntStorage) 513 514 @unittest.skip("was disabled due to not enough memory, but actually it always fail") 515 def test_arithmetic_large_tensor(self): 516 x = torch.empty(2**30, device="cuda") 517 518 x.fill_(1) 519 self.assertEqual(x.sum(), 2**30) 520 521 x += 1 522 self.assertEqual(x.sum(), 2**31) 523 524 x.fill_(1) 525 x -= 0.5 526 self.assertEqual(x.sum(), 2**29) 527 528 x.fill_(1) 529 x *= 2 530 self.assertEqual(x.sum(), 2**31) 531 532 x.fill_(1) 533 x /= 2 534 self.assertEqual(x.sum(), 2**29) 535 536 def test_gather_bool(self): 537 t = torch.tensor([[False, True], [True, True]], device="cuda") 538 self.assertEqual( 539 torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]], device="cuda")), 540 torch.tensor([[False, False], [True, True]], device="cuda"), 541 ) 542 543 def test_torch_manual_seed_seeds_cuda_devices(self): 544 with freeze_rng_state(): 545 x = torch.zeros(4, 4).float().cuda() 546 torch.manual_seed(2) 547 self.assertEqual(torch.cuda.initial_seed(), 2) 548 x.uniform_() 549 torch.manual_seed(2) 550 y = x.clone().uniform_() 551 self.assertEqual(x, y) 552 self.assertEqual(torch.cuda.initial_seed(), 2) 553 554 def test_manual_seed(self): 555 with freeze_rng_state(): 556 x = torch.zeros(4, 4).float().cuda() 557 torch.cuda.manual_seed(2) 558 self.assertEqual(torch.cuda.initial_seed(), 2) 559 x.uniform_() 560 a = torch.bernoulli(torch.full_like(x, 0.5)) 561 torch.cuda.manual_seed(2) 562 y = x.clone().uniform_() 563 b = torch.bernoulli(torch.full_like(x, 0.5)) 564 self.assertEqual(x, y) 565 self.assertEqual(a, b) 566 self.assertEqual(torch.cuda.initial_seed(), 2) 567 568 def test_specify_improper_device_name(self): 569 import os 570 571 fname = "tempfile.pt" 572 try: 573 with self.assertRaisesRegex(RuntimeError, "Invalid device string"): 574 torch.save( 575 [torch.nn.Parameter(torch.randn(10, 10))], 576 fname, 577 _use_new_zipfile_serialization=True, 578 ) 579 torch.load(fname, "cuda0") 580 finally: 581 if os.path.exists(fname): 582 os.remove(fname) 583 584 def test_get_device_index(self): 585 from torch.cuda._utils import _get_device_index 586 587 with self.assertRaisesRegex(RuntimeError, "Invalid device string"): 588 _get_device_index("cuda0", optional=True) 589 590 with self.assertRaisesRegex(ValueError, "Expected a cuda device"): 591 cpu_device = torch.device("cpu") 592 _get_device_index(cpu_device, optional=True) 593 594 def test_serialization_array_with_empty(self): 595 x = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()] 596 with tempfile.NamedTemporaryFile() as f: 597 torch.save(x, f) 598 f.seek(0) 599 x_copy = torch.load(f) 600 for original, copy in zip(x, x_copy): 601 self.assertEqual(copy, original) 602 self.assertIs(type(copy), type(original)) 603 self.assertEqual(copy.get_device(), original.get_device()) 604 605 @skipCUDANonDefaultStreamIf(True) 606 def test_streams(self): 607 default_stream = torch.cuda.current_stream() 608 user_stream = torch.cuda.Stream() 609 self.assertEqual(torch.cuda.current_stream(), default_stream) 610 self.assertNotEqual(default_stream, user_stream) 611 self.assertEqual(default_stream.cuda_stream, 0) 612 self.assertNotEqual(user_stream.cuda_stream, 0) 613 with torch.cuda.stream(user_stream): 614 self.assertEqual(torch.cuda.current_stream(), user_stream) 615 self.assertTrue(user_stream.query()) 616 tensor1 = torch.ByteTensor(5).pin_memory() 617 tensor2 = tensor1.cuda(non_blocking=True) + 1 618 default_stream.synchronize() 619 self.assertTrue(default_stream.query()) 620 621 def test_stream_event_repr(self): 622 s = torch.cuda.current_stream() 623 self.assertTrue("torch.cuda.Stream" in s.__repr__()) 624 e = torch.cuda.Event() 625 self.assertTrue("torch.cuda.Event" in e.__repr__()) 626 s.record_event(e) 627 self.assertTrue("torch.cuda.Event" in e.__repr__()) 628 629 def test_events(self): 630 stream = torch.cuda.current_stream() 631 event = torch.cuda.Event(enable_timing=True) 632 self.assertTrue(event.query()) 633 start_event = torch.cuda.Event(enable_timing=True) 634 stream.record_event(start_event) 635 torch.cuda._sleep(int(50 * get_cycles_per_ms())) 636 stream.record_event(event) 637 self.assertFalse(event.query()) 638 event.synchronize() 639 self.assertTrue(event.query()) 640 self.assertGreater(start_event.elapsed_time(event), 0) 641 642 def test_generic_stream_event(self): 643 stream = torch.Stream("cuda") 644 self.assertEqual(stream.device_index, torch.cuda.current_device()) 645 cuda_stream = torch.cuda.Stream( 646 stream_id=stream.stream_id, 647 device_index=stream.device_index, 648 device_type=stream.device_type, 649 ) 650 self.assertEqual(stream.stream_id, cuda_stream.stream_id) 651 self.assertNotEqual(stream.stream_id, torch.cuda.current_stream().stream_id) 652 653 event1 = torch.Event("cuda", enable_timing=True) 654 event2 = torch.Event("cuda", enable_timing=True) 655 self.assertEqual(event1.event_id, 0) 656 a = torch.randn(1000) 657 b = torch.randn(1000) 658 with torch.cuda.stream(cuda_stream): 659 a_cuda = a.to("cuda", non_blocking=True) 660 b_cuda = b.to("cuda", non_blocking=True) 661 self.assertEqual(stream.stream_id, torch.cuda.current_stream().stream_id) 662 event1.record(stream) 663 event1.synchronize() 664 self.assertTrue(event1.query()) 665 c_cuda = a_cuda + b_cuda 666 event2.record() 667 event2.synchronize() 668 self.assertTrue(event2.query()) 669 self.assertNotEqual(event1.event_id, event2.event_id) 670 self.assertEqual(c_cuda.cpu(), a + b) 671 self.assertTrue(event1.elapsed_time(event2) > 0) 672 673 def test_record_stream(self): 674 cycles_per_ms = get_cycles_per_ms() 675 676 t = torch.FloatTensor([1, 2, 3, 4]).pin_memory() 677 result = torch.cuda.FloatTensor(t.size()) 678 stream = torch.cuda.Stream() 679 ptr = [None] 680 681 # Performs the CPU->GPU copy in a background stream 682 def perform_copy(): 683 with torch.cuda.stream(stream): 684 tmp = t.cuda(non_blocking=True) 685 ptr[0] = tmp.data_ptr() 686 torch.cuda.current_stream().wait_stream(stream) 687 tmp.record_stream(torch.cuda.current_stream()) 688 torch.cuda._sleep(int(50 * cycles_per_ms)) # delay the copy 689 result.copy_(tmp) 690 691 perform_copy() 692 with torch.cuda.stream(stream): 693 tmp2 = torch.cuda.FloatTensor(t.size()) 694 tmp2.zero_() 695 self.assertNotEqual( 696 tmp2.data_ptr(), ptr[0], msg="allocation re-used to soon" 697 ) 698 699 self.assertEqual(result.tolist(), [1, 2, 3, 4]) 700 701 if not TEST_CUDAMALLOCASYNC: 702 # In the native allocator, we expect "tmp"'s side-stream-tagged block will be reused 703 # in that side stream after result.copy_(tmp) in the main stream finishes. 704 torch.cuda.current_stream().synchronize() 705 with torch.cuda.stream(stream): 706 tmp3 = torch.cuda.FloatTensor(t.size()) 707 self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used") 708 709 def test_record_stream_on_shifted_view(self): 710 # See issue #27366 711 712 # This test detects unexpected block reallocation. For reliable test, 713 # the stream to allocate tensors is isolated. The allocator will not 714 # reuse free blocks which were allocated from another stream. 715 stream_alloc = torch.cuda.Stream() 716 with torch.cuda.stream(stream_alloc): 717 base = torch.cuda.FloatTensor([10, 10]) 718 719 # Record another stream on a shifted view tensor. 720 view = base[5:] 721 assert view.storage_offset() > 0 722 723 stream_record = torch.cuda.Stream() 724 with torch.cuda.stream(stream_record): 725 torch.cuda._sleep(int(50 * get_cycles_per_ms())) 726 727 view.record_stream(stream_record) 728 729 # Delete those tensors to make the block free soon. 730 data_ptr = base.data_ptr() 731 del base, view 732 733 # A new tensor should not be allocated to the block above. 734 stream_alloc.synchronize() 735 736 with torch.cuda.stream(stream_alloc): 737 try_realloc = torch.cuda.FloatTensor([10, 10]) 738 739 self.assertNotEqual(try_realloc.data_ptr(), data_ptr) 740 741 def test_noncontiguous_pinned_memory(self): 742 # See issue #3266 743 x = torch.arange(0, 10).view((2, 5)) 744 self.assertEqual(x.t(), x.t().pin_memory()) 745 746 def test_caching_pinned_memory(self): 747 cycles_per_ms = get_cycles_per_ms() 748 749 # check that allocations are re-used after deletion 750 t = torch.FloatTensor([1]).pin_memory() 751 ptr = t.data_ptr() 752 del t 753 t = torch.FloatTensor([1]).pin_memory() 754 self.assertEqual(t.data_ptr(), ptr, msg="allocation not reused") 755 756 # check that the allocation is not re-used if it's in-use by a copy 757 gpu_tensor = torch.cuda.FloatTensor([0]) 758 torch.cuda._sleep(int(1000 * cycles_per_ms)) # delay the copy by 1s 759 gpu_tensor.copy_(t, non_blocking=True) 760 del t 761 t = torch.FloatTensor([1]).pin_memory() 762 self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon") 763 self.assertEqual(list(gpu_tensor), [1]) 764 765 def test_caching_allocator_record_stream_oom(self): 766 """allocations delayed by a record_stream call should still be freed on 767 an out-of-memory in cuda_malloc_retry. see issue #19219""" 768 stream = torch.cuda.Stream() 769 770 with torch.cuda.stream(stream): 771 y = torch.zeros(40 * 1024 * 1024, device="cuda") 772 773 for _ in range(100): 774 x = torch.empty(40 * 1024 * 1024, device="cuda") 775 with torch.cuda.stream(stream): 776 y += x 777 # delays re-use of `x` until after all operations in `stream` 778 x.record_stream(stream) 779 del x 780 781 # we've made a mess by allocating up to the device capacity. free any 782 # cached blocks in case it affects future tests. 783 torch.cuda.empty_cache() 784 785 # Tests for historic illegal memory access, see #17040. 786 def test_reduction_gpu_memory_accessing(self): 787 x = torch.ones(512, 8, dtype=torch.float32, device="cuda") 788 torch.sum(x, 0) 789 790 def test_sum_fp16(self): 791 x = torch.zeros(10, device="cuda", dtype=torch.float16) 792 self.assertEqual(x.sum(), 0) 793 794 x = torch.ones(65504, device="cuda", dtype=torch.float16) 795 self.assertEqual(x.sum(), 65504) 796 self.assertEqual(x.sum(dtype=torch.float32), 65504) 797 798 x = torch.ones(65536, device="cuda", dtype=torch.float16) 799 self.assertEqual(x.sum(dtype=torch.float32), 65536) 800 801 a = torch.zeros(1203611).bernoulli_(0.0005) 802 x = a.to(device="cuda", dtype=torch.float16) 803 self.assertEqual(x.sum().item(), a.sum().item()) 804 805 a = torch.zeros(100, 121, 80).bernoulli_(0.0005) 806 x = a.to(device="cuda", dtype=torch.float16) 807 self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2))) 808 809 def test_mean_fp16(self): 810 x = torch.ones(65536, device="cuda", dtype=torch.float16) 811 self.assertEqual(x.mean(), 1) 812 813 x = torch.ones(65536, device="cuda", dtype=torch.float16) 814 self.assertEqual(x.mean(dtype=torch.float32), 1) 815 816 def test_prod_large(self): 817 # tests global reduction (should_global_reduce = true) in case of non-zero identity element 818 x = torch.ones(240000, device="cuda", dtype=torch.float32) 819 self.assertEqual(x.prod(), 1) 820 821 # test for complex types. Note 240k is divisible by 4 822 for dtype in [torch.cfloat, torch.cdouble]: 823 x = torch.ones(240000, device="cuda", dtype=dtype) * (0 + 1j) 824 self.assertEqual(x.prod(), 1) 825 826 def test_multinomial_ext(self): 827 # Test two corner cases from older PyTorch (Issue #4858) 828 freqs = torch.cuda.FloatTensor( 829 [ 830 0.0, 831 0.0, 832 0.0, 833 0.0, 834 0.0, 835 0.0, 836 0.0, 837 0.0, 838 0.0, 839 0.03178183361887932, 840 0.027680952101945877, 841 0.033176131546497345, 842 0.046052902936935425, 843 0.07742464542388916, 844 0.11543981730937958, 845 0.14148041605949402, 846 0.15784293413162231, 847 0.13180233538150787, 848 0.08271478116512299, 849 0.049702685326337814, 850 0.027557924389839172, 851 0.018125897273421288, 852 0.011851548217236996, 853 0.010252203792333603, 854 0.007422595750540495, 855 0.005372154992073774, 856 0.0045109698548913, 857 0.0036087757907807827, 858 0.0035267581697553396, 859 0.0018864056328311563, 860 0.0024605290964245796, 861 0.0022964938543736935, 862 0.0018453967059031129, 863 0.0010662291897460818, 864 0.0009842115687206388, 865 0.00045109697384759784, 866 0.0007791675161570311, 867 0.00020504408166743815, 868 0.00020504408166743815, 869 0.00020504408166743815, 870 0.00012302644609007984, 871 0.0, 872 0.00012302644609007984, 873 4.100881778867915e-05, 874 0.0, 875 0.0, 876 0.0, 877 0.0, 878 0.0, 879 0.0, 880 ] 881 ) 882 883 torch.cuda.manual_seed(11042) 884 sample = torch.multinomial(freqs, 1000, True) 885 self.assertNotEqual(freqs[sample].min(), 0) 886 887 p = torch.zeros(3421, 2, device="cuda", dtype=torch.float) 888 p[:, 1] = 1 889 torch.cuda.manual_seed(5214) 890 r = torch.multinomial(p, 1) 891 self.assertNotEqual(r.min().item(), 0) 892 893 # test corner case from Issue #13867 894 torch.cuda.manual_seed(33) 895 probs = torch.randn(1000000, device="cuda").clamp(min=0) * 3e-5 896 samples = probs.multinomial(1000000, replacement=True) 897 self.assertGreater(probs[samples].min().item(), 0) 898 899 def _spawn_test_multinomial_invalid_probs_cuda(self, probs): 900 import subprocess 901 902 try: 903 p = subprocess.Popen( 904 [ 905 sys.executable, 906 "-c", 907 f"""\ 908import sys 909import torch 910from torch import inf, nan 911try: 912 with torch.random.fork_rng(devices=[0]): 913 torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True) 914 torch.cuda.synchronize() 915 sys.exit(-1) # Should not be reached 916except RuntimeError as e: 917 sys.exit(-2) 918""", 919 ], 920 stdout=subprocess.PIPE, 921 stderr=subprocess.PIPE, 922 universal_newlines=True, 923 ) 924 out, err = p.communicate(timeout=10) 925 p.wait(timeout=10) 926 except subprocess.TimeoutExpired as e: 927 p.kill() 928 out, err = p.communicate() 929 expected_messages = [ 930 "device-side assert triggered", # CUDA 931 "Assertion", # CUDA 932 "HSA_STATUS_ERROR_EXCEPTION", # ROCm 933 "Device-side assertion", # ROCm 934 ] 935 self.assertTrue(any(msg in out or msg in err for msg in expected_messages)) 936 937 @slowTest 938 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") 939 @unittest.skipIf( 940 NO_MULTIPROCESSING_SPAWN, 941 "Disabled for environments that \ 942 don't support multiprocessing with spawn start method", 943 ) 944 def test_multinomial_invalid_probs_cuda(self): 945 self._spawn_test_multinomial_invalid_probs_cuda([1.0, -1.0, 1.0]) 946 self._spawn_test_multinomial_invalid_probs_cuda([1.0, inf, 1.0]) 947 self._spawn_test_multinomial_invalid_probs_cuda([1.0, -inf, 1.0]) 948 self._spawn_test_multinomial_invalid_probs_cuda([1.0, 1.0, nan]) 949 950 @staticmethod 951 def _mute_init(): 952 os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno()) 953 954 def _spawn_method(self, method, arg): 955 ctx = torch.multiprocessing.get_context("spawn") 956 with ctx.Pool(1, initializer=self._mute_init) as pool: 957 errors = pool.map(method, [arg]) 958 for e in errors: 959 if "device-side assert triggered" not in str(e): 960 self.fail(e) 961 962 @staticmethod 963 def _test_index_bounds_cuda(idx): 964 x = torch.arange(10, device="cuda") 965 try: 966 y = x[torch.tensor([idx])] 967 return f"x[torch.tensor([{idx})]={y}" 968 except RuntimeError as err: 969 return err 970 971 @slowTest 972 @unittest.skipIf( 973 NO_MULTIPROCESSING_SPAWN, 974 "Disabled for environments that \ 975 don't support multiprocessing with spawn start method", 976 ) 977 @skipIfRocm 978 def test_index_out_of_bounds_exception_cuda(self): 979 test_method = TestCuda._test_index_bounds_cuda 980 # Test in-bound access works fine 981 self.assertEqual( 982 test_method(1), "x[torch.tensor([1)]=tensor([1], device='cuda:0')" 983 ) 984 # Test that indexing out of bounds causes assert 985 self._spawn_method(test_method, 11) 986 987 @slowTest 988 @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") 989 @serialTest() 990 def test_huge_index(self): 991 src = torch.empty(15000000, 45, device="cuda", dtype=torch.long).random_( 992 0, 2**22 993 ) 994 idx = torch.randperm(src.shape[0], device="cuda") 995 res = src[idx] 996 res_cpu = src.cpu()[idx.cpu()] 997 self.assertEqual(res.cpu(), res_cpu) 998 999 def test_randint_randomness_for_large_range(self) -> None: 1000 # For large ranges, randint generation is slightly different. This lead to a subtle bug where some Philox 1001 # offsets were not calculated correctly, resulting in reused random states. 1002 # See https://github.com/pytorch/pytorch/issues/125224 1003 size = 1_000_000 1004 high = 6_000_000_000 # Keep this above 2**32 1005 1006 def run(dev: torch.device) -> int: 1007 # Measure how many unique numbers are generated in 2 consecutive calls to randint. If random states are 1008 # reused, this will yield fewer unique numbers. 1009 gen = torch.Generator(device=dev) 1010 gen.manual_seed(0) 1011 t1 = torch.randint( 1012 0, high, [size], device=dev, generator=gen, dtype=torch.int64 1013 ) 1014 t2 = torch.randint( 1015 0, high, [size], device=dev, generator=gen, dtype=torch.int64 1016 ) 1017 return torch.stack([t1, t2]).unique().shape[0] 1018 1019 # Use CPU as reference. The results should not deviate too much. 1020 assert abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000 1021 1022 @parametrize("dtype", [torch.float32, torch.double]) 1023 def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None: 1024 # Test if random states do not overlap between consecutive rand/randn calls. 1025 # See https://github.com/pytorch/pytorch/issues/125224 1026 1027 def run(func, dev: torch.device, dtype: torch.dtype) -> int: 1028 # Measure how many unique numbers are generated in 2 consecutive calls. If random states are 1029 # reused, this will yield fewer unique numbers. 1030 size = 1000000 1031 gen = torch.Generator(device=dev) 1032 gen.manual_seed(0) 1033 t1 = func((size,), device=dev, generator=gen, dtype=dtype) 1034 t2 = func((size,), device=dev, generator=gen, dtype=dtype) 1035 return torch.stack([t1, t2]).unique().shape[0] 1036 1037 # Use CPU as reference. The results should not deviate too much. 1038 for func in [torch.rand, torch.randn]: 1039 deviation = abs( 1040 run(func, torch.device("cuda"), dtype) 1041 - run(func, torch.device("cpu"), dtype) 1042 ) 1043 assert deviation < 50_000, deviation 1044 1045 def test_min_max_inits(self): 1046 # Testing if THC_reduceAll received the correct index initialization. 1047 # This affects the result of THC_reduceAll operations at extreme values 1048 x = torch.cuda.ByteTensor([0]) 1049 y = torch.cuda.ByteTensor([255]) 1050 expected = torch.cuda.LongTensor([0])[0] 1051 1052 _, v = x.max(dim=0) 1053 self.assertEqual(v, expected) 1054 1055 _, v = y.min(dim=0) 1056 self.assertEqual(v, expected) 1057 1058 def test_nvtx(self): 1059 # Just making sure we can see the symbols 1060 torch.cuda.nvtx.range_push("foo") 1061 torch.cuda.nvtx.mark("bar") 1062 torch.cuda.nvtx.range_pop() 1063 range_handle = torch.cuda.nvtx.range_start("range_start") 1064 torch.cuda.nvtx.range_end(range_handle) 1065 1066 def test_bincount_ext(self): 1067 # ensure CUDA code coverage 1068 input_size = (100000,) 1069 w = torch.randn(input_size, dtype=torch.double, device="cuda") 1070 w_cpu = w.cpu() 1071 # test shared memory impl 1072 t = torch.randint(50, input_size, dtype=torch.int8, device="cuda") 1073 self.assertEqual(t.cpu().bincount(), t.bincount()) 1074 self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) 1075 # test global memory impl 1076 # see `CUDAHistogramMemoryType` in SummaryOps.cu 1077 # 50000 * sizeof(int64_t) == 390 KiB, which should exceed smem of any known GPU 1078 t = torch.randint(50000, input_size, dtype=torch.int64, device="cuda") 1079 self.assertEqual(t.cpu().bincount(), t.bincount()) 1080 self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) 1081 1082 t = torch.zeros([10], dtype=torch.int32, device="cuda") 1083 # 35488 * 65536 as int32 would cause overflow to negative value 1084 # giving negative bin offset 1085 t[0] = 35488 1086 counted = t.bincount(minlength=65536) 1087 self.assertEqual(torch.sum(counted), 10) 1088 1089 def test_tiny_half_norm_(self): 1090 a = torch.arange(25).cuda().float() 1091 a /= 100000000 1092 b = a.half() 1093 self.assertGreater(b.norm().item(), 0) 1094 1095 def test_norm_type_conversion(self): 1096 a = torch.ones(65536).cuda().half() 1097 self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536) 1098 1099 def test_cuda_memory_leak_detection_propagates_errors(self): 1100 with self.assertRaisesRegex( 1101 RuntimeError, r"The size of tensor a \(3\) must match" 1102 ): 1103 with self.assertLeaksNoCudaTensors(): 1104 x = torch.randn(3, 1, device="cuda") 1105 y = torch.randn(2, 1, device="cuda") 1106 z = x + y 1107 1108 @unittest.skipIf(not TEST_MEDIUM_TENSOR, "not enough memory") 1109 @serialTest() 1110 def test_cuda_kernel_loop_overflow(self): 1111 # Issue #24309: In extreme cases, the loop variable could overflow and continue 1112 # the kernel loop with a negative index, causing a RuntimeError (invalid write): 1113 x = torch.randn(1, 1, 1, 2**30 + 1, dtype=torch.float16, device="cuda") 1114 expected = x[0, 0, 0, 2**30] 1115 y = torch.nn.functional.avg_pool2d(x, kernel_size=1) 1116 torch.cuda.synchronize() 1117 self.assertEqual(y[0, 0, 0, 2**30], expected) 1118 1119 @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") 1120 @gcIfJetson 1121 @serialTest() 1122 def test_cuda_kernel_loop_overflow_large(self): 1123 # Make sure input.numel() > INT_MAX is handled: 1124 x = torch.randn(1, 1, 1, 2**31, dtype=torch.float16, device="cuda") 1125 with self.assertRaisesRegex(RuntimeError, "integer out of range"): 1126 y = torch.nn.functional.avg_pool2d(x, kernel_size=1) 1127 1128 # Issue #24309: In extreme cases, the loop variable could overflow and continue 1129 # the kernel loop with a negative index, causing a RuntimeError (invalid write): 1130 x = torch.randn(1, 1, 1, 2**31 - 1, dtype=torch.float16, device="cuda") 1131 expected = x[0, 0, 0, 2**31 - 2] 1132 y = torch.nn.functional.avg_pool2d(x, kernel_size=1) 1133 torch.cuda.synchronize() 1134 self.assertEqual(y[0, 0, 0, 2**31 - 2], expected) 1135 1136 # this might create a reference cycle on self... 1137 def _make_multiply_in_stream(self): 1138 class MultiplyInStream(torch.autograd.Function): 1139 @staticmethod 1140 def forward(ctx, x, val): 1141 ctx.val = val 1142 ctx.stream = torch.cuda.current_stream() 1143 return x * val 1144 1145 @staticmethod 1146 def backward(ctx, grad): 1147 self.assertEqual(torch.cuda.current_stream(), ctx.stream) 1148 # delays the operation in the background stream 1149 torch.cuda._sleep(1000 * 5000) 1150 return grad * ctx.val, None 1151 1152 return MultiplyInStream 1153 1154 @skipCUDANonDefaultStreamIf(True) 1155 def test_streaming_backwards_sync(self): 1156 default_stream = torch.cuda.current_stream() 1157 stream = torch.cuda.Stream() 1158 1159 MultiplyInStream = self._make_multiply_in_stream() 1160 1161 # Tests using grads outside the backward() stream context 1162 # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html 1163 x = torch.randn(5, 5, device="cuda", requires_grad=True) 1164 with torch.cuda.stream(stream): 1165 stream.wait_stream(default_stream) 1166 output = MultiplyInStream.apply(x, 2) 1167 output.sum().backward() 1168 # sync needed 1169 default_stream.wait_stream(stream) 1170 self.assertEqual(x.grad, torch.ones_like(x) * 2) 1171 self.assertEqual(torch.cuda.current_stream(), default_stream) 1172 1173 # Tests that using grads in the same stream context as backward() 1174 # is safe regardless what streams bwd ops ran on 1175 bwd_ambient_stream = torch.cuda.Stream() 1176 x = torch.randn(5, 5, device="cuda", requires_grad=True) 1177 with torch.cuda.stream(stream): 1178 stream.wait_stream(default_stream) 1179 output = MultiplyInStream.apply(x, 3) 1180 with torch.cuda.stream(bwd_ambient_stream): 1181 bwd_ambient_stream.wait_stream(stream) 1182 output.sum().backward() 1183 # x was first used on "stream" so its AccumulateGrad leaf should run on "stream". 1184 # The end of backward() should have synced "bwd_ambient_stream" with "stream" 1185 # so it should be safe to use x.grad here without any syncs. 1186 self.assertEqual(x.grad, torch.ones_like(x) * 3) 1187 self.assertEqual(torch.cuda.current_stream(), bwd_ambient_stream) 1188 1189 # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190 1190 @skipIfRocm(msg="flakey on ROCm https://github.com/pytorch/pytorch/issues/53190") 1191 def test_streaming_backwards_multiple_streams(self): 1192 MultiplyInStream = self._make_multiply_in_stream() 1193 1194 class StreamModel(torch.nn.Module): 1195 def __init__(self) -> None: 1196 super().__init__() 1197 self.event = torch.cuda.Event() 1198 self.stream0 = torch.cuda.Stream() 1199 self.stream1 = torch.cuda.Stream() 1200 1201 def forward(self, x, x_first_use_on_ambient): 1202 if x_first_use_on_ambient: 1203 x0 = x.clone() 1204 self.stream0.wait_stream(torch.cuda.current_stream()) 1205 self.stream1.wait_stream(torch.cuda.current_stream()) 1206 with torch.cuda.stream(self.stream0): 1207 if not x_first_use_on_ambient: 1208 x0 = x.clone() 1209 y0 = MultiplyInStream.apply(x0, 2) 1210 self.event.record(stream=torch.cuda.current_stream()) 1211 1212 with torch.cuda.stream(self.stream1): 1213 y1 = MultiplyInStream.apply(x, 3) 1214 self.stream1.wait_event(self.event) 1215 return y0 + y1 1216 1217 stream = torch.cuda.Stream() 1218 1219 for x_first_use_on_ambient in (True, False): 1220 # the out_of_place=False, iters=1 case stresses if proper syncs are inserted 1221 # when grads are initially None and stolen by backward ops. 1222 for out_of_place, iters in ((True, 1), (False, 1), (False, 5)): 1223 with torch.cuda.stream(stream): 1224 x = torch.randn(5, 5, device="cuda", requires_grad=True) 1225 model = StreamModel().cuda() 1226 x.register_hook( 1227 lambda grad: self.assertEqual( 1228 torch.cuda.current_stream(), 1229 stream if x_first_use_on_ambient else model.stream0, 1230 ) 1231 ) 1232 for p in model.parameters(): 1233 self.assertTrue(p.grad is None) 1234 for i in range(iters): 1235 loss = model(x, x_first_use_on_ambient).sum() 1236 if out_of_place: 1237 x_grad = torch.autograd.grad((loss,), (x,))[0] 1238 else: 1239 loss.backward() 1240 # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html 1241 torch.cuda.current_stream().wait_stream(stream) 1242 1243 if out_of_place: 1244 self.assertEqual(x_grad, torch.ones_like(x) * 5 * iters) 1245 else: 1246 self.assertEqual(x.grad, torch.ones_like(x) * 5 * iters) 1247 1248 def test_streaming_backwards_sync_graph_root(self): 1249 # This function tests if bwd ops running on a side stream properly sync with the GraphRoot. 1250 # The potential bug it targets is a race condition. The test uses multiple trials and 1251 # torch.cuda._sleep such that if the race condition exists, the test will almost certainly fail, 1252 # but there's a chance it may spuriously pass. Passing does not guarantee the backend is bug-free, 1253 # but failure does guarantee there is a bug. 1254 fwd_bwd_op_stream = torch.cuda.Stream() 1255 bwd_ambient_stream = torch.cuda.Stream() 1256 # We need these streams to be different otherwise the test is meaningless. 1257 self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream) 1258 1259 size = int(1e3) 1260 1261 a = torch.full((size,), 2.0, device="cuda", requires_grad=True) 1262 b = torch.full((size,), 3.0, device="cuda", requires_grad=True) 1263 1264 # I don't think we need any manual record_streams below. 1265 # a and b remain in scope for the entire test. 1266 # c and grad remain in scope for each iteration, and there's a full sync between iterations. 1267 for trial in range(5): 1268 torch.cuda.synchronize() 1269 a.grad = b.grad = None 1270 with torch.cuda.stream(fwd_bwd_op_stream): 1271 c = a * b 1272 1273 with torch.cuda.stream(bwd_ambient_stream): 1274 torch.cuda.synchronize() 1275 # Long-running dummy kernel on bwd_ambient_stream delays filling of grad 1276 torch.cuda._sleep(int(50 * get_cycles_per_ms())) 1277 # Fills grad on bwd_ambient_stream 1278 grad = torch.full((size,), float(trial + 1), device="cuda") 1279 1280 # Bwd ops still run on fwd_bwd_ops_stream, so the following will likely fail if 1281 # bwd ops don't sync with bwd_ambient_stream before consuming grad. 1282 torch.autograd.backward(tensors=c, grad_tensors=grad) 1283 1284 # See https://github.com/pytorch/pytorch/issues/47028 1285 # assertEquals below run on bwd_ambient_stream, so this test may also fail 1286 # if backward() fails to sync with bwd_ambient_stream at the end. 1287 # Synchronizing here works around the issue until a proper fix can be made. 1288 torch.cuda.synchronize() 1289 with torch.no_grad(): 1290 self.assertEqual(a.grad, grad * b) 1291 self.assertEqual(b.grad, grad * a) 1292 1293 def test_streaming_backwards_callback(self): 1294 # Tests if autograd callbacks sync properly with respect to leaf streams and 1295 # the user-facing stream surrounding backward(). If it fails, first suspect is 1296 # sync logic where "final_callbacks_" are called in torch/csrc/autograd/engine.cpp 1297 MultiplyInStream = self._make_multiply_in_stream() 1298 1299 size = int(1e3) 1300 a = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True) 1301 b = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True) 1302 1303 s0 = torch.cuda.Stream() 1304 s1 = torch.cuda.Stream() 1305 s2 = torch.cuda.Stream() 1306 1307 stash = [] 1308 1309 # sets up a nontrivial structure of leaf streams 1310 s0.wait_stream(torch.cuda.current_stream()) 1311 with torch.cuda.stream(s0): 1312 c = MultiplyInStream.apply(a, 2) 1313 1314 s1.wait_stream(torch.cuda.current_stream()) 1315 with torch.cuda.stream(s1): 1316 d = MultiplyInStream.apply(b, 3) 1317 s1.wait_stream(s0) 1318 e = c * d 1319 1320 def clone_leaf_grads(): 1321 stash.append(a.grad.clone()) 1322 stash.append(b.grad.clone()) 1323 1324 # Use a hook on e to install the callback 1325 e.register_hook( 1326 lambda grad: torch.autograd.Variable._execution_engine.queue_callback( 1327 clone_leaf_grads 1328 ) 1329 ) 1330 1331 s2.wait_stream(s1) 1332 with torch.cuda.stream(s2): 1333 e.sum().backward() 1334 # The autograd engine should sync s2 with all leaf streams then run the callback clone_leaf_grads on s2. 1335 # If those things happened properly, checking the values of the cloned grads on s2 should be safe: 1336 self.assertEqual(stash[0], torch.full_like(a, 6)) 1337 self.assertEqual(stash[1], torch.full_like(a, 6)) 1338 1339 @unittest.skipIf( 1340 TEST_WITH_ROCM, 1341 "In ROCm, kernel asserts are disabled due to performance overhead", 1342 ) 1343 def test_fixed_cuda_assert_async(self): 1344 with self.assertRaisesRegex( 1345 RuntimeError, "Boolean value of Tensor with no values is ambiguous" 1346 ): 1347 torch._assert_async(torch.tensor([], device="cuda")) 1348 with self.assertRaisesRegex( 1349 RuntimeError, 1350 "Boolean value of Tensor with more than one value is ambiguous", 1351 ): 1352 torch._assert_async(torch.tensor([0, 0], device="cuda")) 1353 1354 torch._assert_async(torch.tensor(1, device="cuda")) 1355 torch._assert_async(torch.tensor(0.1, device="cuda")) 1356 torch._assert_async(torch.tensor(-0.1, device="cuda")) 1357 torch._assert_async(torch.tensor(True, device="cuda")) 1358 torch._assert_async(torch.tensor(0 + 0.1j, device="cuda")) 1359 1360 fail_stmts = [ 1361 "torch._assert_async(torch.tensor(0, device='cuda'))", 1362 "torch._assert_async(torch.tensor(0.0, device='cuda'))", 1363 "torch._assert_async(torch.tensor(False, device='cuda'))", 1364 "torch._assert_async(torch.tensor(0 + 0j, device='cuda'))", 1365 ] 1366 1367 import subprocess 1368 1369 for stmt in fail_stmts: 1370 with self.subTest(stmt=stmt): 1371 r = subprocess.call( 1372 [ 1373 sys.executable, 1374 "-c", 1375 f"""\ 1376import torch 1377 1378{stmt} 1379torch.cuda.synchronize() 1380""", 1381 ] 1382 ) 1383 self.assertTrue(r != 0) 1384 1385 @unittest.skipIf(TEST_CUDAMALLOCASYNC, "FAIL") 1386 def test_cublas_multiple_threads_same_device(self): 1387 # Note, these parameters should be very carefully tuned 1388 # Too small number makes it hard for the racing condition 1389 # to happen, while too large number sometimes cause hang 1390 size = 1024 1391 num_threads = 2 1392 trials = 3 1393 test_iters = 100 1394 1395 weight = torch.ones((size, size), device="cuda") 1396 results = {} 1397 barrier = threading.Barrier(num_threads) 1398 1399 def _worker(t): 1400 my_stream = torch.cuda.Stream() 1401 # Hard sync so we don't need to worry about creating and using tensors 1402 # across streams or the fact that default streams are thread-local. 1403 # Those issues are not the target of this test. 1404 torch.cuda.synchronize() 1405 # Line up threads to increase likelihood of race conditions. 1406 barrier.wait() 1407 with torch.cuda.stream(my_stream): 1408 for i in range(test_iters): 1409 # If all threads are sharing the same cublas handle, 1410 # the following sequence may occur: 1411 # thread 0 calls cublasSetStream() 1412 # thread 1 calls cublasSetStream() 1413 # thread 0 launches its raw gemm, which it thinks is in 1414 # its own stream, but is actually in thread 1's stream. 1415 # thread 0 enqueues its div_, which IS is its own stream, 1416 # but actually now races with its gemm. 1417 results[t] = torch.mm(results[t], weight) 1418 results[t].div_(float(size)) 1419 torch.cuda.synchronize() 1420 1421 for _ in range(trials): 1422 for t in range(num_threads): 1423 results[t] = torch.ones((size, size), device="cuda") 1424 1425 threads = [ 1426 threading.Thread(target=_worker, args=(t,)) for t in range(num_threads) 1427 ] 1428 1429 for thread in threads: 1430 thread.start() 1431 for thread in threads: 1432 thread.join() 1433 1434 for t in range(num_threads): 1435 self.assertEqual(results[t].sum().item(), size * size) 1436 1437 # Test is flaky on Windows (https://github.com/pytorch/pytorch/issues/57401) 1438 @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows (see issue 57401)") 1439 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 1440 @skipIfRocm 1441 def test_cudnn_multiple_threads_same_device(self): 1442 # This function is intended to test the lazy creation and reuse of per-thread 1443 # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp. 1444 # Failure here likely indicates something wrong with that logic. 1445 weight = torch.ones((1, 1, 2, 2), device="cuda") 1446 1447 results = {} 1448 1449 num_threads = 2 1450 trials = 3 1451 test_iters = 1000 1452 barrier = threading.Barrier(num_threads) 1453 1454 with torch.backends.cudnn.flags(enabled=True): 1455 1456 def _worker(t): 1457 my_stream = torch.cuda.Stream() 1458 # Hard sync so we don't need to worry about creating and using tensors 1459 # across streams or the fact that default streams are thread-local. 1460 # Those issues are not the target of this test. 1461 torch.cuda.synchronize() 1462 # Line up threads to increase likelihood of race conditions. 1463 barrier.wait() 1464 with torch.cuda.stream(my_stream): 1465 for _ in range(test_iters): 1466 # If all threads are sharing the same cudnn handle, 1467 # the following sequence may occur: 1468 # thread 0 calls setCuDNNStreamToCurrent() 1469 # thread 1 calls setCuDNNStreamToCurrent() 1470 # thread 0 launches its raw convolution, which it thinks is in 1471 # its own stream, but is actually in thread 1's stream. 1472 # thread 0 enqueues its div_, which IS is its own stream, 1473 # but now races with its convolution. 1474 results[t] = torch.nn.functional.conv2d( 1475 results[t], weight, padding=0 1476 ) 1477 results[t].div_(4.0) 1478 torch.cuda.synchronize() 1479 1480 for _ in range(trials): 1481 for t in range(num_threads): 1482 results[t] = torch.ones((1, 1, 2048, 2048), device="cuda") 1483 1484 threads = [ 1485 threading.Thread(target=_worker, args=(t,)) 1486 for t in range(num_threads) 1487 ] 1488 1489 for thread in threads: 1490 thread.start() 1491 for thread in threads: 1492 thread.join() 1493 1494 for t in range(num_threads): 1495 self.assertEqual( 1496 results[t].sum().item(), 1497 (2048 - test_iters) * (2048 - test_iters), 1498 ) 1499 1500 def test_cusparse_multiple_threads_same_device(self): 1501 size = 1024 1502 num_threads = 2 1503 trials = 3 1504 test_iters = 500 1505 1506 def ones_sparse(size): 1507 a = torch.arange(size, device="cuda") 1508 indices = torch.cartesian_prod(a, a).t() 1509 values = torch.ones(size * size, device="cuda") 1510 return torch.sparse_coo_tensor(indices, values) 1511 1512 weight = ones_sparse(size) 1513 results = {} 1514 barrier = threading.Barrier(num_threads) 1515 1516 def _worker(t): 1517 my_stream = torch.cuda.Stream() 1518 # Hard sync so we don't need to worry about creating and using tensors 1519 # across streams or the fact that default streams are thread-local. 1520 # Those issues are not the target of this test. 1521 torch.cuda.synchronize() 1522 # Line up threads to increase likelihood of race conditions. 1523 barrier.wait() 1524 with torch.cuda.stream(my_stream): 1525 for i in range(test_iters): 1526 # If all threads are sharing the same cublas handle, 1527 # the following sequence may occur: 1528 # thread 0 calls cublasSetStream() 1529 # thread 1 calls cublasSetStream() 1530 # thread 0 launches its raw gemm, which it thinks is in 1531 # its own stream, but is actually in thread 1's stream. 1532 # thread 0 enqueues its div_, which IS is its own stream, 1533 # but actually now races with its gemm. 1534 results[t] = weight.mm(results[t]) 1535 results[t].div_(float(size)) 1536 torch.cuda.synchronize() 1537 1538 for _ in range(trials): 1539 for t in range(num_threads): 1540 results[t] = torch.ones((size, size), device="cuda") 1541 1542 threads = [ 1543 threading.Thread(target=_worker, args=(t,)) for t in range(num_threads) 1544 ] 1545 1546 for thread in threads: 1547 thread.start() 1548 for thread in threads: 1549 thread.join() 1550 1551 for t in range(num_threads): 1552 self.assertEqual(results[t].sum().item(), size * size) 1553 1554 @slowTest 1555 @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") 1556 @serialTest() 1557 def test_max_large_axis(self): 1558 x = torch.zeros(2**32, device="cuda", dtype=torch.int8) 1559 x[-1] = 1 1560 val, idx = x.max(0) 1561 self.assertEqual(val, 1) 1562 self.assertEqual(idx, x.shape[0] - 1) 1563 1564 @unittest.skipIf(not TEST_NUMPY, "Numpy not found") 1565 def test_to_numpy(self): 1566 self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) 1567 1568 def test_graph_is_current_stream_capturing(self): 1569 self.assertFalse(torch.cuda.is_current_stream_capturing()) 1570 1571 if TEST_CUDA and (not TEST_WITH_ROCM): 1572 s = torch.cuda.Stream() 1573 with torch.cuda.stream(s): 1574 g = torch.cuda.CUDAGraph() 1575 self.assertFalse(torch.cuda.is_current_stream_capturing()) 1576 g.capture_begin() 1577 self.assertTrue(torch.cuda.is_current_stream_capturing()) 1578 g.capture_end() 1579 1580 @unittest.skipIf( 1581 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1582 ) 1583 def test_graph_capture_simple(self): 1584 s = torch.cuda.Stream() 1585 1586 with torch.cuda.stream(s): 1587 a = torch.full((1000,), 1, device="cuda") 1588 g = torch.cuda.CUDAGraph() 1589 torch.cuda.empty_cache() 1590 g.capture_begin() 1591 b = a 1592 for _ in range(10): 1593 b = b + 1 1594 g.capture_end() 1595 torch.cuda.current_stream().wait_stream(s) 1596 1597 g.replay() 1598 1599 self.assertTrue(b.sum().item() == 11000.0) 1600 1601 @unittest.skipIf( 1602 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1603 ) 1604 def test_graphsafe_set_get_rng_state(self): 1605 # Define a function to create generator states, with optional graph registration 1606 def create_states(generator): 1607 """Initializes generator states and registers them with a CUDA graph if provided.""" 1608 # Ensure the CUDA generator is initialized 1609 torch.rand(1, device="cuda") 1610 generator.manual_seed(0) 1611 1612 # Save the current state of the generator 1613 old_state = generator.graphsafe_get_state() 1614 # Create and save a cloned state of the generator 1615 new_state = generator.clone_state() 1616 # Return the original generator and its two states 1617 return generator, old_state, new_state 1618 1619 def register_states_to_graph(generator_state, graph): 1620 generator, old_state, new_state = generator_state 1621 graph.register_generator_state(old_state) 1622 graph.register_generator_state(new_state) 1623 1624 # Define a function to perform specific RNG actions using the generator's states 1625 def perform_random_generation_steps(generator_state): 1626 generator, old_state, new_state = generator_state 1627 random_values = [] 1628 1629 # Generate random numbers with the new generator state 1630 generator.graphsafe_set_state(new_state) 1631 random_values.append(torch.rand(5, device="cuda", generator=generator)) 1632 1633 # Generate random numbers twice with the old generator state 1634 generator.graphsafe_set_state(old_state) 1635 random_values.extend( 1636 [torch.rand(5, device="cuda", generator=generator) for _ in range(2)] 1637 ) 1638 1639 return random_values 1640 1641 # Define a function to retrieve the final offsets of the original and new generator states 1642 def get_final_offsets_of_states(generator_state): 1643 generator, old_state, new_state = generator_state 1644 old_state_offset = old_state.get_offset() 1645 new_state_offset = new_state.get_offset() 1646 return old_state_offset, new_state_offset 1647 1648 # Set up and test a new CUDA generator 1649 generator = torch.Generator(device="cuda") 1650 generator_state = create_states(generator) 1651 1652 # Set up and test the default CUDA generator with a CUDA Graph 1653 g = torch.cuda.CUDAGraph() 1654 s = torch.cuda.Stream() 1655 default_generator = torch.cuda.default_generators[0] 1656 default_generator_state = create_states(default_generator) 1657 register_states_to_graph(default_generator_state, g) 1658 1659 # Perform random number generation within a CUDA graph 1660 with torch.cuda.stream(s): 1661 g.capture_begin() 1662 graphed_random_values = perform_random_generation_steps( 1663 default_generator_state 1664 ) 1665 g.capture_end() 1666 1667 # Synchronize the streams and replay the graph 1668 torch.cuda.current_stream().wait_stream(s) 1669 for _ in range(3): 1670 random_values = perform_random_generation_steps(generator_state) 1671 g.replay() 1672 offset = get_final_offsets_of_states(generator_state) 1673 graph_offset = get_final_offsets_of_states(default_generator_state) 1674 1675 # Compare the final offsets of states for both generators to ensure consistency 1676 self.assertTrue(offset == graph_offset) 1677 # Compare the states generated outside and inside the graph 1678 self.assertEqual(random_values, graphed_random_values) 1679 1680 @unittest.skipIf( 1681 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1682 ) 1683 def test_memory_stats_of_multiple_generators_and_graphs(self): 1684 # Function to clear CUDA cache and collect garbage 1685 def clear_cuda_cache(): 1686 gc.collect() 1687 torch.cuda.empty_cache() 1688 1689 # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph. 1690 def simple_graph_task(graph): 1691 s = torch.cuda.Stream() 1692 with torch.cuda.stream(s): 1693 graph.capture_begin() 1694 torch.rand(1, device="cuda") 1695 graph.capture_end() 1696 torch.cuda.current_stream().wait_stream(s) 1697 graph.replay() # Replays the captured operations 1698 1699 def get_memory_stats(): 1700 stats = torch.cuda.memory_stats() 1701 num_blocks = stats["active.all.current"] 1702 total_size = stats["active_bytes.all.current"] 1703 return num_blocks, total_size 1704 1705 def test(num_graphs, num_generators): 1706 baseline = get_memory_stats() 1707 baseline_num_blocks, baseline_total_size = baseline 1708 1709 # Allocate CUDA graphs 1710 graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)] 1711 1712 # Allocate and manage generator states 1713 default_generator = torch.cuda.default_generators[0] 1714 generators = [default_generator.graphsafe_get_state()] 1715 1716 # Starts from 1 as one state is already added 1717 for _ in range(1, num_generators): 1718 generators.append(default_generator.clone_state()) 1719 1720 for graph in graphs: 1721 for generator_state in generators: 1722 graph.register_generator_state(generator_state) 1723 simple_graph_task(graph) 1724 1725 # Assert conditions after graph tasks 1726 num_blocks, total_size = get_memory_stats() 1727 # The allocated blocks should only be proportional to the number of generators 1728 expected_blocks_diff = 2 * num_generators 1729 expected_size_diff = 2 * 512 * num_generators # Each block's size is 512 1730 1731 self.assertTrue( 1732 (num_blocks - baseline_num_blocks) == expected_blocks_diff, 1733 "Unexpected number of active blocks.", 1734 ) 1735 self.assertTrue( 1736 (total_size - baseline_total_size) == expected_size_diff, 1737 "Unexpected total memory size.", 1738 ) 1739 1740 # Cleanup graphs and clear CUDA cache 1741 while graphs: 1742 graph = graphs.pop() 1743 del graph 1744 clear_cuda_cache() 1745 1746 # Assert that memory stats return to baseline after cleanup 1747 self.assertTrue( 1748 get_memory_stats() == baseline, 1749 "Memory stats do not match baseline after cleanup.", 1750 ) 1751 1752 # Running the test function with different parameters 1753 test(1, 1) 1754 test(3, 2) 1755 test(10, 20) 1756 1757 @unittest.skipIf( 1758 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1759 ) 1760 def test_graph_capture_reset_recapture(self): 1761 s = torch.cuda.Stream() 1762 1763 with torch.cuda.stream(s): 1764 a = torch.full((1000,), 1, device="cuda") 1765 g = torch.cuda.CUDAGraph() 1766 torch.cuda.empty_cache() 1767 g.capture_begin() 1768 b = a 1769 for _ in range(10): 1770 b = b + 1 1771 g.capture_end() 1772 torch.cuda.current_stream().wait_stream(s) 1773 1774 g.replay() 1775 1776 self.assertTrue(b.sum().item() == 11000.0) 1777 1778 g.reset() 1779 1780 with torch.cuda.stream(s): 1781 g.capture_begin() 1782 b.fill_(2.0) 1783 for _ in range(10): 1784 b = b + 2 1785 g.capture_end() 1786 torch.cuda.current_stream().wait_stream(s) 1787 1788 g.replay() 1789 self.assertTrue(b.sum().item() == 22000.0) 1790 1791 g.reset() 1792 del g 1793 1794 @unittest.skipIf( 1795 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1796 ) 1797 def test_graph_debugdump(self): 1798 torch.cuda.empty_cache() 1799 x = torch.randn(10240000, device="cuda") 1800 y = torch.rand_like(x) 1801 g = torch.cuda.CUDAGraph() 1802 g.enable_debug_mode() 1803 s0 = torch.cuda.Stream() 1804 s1 = torch.cuda.Stream() 1805 s0.wait_stream(torch.cuda.current_stream()) 1806 with torch.cuda.stream(s0): 1807 g.capture_begin() 1808 z = x + y 1809 with torch.cuda.stream(s1): 1810 s1.wait_stream(s0) 1811 w = z + y 1812 s0.wait_stream(s1) 1813 g.capture_end() 1814 s0.synchronize() 1815 torch.cuda.synchronize() 1816 with tempfile.TemporaryDirectory() as tempdir: 1817 g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot")) 1818 1819 @unittest.skipIf( 1820 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1821 ) 1822 def test_graph_error(self): 1823 # We need to run this test in a separate thread as the error we trigger 1824 # puts the cuda context in a bad state 1825 script = """ 1826import torch 1827 1828g = torch.cuda.CUDAGraph() 1829try: 1830 g.capture_begin() 1831except RuntimeError as e: 1832 if "CUDA graphs must be captured on a non-default stream." in str(e): 1833 exit(0) 1834 else: 1835 exit(1) 1836exit(2) 1837""" 1838 try: 1839 a = subprocess.check_output( 1840 [sys.executable, "-c", script], 1841 stderr=subprocess.STDOUT, 1842 # On Windows, opening the subprocess with the default CWD makes `import torch` 1843 # fail, so just set CWD to this script's directory 1844 cwd=os.path.dirname(os.path.realpath(__file__)), 1845 ) 1846 except subprocess.CalledProcessError as e: 1847 if e.returncode == 1: 1848 self.assertTrue( 1849 False, 1850 "Error raise by starting capture without a stream is not the expected one", 1851 ) 1852 elif e.returncode == 2: 1853 self.assertTrue( 1854 False, 1855 "Error raised by starting capture without a stream was not caught", 1856 ) 1857 1858 @unittest.skipIf( 1859 (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11, 1860 "CUDA >= 11.0 required for graphs", 1861 ) 1862 def test_graph_warn_if_has_zero_nodes(self): 1863 with warnings.catch_warnings(record=True) as caught: 1864 g = torch.cuda.CUDAGraph() 1865 s = torch.cuda.Stream() 1866 with torch.cuda.stream(s): 1867 g.capture_begin() 1868 g.capture_end() 1869 self.assertTrue( 1870 any("The CUDA Graph is empty" in str(w.message) for w in caught) 1871 ) 1872 1873 @unittest.skipIf( 1874 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1875 ) 1876 @unittest.skipIf( 1877 IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support" 1878 ) 1879 def test_graph_capture_oom(self): 1880 oom_regex = ( 1881 "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory" 1882 ) 1883 with self.assertRaisesRegex(RuntimeError, oom_regex): 1884 with torch.cuda.graph(torch.cuda.CUDAGraph()): 1885 torch.zeros(2**40, device="cuda") 1886 1887 @unittest.skipIf( 1888 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1889 ) 1890 @serialTest() 1891 def test_repeat_graph_capture_cublas_workspace_memory(self): 1892 (x, y, z) = 1024, 512, 64 1893 a = torch.rand((x, y), device="cuda") 1894 b = torch.rand((y, z), device="cuda") 1895 1896 # warmup 1897 torch.mm(a, b) 1898 1899 free_bytes_before, total_bytes = torch.cuda.mem_get_info() 1900 used_gb_before = (total_bytes - free_bytes_before) / 1e9 1901 1902 for i in range(100): 1903 torch_graph = torch.cuda.CUDAGraph() 1904 with torch.cuda.graph(torch_graph): 1905 torch.mm(a, b) 1906 torch_graph.replay() 1907 1908 free_bytes_after, _ = torch.cuda.mem_get_info() 1909 used_gb_after = (total_bytes - free_bytes_after) / 1e9 1910 1911 self.assertFalse(used_gb_before + 0.1 < used_gb_after) 1912 1913 @unittest.skipIf( 1914 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 1915 ) 1916 def test_graph_rng_functional(self): 1917 ops_with_kwargs = ( 1918 (torch.nn.functional.dropout, {"p": 0.1}), 1919 (torch.nn.functional.rrelu, {"training": True}), 1920 ) 1921 size = 10000 1922 1923 def run(op, kwargs): 1924 a = torch.randn((size,), device="cuda", dtype=torch.float) 1925 1926 # Control 1927 torch.cuda.manual_seed(5) 1928 eager_out = a 1929 for _ in range(6): 1930 eager_out = op(eager_out, **kwargs) 1931 1932 graph_in = a.clone() 1933 stream = torch.cuda.Stream() 1934 stream.wait_stream(torch.cuda.current_stream()) 1935 with torch.cuda.stream(stream): 1936 torch.cuda.manual_seed(5) 1937 1938 g = torch.cuda.CUDAGraph() 1939 torch.cuda.empty_cache() 1940 g.capture_begin() 1941 graph_out = graph_in 1942 for _ in range(2): 1943 graph_out = op(graph_out, **kwargs) 1944 g.capture_end() 1945 torch.cuda.current_stream().wait_stream(stream) 1946 1947 # Runs a graphed->eager->graphed sequence of RNG ops. 1948 # replay() plays 2 invocations of the op, so the sequence has 6 1949 # invocations total, matching Control. 1950 # replay() reads from graph_in and writes to graph_out. 1951 g.replay() 1952 out = op(graph_out, **kwargs) 1953 out = op(out, **kwargs) 1954 graph_in.copy_(out) 1955 g.replay() 1956 1957 # If replay() updated RNG state correctly, graph_out 1958 # should now hold data equal to eager_out. 1959 try: 1960 self.assertEqual(eager_out, graph_out) 1961 except Exception as e: 1962 raise RuntimeError("Failed on ", op) from e 1963 1964 # Do the same operations varying seeds 1965 seeds = [6, 128, 9999] 1966 1967 for seed in seeds: 1968 torch.cuda.manual_seed(seed) 1969 graph_in.copy_(a) 1970 for _ in range(3): 1971 g.replay() 1972 1973 # If the random seed was not updated then the graph would 1974 # generate the same output as in previous check. 1975 try: 1976 self.assertNotEqual(eager_out, graph_out) 1977 except Exception as e: 1978 raise RuntimeError("Failed on ", op) from e 1979 1980 # Now repeat the same operations in non-graphed mode. 1981 torch.cuda.manual_seed(seed) 1982 for _ in range(3): 1983 eager_out.copy_(a) 1984 eager_out = op(eager_out, **kwargs) 1985 eager_out = op(eager_out, **kwargs) 1986 1987 # In the end, graph_out and eager_out must be equal 1988 # as they went under the same set of operations. 1989 try: 1990 self.assertEqual(eager_out, graph_out) 1991 except Exception as e: 1992 raise RuntimeError("Failed on ", op) from e 1993 1994 # We hold references to all tensors used across streams up til this sync, 1995 # so no need to call record_stream on those tensors. 1996 torch.cuda.synchronize() 1997 1998 for op, kwargs in ops_with_kwargs: 1999 run(op, kwargs) 2000 2001 @unittest.skipIf( 2002 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2003 ) 2004 def test_graph_rng_distributions(self): 2005 size = 10000 2006 input = torch.rand((size,), device="cuda", dtype=torch.float) 2007 alloc = torch.empty((size,), device="cuda", dtype=torch.float) 2008 2009 # Torch ops to test with sample args (tuple) and kwargs (dict) 2010 torch_with_args = ( 2011 ("bernoulli", (input.clone(),), {}), 2012 # multinomial uses some uncapturable CUDA calls. 2013 # TODO: reenable multinomial tests if/when the implementation is capturable. 2014 # ("multinomial", (input.clone(), size, True), {}), 2015 # ("multinomial", (input.clone(), size // 2, False), {}), 2016 # TODO: reenable normal test, where std is a device 2017 # tensor, when graph test failures are fixed 2018 # ("normal", (input.clone() + 1, input.clone()), {}), 2019 ("normal", (input.clone() + 1, 1.0), {}), 2020 ("poisson", (input.clone(),), {}), 2021 ("rand", (size,), {"device": "cuda", "dtype": torch.float}), 2022 ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), 2023 ("randn", (size,), {"device": "cuda", "dtype": torch.float}), 2024 ) 2025 2026 # Tensor methods to test with sample args (tuple) 2027 tensor_with_args = ( 2028 ("bernoulli_", (input.clone(),)), 2029 ("cauchy_", ()), 2030 ("exponential_", ()), 2031 ("geometric_", (0.3,)), 2032 ("log_normal_", ()), 2033 ("normal_", ()), 2034 ("random_", ()), 2035 ("uniform_", ()), 2036 ) 2037 2038 def run(module, op, args, kwargs): 2039 torch.cuda.manual_seed(5) 2040 2041 # Each path runs a dummy op to increment the state a bit before creating controls. 2042 if module == "torch": 2043 dummy = getattr(torch, op)(*args, **kwargs) 2044 control1 = getattr(torch, op)(*args, **kwargs) 2045 control2 = getattr(torch, op)(*args, **kwargs) 2046 else: 2047 dummy = alloc.clone() 2048 control1 = alloc.clone() 2049 control2 = alloc.clone() 2050 getattr(dummy, op)(*args) 2051 getattr(control1, op)(*args) 2052 getattr(control2, op)(*args) 2053 2054 stream = torch.cuda.Stream() 2055 stream.wait_stream(torch.cuda.current_stream()) 2056 with torch.cuda.stream(stream): 2057 torch.cuda.manual_seed(5) 2058 2059 g = torch.cuda.CUDAGraph() 2060 torch.cuda.empty_cache() 2061 if module == "torch": 2062 g.capture_begin() 2063 t1 = getattr(torch, op)(*args, **kwargs) 2064 t2 = getattr(torch, op)(*args, **kwargs) 2065 g.capture_end() 2066 else: 2067 t1 = alloc.clone() 2068 t2 = alloc.clone() 2069 g.capture_begin() 2070 getattr(t1, op)(*args) 2071 getattr(t2, op)(*args) 2072 g.capture_end() 2073 torch.cuda.current_stream().wait_stream(stream) 2074 2075 if not TEST_CUDAMALLOCASYNC: 2076 # Makes sure values haven't been populated yet 2077 # (in other words, makes sure capture didn't actually run ops). 2078 # We can only try this with the native allocator, for which captured 2079 # addresses are already backed by cudaMalloced memory. 2080 # If we try it with cudaMallocAsync, CUDA won't event consider 2081 # the captured addresses allocated until replay(), and if we 2082 # access them before replay() we get IMAs. 2083 try: 2084 self.assertNotEqual(control1, t1) 2085 self.assertNotEqual(control2, t2) 2086 except Exception as e: 2087 raise RuntimeError("Failed on " + module + "." + op) from e 2088 2089 # Set a new seed to check if graph would use it 2090 for seed in [6, 314, 271]: 2091 torch.cuda.manual_seed(seed) 2092 # Runs a dummy op prelude, as for controls, to make sure replay() 2093 # picks up the dummy op's state increment. 2094 if module == "torch": 2095 dummy = getattr(torch, op)(*args, **kwargs) 2096 control1 = getattr(torch, op)(*args, **kwargs) 2097 control2 = getattr(torch, op)(*args, **kwargs) 2098 else: 2099 getattr(dummy, op)(*args) 2100 getattr(control1, op)(*args) 2101 getattr(control2, op)(*args) 2102 2103 torch.cuda.manual_seed(seed) 2104 if module == "torch": 2105 dummy = getattr(torch, op)(*args, **kwargs) 2106 else: 2107 getattr(dummy, op)(*args) 2108 2109 # see above comment on TEST_CUDAMALLOCASYNC 2110 if not TEST_CUDAMALLOCASYNC: 2111 t1.copy_(alloc) 2112 t2.copy_(alloc) 2113 2114 # Runs RNG ops that fill t1 and t2. 2115 g.replay() 2116 2117 try: 2118 self.assertEqual(control1, t1) 2119 self.assertEqual(control2, t2) 2120 except Exception as e: 2121 raise RuntimeError("Failed on " + module + "." + op) from e 2122 2123 # We hold references to all tensors used across streams up til this sync, 2124 # so no need to call record_stream on those tensors. 2125 torch.cuda.synchronize() 2126 2127 for op_with_args in torch_with_args: 2128 run("torch", *op_with_args) 2129 2130 for meth_with_args in tensor_with_args: 2131 # Adds an empty dict for kwargs, which none of the Tensor methods use 2132 run("Tensor", *(meth_with_args + ({},))) 2133 2134 @unittest.skipIf( 2135 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2136 ) 2137 def test_graph_two_successive(self): 2138 torch.cuda.empty_cache() 2139 2140 size = 1000 2141 kSmallBuffer = 2097152 2142 2143 def func_with_temps(t, val): 2144 x = t.clone() + val 2145 y = t.clone() + val 2146 return x + y 2147 2148 s = torch.cuda.Stream() 2149 2150 for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): 2151 g0 = torch.cuda.CUDAGraph() 2152 g1 = torch.cuda.CUDAGraph() 2153 2154 a = torch.ones((size,), device="cuda") 2155 2156 s.wait_stream(torch.cuda.current_stream()) 2157 with torch.cuda.stream(s): 2158 g0_args = ( 2159 (torch.cuda.graph_pool_handle(),) 2160 if share_mem == "via graph_pool_handle()" 2161 else () 2162 ) 2163 g0.capture_begin(*g0_args) 2164 b = a.clone() 2165 for _ in range(5): 2166 b = func_with_temps(b, 1) 2167 g0.capture_end() 2168 2169 g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args 2170 g1.capture_begin(*g1_args) 2171 for _ in range(5): 2172 b = func_with_temps(b, 1) 2173 g1.capture_end() 2174 torch.cuda.current_stream().wait_stream(s) 2175 2176 # mixes unrelated eager ops with replays 2177 c = a.clone() 2178 for _ in range(2): 2179 c = func_with_temps(c, 3) 2180 g0.replay() 2181 for _ in range(2): 2182 c = func_with_temps(c, 3) 2183 g1.replay() 2184 for _ in range(2): 2185 c = func_with_temps(c, 3) 2186 2187 self.assertEqual(b.sum().item(), size * 3070) 2188 self.assertEqual(c.sum().item(), size * 442) 2189 2190 if not TEST_CUDAMALLOCASYNC: 2191 # These stat checks are specific to the native allocator. 2192 if share_mem != "Don't share": 2193 self.assertEqual( 2194 reserved_no_sharing # noqa: F821 2195 - torch.cuda.memory_stats()["reserved_bytes.all.current"], 2196 kSmallBuffer, 2197 ) 2198 else: 2199 reserved_no_sharing = torch.cuda.memory_stats()[ 2200 "reserved_bytes.all.current" 2201 ] 2202 2203 del a, b, c, g0, g1 2204 # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. 2205 torch.cuda.synchronize() 2206 torch.cuda.empty_cache() 2207 2208 @unittest.skipIf( 2209 (not TEST_CUDA_GRAPH) 2210 or IS_WINDOWS 2211 or ( # appears to still be broken on Windows as of 11.4+ 2212 torch.version.cuda 2213 and int(torch.version.cuda.split(".")[0]) == 11 2214 and int(torch.version.cuda.split(".")[1]) < 4 2215 ), 2216 "Graph bindings disallow concurrent replay for CUDA < 11.4, see " 2217 + "https://github.com/pytorch/pytorch/pull/57556", 2218 ) 2219 @unittest.skipIf( 2220 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2221 ) 2222 def test_graph_concurrent_replay(self): 2223 torch.cuda.empty_cache() 2224 2225 size = 1000000 # largeish to help expose race conditions 2226 2227 def func_with_temps(t, val): 2228 x = t.clone() + val 2229 y = t.clone() + val 2230 return x + y 2231 2232 s = torch.cuda.Stream() 2233 2234 for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): 2235 g0 = torch.cuda.CUDAGraph() 2236 g1 = torch.cuda.CUDAGraph() 2237 2238 s0 = torch.cuda.Stream() 2239 s1 = torch.cuda.Stream() 2240 2241 a = torch.ones((size,), device="cuda") 2242 2243 s.wait_stream(torch.cuda.current_stream()) 2244 with torch.cuda.stream(s): 2245 g0_args = ( 2246 (torch.cuda.graph_pool_handle(),) 2247 if share_mem == "via graph_pool_handle()" 2248 else () 2249 ) 2250 g0.capture_begin(*g0_args) 2251 b = a.clone() 2252 for _ in range(5): 2253 b = func_with_temps(b, 1) 2254 g0.capture_end() 2255 2256 g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args 2257 g1.capture_begin(*g1_args) 2258 c = a.clone() 2259 for _ in range(5): 2260 c = func_with_temps(c, 2) 2261 g1.capture_end() 2262 2263 # To reproduce data corruption, I need g0 and g1's kernels to run concurrently. 2264 # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead. 2265 # The following pattern helps align device-side execution of g0 and g1's kernels. 2266 torch.cuda.synchronize() 2267 with torch.cuda.stream(s0): 2268 torch.cuda._sleep(1000000) 2269 s1.wait_stream(s0) 2270 g0.replay() 2271 with torch.cuda.stream(s1): 2272 g1.replay() 2273 torch.cuda.current_stream().wait_stream(s0) 2274 torch.cuda.current_stream().wait_stream(s1) 2275 2276 if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"): 2277 # If we used the native allocator and shared mempools, 2278 # we expect the concurrent replays corrupted each other. 2279 self.assertNotEqual(b.sum().item(), size * 94) 2280 self.assertNotEqual(c.sum().item(), size * 156) 2281 else: 2282 # If we EITHER 2283 # - used the native allocator without sharing mempools, OR 2284 # - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe 2285 # we don't expect memory corruption. 2286 self.assertEqual(b.sum().item(), size * 94) 2287 self.assertEqual(c.sum().item(), size * 156) 2288 2289 del a, b, c, g0, g1 2290 # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them. 2291 torch.cuda.synchronize() 2292 torch.cuda.empty_cache() 2293 2294 @unittest.skipIf( 2295 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2296 ) 2297 def test_graph_three_successive(self): 2298 torch.cuda.empty_cache() 2299 2300 size = 1000 2301 2302 s = torch.cuda.Stream() 2303 2304 for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): 2305 a = torch.ones((size,), device="cuda") 2306 2307 g0 = torch.cuda.CUDAGraph() 2308 g1 = torch.cuda.CUDAGraph() 2309 g2 = torch.cuda.CUDAGraph() 2310 2311 s.wait_stream(torch.cuda.current_stream()) 2312 with torch.cuda.stream(s): 2313 g0_args = ( 2314 (torch.cuda.graph_pool_handle(),) 2315 if share_mem == "via graph_pool_handle()" 2316 else () 2317 ) 2318 g0.capture_begin(*g0_args) 2319 b = a.clone() 2320 c = b + 1 2321 d = b + 2 2322 g0.capture_end() 2323 2324 args = (g0.pool(),) if share_mem == "via pool()" else g0_args 2325 2326 g1.capture_begin(*args) 2327 e = c + 3 2328 del c 2329 g1.capture_end() 2330 2331 g2.capture_begin(*args) 2332 f = d + 4 2333 g2.capture_end() 2334 torch.cuda.current_stream().wait_stream(s) 2335 2336 # Tests that replaying in capture order is valid 2337 g0.replay() 2338 g1.replay() 2339 g2.replay() 2340 2341 self.assertEqual(e.sum().item(), size * 5) 2342 self.assertEqual(f.sum().item(), size * 7) 2343 2344 # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool 2345 g0.replay() 2346 g2.replay() 2347 g1.replay() 2348 2349 expect_corruption = (not TEST_CUDAMALLOCASYNC) and ( 2350 share_mem != "Don't share" 2351 ) 2352 # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f. 2353 # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3". 2354 self.assertEqual( 2355 e.sum().item(), size * (7 + 3) if expect_corruption else size * 5 2356 ) 2357 self.assertEqual(f.sum().item(), size * 7) 2358 2359 del a, b, d, e, f, g0, g1, g2 2360 # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them. 2361 torch.cuda.synchronize() 2362 torch.cuda.empty_cache() 2363 2364 @unittest.skipIf( 2365 (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC, 2366 "CUDA >= 11.0 or ROCM >= 5.3 required for graphs", 2367 ) 2368 def test_graph_memory_stats_and_use_result_after_destroy_graph(self): 2369 kSmallSize = 1048576 2370 kSmallBuffer = 2097152 2371 kLargeBuffer = 20971520 2372 kMinLargeAlloc = 10485760 2373 kRoundLarge = 2097152 2374 2375 elem = 4 2376 2377 # this was annoying to write but stresses the expectations pretty rigorously 2378 cases = ( 2379 (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"), 2380 (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"), 2381 ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"), 2382 ( 2383 (kMinLargeAlloc - 512) // elem, 2384 2, 2385 2 * kLargeBuffer, 2386 kLargeBuffer, 2387 "large_pool", 2388 ), 2389 ( 2390 (kMinLargeAlloc + 512) // elem, 2391 3, 2392 3 2393 * ( 2394 kRoundLarge 2395 * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge) 2396 ), 2397 kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge), 2398 "large_pool", 2399 ), 2400 ) 2401 2402 stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.") 2403 2404 gc.collect() 2405 torch.cuda.empty_cache() 2406 2407 s = torch.cuda.Stream() 2408 2409 for ( 2410 numel, 2411 delta_cudaMallocs, 2412 delta_cudaMalloc_bytes, 2413 delta_cudaMalloc_bytes_post_del_g, 2414 pool_string, 2415 ) in cases: 2416 if pool_string == "small_pool": 2417 delta_active_blocks = 3 # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders 2418 delta_active_bytes = ( 2419 numel * elem + 1024 2420 ) # + 1024 for CUDAGraph's rng seed and offset holders each 2421 else: 2422 delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder 2423 delta_active_bytes = numel * elem 2424 2425 g = torch.cuda.CUDAGraph() 2426 s.wait_stream(torch.cuda.current_stream()) 2427 with torch.cuda.stream(s): 2428 # Allocation stat estimates assume input is created on the same stream as capture_begin() 2429 # (in other words, the same stream silo as the rng offset holder, which is not allocated from the 2430 # capture's private pool). 2431 a = torch.ones((numel,), device="cuda") 2432 2433 precapture_stats = torch.cuda.memory_stats() 2434 2435 g.capture_begin() 2436 b = a.clone() 2437 for _ in range(5): 2438 b = b.clone() + 1 2439 g.capture_end() 2440 torch.cuda.current_stream().wait_stream(s) 2441 2442 gc.collect() 2443 2444 postcapture_stats = torch.cuda.memory_stats() 2445 2446 expecteds = ( 2447 delta_cudaMallocs, 2448 delta_cudaMalloc_bytes, 2449 delta_active_blocks, 2450 delta_active_bytes, 2451 ) 2452 # Double checks replay and stats before and after a call to empty_cache 2453 for i in range(2): 2454 for stat, expected in zip(stats_to_check, expecteds): 2455 stat = stat + pool_string + ".current" 2456 current = postcapture_stats[stat] - precapture_stats[stat] 2457 2458 # There will only ever be one expandable segment in each of the small and large pools. The way the 2459 # bookeeping is done in the allocator means that we never increment the number of segments. 2460 if self.expandable_segments and "segment" in stat: 2461 expected = 0 2462 # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an 2463 # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is 2464 # smaller than the page size 2465 if ( 2466 self.expandable_segments 2467 and "reserved" in stat 2468 and (numel == cases[3][0] or numel == cases[4][0]) 2469 ): 2470 expected = 2 * kLargeBuffer 2471 2472 self.assertEqual( 2473 current, 2474 expected, 2475 "Pre to post capture delta of " 2476 + stat 2477 + f" = {current}, expected = {expected}, numel = {numel}", 2478 ) 2479 2480 g.replay() 2481 self.assertEqual(b.sum().item(), 6 * numel) 2482 if i == 0: 2483 torch.cuda.empty_cache() 2484 2485 del g 2486 gc.collect() 2487 torch.cuda.empty_cache() 2488 postdel_stats = torch.cuda.memory_stats() 2489 2490 # Uses graph result b after graph has been deleted 2491 self.assertEqual(b.sum().item(), 6 * numel) 2492 2493 # b should be the only live reference remaining from the graph's private pool 2494 expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem) 2495 for stat, expected in zip(stats_to_check, expecteds): 2496 stat = stat + pool_string + ".current" 2497 current = postdel_stats[stat] - precapture_stats[stat] 2498 2499 # There will only ever be one expandable segment in each of the small and large pools. The way the 2500 # bookeeping is done in the allocator means that we never increment the number of segments. 2501 if self.expandable_segments and "segment" in stat: 2502 expected = 0 2503 # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an 2504 # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is 2505 # smaller than the page size 2506 if ( 2507 self.expandable_segments 2508 and "reserved" in stat 2509 and numel == cases[3][0] 2510 ): 2511 expected = 2 * kLargeBuffer 2512 if ( 2513 self.expandable_segments 2514 and "reserved" in stat 2515 and numel == cases[4][0] 2516 ): 2517 expected = kLargeBuffer 2518 2519 self.assertEqual( 2520 current, 2521 expected, 2522 "Pre capture to post graph delete delta of " 2523 + stat 2524 + f" = {current}, expected = {expected}, numel = {numel}", 2525 ) 2526 2527 # del a, b before the next case is essential, otherwise overwriting a and b in the next case 2528 # can throw off its allocation/deallocation counts. 2529 del a, b 2530 # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. 2531 torch.cuda.synchronize() 2532 torch.cuda.empty_cache() 2533 2534 @unittest.skipIf( 2535 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2536 ) 2537 def test_graph_record_stream(self): 2538 # Makes sure graph capture defers attempting to reclaim allocations used across streams. See 2539 # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp 2540 torch.cuda.empty_cache() 2541 2542 potential_problem = torch.zeros((3,), device="cuda") 2543 a = torch.zeros((3,), device="cuda") 2544 s0 = torch.cuda.Stream() 2545 s1 = torch.cuda.Stream() 2546 s2 = torch.cuda.Stream() 2547 g = torch.cuda.CUDAGraph() 2548 2549 torch.cuda.synchronize() 2550 with torch.cuda.stream(s0): 2551 potential_problem.record_stream(s0) 2552 torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES) 2553 potential_problem.fill_(1.0) 2554 del potential_problem 2555 2556 with torch.cuda.stream(s1): 2557 g.capture_begin() 2558 # potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc 2559 # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life 2560 # event, which will cause the capture to error. 2561 b = a.clone() 2562 2563 # Let's also see what happens if we record_stream on a tensor during capture. 2564 s2.wait_stream(s1) 2565 with torch.cuda.stream(s2): 2566 b.fill_(1.0) 2567 b.record_stream(s2) # dummy record_stream 2568 del b 2569 s1.wait_stream(s2) 2570 g.capture_end() 2571 torch.cuda.synchronize() 2572 2573 # dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event. 2574 c = torch.zeros((3,), device="cuda") 2575 2576 @skipIfRocm 2577 @unittest.skipIf( 2578 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2579 ) 2580 # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize 2581 # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior 2582 # as a memory leak unless we skip the leak check. 2583 @skipCUDAMemoryLeakCheckIf(True) 2584 @serialTest() 2585 def test_graph_cudnn_dropout(self): 2586 # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp. 2587 # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should 2588 # avoid syncing noncapturing streams with captured events or vice versa. 2589 torch.cuda.empty_cache() 2590 2591 model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda() 2592 x = torch.ones(100, 192, 512, device="cuda") 2593 2594 y = model(x) 2595 2596 g = torch.cuda.CUDAGraph() 2597 s = torch.cuda.Stream() 2598 s.wait_stream(torch.cuda.current_stream()) 2599 with torch.cuda.stream(s): 2600 g.capture_begin() 2601 y = model(x) 2602 g.capture_end() 2603 torch.cuda.current_stream().wait_stream(s) 2604 2605 g.replay() 2606 2607 y = model(x) 2608 2609 @unittest.skipIf( 2610 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2611 ) 2612 @parametrize( 2613 "with_amp,cache_enabled,allow_unused_input", 2614 [ 2615 subtest((False, False, True), decorators=[skipIfRocm]), 2616 subtest((True, False, True), decorators=[skipIfRocm]), 2617 subtest((True, True, True), decorators=[unittest.expectedFailure]), 2618 subtest((False, False, False), decorators=[unittest.expectedFailure]), 2619 ], 2620 name_fn=lambda x, y, z: "{}{}{}".format( 2621 {True: "with_amp", False: "without_amp"}[x], 2622 {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", 2623 {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], 2624 ), 2625 ) 2626 @serialTest() 2627 def test_graph_make_graphed_callables( 2628 self, with_amp, cache_enabled, allow_unused_input 2629 ): 2630 torch.manual_seed(5) 2631 torch.cuda.manual_seed(5) 2632 2633 N, D_in, H, D_out = 640, 4096, 2048, 1024 2634 2635 class MLP1(torch.nn.Module): 2636 def __init__(self, D_in: int, H: int, D_out: int): 2637 super().__init__() 2638 self.net_1 = torch.nn.Sequential( 2639 torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) 2640 ).cuda() 2641 self.net_2 = torch.nn.Sequential( 2642 torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) 2643 ).cuda() 2644 2645 def forward(self, input_dict: dict): 2646 x = input_dict["x"] 2647 return self.net_2(self.net_1(x)) 2648 2649 class MLP2(torch.nn.Module): 2650 def __init__(self, D_in: int, H: int, D_out: int): 2651 super().__init__() 2652 self.net_1 = torch.nn.Sequential( 2653 torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) 2654 ).cuda() 2655 self.net_2 = torch.nn.Sequential( 2656 torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) 2657 ).cuda() 2658 2659 def forward(self, x): 2660 return self.net_2(self.net_1(x)) 2661 2662 class ParameterlessModule(torch.nn.Module): 2663 def forward(self, x): 2664 idx = ( 2665 torch.arange(x.size(0), device=x.device) 2666 .view(-1, 1) 2667 .repeat(1, x.size(1)) 2668 ) 2669 return {"output": torch.gather(x, 0, idx)} 2670 2671 models = [] 2672 for _ in range(2): 2673 model_section1 = MLP1(D_in, H, H).cuda() 2674 model_section2 = MLP2(H, H, D_out).cuda() 2675 model_section3 = ParameterlessModule().cuda() 2676 models.append( 2677 torch.nn.Sequential(model_section1, model_section2, model_section3) 2678 ) 2679 2680 model_graphed = models[0] 2681 model_control = models[1] 2682 2683 model_graphed.load_state_dict(model_control.state_dict()) 2684 2685 opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1) 2686 opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1) 2687 2688 x = torch.randn(N, D_in, device="cuda") 2689 h = torch.randn(N, H, device="cuda", requires_grad=True) 2690 h2 = torch.randn(N, D_out, device="cuda", requires_grad=True) 2691 unused_input = torch.randn(N, H, device="cuda", requires_grad=True) 2692 y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True) 2693 y = torch.randn(N, D_out, device="cuda") 2694 2695 loss_fn_control = torch.nn.functional.mse_loss 2696 relu_control = torch.nn.functional.relu 2697 2698 # This is a good stress test. It graphs four callables: two Modules and two python functions. 2699 with torch.amp.autocast( 2700 device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled 2701 ): 2702 ( 2703 model_graphed[0], 2704 model_graphed[1], 2705 model_graphed[2], 2706 relu_graphed, 2707 loss_fn_graphed, 2708 ) = torch.cuda.make_graphed_callables( 2709 ( 2710 model_graphed[0], 2711 model_graphed[1], 2712 model_graphed[2], 2713 relu_control, 2714 loss_fn_control, 2715 ), 2716 ( 2717 ({"x": x, "unused_input": unused_input},), 2718 (h,), 2719 (h2,), 2720 (y_pred,), 2721 (y_pred, y), 2722 ), 2723 allow_unused_input=allow_unused_input, 2724 ) 2725 2726 real_inputs = [torch.rand_like(x) for _ in range(10)] 2727 real_targets = [torch.rand_like(y) for _ in range(10)] 2728 2729 for m, opt, relu, loss_fn in zip( 2730 (model_graphed, model_control), 2731 (opt_graphed, opt_control), 2732 (relu_graphed, relu_control), 2733 (loss_fn_graphed, loss_fn_control), 2734 ): 2735 # Resets RNC states before iterations for graphed and ungraphed models, 2736 # so dropout math should be bitwise identical for both. 2737 torch.manual_seed(5) 2738 torch.cuda.manual_seed(5) 2739 for data, target in zip(real_inputs, real_targets): 2740 opt.zero_grad(set_to_none=True) 2741 with torch.amp.autocast( 2742 device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled 2743 ): 2744 y_pred = m({"x": data, "unused_input": unused_input})["output"] 2745 y_pred = relu(y_pred) 2746 loss = loss_fn(y_pred, target) 2747 loss.backward() 2748 opt.step() 2749 2750 for p, pc in zip(model_graphed.parameters(), model_control.parameters()): 2751 self.assertEqual(p, pc) 2752 2753 # We graphed the models in training mode. Eval should still run ungraphed. 2754 model_graphed.eval() 2755 model_control.eval() 2756 self.assertEqual( 2757 model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) 2758 ) 2759 2760 @unittest.skipIf( 2761 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2762 ) 2763 @parametrize( 2764 "with_amp,cache_enabled,allow_unused_input", 2765 [ 2766 subtest((False, False, True), decorators=[skipIfRocm]), 2767 subtest((True, False, True), decorators=[skipIfRocm]), 2768 subtest((True, True, True), decorators=[unittest.expectedFailure]), 2769 subtest((False, False, False), decorators=[skipIfRocm]), 2770 ], 2771 name_fn=lambda x, y, z: "{}{}{}".format( 2772 {True: "with_amp", False: "without_amp"}[x], 2773 {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", 2774 {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], 2775 ), 2776 ) 2777 @serialTest() 2778 def test_graph_make_graphed_callables_parameterless_nograd_module( 2779 self, with_amp, cache_enabled, allow_unused_input 2780 ): 2781 torch.manual_seed(5) 2782 torch.cuda.manual_seed(5) 2783 2784 N, D_in, H, D_out = 640, 4096, 2048, 1024 2785 2786 class ParameterlessModule(torch.nn.Module): 2787 def forward(self, input_dict: dict): 2788 x = input_dict["x"] 2789 idx = ( 2790 torch.arange(x.size(0), device=x.device) 2791 .view(-1, 1) 2792 .repeat(1, x.size(1)) 2793 ) 2794 return {"output": torch.gather(x, 0, idx)} 2795 2796 models = [] 2797 for _ in range(2): 2798 model_section1 = ParameterlessModule().cuda() 2799 models.append(torch.nn.Sequential(model_section1)) 2800 2801 model_graphed = models[0] 2802 model_control = models[1] 2803 2804 model_graphed.load_state_dict(model_control.state_dict()) 2805 2806 x = torch.randn(N, D_in, device="cuda", requires_grad=False) 2807 unused_input = torch.randn(N, H, device="cuda", requires_grad=False) 2808 y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False) 2809 y = torch.randn(N, D_in, device="cuda") 2810 2811 # This is a good stress test. It graphs four callables: two Modules and two python functions. 2812 with torch.amp.autocast( 2813 device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled 2814 ): 2815 model_graphed[0] = torch.cuda.make_graphed_callables( 2816 model_graphed[0], 2817 ({"x": x, "unused_input": unused_input},), 2818 allow_unused_input=allow_unused_input, 2819 ) 2820 2821 real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)] 2822 real_targets = [torch.rand_like(y) for _ in range(10)] 2823 2824 for m in (model_graphed, model_control): 2825 # Resets RNC states before iterations for graphed and ungraphed models, 2826 # so dropout math should be bitwise identical for both. 2827 torch.manual_seed(5) 2828 torch.cuda.manual_seed(5) 2829 for data, _ in zip(real_inputs, real_targets): 2830 with torch.amp.autocast( 2831 device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled 2832 ): 2833 out = m({"x": data, "unused_input": unused_input})["output"] 2834 2835 # We graphed the models in training mode. Eval should still run ungraphed. 2836 model_graphed.eval() 2837 model_control.eval() 2838 self.assertEqual( 2839 model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) 2840 ) 2841 2842 @unittest.skipIf( 2843 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2844 ) 2845 def test_graph_make_graphed_callables_same_pool(self): 2846 torch.manual_seed(5) 2847 torch.cuda.manual_seed(5) 2848 models = [] 2849 num_models = 3 2850 for _ in range(num_models): 2851 models.append( 2852 torch.nn.Sequential( 2853 torch.nn.Linear(32, 128), 2854 torch.nn.ReLU(), 2855 torch.nn.Linear(128, 128), 2856 ).cuda() 2857 ) 2858 # we will reuse the same pool for all graph captures 2859 mempool = torch.cuda.graph_pool_handle() 2860 graphed_models = [] 2861 for model in models: 2862 x = torch.randn([64, 32], device="cuda") 2863 graphed_model = deepcopy(model) 2864 graphed_model = torch.cuda.make_graphed_callables( 2865 graphed_model, (x,), pool=mempool 2866 ) 2867 graphed_models.append(graphed_model) 2868 2869 for model, graphed_model in zip(models, graphed_models): 2870 x = torch.randn([64, 32], device="cuda") 2871 y = model(x) 2872 yg = graphed_model(x) 2873 l = y.norm() 2874 lg = yg.norm() 2875 l.backward() 2876 lg.backward() 2877 2878 self.assertEqual(y, yg) 2879 self.assertEqual(l, lg) 2880 for p, pg in zip(model.parameters(), graphed_model.parameters()): 2881 self.assertEqual(p, pg) 2882 self.assertEqual(p.grad, pg.grad) 2883 self.assertNotEqual(p.data_ptr(), pg.data_ptr()) 2884 self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr()) 2885 2886 def _test_graphed_optimizer( 2887 self, steps_warmup, steps_train, optimizer_ctor, kwargs 2888 ): 2889 for actually_do_graphs in (True, False): 2890 params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [ 2891 torch.randn((), device="cuda") 2892 ] 2893 params_control = [p.clone().requires_grad_() for p in params] 2894 params_graphed = [p.clone().requires_grad_() for p in params] 2895 2896 grads = [ 2897 [torch.randn_like(p) for p in params] 2898 for _ in range(steps_warmup + steps_train) 2899 ] 2900 2901 # Control (capturable=False) 2902 2903 opt = optimizer_ctor(params_control, capturable=False, **kwargs) 2904 2905 for i in range(steps_warmup + steps_train): 2906 for j, p in enumerate(params_control): 2907 p.grad = grads[i][j] 2908 opt.step() 2909 2910 # capturable=True 2911 2912 opt = optimizer_ctor(params_graphed, capturable=True, **kwargs) 2913 2914 for i in range(steps_warmup): 2915 for j, p in enumerate(params_graphed): 2916 p.grad = grads[i][j] 2917 opt.step() 2918 2919 if actually_do_graphs: 2920 g = torch.cuda.CUDAGraph() 2921 with torch.cuda.graph(g): 2922 opt.step() 2923 2924 for i in range(steps_train): 2925 if actually_do_graphs: 2926 for j, p in enumerate(params_graphed): 2927 p.grad.copy_(grads[i + steps_warmup][j]) 2928 g.replay() 2929 else: 2930 # Passing capturable=True to the constructor and running without graphs should still be 2931 # numerically correct, even if it's not ideal for performance. 2932 for j, p in enumerate(params_graphed): 2933 p.grad = grads[i + steps_warmup][j] 2934 opt.step() 2935 2936 for p_control, p_graphed in zip(params_control, params_graphed): 2937 self.assertEqual(p_control, p_graphed) 2938 2939 @unittest.skipIf( 2940 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 2941 ) 2942 def test_graph_optims_with_explicitly_capturable_param_groups(self): 2943 # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__ 2944 n_warmup, n_replay = 3, 2 2945 for optimizer, second_param_group_capturable in product( 2946 ( 2947 torch.optim.Adam, 2948 torch.optim.AdamW, 2949 torch.optim.ASGD, 2950 torch.optim.Adamax, 2951 torch.optim.NAdam, 2952 torch.optim.RAdam, 2953 torch.optim.Adadelta, 2954 torch.optim.RMSprop, 2955 torch.optim.Rprop, 2956 ), 2957 (True, False), 2958 ): 2959 ref_p1, param1 = ( 2960 torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) 2961 ) 2962 ref_p2, param2 = ( 2963 torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) 2964 ) 2965 grads1, grads2 = ( 2966 [torch.randn_like(param1) for _ in range(n_warmup + n_replay)] 2967 for _ in range(2) 2968 ) 2969 ref_grads1, ref_grads2 = ( 2970 [t.clone() for t in tensors] for tensors in (grads1, grads2) 2971 ) 2972 params = [ 2973 {"params": [param1], "capturable": True}, 2974 {"params": [param2], "capturable": second_param_group_capturable}, 2975 ] 2976 opt = optimizer(params) 2977 opt_ = optimizer( 2978 [ 2979 {"params": [ref_p1], "capturable": False}, 2980 {"params": [ref_p2], "capturable": False}, 2981 ] 2982 ) 2983 2984 for i in range(n_warmup + n_replay): 2985 ref_p1.grad = ref_grads1[i] 2986 ref_p2.grad = ref_grads2[i] 2987 opt_.step() 2988 2989 for i in range(n_warmup): 2990 param1.grad = grads1[i] 2991 param2.grad = grads2[i] 2992 opt.step() 2993 2994 g = torch.cuda.CUDAGraph() 2995 if not second_param_group_capturable: 2996 with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"): 2997 with torch.cuda.graph(g): 2998 opt.step() 2999 else: 3000 with torch.cuda.graph(g): 3001 opt.step() 3002 3003 for i in range(n_replay): 3004 param1.grad.copy_(grads1[n_warmup + i]) 3005 param2.grad.copy_(grads2[n_warmup + i]) 3006 g.replay() 3007 self.assertEqual(ref_p1, param1) 3008 self.assertEqual(ref_p2, param2) 3009 3010 @unittest.skipIf( 3011 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 3012 ) 3013 def test_cuda_graph_error_options(self): 3014 def fn(): 3015 x = torch.zeros([2000], device="cuda") 3016 y = x + x + x 3017 return y 3018 3019 mem = None 3020 3021 def raw_malloc(): 3022 global mem 3023 mem = None 3024 stream = torch.cuda.Stream() 3025 try: 3026 with torch.cuda.stream(stream): 3027 mem = torch.cuda.caching_allocator_alloc(1024) 3028 except BaseException: 3029 if mem is None: 3030 return 3031 try: 3032 torch.cuda.caching_allocator_delete(mem) 3033 mem = None 3034 return None 3035 except BaseException: 3036 pass 3037 3038 def throws_on_cuda_event(capture_error_mode): 3039 graph = torch.cuda.CUDAGraph() 3040 torch.cuda.synchronize() 3041 stream = torch.cuda.Stream() 3042 stream.wait_stream(torch.cuda.current_stream()) 3043 with torch.cuda.stream(stream): 3044 fn() 3045 stream.synchronize() 3046 torch.cuda.current_stream().wait_stream(stream) 3047 torch.cuda.synchronize() 3048 try: 3049 with torch.cuda.graph( 3050 graph, stream=stream, capture_error_mode=capture_error_mode 3051 ): 3052 out = fn() 3053 thread = threading.Thread(target=raw_malloc) 3054 thread.start() 3055 thread.join() 3056 except Exception: 3057 if mem is not None: 3058 torch.cuda.caching_allocator_delete(mem) 3059 return True 3060 3061 return False 3062 3063 self.assertFalse(throws_on_cuda_event("thread_local")) 3064 self.assertFalse(throws_on_cuda_event("relaxed")) 3065 3066 # Exception would Corrupt Process and make other tests fail 3067 # self.assertTrue(throws_on_cuda_event("global")) 3068 3069 @unittest.skipIf( 3070 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 3071 ) 3072 def test_cuda_graph_allocator_propagates_stream(self): 3073 segments = torch.cuda.memory_snapshot() 3074 existing_pools = {s["segment_pool_id"] for s in segments} 3075 x = torch.randn(10240000, device="cuda") 3076 y = torch.rand_like(x) 3077 g = torch.cuda.CUDAGraph() 3078 s0 = torch.cuda.Stream() 3079 s1 = torch.cuda.Stream() 3080 s0.wait_stream(torch.cuda.current_stream()) 3081 with torch.cuda.stream(s0): 3082 g.capture_begin() 3083 z = x + y 3084 with torch.cuda.stream(s1): 3085 s1.wait_stream(s0) 3086 w = z + y 3087 s0.wait_stream(s1) 3088 with torch.cuda.stream(s0): 3089 g.capture_end() 3090 segments = torch.cuda.memory_snapshot() 3091 x = [ 3092 s["segment_pool_id"] 3093 for s in segments 3094 if s["segment_pool_id"] not in existing_pools 3095 ] 3096 self.assertEqual(len(x), 2) 3097 self.assertEqual(x[0], x[1]) 3098 3099 def test_batch_norm_gather_stats(self): 3100 input = torch.randn(1, 3, 3, 3, device="cuda") 3101 mean, invstd = torch.batch_norm_gather_stats( 3102 input, 3103 mean=torch.ones(2, 3, device="cuda"), 3104 invstd=torch.ones(2, 3, device="cuda"), 3105 running_mean=None, 3106 running_var=None, 3107 momentum=0.1, 3108 eps=1e-5, 3109 count=2, 3110 ) 3111 self.assertEqual(mean, torch.ones(3, device="cuda")) 3112 self.assertEqual(invstd, torch.ones(3, device="cuda")) 3113 3114 def test_matmul_memory_use(self): 3115 def get_max_used(): 3116 torch.cuda.synchronize() 3117 val = torch.cuda.max_memory_allocated() 3118 torch.cuda.reset_peak_memory_stats() 3119 return val 3120 3121 a = torch.rand(1, 32, 32, device="cuda") 3122 b = torch.rand(24, 32, 1, device="cuda") 3123 3124 get_max_used() 3125 3126 torch.matmul(a, b) 3127 3128 matmul_mem = get_max_used() 3129 3130 a = a.expand(24, 32, 32) 3131 torch.matmul(a, b) 3132 3133 matmul_expand_mem = get_max_used() 3134 3135 torch.bmm(a, b) 3136 3137 bmm_mem = get_max_used() 3138 3139 self.assertEqual(matmul_expand_mem, matmul_mem) 3140 self.assertEqual(bmm_mem, matmul_mem) 3141 3142 @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test") 3143 def test_rocm_backward_pass_guard(self): 3144 # The test exercises a ROCm-specific feature. 3145 3146 class MyFunction(torch.autograd.Function): 3147 @staticmethod 3148 def forward(ctx, tensor, constant): 3149 self.assertFalse(torch._C._rocm_is_backward_pass()) 3150 ctx.constant = constant 3151 return tensor * constant 3152 3153 @staticmethod 3154 def backward(ctx, grad_output): 3155 self.assertTrue(torch._C._rocm_is_backward_pass()) 3156 return grad_output * ctx.constant, None 3157 3158 class MyModule(torch.nn.Module): 3159 def __init__(self) -> None: 3160 super().__init__() 3161 self.a = torch.nn.Parameter(torch.randn(())) 3162 3163 def forward(self, x): 3164 return MyFunction.apply(x, self.a) 3165 3166 model = MyModule() 3167 criterion = torch.nn.MSELoss(reduction="sum") 3168 optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) 3169 3170 x = torch.randn(5, 5) 3171 result = model(x) 3172 loss = criterion(result, x) 3173 optimizer.zero_grad() 3174 loss.backward() 3175 optimizer.step() 3176 3177 def test_matmul_device_mismatch(self): 3178 cpu = torch.rand((10, 10)) 3179 cuda = cpu.cuda() 3180 with self.assertRaisesRegex( 3181 RuntimeError, "Expected all tensors to be on the same device" 3182 ): 3183 cpu @ cuda 3184 with self.assertRaisesRegex( 3185 RuntimeError, "Expected all tensors to be on the same device" 3186 ): 3187 cuda @ cpu 3188 3189 for s, m1, m2 in product((cpu, cuda), repeat=3): 3190 if s.device == m1.device == m2.device: 3191 torch.addmm(s, m1, m2) 3192 else: 3193 with self.assertRaisesRegex( 3194 RuntimeError, "Expected all tensors to be on the same device" 3195 ): 3196 torch.addmm(s, m1, m2) 3197 3198 @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient") 3199 def test_lazy_init(self): 3200 """Validate that no CUDA calls are made during `import torch` call""" 3201 3202 def check_output(script: str) -> str: 3203 return ( 3204 subprocess.check_output([sys.executable, "-c", script]) 3205 .decode("ascii") 3206 .strip() 3207 ) 3208 3209 VISIBLE_DEVICES = ( 3210 "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES" 3211 ) 3212 test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())" 3213 rc = check_output(test_script) 3214 self.assertEqual(rc, "0") 3215 if not TEST_WITH_ROCM: 3216 # Check that `cuInit` was not called during the import 3217 # By using ctypes and calling cuDeviceCountGet() and expect CUDA_ERROR_NOT_INITIALIZED == 3 3218 # See https://github.com/pytorch/pytorch/issues/116276 for more details 3219 libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll" 3220 cuda_driver_api_call = ( 3221 f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))" 3222 ) 3223 rc = check_output( 3224 f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})" 3225 ) 3226 self.assertEqual(rc, "3") 3227 3228 @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") 3229 def test_hip_device_count(self): 3230 """Validate device_count works with both CUDA/HIP visible devices""" 3231 test_script = """\ 3232import torch 3233import os 3234print(f"{torch.cuda.device_count()}") 3235""" 3236 custom_envs = [ 3237 {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, 3238 {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"}, 3239 {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"}, 3240 ] 3241 3242 for env_config in custom_envs: 3243 env = os.environ.copy() 3244 for key, value in env_config.items(): 3245 if value is None: 3246 env.pop(key, None) 3247 else: 3248 env[key] = value 3249 r = ( 3250 subprocess.check_output([sys.executable, "-c", test_script], env=env) 3251 .decode("ascii") 3252 .strip() 3253 ) 3254 self.assertEqual("1", r) 3255 3256 @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices") 3257 def test_device_count_not_cached_pre_init(self): 3258 visible_devices = ( 3259 "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES" 3260 ) 3261 test_script = f"""\ 3262import torch 3263import os 3264r1 = torch.cuda.device_count() 3265os.environ['{visible_devices}'] = '0' 3266r2 = torch.cuda.device_count() 3267torch.empty(10, device='cuda') 3268print(f"{{r1}}, {{r2}}") 3269""" 3270 3271 r = ( 3272 subprocess.check_output([sys.executable, "-c", test_script]) 3273 .decode("ascii") 3274 .strip() 3275 ) 3276 3277 x = torch.cuda.device_count() 3278 self.assertEqual(f"{x}, 1", r) 3279 3280 @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") 3281 def test_gds_fails_in_ci(self): 3282 if IS_WINDOWS or TEST_WITH_ROCM: 3283 error_msg = "is not supported on this platform" 3284 else: 3285 error_msg = "cuFileHandleRegister failed" 3286 with TemporaryFileName() as f: 3287 with self.assertRaisesRegex(RuntimeError, error_msg): 3288 file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) 3289 3290 3291@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") 3292@torch.testing._internal.common_utils.markDynamoStrictTest 3293class TestCudaMallocAsync(TestCase): 3294 @unittest.skipIf( 3295 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3296 ) 3297 def test_memory_snapshot(self): 3298 try: 3299 torch.cuda.memory.empty_cache() 3300 torch.cuda.memory._record_memory_history("state", stacks="python") 3301 # make x the second block in a segment 3302 torch.rand(2 * 311, 411, device="cuda") 3303 unused = torch.rand(310, 410, device="cuda") 3304 x = torch.rand(311, 411, device="cuda") 3305 3306 # create a bunch of tensors that all will tile into the 3307 # same segment to exercise the history merging code 3308 # 512B is the minimum block size, 3309 # so we allocate all the tensors to this size to make sure 3310 # they tile evenly 3311 tensors = [torch.rand(128, device="cuda") for _ in range(1000)] 3312 while tensors: 3313 del tensors[randint(0, len(tensors) - 1)] 3314 3315 # exercise the history trimming code 3316 torch.rand(128 * 5, device="cuda") 3317 3318 ss = torch.cuda.memory._snapshot() 3319 found_it = False 3320 for seg in ss["segments"]: 3321 self.assertTrue("frames" in seg) 3322 for b in seg["blocks"]: 3323 if b["requested_size"] == 311 * 411 * 4: 3324 self.assertTrue("test_cuda" in b["frames"][0]["filename"]) 3325 found_it = True 3326 self.assertEqual(x.untyped_storage().data_ptr(), b["address"]) 3327 self.assertTrue(found_it) 3328 3329 if not IS_WINDOWS: 3330 with tempfile.NamedTemporaryFile() as f: 3331 torch.cuda.memory._save_segment_usage(f.name) 3332 with open(f.name) as f2: 3333 self.assertTrue("test_cuda.py" in f2.read()) 3334 del unused 3335 del x 3336 torch.cuda.empty_cache() 3337 ss = torch.cuda.memory._snapshot() 3338 self.assertTrue( 3339 ss["device_traces"][0][-1]["action"] 3340 in ("segment_free", "segment_unmap") 3341 ) 3342 3343 finally: 3344 torch.cuda.memory._record_memory_history(None) 3345 3346 @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") 3347 def test_direct_traceback(self): 3348 from torch._C._profiler import gather_traceback, symbolize_tracebacks 3349 3350 c = gather_traceback(True, True, True) 3351 (r,) = symbolize_tracebacks([c]) 3352 r = str(r) 3353 self.assertTrue("test_cuda.py" in r) 3354 self.assertTrue("unwind" in r) 3355 3356 @unittest.skipIf( 3357 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3358 ) 3359 @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") 3360 def test_memory_snapshot_with_cpp(self): 3361 try: 3362 torch.cuda.memory.empty_cache() 3363 torch.cuda.memory._record_memory_history("state", stacks="all") 3364 x = torch.rand(311, 411, device="cuda") 3365 3366 ss = torch.cuda.memory._snapshot()["segments"] 3367 found_it = False 3368 for seg in ss: 3369 for b in seg["blocks"]: 3370 if b["requested_size"] == 311 * 411 * 4: 3371 self.assertTrue("::rand" in str(b["frames"])) 3372 found_it = True 3373 self.assertTrue(found_it) 3374 3375 finally: 3376 torch.cuda.memory._record_memory_history(None) 3377 3378 @skipIfRocm 3379 def test_memory_profiler_viz(self): 3380 with torch.profiler.profile( 3381 with_stack=True, profile_memory=True, record_shapes=True 3382 ) as prof: 3383 x = torch.rand(128, 128, device="cuda") 3384 x * x + x * x 3385 plot = profile_plot(prof) 3386 plot = json.dumps(_profile_to_snapshot(prof)) 3387 self.assertTrue("test_cuda.py" in plot) 3388 self.assertTrue("test_memory_profiler_viz" in plot) 3389 self.assertTrue("category" in plot) 3390 3391 @unittest.skipIf( 3392 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3393 ) 3394 @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") 3395 def test_cycles(self): 3396 fired = False 3397 3398 def observer(html): 3399 nonlocal fired 3400 fired = True 3401 self.assertTrue("torch.Tensor" in html) 3402 self.assertTrue("test_cuda" in html) 3403 self.assertTrue("cell_contents" in html) 3404 3405 disarm = observe_tensor_cycles(observer) 3406 3407 def noop(): 3408 pass 3409 3410 try: 3411 3412 def create(): 3413 x = torch.empty(3, 4, device="cuda") 3414 3415 def foo(p): 3416 if p: 3417 return foo(not p) 3418 else: 3419 return x 3420 3421 return foo 3422 3423 create() 3424 gc.collect() 3425 # the callback has to run outside of the collect 3426 # call so it doesn't actual fire until the next 3427 # method call after a gc.collect 3428 noop() 3429 self.assertTrue(fired) 3430 finally: 3431 disarm() 3432 3433 @unittest.skipIf( 3434 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3435 ) 3436 @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") 3437 def test_memory_plots(self): 3438 for context, stacks in ( 3439 ("all", "all" if IS_LINUX else "python"), 3440 ("all", "python"), 3441 (None, "python"), 3442 ): 3443 try: 3444 torch.cuda.memory.empty_cache() 3445 torch.cuda.memory._record_memory_history( 3446 "all", context=context, stacks=stacks 3447 ) 3448 3449 def run(): 3450 x = torch.rand(128, 128, device="cuda") 3451 x * x + x * x 3452 3453 run() 3454 cpp = stacks == "all" 3455 record_context = context is not None 3456 ss = torch.cuda.memory._snapshot() 3457 3458 tplot = trace_plot(ss) 3459 splot = segment_plot(ss) 3460 text = json.dumps(ss) 3461 3462 self.assertTrue(record_context == ("test_memory_plots" in text)) 3463 self.assertTrue(cpp == ("::rand" in text)) 3464 self.assertTrue(str(128 * 128 * 4) in text) 3465 3466 finally: 3467 torch.cuda.memory._record_memory_history(None) 3468 3469 @unittest.skipIf( 3470 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3471 ) 3472 @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") 3473 def test_memory_plots_free_stack(self): 3474 for context in ["alloc", "all", "state"]: 3475 try: 3476 torch.cuda.memory.empty_cache() 3477 torch.cuda.memory._record_memory_history(context=context) 3478 x = None 3479 3480 def thealloc(): 3481 nonlocal x 3482 x = torch.rand(3, 4, device="cuda") 3483 3484 def thefree(): 3485 nonlocal x 3486 del x 3487 3488 thealloc() 3489 thefree() 3490 ss = json.dumps(torch.cuda.memory._snapshot()) 3491 self.assertTrue(("thefree" in ss) == (context == "all")) 3492 self.assertTrue(("thealloc" in ss) == (context != "state")) 3493 finally: 3494 torch.cuda.memory._record_memory_history(None) 3495 3496 @unittest.skipIf( 3497 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3498 ) 3499 @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") 3500 def test_memory_plots_history_context(self): 3501 try: 3502 torch.cuda.memory.empty_cache() 3503 x = None 3504 3505 def should_capture1(): 3506 nonlocal x 3507 x = torch.rand(4, 4, device="cuda") 3508 3509 def should_not_capture(): 3510 nonlocal x 3511 x = torch.rand(3, 4, device="cuda") 3512 3513 def should_capture2(): 3514 nonlocal x 3515 x = torch.rand(4, 4, device="cuda") 3516 3517 # Recording with context and python call stacks should capture the call stack. 3518 torch.cuda.memory._record_memory_history(context="all", stacks="python") 3519 should_capture1() 3520 # Recording with context=None should not capture the call stack. 3521 torch.cuda.memory._record_memory_history(context=None) 3522 should_not_capture() 3523 # Recording with context and python call stacks should capture the call stack. 3524 torch.cuda.memory._record_memory_history(context="all", stacks="python") 3525 should_capture2() 3526 3527 ss = json.dumps(torch.cuda.memory._snapshot()) 3528 self.assertTrue("should_capture1" in ss) 3529 self.assertTrue("should_not_capture" not in ss) 3530 self.assertTrue("should_capture2" in ss) 3531 finally: 3532 torch.cuda.memory._record_memory_history(None) 3533 3534 @unittest.skipIf( 3535 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3536 ) 3537 @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") 3538 def test_memory_plots_free_segment_stack(self): 3539 for context in ["alloc", "all", "state"]: 3540 try: 3541 torch.cuda.memory.empty_cache() 3542 torch.cuda.memory._record_memory_history(context=context) 3543 x = torch.rand(3, 4, device="cuda") 3544 del x 3545 torch.cuda.memory.empty_cache() 3546 3547 ss = json.dumps(torch.cuda.memory._snapshot()) 3548 self.assertTrue(("empty_cache" in ss) == (context == "all")) 3549 finally: 3550 torch.cuda.memory._record_memory_history(None) 3551 3552 @unittest.skipIf( 3553 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3554 ) 3555 def test_memory_snapshot_script(self): 3556 try: 3557 torch.cuda.memory.empty_cache() 3558 torch.cuda.memory._record_memory_history("state", stacks="python") 3559 3560 @torch.jit.script 3561 def foo(): 3562 return torch.rand(311, 411, device="cuda") 3563 3564 x = foo() 3565 3566 ss = torch.cuda.memory._snapshot()["segments"] 3567 found_it = False 3568 for seg in ss: 3569 for b in seg["blocks"]: 3570 if b["requested_size"] == 311 * 411 * 4: 3571 self.assertTrue(b["frames"][0]["name"] == "foo") 3572 found_it = True 3573 self.assertTrue(found_it) 3574 3575 finally: 3576 torch.cuda.memory._record_memory_history(None) 3577 3578 def test_max_split_expandable(self): 3579 torch.cuda.memory.empty_cache() 3580 mb = 1024 * 1024 3581 _, all_memory = torch.cuda.memory.mem_get_info() 3582 total_allowed = 120 * mb 3583 fraction_allowed = total_allowed / all_memory 3584 assert int(fraction_allowed * all_memory) == total_allowed 3585 torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) 3586 3587 def alloc(n): 3588 return torch.ones(n * mb, dtype=torch.int8, device="cuda") 3589 3590 torch.cuda.memory._set_allocator_settings( 3591 "expandable_segments:False,max_split_size_mb:40" 3592 ) 3593 a = alloc(40) 3594 torch.cuda.memory._set_allocator_settings( 3595 "expandable_segments:True,max_split_size_mb:40" 3596 ) 3597 b = alloc(40) 3598 torch.cuda.memory._set_allocator_settings( 3599 "expandable_segments:False,max_split_size_mb:40" 3600 ) 3601 c = alloc(40) 3602 with self.assertRaises(torch.OutOfMemoryError): 3603 alloc(40) 3604 del a, b, c 3605 # force release_cached_blocks to run with some expandable segments in the free list 3606 alloc(120) 3607 3608 def test_garbage_collect_expandable(self): 3609 torch.cuda.memory.empty_cache() 3610 mb = 1024 * 1024 3611 _, all_memory = torch.cuda.memory.mem_get_info() 3612 total_allowed = 120 * mb 3613 fraction_allowed = total_allowed / all_memory 3614 assert int(fraction_allowed * all_memory) == total_allowed 3615 torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) 3616 3617 def alloc(n): 3618 return torch.ones(n * mb, dtype=torch.int8, device="cuda") 3619 3620 torch.cuda.memory._set_allocator_settings( 3621 "expandable_segments:False,garbage_collection_threshold:0.5" 3622 ) 3623 a = alloc(40) 3624 torch.cuda.memory._set_allocator_settings( 3625 "expandable_segments:True,garbage_collection_threshold:0.5" 3626 ) 3627 b = alloc(40) 3628 del a, b 3629 # causes GC to run. The expandable segment block will be split 3630 # so GC would not attempt to free it anyway, but this at least makes sure 3631 # expandable_segment blocks can be in the free list when this is called. 3632 alloc(80) 3633 3634 def test_allocator_settings(self): 3635 def power2_div(size, div_factor): 3636 pow2 = 1 3637 while pow2 < size: 3638 pow2 = pow2 * 2 3639 if pow2 == size: 3640 return pow2 3641 step = pow2 / 2 / div_factor 3642 ret = pow2 / 2 3643 while ret < size: 3644 ret = ret + step 3645 return ret 3646 3647 torch.cuda.memory.empty_cache() 3648 key_allocated = ( 3649 "active_bytes.all.allocated" 3650 if not TEST_CUDAMALLOCASYNC 3651 else "allocated_bytes.all.current" 3652 ) 3653 key_requested = "requested_bytes.all.allocated" 3654 3655 nelems = 21 * 1024 * 1024 3656 nbytes = 4 * nelems # floats are 4 bytes 3657 3658 nelems_big = 100 * 1024 * 1024 3659 nbytes_big = 4 * nelems_big # floats are 4 bytes 3660 3661 start_mem = torch.cuda.memory_stats()[key_allocated] 3662 torch.cuda.memory._set_allocator_settings("") 3663 x = torch.rand(nelems, device="cuda") 3664 3665 # test roundup_power2_divisions single value syntax 3666 reg_mem = torch.cuda.memory_stats()[key_allocated] 3667 start_requested = torch.cuda.memory_stats()[key_requested] 3668 torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4") 3669 y = torch.rand(nelems, device="cuda") 3670 3671 pow2_div4_mem = torch.cuda.memory_stats()[key_allocated] 3672 current_requested = torch.cuda.memory_stats()[key_requested] 3673 3674 self.assertTrue(reg_mem - start_mem == nbytes) 3675 if not TEST_CUDAMALLOCASYNC: 3676 # not supported with the cudaMallocAsync backend 3677 self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4)) 3678 self.assertTrue(current_requested - start_requested == nbytes) 3679 3680 torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5") 3681 torch.cuda.memory._set_allocator_settings( 3682 "garbage_collection_threshold:0.5,max_split_size_mb:40" 3683 ) 3684 3685 # should have reset the power2 divisions now 3686 torch.cuda.memory.empty_cache() 3687 start_mem = torch.cuda.memory_stats()[key_allocated] 3688 z = torch.rand(nelems, device="cuda") 3689 reg_mem = torch.cuda.memory_stats()[key_allocated] 3690 self.assertTrue(reg_mem - start_mem == nbytes) 3691 3692 # roundup_power2_divisions knob array syntax 3693 torch.cuda.memory.empty_cache() 3694 torch.cuda.memory._set_allocator_settings( 3695 "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]" 3696 ) 3697 start_mem = torch.cuda.memory_stats()[key_allocated] 3698 w = torch.rand(nelems, device="cuda") 3699 3700 pow2_div8_mem = torch.cuda.memory_stats()[key_allocated] 3701 if not TEST_CUDAMALLOCASYNC: 3702 # not supported with the cudaMallocAsync backend 3703 self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8)) 3704 3705 torch.cuda.memory.empty_cache() 3706 start_mem = torch.cuda.memory_stats()[key_allocated] 3707 v = torch.rand(nelems_big, device="cuda") 3708 3709 pow2_div2_mem = torch.cuda.memory_stats()[key_allocated] 3710 if not TEST_CUDAMALLOCASYNC: 3711 # not supported with the cudaMallocAsync backend 3712 self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2)) 3713 3714 torch.cuda.memory.empty_cache() 3715 torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True") 3716 start_mem = torch.cuda.memory_stats()[key_allocated] 3717 w = torch.rand(nelems, device="cuda") 3718 reg_mem = torch.cuda.memory_stats()[key_allocated] 3719 self.assertTrue(reg_mem - start_mem == nbytes) 3720 3721 with self.assertRaises(RuntimeError): 3722 torch.cuda.memory._set_allocator_settings("foo:1,bar:2") 3723 3724 with self.assertRaises(RuntimeError): 3725 torch.cuda.memory._set_allocator_settings( 3726 "garbage_collection_threshold:1.2" 3727 ) 3728 3729 with self.assertRaises(RuntimeError): 3730 torch.cuda.memory._set_allocator_settings("max_split_size_mb:2") 3731 3732 with self.assertRaises(RuntimeError): 3733 torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none") 3734 3735 with self.assertRaises(RuntimeError): 3736 torch.cuda.memory._set_allocator_settings( 3737 "pinned_use_cuda_host_register:none" 3738 ) 3739 3740 with self.assertRaises(RuntimeError): 3741 torch.cuda.memory._set_allocator_settings( 3742 "pinned_num_register_threads:none" 3743 ) 3744 3745 with self.assertRaises(RuntimeError): 3746 torch.cuda.memory._set_allocator_settings( 3747 "pinned_num_register_threads:1024" 3748 ) 3749 3750 @parametrize("max_split_size_mb_setting", [False, True]) 3751 def test_raises_oom(self, max_split_size_mb_setting): 3752 if max_split_size_mb_setting: 3753 # CudaCachingAllocator does early return when searching available blocks 3754 # if max_split_size_mb is not set 3755 # Setting this triggers more parts of the code 3756 torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024") 3757 torch.cuda.memory.empty_cache() 3758 with self.assertRaises(torch.cuda.OutOfMemoryError): 3759 torch.empty(1024 * 1024 * 1024 * 1024, device="cuda") 3760 3761 @unittest.skipIf( 3762 not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux" 3763 ) 3764 @unittest.skipIf( 3765 TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3766 ) 3767 def test_cpp_memory_snapshot_pickle(self): 3768 from torch.utils.cpp_extension import load_inline 3769 3770 source = """ 3771 #include <torch/csrc/cuda/memory_snapshot.h> 3772 py::object do_snapshot() { 3773 std::string data = torch::cuda::_memory_snapshot_pickled(); 3774 return py::bytes(data); 3775 } 3776 void record(bool e, bool ctx) { 3777 torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx); 3778 } 3779 """ 3780 m = load_inline( 3781 name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"] 3782 ) 3783 for ctx in (False, True): 3784 try: 3785 m.record(True, ctx) 3786 3787 @torch.jit.script 3788 def the_script_fn(): 3789 return torch.rand(311, 411, device="cuda") 3790 3791 def run(): 3792 t = the_script_fn() 3793 return pickle.loads(m.do_snapshot()) 3794 3795 mem = run() 3796 found = False 3797 for s in mem["segments"]: 3798 for b in s["blocks"]: 3799 if b["state"] == "active_allocated": 3800 if b["requested_size"] == 311 * 411 * 4: 3801 if ctx: 3802 frame_text = str(b["frames"]) 3803 # C++ frame 3804 self.assertTrue("::rand" in frame_text) 3805 # script frame 3806 self.assertTrue("the_script_fn" in frame_text) 3807 # python frame 3808 self.assertTrue("case.py" in frame_text) 3809 found = True 3810 last_action = mem["device_traces"][0][-1] 3811 self.assertTrue(last_action["action"] == "alloc") 3812 self.assertTrue(last_action["size"] == 311 * 411 * 4) 3813 self.assertTrue(found) 3814 finally: 3815 m.record(False, False) 3816 3817 @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled") 3818 def test_notifies_oom(self): 3819 x = False 3820 3821 def cb(device, alloc, device_alloc, device_free): 3822 nonlocal x 3823 x = True 3824 3825 torch._C._cuda_attach_out_of_memory_observer(cb) 3826 with self.assertRaises(torch.cuda.OutOfMemoryError): 3827 torch.empty(1024 * 1024 * 1024 * 1024, device="cuda") 3828 self.assertTrue(x) 3829 3830 def test_allocator_fuzz(self): 3831 # fuzz 3832 state = random.getstate() 3833 random.seed(123) 3834 N = 10000 3835 try: 3836 mem = [] 3837 total = 0 3838 c = 0 3839 3840 def alloc(): 3841 nonlocal total, c 3842 b = random.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4) 3843 mem.append((c, torch.full((b,), c, dtype=torch.int32, device="cuda"))) 3844 c += 1 3845 total += b 3846 3847 def free(): 3848 nonlocal total 3849 idx = random.randrange(0, len(mem)) 3850 v, x = mem.pop(idx) 3851 assert torch.all(v == x) 3852 total -= x.numel() 3853 3854 choices = [alloc, free, torch.cuda.memory.empty_cache] 3855 for i in range(N): 3856 while total >= 1024 * 1024 * 1024 / (4 * 10): 3857 free() 3858 (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1]) 3859 action() 3860 finally: 3861 random.setstate(state) 3862 3863 @unittest.skipIf(TEST_PYNVML, "pynvml is not available") 3864 def test_nvml_get_handler(self): 3865 if not torch.version.hip: 3866 self.assertTrue(torch.cuda._get_pynvml_handler() is not None) 3867 else: 3868 self.assertTrue(torch.cuda._get_amdsmi_handler() is not None) 3869 3870 @unittest.skipIf(TEST_PYNVML, "pynvml is not available") 3871 def test_temperature(self): 3872 self.assertTrue(0 <= torch.cuda.temperature() <= 150) 3873 3874 @unittest.skipIf(TEST_PYNVML, "pynvml is not available") 3875 def test_power_draw(self): 3876 self.assertTrue(torch.cuda.power_draw() >= 0) 3877 3878 @unittest.skipIf(TEST_PYNVML, "pynvml is not available") 3879 def test_clock_speed(self): 3880 self.assertTrue(torch.cuda.clock_rate() >= 0) 3881 3882 3883MIN_BLOCK_SIZE = 512 3884SMALL_SIZE = 1048576 3885SMALL_BUFFER = 2097152 3886LARGE_BUFFER = 20971520 3887 3888 3889def get_cudagraph_segments(pool_id): 3890 segments = torch.cuda.memory_snapshot() 3891 return [segment for segment in segments if segment["segment_pool_id"] == pool_id] 3892 3893 3894def get_all_cudagraph_segments(): 3895 segments = torch.cuda.memory_snapshot() 3896 return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)] 3897 3898 3899def cudagraphify(fn, inputs, pool=None): 3900 if not TEST_CUDA_GRAPH: 3901 raise unittest.SkipTest("cuda graph test is skipped") 3902 3903 torch.cuda.synchronize() 3904 stream = torch.cuda.Stream() 3905 stream.wait_stream(torch.cuda.current_stream()) 3906 with torch.cuda.stream(stream): 3907 fn(*inputs) 3908 stream.synchronize() 3909 torch.cuda.current_stream().wait_stream(stream) 3910 torch.cuda.synchronize() 3911 3912 graph = torch.cuda.CUDAGraph() 3913 with torch.cuda.graph(graph, stream=stream, pool=pool): 3914 static_outputs = fn(*inputs) 3915 3916 return graph, static_outputs 3917 3918 3919def int8_cuda(size): 3920 return torch.ones([size], device="cuda", dtype=torch.uint8) 3921 3922 3923def live_blocks(pool_id): 3924 blocks = 0 3925 seg = get_cudagraph_segments(pool_id) 3926 for segment in get_cudagraph_segments(pool_id): 3927 for block in segment["blocks"]: 3928 blocks += block["state"] == "active_allocated" 3929 return blocks 3930 3931 3932def tensor_metadata(x): 3933 return { 3934 "nbytes": x.untyped_storage().nbytes(), 3935 "data_ptr": x.untyped_storage().data_ptr(), 3936 "size": x.shape, 3937 "stride": x.stride(), 3938 "dtype": x.dtype, 3939 "device": x.device, 3940 "storage_offset": x.storage_offset(), 3941 } 3942 3943 3944def reconstruct_from_tensor_metadata(metadata): 3945 s = torch._C._construct_storage_from_data_pointer( 3946 metadata["data_ptr"], metadata["device"], metadata["nbytes"] 3947 ) 3948 t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"]) 3949 t.set_( 3950 source=s, 3951 storage_offset=metadata["storage_offset"], 3952 size=metadata["size"], 3953 stride=metadata["stride"], 3954 ) 3955 return t 3956 3957 3958@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI") 3959@torch.testing._internal.common_utils.markDynamoStrictTest 3960class TestBlockStateAbsorption(TestCase): 3961 @property 3962 def expandable_segments(self): 3963 return EXPANDABLE_SEGMENTS 3964 3965 def checkCheckpointedBlock(self, before_block, after_block): 3966 for field in ("size", "state"): 3967 self.assertEqual(before_block[field], after_block[field]) 3968 3969 def checkCheckpointedState(self, before_segments, after_segments): 3970 # after may contain additional segments, but all of the segments in before 3971 # should be exactly equivalent to after 3972 after_ptr_to_segment = { 3973 segment["address"]: segment for segment in after_segments 3974 } 3975 3976 for before_segment in before_segments: 3977 self.assertTrue(before_segment["address"] in after_ptr_to_segment) 3978 after_segment = after_ptr_to_segment[before_segment["address"]] 3979 3980 for field in ( 3981 "device", 3982 "total_size", 3983 "allocated_size", 3984 "active_size", 3985 "segment_type", 3986 "segment_pool_id", 3987 ): 3988 self.assertEqual(before_segment[field], after_segment[field]) 3989 3990 self.assertEqual( 3991 len(before_segment["blocks"]), len(after_segment["blocks"]) 3992 ) 3993 for before_block, after_block in zip( 3994 before_segment["blocks"], after_segment["blocks"] 3995 ): 3996 self.checkCheckpointedBlock(before_block, after_block) 3997 3998 @staticmethod 3999 def setCheckpointPoolState( 4000 device, state, stale_storages_ptr, storages_deleters=None 4001 ): 4002 stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr] 4003 storages_deleters = ( 4004 [] 4005 if not storages_deleters 4006 else [t.untyped_storage()._cdata for t in storages_deleters] 4007 ) 4008 torch._C._cuda_setCheckpointPoolState( 4009 device, state, stale_storages_ptr, storages_deleters 4010 ) 4011 4012 def checkFunction(self, fn, inputs, pool=None): 4013 graph, outputs = cudagraphify(fn, inputs, pool=pool) 4014 4015 pool_id = graph.pool() 4016 device = outputs[0].device.index 4017 4018 segments_before_checkpoint = get_cudagraph_segments(pool_id) 4019 4020 state = torch._C._cuda_getCheckpointState(device, pool_id) 4021 self.setCheckpointPoolState(device, state, [], []) 4022 4023 self.checkCheckpointedState( 4024 segments_before_checkpoint, get_cudagraph_segments(pool_id) 4025 ) 4026 4027 def setUp(self): 4028 super().setUp() 4029 self.segment_length = len(get_all_cudagraph_segments()) 4030 4031 def tearDown(self): 4032 torch.cuda.synchronize() 4033 gc.collect() 4034 torch.cuda.empty_cache() 4035 4036 self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length) 4037 4038 super().tearDown() 4039 4040 def test_simple(self): 4041 def foo(): 4042 x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8) 4043 x = x + x 4044 x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) 4045 y = int8_cuda(SMALL_SIZE) + x1 4046 z = int8_cuda(SMALL_SIZE) 4047 return x, y, z 4048 4049 self.checkFunction(foo, []) 4050 4051 def test_allocated_in_middle_of_segment(self): 4052 def foo(): 4053 small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] 4054 return small_buffers[5].add_(2) 4055 4056 self.checkFunction(foo, []) 4057 4058 def test_multiple_middle_allocations(self): 4059 def foo(): 4060 small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] 4061 return small_buffers[5], small_buffers[8] 4062 4063 self.checkFunction(foo, []) 4064 4065 def test_middle_allocations_contiguous(self): 4066 def foo(): 4067 small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] 4068 return small_buffers[5], small_buffers[6] 4069 4070 self.checkFunction(foo, []) 4071 4072 def test_additional_free_following_checkpoint(self): 4073 def foo(): 4074 return (int8_cuda(MIN_BLOCK_SIZE),) 4075 4076 def foo2(): 4077 return (int8_cuda(MIN_BLOCK_SIZE),) 4078 4079 graph, outputs = cudagraphify(foo, []) 4080 pool_id = graph.pool() 4081 4082 segments_before_checkpoint = get_cudagraph_segments(pool_id) 4083 4084 state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) 4085 4086 graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool()) 4087 4088 self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, []) 4089 4090 del outputs2 4091 4092 self.checkCheckpointedState( 4093 segments_before_checkpoint, get_cudagraph_segments(pool_id) 4094 ) 4095 4096 # TODO: re-enable 4097 # def test_additional_free_error(self): 4098 # def foo(): 4099 # return int8_cuda(MIN_BLOCK_SIZE), 4100 4101 # def foo2(): 4102 # return int8_cuda(MIN_BLOCK_SIZE), 4103 4104 # graph, outputs = cudagraphify(foo, []) 4105 # pool_id = graph.pool() 4106 4107 # segments_before_checkpoint = get_cudagraph_segments(pool_id) 4108 4109 # state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) 4110 4111 # graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool()) 4112 # with self.assertRaisesRegex(Exception, "being manually freed must be passed"): 4113 # self.setCheckpointPoolState(outputs[0].device.index, state, [], []) 4114 4115 def test_tensor_dies_after_checkpoint(self): 4116 def foo(): 4117 return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE) 4118 4119 graph, outputs = cudagraphify(foo, []) 4120 pool_id = graph.pool() 4121 device = outputs[0].device.index 4122 4123 segments_before_checkpoint = get_cudagraph_segments(pool_id) 4124 state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) 4125 4126 output_data_ptrs = [output.data_ptr() for output in outputs] 4127 4128 del outputs 4129 4130 self.setCheckpointPoolState(device, state, [], []) 4131 4132 self.assertEqual(live_blocks(pool_id), 2) 4133 torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0]) 4134 self.assertEqual(live_blocks(pool_id), 1) 4135 torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1]) 4136 self.assertEqual(live_blocks(pool_id), 0) 4137 4138 def test_assigning_back_deleter_fns_to_tensor(self): 4139 def foo(x): 4140 return ( 4141 int8_cuda(SMALL_BUFFER) + x, 4142 int8_cuda(SMALL_BUFFER) + x, 4143 int8_cuda(LARGE_BUFFER) + x, 4144 ) 4145 4146 inp = torch.tensor([1], device="cuda") 4147 graph, outputs = cudagraphify(foo, [inp]) 4148 pool_id = graph.pool() 4149 graph.replay() 4150 4151 device = outputs[0].device.index 4152 4153 for i in range(len(outputs)): 4154 self.assertTrue(outputs[i].mean(dtype=torch.float) == 2) 4155 4156 state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) 4157 4158 output_ptrs = [output.untyped_storage().data_ptr() for output in outputs] 4159 ten_metadata = [tensor_metadata(t) for t in outputs] 4160 4161 self.assertEqual(live_blocks(pool_id), 3) 4162 4163 del outputs 4164 4165 self.assertEqual(live_blocks(pool_id), 0) 4166 4167 reconstructed_tensors = [ 4168 reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata 4169 ] 4170 4171 for i in range(len(reconstructed_tensors)): 4172 self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2) 4173 4174 inp.add_(1) 4175 graph.replay() 4176 4177 for i in range(len(reconstructed_tensors)): 4178 self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3) 4179 4180 self.setCheckpointPoolState( 4181 device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]] 4182 ) 4183 4184 self.assertEqual(live_blocks(pool_id), 3) 4185 4186 reconstructed_tensors[0] = None 4187 self.assertEqual(live_blocks(pool_id), 2) 4188 4189 reconstructed_tensors[1] = None 4190 self.assertEqual(live_blocks(pool_id), 1) 4191 4192 # should not change, we did not pass it in to swap data ptrs 4193 reconstructed_tensors[2] = None 4194 self.assertEqual(live_blocks(pool_id), 1) 4195 4196 torch._C._cuda_cudaCachingAllocator_raw_delete(output_ptrs[2]) 4197 4198 self.assertEqual(live_blocks(pool_id), 0) 4199 4200 @skipIfNoTorchVision 4201 def test_resnet(self): 4202 import torchvision 4203 4204 m = torchvision.models.resnet50() 4205 m.eval() 4206 m = m.cuda() 4207 4208 inp = torch.rand([1, 3, 255, 255], device="cuda") 4209 self.checkFunction(m, [inp]) 4210 4211 def test_check_pool_live_allocations(self): 4212 def foo(): 4213 return torch.ones([4], device="cuda") 4214 4215 pool = torch.cuda.graph_pool_handle() 4216 graph, outputs = cudagraphify(foo, [], pool=pool) 4217 4218 index = outputs[0].device.index 4219 4220 def check(live_dps): 4221 return torch._C._cuda_checkPoolLiveAllocations(index, pool, live_dps) 4222 4223 self.assertTrue(check({outputs[0].data_ptr()})) 4224 4225 self.assertFalse(check({outputs[0].data_ptr(), 0})) 4226 self.assertFalse(check(set())) 4227 4228 del outputs 4229 self.assertTrue(check(set())) 4230 4231 def test_allocate_in_thread_to_pool(self): 4232 def foo(): 4233 return torch.rand([4], device="cuda") 4234 4235 pool = torch.cuda.graph_pool_handle() 4236 graph, outputs = cudagraphify(foo, [], pool=pool) 4237 device = outputs[0].device.index 4238 del outputs 4239 4240 @contextlib.contextmanager 4241 def _use_cuda_memory_pool_manager(device, mem_pool): 4242 """ 4243 Context manager to use cuda graph pool for new allocations. If you use this manager 4244 all cudagraph tensors in use should be reflected in the allocator or they will be overwritten. 4245 existing_graph should already have been used in a capture, and the mem_pool must already exist. 4246 """ 4247 torch.cuda.synchronize() 4248 stream = torch.cuda.Stream() 4249 stream.wait_stream(torch.cuda.current_stream()) 4250 stream_context = torch.cuda.stream(stream) 4251 stream_context.__enter__() 4252 torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool) 4253 try: 4254 yield 4255 finally: 4256 torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool) 4257 torch._C._cuda_releasePool(device, mem_pool) 4258 stream_context.__exit__(None, None, None) 4259 4260 segments = get_cudagraph_segments(pool) 4261 self.assertEqual(len(get_cudagraph_segments(pool)), 1) 4262 4263 def use_pool(): 4264 def alloc_three(): 4265 a = int8_cuda(LARGE_BUFFER) 4266 b = int8_cuda(LARGE_BUFFER) 4267 c = a + b 4268 4269 with _use_cuda_memory_pool_manager(device, pool): 4270 # three allocations 4271 for _ in range(10): 4272 alloc_three() 4273 4274 # three more allocations not in pool 4275 alloc_three() 4276 4277 def no_pool(): 4278 # two allocations 4279 for _ in range(10): 4280 a = int8_cuda(LARGE_BUFFER) 4281 b = int8_cuda(LARGE_BUFFER) 4282 del a, b 4283 4284 graph_thread = threading.Thread(target=use_pool) 4285 no_graph_thread = threading.Thread(target=no_pool) 4286 graph_thread.start() 4287 no_graph_thread.start() 4288 4289 graph_thread.join() 4290 no_graph_thread.join() 4291 4292 self.assertEqual( 4293 len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4 4294 ) 4295 4296 del graph 4297 4298 torch.cuda.synchronize() 4299 gc.collect() 4300 torch.cuda.empty_cache() 4301 4302 self.assertEqual(len(get_cudagraph_segments(pool)), 0) 4303 4304 def test_no_triton_on_import(self): 4305 """Test that Trition is not imported on first GPU use""" 4306 script = "import sys; import torch; torch.rand(2, device='cuda'); print('triton' in sys.modules)" 4307 4308 rc = ( 4309 subprocess.check_output( 4310 [sys.executable, "-c", script], 4311 # On Windows, opening the subprocess with the default CWD makes `import torch` 4312 # fail, so just set CWD to this script's directory 4313 cwd=os.path.dirname(os.path.realpath(__file__)), 4314 ) 4315 .strip() 4316 .decode("ascii") 4317 ) 4318 self.assertEqual(rc, "False", "Triton was imported when importing torch!") 4319 4320 4321@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") 4322class TestMemPool(TestCase): 4323 def test_mempool_id(self): 4324 pool1 = torch.cuda.graph_pool_handle() 4325 pool2 = torch.cuda.MemPool().id 4326 4327 # first value of id in a user created pool is always zero 4328 self.assertEqual(pool1[0] == 0, pool2[0] == 0) 4329 4330 # each call to torch.cuda.graph_pool_handle() or torch.cuda.MemPool() 4331 # increments the id 4332 self.assertTrue(abs(pool2[1] - pool1[1]) > 0) 4333 4334 def test_mempool_with_allocator(self): 4335 pool = torch.cuda.MemPool() 4336 4337 # MemPool doesn't have an allocator by default 4338 self.assertEqual(pool.allocator, None) 4339 4340 from torch.utils.cpp_extension import load_inline 4341 4342 dummy_allocator_source = """ 4343 #include <torch/extension.h> 4344 #include <ATen/cuda/Exceptions.h> 4345 #include <cuda_runtime_api.h> 4346 4347 extern "C" { 4348 C10_EXPORT int called_dummy_alloc = 0; 4349 C10_EXPORT int called_dummy_free = 0; 4350 4351 // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 4352 C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { 4353 called_dummy_alloc = 123; 4354 void* ptr; 4355 C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); 4356 return ptr; 4357 } 4358 4359 C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { 4360 called_dummy_free = 321; 4361 C10_CUDA_CHECK(cudaFree(ptr)); 4362 } 4363 } 4364 """ 4365 dummy_allocator_libname = "dummy_allocator" 4366 dummy_allocator = load_inline( 4367 name=dummy_allocator_libname, 4368 cpp_sources=dummy_allocator_source, 4369 is_python_module=False, 4370 keep_intermediates=False, 4371 verbose=True, 4372 with_cuda=True, 4373 ) 4374 allocator = torch.cuda.memory.CUDAPluggableAllocator( 4375 dummy_allocator, 4376 "dummy_alloc", 4377 "dummy_free", 4378 ) 4379 pool = torch.cuda.MemPool(allocator.allocator()) 4380 4381 # pool should point to the same allocator as the one passed into it 4382 self.assertEqual(allocator.allocator(), pool.allocator) 4383 4384 # no allocations happened yet, so called_dummy_alloc should be 0 4385 alloc_lib = ctypes.CDLL(dummy_allocator) 4386 called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc") 4387 self.assertEqual(called_dummy_alloc.value, 0) 4388 4389 with torch.cuda.use_mem_pool(pool): 4390 out = torch.randn(1, device="cuda") 4391 4392 # called_dummy_alloc should be 123 if dummy_alloc was used to allocate 4393 # out tensor 4394 self.assertEqual(called_dummy_alloc.value, 123) 4395 4396 def test_mempool_context(self): 4397 active_pool = torch.cuda.MemPoolContext.active_pool() 4398 4399 # there is no active pool if none was made active 4400 self.assertEqual(active_pool, None) 4401 4402 pool = torch.cuda.MemPool() 4403 ctx = torch.cuda.MemPoolContext(pool) 4404 active_pool = torch.cuda.MemPoolContext.active_pool() 4405 4406 # pool was made active 4407 self.assertEqual(active_pool, pool) 4408 4409 del ctx 4410 active_pool = torch.cuda.MemPoolContext.active_pool() 4411 4412 # ctx was deleted, so active pool is the previous one 4413 self.assertEqual(active_pool, None) 4414 4415 def test_mempool_multithread(self): 4416 pool_ids = [] 4417 active_pool_ids = [] 4418 4419 def create_mempool_and_make_active(): 4420 pool = torch.cuda.MemPool() 4421 pool_ids.extend([pool.id]) 4422 4423 ctx = torch.cuda.MemPoolContext(pool) 4424 active_pool = torch.cuda.MemPoolContext.active_pool() 4425 active_pool_ids.extend([active_pool.id]) 4426 del ctx 4427 4428 num_threads = 4 4429 threads = [ 4430 threading.Thread(target=create_mempool_and_make_active) 4431 for t in range(num_threads) 4432 ] 4433 for thread in threads: 4434 thread.start() 4435 for thread in threads: 4436 thread.join() 4437 4438 # each thread should create a unique mempool, since 4439 # mempool id creation is atomic 4440 self.assertEqual(len(set(pool_ids)), 4) 4441 4442 # each thread should have different active mempool, since 4443 # the pointer to the mempool is thread local 4444 self.assertEqual(len(set(active_pool_ids)), 4) 4445 4446 4447@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") 4448@torch.testing._internal.common_utils.markDynamoStrictTest 4449class TestCudaOptims(TestCase): 4450 # These tests will be instantiate with instantiate_device_type_tests 4451 # to apply the new OptimizerInfo structure. 4452 4453 @onlyCUDA 4454 @unittest.skipIf( 4455 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs" 4456 ) 4457 @optims( 4458 [optim for optim in optim_db if optim.has_capturable_arg], 4459 dtypes=[torch.float32], 4460 ) 4461 def test_graph_optims(self, device, dtype, optim_info): 4462 optim_cls = optim_info.optim_cls 4463 all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 4464 device, dtype, optim_info, skip=("differentiable",) 4465 ) 4466 4467 steps_warmup = 3 4468 steps_train = 2 4469 4470 for optim_input in all_optim_inputs: 4471 kwargs = optim_input.kwargs 4472 4473 # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam 4474 # and torch.optim.adamw 4475 kwargs["lr"] = 0.1 4476 4477 for actually_do_graphs in (True, False): 4478 params = [ 4479 torch.randn((i + 5, i + 5), device=device) for i in range(2) 4480 ] + [torch.randn((), device=device)] 4481 params_control = [p.clone().requires_grad_() for p in params] 4482 params_graphed = [p.clone().requires_grad_() for p in params] 4483 4484 grads = [ 4485 [torch.randn_like(p) for p in params] 4486 for _ in range(steps_warmup + steps_train) 4487 ] 4488 4489 # Control (capturable=False) 4490 kwargs["capturable"] = False 4491 4492 opt = optim_cls(params_control, **kwargs) 4493 for i in range(steps_warmup + steps_train): 4494 for j, p in enumerate(params_control): 4495 p.grad = grads[i][j] 4496 opt.step() 4497 4498 # capturable=True 4499 kwargs["capturable"] = True 4500 opt = optim_cls(params_graphed, **kwargs) 4501 4502 for i in range(steps_warmup): 4503 for j, p in enumerate(params_graphed): 4504 p.grad = grads[i][j] 4505 opt.step() 4506 4507 if actually_do_graphs: 4508 g = torch.cuda.CUDAGraph() 4509 with torch.cuda.graph(g): 4510 opt.step() 4511 4512 for i in range(steps_train): 4513 if actually_do_graphs: 4514 for j, p in enumerate(params_graphed): 4515 p.grad.copy_(grads[i + steps_warmup][j]) 4516 g.replay() 4517 else: 4518 # Passing capturable=True to the constructor and running without graphs should still be 4519 # numerically correct, even if it's not ideal for performance. 4520 for j, p in enumerate(params_graphed): 4521 p.grad = grads[i + steps_warmup][j] 4522 opt.step() 4523 4524 for p_control, p_graphed in zip(params_control, params_graphed): 4525 self.assertEqual(p_control, p_graphed) 4526 4527 @onlyCUDA 4528 @unittest.skipIf( 4529 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 4530 ) 4531 @optims( 4532 [ 4533 optim 4534 for optim in optim_db 4535 if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on 4536 ], 4537 dtypes=[torch.float32], 4538 ) 4539 def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info): 4540 optim_cls = optim_info.optim_cls 4541 4542 steps_warmup = 3 4543 steps_train = 2 4544 4545 optim_inputs = optim_info.optim_inputs_func(device=device) 4546 4547 for optim_input in optim_inputs: 4548 kwargs = optim_input.kwargs 4549 kwargs["fused"] = True 4550 4551 for actually_do_graphs in ( 4552 (True, False) if optim_info.has_capturable_arg else (True,) 4553 ): 4554 params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)] 4555 params_control = [p.clone().requires_grad_() for p in params] 4556 params_graphed = [p.clone().requires_grad_() for p in params] 4557 4558 # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. 4559 grads = [ 4560 [torch.randn_like(p) for p in params] 4561 for _ in range(steps_warmup + steps_train) 4562 ] 4563 with torch.no_grad(): 4564 grads_control = [[g.clone() for g in gs] for gs in grads] 4565 grads_graphed = [[g.clone() for g in gs] for gs in grads] 4566 4567 # Gradient Scaler 4568 scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) 4569 with torch.no_grad(): 4570 scaler_for_control._lazy_init_scale_growth_tracker(device) 4571 4572 scaler_for_graphed = torch.cuda.amp.GradScaler() 4573 scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) 4574 with torch.no_grad(): 4575 scaler_for_graphed._lazy_init_scale_growth_tracker(device) 4576 4577 # Control (capturable=False) 4578 if optim_info.has_capturable_arg: 4579 kwargs["capturable"] = False 4580 opt = optim_cls(params_control, **kwargs) 4581 4582 for i in range(steps_warmup + steps_train): 4583 for j, p in enumerate(params_control): 4584 p.grad = grads_control[i][j] 4585 scaler_for_control.step(opt) 4586 scaler_for_control.update() 4587 4588 # capturable=True 4589 if optim_info.has_capturable_arg: 4590 kwargs["capturable"] = True 4591 opt = optim_cls(params_graphed, **kwargs) 4592 4593 for i in range(steps_warmup): 4594 for j, p in enumerate(params_graphed): 4595 p.grad = grads_graphed[i][j] 4596 scaler_for_graphed.step(opt) 4597 scaler_for_graphed.update() 4598 4599 if actually_do_graphs: 4600 g = torch.cuda.CUDAGraph() 4601 with torch.cuda.graph(g): 4602 scaler_for_graphed.step(opt) 4603 scaler_for_graphed.update() 4604 4605 for i in range(steps_train): 4606 if actually_do_graphs: 4607 for j, p in enumerate(params_graphed): 4608 p.grad.copy_(grads_graphed[i + steps_warmup][j]) 4609 g.replay() 4610 else: 4611 # Passing capturable=True to the constructor and running without graphs should still be 4612 # numerically correct, even if it's not ideal for performance. 4613 for j, p in enumerate(params_graphed): 4614 p.grad = grads_graphed[i + steps_warmup][j] 4615 scaler_for_graphed.step(opt) 4616 scaler_for_graphed.update() 4617 4618 for p_control, p_graphed in zip(params_control, params_graphed): 4619 self.assertEqual(p_control, p_graphed) 4620 4621 @onlyNativeDeviceTypes 4622 @optims( 4623 [optim for optim in optim_db if "fused" in optim.supported_impls], 4624 dtypes=[torch.float32], 4625 ) 4626 def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): 4627 device = device.split(":")[0] 4628 if device not in optim_info.supports_fused_on: 4629 self.skipTest( 4630 f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" 4631 ) 4632 optim_inputs = optim_info.optim_inputs_func(device=device) 4633 optim_cls = optim_info.optim_cls 4634 for optim_input in optim_inputs: 4635 for _separate_unscale in (True, False): 4636 kwargs = optim_input.kwargs 4637 kwargs["fused"] = True 4638 torch.manual_seed(20) 4639 ( 4640 mod_control, 4641 mod_scaling, 4642 opt_control, 4643 opt_scaling, 4644 data, 4645 loss_fn, 4646 _, 4647 ) = _create_scaling_case( 4648 optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device 4649 ) 4650 optimizer_kwargs = deepcopy(kwargs) 4651 optimizer_kwargs["fused"] = False 4652 if "lr" not in kwargs: 4653 # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr 4654 optimizer_kwargs["lr"] = 1.0 4655 opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs) 4656 scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0) 4657 scaler_control = torch.amp.GradScaler(device, init_scale=128.0) 4658 tracker = TensorTracker() 4659 for input, target in data: 4660 opt_control.zero_grad() 4661 with torch.autocast(device_type=device, dtype=torch.half): 4662 output_control = mod_control(input) 4663 loss_control = loss_fn(output_control, target) 4664 scaler_control.scale(loss_control).backward() 4665 scaler_control.step(opt_control) 4666 scaler_control.update() 4667 4668 opt_scaling.zero_grad() 4669 with torch.autocast(device_type=device, dtype=torch.half): 4670 output_scaling = mod_scaling(input) 4671 loss_scaling = loss_fn(output_scaling, target) 4672 scaler_scaling.scale(loss_scaling).backward() 4673 if _separate_unscale: 4674 scaler_scaling.unscale_(opt_scaling) 4675 scaler_scaling.step(opt_scaling) 4676 scaler_scaling.update() 4677 4678 tracker.add(loss_control) 4679 tracker.pop_check_set(loss_scaling, self) 4680 for param_control, param_scaling in zip( 4681 mod_control.parameters(), mod_scaling.parameters() 4682 ): 4683 tracker.add(param_control.grad) 4684 tracker.pop_check_set(param_scaling.grad, self) 4685 tracker.add(param_control) 4686 tracker.pop_check_set(param_scaling, self) 4687 4688 state_control, state_scaling = ( 4689 opt_control.state[param_control], 4690 opt_scaling.state[param_scaling], 4691 ) 4692 4693 for k in state_control: 4694 actual = state_scaling[k] 4695 if k == "step": 4696 actual = actual.squeeze() 4697 tracker.add(state_control[k]) 4698 tracker.pop_check_set(actual, self) 4699 4700 @onlyCUDA 4701 @parametrize("in_place_unscale", [False, True]) 4702 @optims( 4703 [optim for optim in optim_db if "cuda" in optim.supports_fused_on], 4704 dtypes=[torch.float32], 4705 ) 4706 def test_grad_scaler_with_preset_grad_scale( 4707 self, device, dtype, optim_info, in_place_unscale 4708 ): 4709 weight = torch.ones((5, 5), device="cuda", requires_grad=True) 4710 weight.grad = torch.full_like(weight, fill_value=15) 4711 opt = optim_info.optim_cls([weight], lr=0.1, fused=True) 4712 scaler = torch.amp.GradScaler(init_scale=5) 4713 4714 # simulate scaling a loss 4715 scaler.scale(torch.ones(5)) 4716 4717 if in_place_unscale: 4718 scaler.unscale_(opt) 4719 # the gradient should have been divided in-place 4720 self.assertEqual(weight.grad, torch.full_like(weight, fill_value=3)) 4721 4722 # the user sets a `grad_scale` value which should be fused with the optimizer step 4723 opt.grad_scale = torch.Tensor([3]).cuda() 4724 scaler.step(opt) 4725 4726 # check that the user's grad_scale was respected (i.e. the gradient was divided by 5 * 3) 4727 self.assertEqual(weight.grad, torch.full_like(weight, fill_value=1)) 4728 4729 @onlyCUDA 4730 @unittest.skipIf( 4731 not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" 4732 ) 4733 @parametrize("foreach, fused", [(False, False), (True, False), (False, True)]) 4734 @optims( 4735 [ 4736 optim 4737 for optim in optim_db 4738 if "foreach" in optim.supported_impls and "cuda" in optim.supports_fused_on 4739 ], 4740 dtypes=[torch.float32], 4741 ) 4742 def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused): 4743 torch.cuda.empty_cache() 4744 4745 scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0) 4746 g = torch.cuda.CUDAGraph() 4747 s = torch.cuda.Stream() 4748 4749 weight = torch.ones((100,), device="cuda", requires_grad=True) 4750 opt = optim_info.optim_cls([weight], lr=0.1, foreach=foreach, fused=fused) 4751 static_input = torch.ones_like(weight) 4752 static_grad = torch.ones_like(weight) 4753 4754 # warmup 4755 s = torch.cuda.Stream() 4756 s.wait_stream(torch.cuda.current_stream()) 4757 with torch.cuda.stream(s): 4758 loss = (weight.half() * static_input).sum() 4759 scaler.scale(loss).backward() 4760 torch.cuda.current_stream().wait_stream(s) 4761 4762 opt.zero_grad(set_to_none=True) 4763 4764 # capture 4765 with torch.cuda.stream(s): 4766 g.capture_begin() 4767 loss = (weight.half() * static_input).sum() 4768 scaler.scale(loss).backward() 4769 g.capture_end() 4770 4771 input_vals = [5, 20000, 5, 40000] 4772 # If the scale gets updated properly, these are the scale, growth tracker, 4773 # and grad values we expect. 4774 expected_scales = [4, 2, 2, 1] 4775 expected_growth_trackers = [1, 0, 1, 0] 4776 expected_grad_vals = [5 * 4, float("inf"), 5 * 2, float("inf")] 4777 4778 for data, scale, growth_tracker, grad_val in zip( 4779 input_vals, expected_scales, expected_growth_trackers, expected_grad_vals 4780 ): 4781 static_input.fill_(data) 4782 g.replay() 4783 self.assertEqual(weight.grad, torch.full_like(weight.grad, grad_val)) 4784 scaler.step(opt) 4785 scaler.update() 4786 self.assertEqual(scaler._scale, scale) 4787 self.assertEqual(scaler._growth_tracker, growth_tracker) 4788 4789 4790@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") 4791class TestGDS(TestCase): 4792 def _get_tmp_dir_fs_type(self): 4793 my_path = os.path.realpath("/tmp") 4794 root_type = "" 4795 for part in psutil.disk_partitions(): 4796 if part.mountpoint == "/": 4797 root_type = part.fstype 4798 continue 4799 if part.mountpoint == my_path: 4800 return part.fstype 4801 return root_type 4802 4803 @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") 4804 def test_gds_read_write_tensors(self): 4805 if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"): 4806 self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem") 4807 src1 = torch.randn(1024, device="cuda") 4808 src2 = torch.randn(2, 1024, device="cuda") 4809 torch.cuda.gds._gds_register_buffer(src1.untyped_storage()) 4810 torch.cuda.gds._gds_register_buffer(src2.untyped_storage()) 4811 dest1 = torch.empty(1024, device="cuda") 4812 dest2 = torch.empty(2, 1024, device="cuda") 4813 with TemporaryFileName() as f: 4814 file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) 4815 file.save_storage(src1.untyped_storage(), offset=0) 4816 file.save_storage(src2.untyped_storage(), offset=src1.nbytes) 4817 file.load_storage(dest1.untyped_storage(), offset=0) 4818 file.load_storage(dest2.untyped_storage(), offset=src1.nbytes) 4819 self.assertEqual(src1, dest1) 4820 self.assertEqual(src2, dest2) 4821 torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage()) 4822 torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage()) 4823 4824 4825@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") 4826class TestCudaAutocast(TestAutocast): 4827 def setUp(self): 4828 super().setUp() 4829 self.autocast_lists = AutocastTestLists(torch.device("cuda:0")) 4830 4831 def tearDown(self): 4832 del self.autocast_lists 4833 super().tearDown() 4834 4835 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4836 def test_autocast_torch_fp16(self): 4837 with torch.backends.cudnn.flags(enabled=True, deterministic=True): 4838 for op_with_args in self.autocast_lists.torch_fp16: 4839 skip_test = False 4840 op, args = op_with_args[0], op_with_args[1] 4841 if len(op_with_args) == 3: 4842 skip_test = op_with_args[2] # TEST_WITH_ROCM 4843 if not skip_test: 4844 self._run_autocast_outofplace( 4845 op, args, torch.float16, device="cuda", amp_dtype=torch.float16 4846 ) 4847 4848 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4849 def test_autocast_torch_bf16(self): 4850 with torch.backends.cudnn.flags(enabled=True, deterministic=True): 4851 for op_with_args in self.autocast_lists.torch_fp16: 4852 skip_test = False 4853 op, args = op_with_args[0], op_with_args[1] 4854 if len(op_with_args) == 3: 4855 skip_test = op_with_args[2] # TEST_WITH_ROCM 4856 should_error_from_cudnn = "cudnn" in op and ( 4857 "TORCH_CUDNN_V8_API_DISABLED" in os.environ 4858 and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"]) 4859 or torch.cuda.get_device_capability() < (8, 0) 4860 ) 4861 should_error_from_not_implemented = should_error_from_cudnn 4862 if not skip_test: 4863 if should_error_from_not_implemented: 4864 with self.assertRaises( 4865 RuntimeError, 4866 msg=str(op) + " should not be supported for bfloat16!", 4867 ): 4868 self._run_autocast_outofplace( 4869 op, args, torch.bfloat16, device="cuda" 4870 ) 4871 else: 4872 if torch.cuda.is_bf16_supported(): 4873 self._run_autocast_outofplace( 4874 op, args, torch.bfloat16, device="cuda" 4875 ) 4876 else: 4877 with self.assertRaisesRegex( 4878 RuntimeError, "Device does not support bfloat16" 4879 ): 4880 self._run_autocast_outofplace( 4881 op, args, torch.bfloat16, device="cuda" 4882 ) 4883 4884 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4885 def test_autocast_torch_fp32(self): 4886 for op_with_args in self.autocast_lists.torch_fp32: 4887 op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) 4888 self._run_autocast_outofplace( 4889 op, 4890 args, 4891 torch.float32, 4892 device="cuda", 4893 add_kwargs=maybe_kwargs, 4894 amp_dtype=torch.float16, 4895 ) 4896 4897 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4898 def test_autocast_torch_need_autocast_promote(self): 4899 for op, args in self.autocast_lists.torch_need_autocast_promote: 4900 self._run_autocast_outofplace( 4901 op, args, torch.float32, device="cuda", amp_dtype=torch.float16 4902 ) 4903 4904 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4905 def test_autocast_torch_expect_builtin_promote(self): 4906 for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: 4907 self._run_autocast_outofplace( 4908 op, 4909 args, 4910 torch.float32, 4911 device="cuda", 4912 out_type=out_type, 4913 amp_dtype=torch.float16, 4914 ) 4915 4916 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4917 def test_autocast_nn_fp16(self): 4918 with torch.backends.cudnn.flags(enabled=True, deterministic=True): 4919 for op, args in self.autocast_lists.nn_fp16: 4920 self._run_autocast_outofplace( 4921 op, 4922 args, 4923 torch.float16, 4924 device="cuda", 4925 module=torch._C._nn, 4926 amp_dtype=torch.float16, 4927 ) 4928 4929 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4930 def test_autocast_nn_bf16(self): 4931 with torch.backends.cudnn.flags(enabled=True, deterministic=True): 4932 for op, args in self.autocast_lists.nn_fp16: 4933 if torch.cuda.is_bf16_supported(): 4934 self._run_autocast_outofplace( 4935 op, args, torch.bfloat16, device="cuda", module=torch._C._nn 4936 ) 4937 else: 4938 with self.assertRaisesRegex( 4939 RuntimeError, "Device does not support bfloat16" 4940 ): 4941 self._run_autocast_outofplace( 4942 op, args, torch.bfloat16, device="cuda", module=torch._C._nn 4943 ) 4944 4945 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4946 def test_autocast_nn_fp32(self): 4947 for op, args in self.autocast_lists.nn_fp32: 4948 self._run_autocast_outofplace( 4949 op, 4950 args, 4951 torch.float32, 4952 device="cuda", 4953 module=torch._C._nn, 4954 amp_dtype=torch.float16, 4955 ) 4956 4957 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4958 def test_autocast_linalg_fp16(self): 4959 with torch.backends.cudnn.flags(enabled=True, deterministic=True): 4960 for op, args in self.autocast_lists.linalg_fp16: 4961 self._run_autocast_outofplace( 4962 op, 4963 args, 4964 torch.float16, 4965 device="cuda", 4966 module=torch._C._linalg, 4967 amp_dtype=torch.float16, 4968 ) 4969 4970 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4971 def test_autocast_methods_fp16(self): 4972 with torch.backends.cudnn.flags(enabled=True, deterministic=True): 4973 for op, args in self.autocast_lists.methods_fp16: 4974 self._run_autocast_outofplace( 4975 op, 4976 args, 4977 torch.float16, 4978 device="cuda", 4979 module=None, 4980 amp_dtype=torch.float16, 4981 ) 4982 4983 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4984 def test_autocast_methods_fp32(self): 4985 for op, args in self.autocast_lists.methods_fp32: 4986 self._run_autocast_outofplace( 4987 op, 4988 args, 4989 torch.float32, 4990 device="cuda", 4991 module=None, 4992 amp_dtype=torch.float16, 4993 ) 4994 4995 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 4996 def test_autocast_methods_expect_builtin_promote(self): 4997 for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote: 4998 self._run_autocast_outofplace( 4999 op, 5000 args, 5001 torch.float32, 5002 device="cuda", 5003 module=None, 5004 out_type=out_type, 5005 amp_dtype=torch.float16, 5006 ) 5007 5008 def test_autocast_banned(self): 5009 with torch.autocast("cuda"): 5010 for op, args, module in self.autocast_lists.banned: 5011 with self.assertRaises(RuntimeError): 5012 getattr(module, op)(*args) 5013 5014 def test_autocast_ignored_types(self): 5015 with torch.autocast("cuda"): 5016 for ignore_type in (torch.double, torch.int32): 5017 a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") 5018 b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") 5019 c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0") 5020 5021 # Tests if CastPolicy::fp16 ops ignore double and int 5022 # Currently, no ops belonging to this policy support integer inputs. 5023 if ignore_type is torch.double: 5024 with self.assertRaises(RuntimeError): 5025 torch.mm(a_ignore, c_16) 5026 with torch.autocast("cuda", enabled=False): 5027 type_no_autocast = torch.mm(a_ignore, b_ignore).dtype 5028 self.assertTrue( 5029 torch.mm(a_ignore, b_ignore).dtype is type_no_autocast 5030 ) 5031 5032 # Tests if CastPolicy::fp32 ops ignore double and int 5033 with torch.autocast("cuda", enabled=False): 5034 type_no_autocast = torch.pow(a_ignore, 2.0).dtype 5035 self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast) 5036 5037 # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int 5038 with torch.autocast("cuda", enabled=False): 5039 type_no_autocast = torch.sum(a_ignore).dtype 5040 self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast) 5041 5042 # Tests if CastPolicy::fp32_append_dtype ops ignore double and int 5043 # Currently, no ops belonging to this policy support integer inputs. 5044 if ignore_type is torch.double: 5045 with torch.autocast("cuda", enabled=False): 5046 type_no_autocast = torch.norm(a_ignore).dtype 5047 self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast) 5048 5049 def test_autocast_custom_enabled(self): 5050 class MyMM(torch.autograd.Function): 5051 @staticmethod 5052 @torch.amp.custom_fwd(device_type="cuda") 5053 def forward(ctx, a, b): 5054 self.assertTrue(a.dtype is torch.float32) 5055 self.assertTrue(b.dtype is torch.float32) 5056 self.assertTrue(torch.is_autocast_enabled()) 5057 ctx.save_for_backward(a, b) 5058 return a.mm(b) 5059 5060 @staticmethod 5061 @torch.amp.custom_bwd(device_type="cuda") 5062 def backward(ctx, grad): 5063 self.assertTrue(torch.is_autocast_enabled()) 5064 a, b = ctx.saved_tensors 5065 a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad) 5066 self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype) 5067 return a_grad, b_grad 5068 5069 mymm = MyMM.apply 5070 5071 x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) 5072 y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) 5073 5074 dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,) 5075 for dtype in dtypes: 5076 with torch.cuda.amp.autocast(dtype=dtype): 5077 output = mymm(x, y) 5078 self.assertTrue(output.dtype is dtype) 5079 loss = output.sum() 5080 loss.backward() 5081 5082 def test_autocast_custom_cast_inputs(self): 5083 class MyMM(torch.autograd.Function): 5084 @staticmethod 5085 @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) 5086 def forward(ctx, a, container, expect_type): 5087 b = container[1][0] 5088 self.assertTrue(a.dtype is expect_type) 5089 self.assertTrue(b.dtype is expect_type) 5090 self.assertFalse(torch.is_autocast_enabled()) 5091 ctx.save_for_backward(a, b) 5092 return a.mm(b) 5093 5094 @staticmethod 5095 @torch.amp.custom_bwd(device_type="cuda") 5096 def backward(ctx, grad): 5097 self.assertFalse(torch.is_autocast_enabled()) 5098 a, b = ctx.saved_tensors 5099 return grad.mm(b.t()), None, None 5100 5101 mymm = MyMM.apply 5102 5103 x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True) 5104 # Puts one input tensor in a nested container. y's contained Tensor won't receive a gradient, 5105 # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments. 5106 # Sets requires_grad=False explicitly so we don't lie about expecting a gradient. 5107 y = ( 5108 0, 5109 { 5110 0: torch.randn( 5111 (8, 8), device="cuda", dtype=torch.float16, requires_grad=False 5112 ) 5113 }, 5114 ) 5115 5116 with torch.autocast("cuda"): 5117 output = mymm(x, y, torch.float32) 5118 self.assertTrue(output.dtype is torch.float32) 5119 loss = output.sum() 5120 loss.backward() 5121 5122 # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region. 5123 output = mymm(x, y, torch.float16) 5124 self.assertTrue(output.dtype is torch.float16) 5125 loss = output.sum() 5126 loss.backward() 5127 5128 def test_autocast_custom_deprecated_warning(self): 5129 with warnings.catch_warnings(record=True) as w: 5130 5131 class MyMM(torch.autograd.Function): 5132 @staticmethod 5133 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 5134 def forward(ctx, x, y): 5135 ctx.save_for_backward(x, y) 5136 self.assertFalse(torch.is_autocast_enabled()) 5137 return x + y 5138 5139 @staticmethod 5140 @torch.cuda.amp.custom_bwd 5141 def backward(ctx, grad): 5142 _, _ = ctx.saved_tensors 5143 self.assertFalse(torch.is_autocast_enabled()) 5144 return grad, grad 5145 5146 self.assertRegex( 5147 str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated." 5148 ) 5149 self.assertRegex( 5150 str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated." 5151 ) 5152 5153 mymm = MyMM.apply 5154 x = torch.randn(3, 3, requires_grad=True) 5155 y = torch.randn(3, 3, requires_grad=True) 5156 with torch.amp.autocast("cuda"): 5157 output = mymm(x, y) 5158 loss = output.sum() 5159 loss.backward() 5160 5161 def test_autocast_cat_jit(self): 5162 # Reported at https://github.com/pytorch/pytorch/issues/38958 5163 5164 class Model(torch.nn.Module): 5165 def forward(self): 5166 a = torch.randn(1) 5167 b = torch.randn(1) 5168 c = torch.cat((a, b), 0) 5169 d = torch.stack([c, c], 0) 5170 return d 5171 5172 # The JIT here doesn't really matter, we just need to call 5173 # cat via the boxed API 5174 model = Model() 5175 model_jit_script = torch.jit.script(model) 5176 5177 with torch.autocast("cuda", enabled=True): 5178 model() 5179 model_jit_script() 5180 5181 # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened) 5182 # so they get a dedicated test. 5183 # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100). 5184 @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") 5185 def test_autocast_rnn(self): 5186 with torch.backends.cudnn.flags(enabled=True, deterministic=True): 5187 # seq, batch, features, hidden size 5188 clses = ("RNN", "GRU", "LSTM") 5189 T, B, F, H = 3, 4, 5, 6 5190 dtypes = (torch.float16, torch.float32) 5191 input_layouts = ("seq_first", "batch_first", "packed") 5192 5193 for ( 5194 cls, 5195 num_layers, 5196 bias, 5197 input_layout, 5198 bidirectional, 5199 try_nonpreflattened_weights, 5200 input_dtype, 5201 hidden_dtype, 5202 weight_dtype, 5203 ) in product( 5204 clses, 5205 (1, 2), 5206 (True, False), 5207 input_layouts, 5208 (True, False), 5209 (True, False), 5210 dtypes, 5211 dtypes, 5212 dtypes, 5213 ): 5214 if input_layout == "seq_first": 5215 batch_first = False 5216 x = torch.randn((T, B, F), device="cuda", dtype=input_dtype) 5217 elif input_layout == "batch_first": 5218 batch_first = True 5219 x = torch.randn((B, T, F), device="cuda", dtype=input_dtype) 5220 elif input_layout == "packed": 5221 batch_first = False 5222 x = torch.nn.utils.rnn.pack_padded_sequence( 5223 torch.randn((T, B, F), device="cuda", dtype=input_dtype), 5224 lengths=(3, 2, 1, 3), 5225 enforce_sorted=False, 5226 ) 5227 5228 rnn = ( 5229 getattr(torch.nn, cls)( 5230 F, 5231 H, 5232 num_layers=num_layers, 5233 bidirectional=bidirectional, 5234 bias=bias, 5235 batch_first=batch_first, 5236 ) 5237 .cuda() 5238 .to(dtype=weight_dtype) 5239 ) 5240 5241 if try_nonpreflattened_weights: 5242 for p in rnn.parameters(): 5243 with torch.no_grad(): 5244 p.set_(p.clone()) 5245 5246 h = torch.randn( 5247 (num_layers * (2 if bidirectional else 1), B, H), 5248 device="cuda", 5249 dtype=hidden_dtype, 5250 ) 5251 if cls == "LSTM": 5252 c = torch.randn( 5253 (num_layers * (2 if bidirectional else 1), B, H), 5254 device="cuda", 5255 dtype=hidden_dtype, 5256 ) 5257 h = (h, c) 5258 5259 with torch.autocast("cuda"): 5260 out, h_out = rnn(x, h) 5261 out = out.data if input_layout == "packed" else out 5262 self.assertEqual(out.dtype, torch.float16) 5263 # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee 5264 # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has 5265 # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed. 5266 self.assertEqual( 5267 out.grad_fn.name(), 5268 "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0", 5269 ) 5270 out.sum().backward() 5271 grads = [p.grad.clone() for p in rnn.parameters()] 5272 5273 rnn.zero_grad() 5274 5275 if cls == "LSTM": 5276 out_control, h_out_control = rnn.to(dtype=torch.float16)( 5277 x.half(), (h[0].half(), h[1].half()) 5278 ) 5279 else: 5280 out_control, h_out_control = rnn.to(dtype=torch.float16)( 5281 x.half(), h.half() 5282 ) 5283 out_control = ( 5284 out_control.data if input_layout == "packed" else out_control 5285 ) 5286 out_control.sum().backward() 5287 grads_control = [p.grad.clone() for p in rnn.parameters()] 5288 5289 # Compares with default tolerances, even for FP16 execution. Barring nondeterminism, 5290 # autocast and control results should be bitwise identical. 5291 self.assertEqual(out, out_control) 5292 5293 if cls == "LSTM": 5294 self.assertTrue( 5295 h_out[0].dtype is torch.float16 5296 and h_out[1].dtype is torch.float16 5297 ) 5298 self.assertEqual(h_out[0], h_out_control[0]) 5299 self.assertEqual(h_out[1], h_out_control[1]) 5300 else: 5301 self.assertEqual(h_out.dtype, torch.float16) 5302 self.assertEqual(h_out, h_out_control) 5303 for grad, grad_control in zip(grads, grads_control): 5304 self.assertEqual(grad.half(), grad_control) 5305 5306 def test_autocast_cache_leak(self): 5307 # Reported at https://github.com/pytorch/pytorch/issues/48049 5308 # Test is used to check, if autocast recaches the same parameters 5309 # when executed in a `torch.no_grad()` block. 5310 5311 linear = torch.nn.Linear(10, 10).to("cuda") 5312 data = torch.randn(1, 10, device="cuda") 5313 5314 with torch.autocast("cuda"): 5315 with torch.no_grad(): 5316 out = linear(data) 5317 first_iter_mem = torch.cuda.memory_allocated() 5318 for _ in range(3): 5319 out = linear(data) 5320 self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) 5321 5322 def test_autocast_checkpointing(self): 5323 model = torch.nn.Sequential( 5324 torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8) 5325 ).cuda() 5326 input = torch.rand( 5327 (8, 8), device="cuda", dtype=torch.float16, requires_grad=True 5328 ) 5329 for reentrant in (True, False): 5330 with torch.autocast("cuda"): 5331 output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant) 5332 self.assertTrue(output.requires_grad) 5333 self.assertTrue(output.dtype is torch.float16) 5334 output.sum().backward() 5335 5336 def test_cuda_autocast_deprecated_warning(self): 5337 with self.assertWarnsRegex( 5338 FutureWarning, 5339 r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.", 5340 ): 5341 with torch.cuda.amp.autocast(): 5342 _ = torch.ones(10) 5343 5344 5345instantiate_parametrized_tests(TestCuda) 5346instantiate_parametrized_tests(TestCudaMallocAsync) 5347instantiate_device_type_tests(TestCudaOptims, globals()) 5348 5349if __name__ == "__main__": 5350 run_tests() 5351