# Owner(s): ["module: cuda"]

import contextlib
import ctypes
import gc
import json
import os
import pickle
import random
import subprocess
import sys
import tempfile
import threading
import unittest
import warnings
from copy import deepcopy
from itertools import product
from random import randint

import psutil

import torch
import torch.cuda
from torch import inf, nan
from torch.cuda._memory_viz import (
    _profile_to_snapshot,
    profile_plot,
    segment_plot,
    trace_plot,
)
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_cuda import (
    _create_scaling_case,
    _get_torch_cuda_version,
    TEST_CUDNN,
    TEST_MULTIGPU,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    optims,
    TensorTracker,
)
from torch.testing._internal.common_utils import (
    EXPANDABLE_SEGMENTS,
    freeze_rng_state,
    gcIfJetson,
    get_cycles_per_ms,
    instantiate_parametrized_tests,
    IS_ARM64,
    IS_FBCODE,
    IS_JETSON,
    IS_LINUX,
    IS_SANDCASTLE,
    IS_WINDOWS,
    load_tests,
    NO_MULTIPROCESSING_SPAWN,
    parametrize,
    run_tests,
    serialTest,
    skipCUDAMemoryLeakCheckIf,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    slowTest,
    subtest,
    TemporaryFileName,
    TEST_CUDA,
    TEST_CUDA_GRAPH,
    TEST_NUMPY,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.utils.checkpoint import checkpoint_sequential
from torch.utils.viz._cycles import observe_tensor_cycles


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

try:
    import torchvision.models  # noqa: F401
    from torchvision.models import resnet18  # noqa: F401

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

TEST_CUDAMALLOCASYNC = TEST_CUDA and (
    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
)
TEST_LARGE_TENSOR = TEST_CUDA
TEST_MEDIUM_TENSOR = TEST_CUDA
TEST_BF16 = False
TEST_PYNVML = not torch.cuda._HAS_PYNVML
if TEST_CUDA:
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
    TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9
    TEST_BF16 = torch.cuda.is_bf16_supported()

_cycles_per_ms = None
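# NOTE: TEST_LARGE_TENSOR / TEST_MEDIUM_TENSOR above are coarse capacity gates:
# they only check that device 0 reports at least ~12 GB / ~6 GB of total memory,
# so tests guarded by them can still run out of memory if other processes already
# hold allocations on the same GPU.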


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCuda(TestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    FIFTY_MIL_CYCLES = 50000000

    def setUp(self):
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @property
    def expandable_segments(self):
        return EXPANDABLE_SEGMENTS

    def test_pinned_memory_with_cudaregister(self):
        torch.cuda.memory._set_allocator_settings(
            "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
        )
        t = torch.ones(20)
        self.assertFalse(t.is_pinned())
        try:
            pinned_t = torch.ones(1 << 21).pin_memory()
            self.assertTrue(pinned_t.is_pinned())
            pinned_t = torch.ones(1 << 24).pin_memory()
            self.assertTrue(pinned_t.is_pinned())
        except RuntimeError as e:
            # Some GPUs don't support same address space on host and device side
            pass

    def test_pinned_memory_with_cudaregister_multithread(self):
        num_threads = 4
        threads = [
            threading.Thread(target=self.test_pinned_memory_with_cudaregister)
            for t in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def test_pinned_memory_empty_cache(self):
        for alloc_settings in (True, False):
            torch.cuda.memory._set_allocator_settings(
                f"pinned_use_cuda_host_register:{alloc_settings}"
            )
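            # When pinned_use_cuda_host_register is enabled, the host allocator is
            # expected to pin pages via cudaHostRegister rather than cudaHostAlloc;
            # either way the allocation below should behave like ordinary pinned
            # memory and be releasable through torch._C._host_emptyCache().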
            try:
                t = torch.ones(1024 * 1024, pin_memory=True)
                self.assertTrue(t.is_pinned())
                del t
                torch._C._host_emptyCache()
            except RuntimeError as e:
                # Some GPUs don't support same address space on host and device side
                pass

    def test_cudart_register(self):
        t = torch.ones(20)
        self.assertFalse(t.is_pinned())
        cudart = torch.cuda.cudart()
        r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
        self.assertEqual(r, 0)
        self.assertTrue(t.is_pinned())
        r = cudart.cudaHostUnregister(t.data_ptr())
        self.assertEqual(r, 0)
        self.assertFalse(t.is_pinned())

    def test_memory_allocation(self):
        gc.collect()
        torch.cuda.empty_cache()
        mem = None
        size = 1
        prev = 0
        try:
            prev = torch.cuda.memory_allocated()
            mem = torch.cuda.caching_allocator_alloc(size)
            self.assertGreater(torch.cuda.memory_allocated(), prev)
        finally:
            if mem is not None:
                torch.cuda.caching_allocator_delete(mem)
                self.assertEqual(torch.cuda.memory_allocated(), prev)

    def test_check_error(self):
        # Assert this call doesn't raise.
        torch.cuda.check_error(0)

        with self.assertRaisesRegex(
            torch.cuda.CudaError, "out of memory|hipErrorOutOfMemory"
        ):
            torch.cuda.check_error(2)

    def test_cuda_get_device_name(self):
        # Testing the behaviour with None as an argument
        current_device = torch.cuda.current_device()
        current_device_name = torch.cuda.get_device_name(current_device)
        device_name_None = torch.cuda.get_device_name(None)
        self.assertEqual(current_device_name, device_name_None)

        # Testing the behaviour with no argument
        device_name_no_argument = torch.cuda.get_device_name()
        self.assertEqual(current_device_name, device_name_no_argument)

    def test_cuda_get_device_capability(self):
        # Testing the behaviour with None as an argument
        current_device = torch.cuda.current_device()
        current_device_capability = torch.cuda.get_device_capability(current_device)
        device_capability_None = torch.cuda.get_device_capability(None)
        self.assertEqual(current_device_capability, device_capability_None)

        # Testing the behaviour with no argument
        device_capability_no_argument = torch.cuda.get_device_capability()
        self.assertEqual(current_device_capability, device_capability_no_argument)

    def test_out_of_memory(self):
        tensor = torch.zeros(1024, device="cuda")

        oom_regex = (
            "would exceed allowed memory"
            if TEST_CUDAMALLOCASYNC
            else "Tried to allocate 800000000.00 GiB"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="cuda")

        with self.assertRaisesRegex(
            RuntimeError, "Tried to allocate more than 1EB memory"
        ):
            torch.empty(
                1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="cuda"
            )

        # ensure out of memory error doesn't disturb subsequent kernel
        tensor.fill_(1)
        self.assertTrue((tensor == 1).all())

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)"
    )
    @serialTest()
    def test_out_of_memory_retry(self):
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        oom_regex = (
            "would exceed allowed memory"
            if TEST_CUDAMALLOCASYNC
            else "Tried to allocate"
        )
        size = int(total_memory * 0.5)
        a = torch.empty(size, dtype=torch.int8, device="cuda")
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            b = torch.empty(size, dtype=torch.int8, device="cuda")
        del a
        b = torch.empty(size, dtype=torch.int8, device="cuda")
        del b
        # We used a lot of memory here, clean up so we don't affect other tests too much
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    @serialTest()
    def test_set_per_process_memory_fraction(self):
        # test invalid fraction value.
        with self.assertRaisesRegex(TypeError, "Invalid type"):
            torch.cuda.set_per_process_memory_fraction(1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(-0.1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(2.0)

        tensor = torch.zeros(1024, device="cuda")
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        torch.cuda.set_per_process_memory_fraction(0.5, 0)

        # test 0.499 allocation is ok.
        application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved()
        tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda")
        del tmp_tensor
        torch.cuda.empty_cache()

        application = int(total_memory * 0.5)
        # it will OOM when trying to allocate more than half of the total memory.
        oom_regex = (
            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(application, dtype=torch.int8, device="cuda")

        # ensure out of memory error doesn't disturb subsequent kernel
        tensor.fill_(1)
        self.assertTrue((tensor == 1).all())

    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "uuid attribute not yet available")
    def test_uuid(self):
        uuid = torch.cuda.get_device_properties(0).uuid
        self.assertEqual(len(str(uuid)), 36)  # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
        self.assertEqual(len(uuid.bytes), 16)

    def test_copy_non_blocking(self):
        def _test_copy_non_blocking(a, b):
            event = torch.cuda.Event()
            a.copy_(b, non_blocking=True)
            event.record()
            event.synchronize()
            self.assertEqual(a, b)

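        # Note: copy_(..., non_blocking=True) only overlaps with the host when the
        # CPU-side tensor lives in pinned memory; the event recorded and synchronized
        # in the helper above is what makes the equality check below safe.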
        # 10MB copies
        x = torch.ones(10000000, dtype=torch.uint8).cuda()
        y = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        _test_copy_non_blocking(x, y)

        x = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        y = torch.ones(10000000, dtype=torch.uint8).cuda()
        _test_copy_non_blocking(x, y)

        # Test the case where the pinned data_ptr is not equal to the storage data_ptr.
        x_base = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        x = x_base[1:]
        self.assertTrue(x.is_pinned())
        self.assertTrue(x_base.is_pinned())
        self.assertNotEqual(x_base.data_ptr(), x.data_ptr())
        self.assertEqual(x_base.storage().data_ptr(), x.storage().data_ptr())
        y = torch.ones(10000000 - 1, dtype=torch.uint8).cuda()
        _test_copy_non_blocking(x, y)

    def test_copy_non_blocking_type_conversion(self):
        a = torch.ones(1, device="cuda")
        b = torch.zeros(1, device="cpu", pin_memory=True)
        c = torch.empty(1, device="cuda", dtype=torch.long)
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        b.copy_(a, non_blocking=True)
        c.copy_(b, non_blocking=True)
        self.assertEqual(a, c, exact_dtype=False)

    @serialTest()
    def test_to_non_blocking(self):
        stream = torch.cuda.current_stream()

        def _test_to_non_blocking(a, non_blocking, dst):
            torch.cuda.synchronize()
            # Pushes a 0.1 second spin to stream so if the copy is non-blocking,
            # stream will almost surely be active when we query().
            torch.cuda._sleep(int(100 * get_cycles_per_ms()))
            b = a.to(device=dst, non_blocking=non_blocking)
            self.assertEqual(stream.query(), not non_blocking)
            stream.synchronize()
            self.assertEqual(a, b)
            self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu"))

        for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)):
            # Creates source on the opposite device from destination.
            src = torch.randn(
                1000000,
                device="cuda" if dst == "cpu" else "cpu",
                pin_memory=True if dst == "cuda" else False,
            )
            _test_to_non_blocking(src, try_non_blocking, dst)

    def test_to_cpu_blocking_by_default(self):
        src = torch.randn(1000000, device="cuda")
        torch.cuda.synchronize()
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        dst = src.to(device="cpu")
        self.assertEqual(torch.cuda.current_stream().query(), True)
        self.assertEqual(src, dst)
        self.assertFalse(dst.is_pinned())

    def test_serialization_array_with_storage(self):
        x = torch.randn(5, 5).cuda()
        y = torch.IntTensor(2, 5).fill_(0).cuda()
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
        self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
        self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        q_copy[1].fill_(10)
        self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
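
    # Informal note on the cuBLAS workspace test below: CUBLAS_WORKSPACE_CONFIG is a
    # sequence of ":size_in_KiB:count" pairs, so the default ":4096:2:16:8" reserves
    # 2 x 4096 KiB + 8 x 16 KiB of workspace, which is how default_workspace_size
    # in the test is computed.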

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
    )
    @unittest.skipIf(
        _get_torch_cuda_version() >= (12, 2),
        "skipped as explicit workspace allocation is removed",
    )
    def test_cublas_workspace_explicit_allocation(self):
        a = torch.randn(7, 7, device="cuda", requires_grad=False)
        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
        # different size (32 MiB) expected on Hopper GPU
        if torch.cuda.get_device_capability() == (9, 0):
            default_workspace_size = 4096 * 8 * 1024

        def check_workspace_size(inp):
            torch._C._cuda_clearCublasWorkspaces()
            start = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            with torch.no_grad():
                torch.matmul(inp, inp)
            finish = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            return finish - start

        # check default
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)

        # check default with bad user config
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = "-1"
        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)

        # check valid config
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":128:8:64:16:32:32"
        self.assertTrue(abs(check_workspace_size(a) - (3072 * 1024)) < 524288)

        torch._C._cuda_clearCublasWorkspaces()

    def test_cublas_allow_tf32_get_set(self):
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
        )
        if skip_tf32_cublas:
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
            return

        orig = torch.backends.cuda.matmul.allow_tf32
        self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
        torch.backends.cuda.matmul.allow_tf32 = not orig
        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
        torch.backends.cuda.matmul.allow_tf32 = orig

    def test_float32_matmul_precision_get_set(self):
        orig = torch.get_float32_matmul_precision()
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
        )
        # this is really just checking that the environment variable is respected during testing
        # and not overwritten by another function that doesn't revert it to the initial value
        if not skip_tf32_cublas:
            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
            self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        else:
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
        for p in ("medium", "high"):
            torch.set_float32_matmul_precision(p)
            self.assertEqual(torch.get_float32_matmul_precision(), p)
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
        torch.set_float32_matmul_precision("highest")
        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
        torch.set_float32_matmul_precision(orig)

    def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        self.assertEqual(
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
        )
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
        self.assertEqual(
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
        )
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig

    def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
        orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        self.assertEqual(
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
        )
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
        self.assertEqual(
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
        )
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig

    def test_cudnn_allow_tf32_get_set(self):
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=False
        ):
            self.assertFalse(torch.backends.cudnn.allow_tf32)
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
        ):
            self.assertTrue(torch.backends.cudnn.allow_tf32)

    def test_type_conversions(self):
        x = torch.randn(5, 5)
        self.assertIsInstance(x.float(), torch.FloatTensor)
        self.assertIsInstance(x.cuda().double(), torch.cuda.DoubleTensor)
        self.assertIsInstance(x.cuda().float(), torch.cuda.FloatTensor)
        self.assertIsInstance(x.cuda().float().cpu(), torch.FloatTensor)
        self.assertIsInstance(x.cuda().float().cpu().int(), torch.IntTensor)

        y = x.storage()
        self.assertIsInstance(y.float(), torch.FloatStorage)
        self.assertIsInstance(y.cuda().double(), torch.cuda.DoubleStorage)
        self.assertIsInstance(y.cuda().float(), torch.cuda.FloatStorage)
        self.assertIsInstance(y.cuda().float().cpu(), torch.FloatStorage)
        self.assertIsInstance(y.cuda().float().cpu().int(), torch.IntStorage)

    @unittest.skip("was disabled due to not enough memory, but actually it always fails")
    def test_arithmetic_large_tensor(self):
        x = torch.empty(2**30, device="cuda")

        x.fill_(1)
        self.assertEqual(x.sum(), 2**30)

        x += 1
        self.assertEqual(x.sum(), 2**31)

        x.fill_(1)
        x -= 0.5
        self.assertEqual(x.sum(), 2**29)

        x.fill_(1)
        x *= 2
        self.assertEqual(x.sum(), 2**31)

        x.fill_(1)
        x /= 2
        self.assertEqual(x.sum(), 2**29)

    def test_gather_bool(self):
        t = torch.tensor([[False, True], [True, True]], device="cuda")
        self.assertEqual(
            torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]], device="cuda")),
            torch.tensor([[False, False], [True, True]], device="cuda"),
        )

    def test_torch_manual_seed_seeds_cuda_devices(self):
        with freeze_rng_state():
            x = torch.zeros(4, 4).float().cuda()
            torch.manual_seed(2)
            self.assertEqual(torch.cuda.initial_seed(), 2)
            x.uniform_()
            torch.manual_seed(2)
            y = x.clone().uniform_()
            self.assertEqual(x, y)
            self.assertEqual(torch.cuda.initial_seed(), 2)

    def test_manual_seed(self):
        with freeze_rng_state():
            x = torch.zeros(4, 4).float().cuda()
            torch.cuda.manual_seed(2)
            self.assertEqual(torch.cuda.initial_seed(), 2)
            x.uniform_()
            a = torch.bernoulli(torch.full_like(x, 0.5))
            torch.cuda.manual_seed(2)
            y = x.clone().uniform_()
            b = torch.bernoulli(torch.full_like(x, 0.5))
            self.assertEqual(x, y)
            self.assertEqual(a, b)
            self.assertEqual(torch.cuda.initial_seed(), 2)

    def test_specify_improper_device_name(self):
        import os

        fname = "tempfile.pt"
        try:
            with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
                torch.save(
                    [torch.nn.Parameter(torch.randn(10, 10))],
                    fname,
                    _use_new_zipfile_serialization=True,
                )
                torch.load(fname, "cuda0")
        finally:
            if os.path.exists(fname):
                os.remove(fname)

    def test_get_device_index(self):
        from torch.cuda._utils import _get_device_index

        with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
            _get_device_index("cuda0", optional=True)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device"):
            cpu_device = torch.device("cpu")
            _get_device_index(cpu_device, optional=True)

    def test_serialization_array_with_empty(self):
        x = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())

    @skipCUDANonDefaultStreamIf(True)
    def test_streams(self):
        default_stream = torch.cuda.current_stream()
        user_stream = torch.cuda.Stream()
        self.assertEqual(torch.cuda.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        self.assertEqual(default_stream.cuda_stream, 0)
        self.assertNotEqual(user_stream.cuda_stream, 0)
        with torch.cuda.stream(user_stream):
            self.assertEqual(torch.cuda.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        tensor1 = torch.ByteTensor(5).pin_memory()
        tensor2 = tensor1.cuda(non_blocking=True) + 1
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    def test_stream_event_repr(self):
        s = torch.cuda.current_stream()
        self.assertTrue("torch.cuda.Stream" in s.__repr__())
        e = torch.cuda.Event()
        self.assertTrue("torch.cuda.Event" in e.__repr__())
        s.record_event(e)
        self.assertTrue("torch.cuda.Event" in e.__repr__())

    def test_events(self):
        stream = torch.cuda.current_stream()
        event = torch.cuda.Event(enable_timing=True)
        self.assertTrue(event.query())
        start_event = torch.cuda.Event(enable_timing=True)
        stream.record_event(start_event)
        torch.cuda._sleep(int(50 * get_cycles_per_ms()))
        stream.record_event(event)
        self.assertFalse(event.query())
        event.synchronize()
        self.assertTrue(event.query())
        self.assertGreater(start_event.elapsed_time(event), 0)

    def test_generic_stream_event(self):
        stream = torch.Stream("cuda")
        self.assertEqual(stream.device_index, torch.cuda.current_device())
        cuda_stream = torch.cuda.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, cuda_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.cuda.current_stream().stream_id)

        event1 = torch.Event("cuda", enable_timing=True)
        event2 = torch.Event("cuda", enable_timing=True)
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.cuda.stream(cuda_stream):
            a_cuda = a.to("cuda", non_blocking=True)
            b_cuda = b.to("cuda", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.cuda.current_stream().stream_id)
        event1.record(stream)
        event1.synchronize()
        self.assertTrue(event1.query())
        c_cuda = a_cuda + b_cuda
        event2.record()
        event2.synchronize()
        self.assertTrue(event2.query())
        self.assertNotEqual(event1.event_id, event2.event_id)
        self.assertEqual(c_cuda.cpu(), a + b)
        self.assertTrue(event1.elapsed_time(event2) > 0)

    def test_record_stream(self):
        cycles_per_ms = get_cycles_per_ms()

        t = torch.FloatTensor([1, 2, 3, 4]).pin_memory()
        result = torch.cuda.FloatTensor(t.size())
        stream = torch.cuda.Stream()
        ptr = [None]

        # Performs the CPU->GPU copy in a background stream
        def perform_copy():
            with torch.cuda.stream(stream):
                tmp = t.cuda(non_blocking=True)
                ptr[0] = tmp.data_ptr()
            torch.cuda.current_stream().wait_stream(stream)
            tmp.record_stream(torch.cuda.current_stream())
            torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy
            result.copy_(tmp)

        perform_copy()
        with torch.cuda.stream(stream):
            tmp2 = torch.cuda.FloatTensor(t.size())
            tmp2.zero_()
            self.assertNotEqual(
                tmp2.data_ptr(), ptr[0], msg="allocation re-used too soon"
            )

        self.assertEqual(result.tolist(), [1, 2, 3, 4])

        if not TEST_CUDAMALLOCASYNC:
            # In the native allocator, we expect "tmp"'s side-stream-tagged block will be reused
            # in that side stream after result.copy_(tmp) in the main stream finishes.
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(stream):
                tmp3 = torch.cuda.FloatTensor(t.size())
                self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used")

    def test_record_stream_on_shifted_view(self):
        # See issue #27366

        # This test detects unexpected block reallocation. For a reliable test,
        # the stream to allocate tensors is isolated. The allocator will not
        # reuse free blocks which were allocated from another stream.
        stream_alloc = torch.cuda.Stream()
        with torch.cuda.stream(stream_alloc):
            base = torch.cuda.FloatTensor([10, 10])

        # Record another stream on a shifted view tensor.
        view = base[5:]
        assert view.storage_offset() > 0

        stream_record = torch.cuda.Stream()
        with torch.cuda.stream(stream_record):
            torch.cuda._sleep(int(50 * get_cycles_per_ms()))
        view.record_stream(stream_record)

        # Delete those tensors to make the block free soon.
        data_ptr = base.data_ptr()
        del base, view

        # A new tensor should not be allocated to the block above.
        stream_alloc.synchronize()

        with torch.cuda.stream(stream_alloc):
            try_realloc = torch.cuda.FloatTensor([10, 10])

        self.assertNotEqual(try_realloc.data_ptr(), data_ptr)

    def test_noncontiguous_pinned_memory(self):
        # See issue #3266
        x = torch.arange(0, 10).view((2, 5))
        self.assertEqual(x.t(), x.t().pin_memory())

    def test_caching_pinned_memory(self):
        cycles_per_ms = get_cycles_per_ms()

        # check that allocations are re-used after deletion
        t = torch.FloatTensor([1]).pin_memory()
        ptr = t.data_ptr()
        del t
        t = torch.FloatTensor([1]).pin_memory()
        self.assertEqual(t.data_ptr(), ptr, msg="allocation not reused")

        # check that the allocation is not re-used if it's in-use by a copy
        gpu_tensor = torch.cuda.FloatTensor([0])
        torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by 1s
        gpu_tensor.copy_(t, non_blocking=True)
        del t
        t = torch.FloatTensor([1]).pin_memory()
        self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon")
        self.assertEqual(list(gpu_tensor), [1])

    def test_caching_allocator_record_stream_oom(self):
        """allocations delayed by a record_stream call should still be freed on
        an out-of-memory in cuda_malloc_retry. see issue #19219"""
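        # record_stream() marks a block as in use by another stream, so the caching
        # allocator normally defers its reuse until the work recorded on that stream
        # completes; this test exercises the path where an OOM forces those deferred
        # blocks to be reclaimed instead of failing the allocation outright.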
    def test_caching_allocator_record_stream_oom(self):
        """allocations delayed by a record_stream call should still be freed on
        an out-of-memory in cuda_malloc_retry. see issue #19219"""
        stream = torch.cuda.Stream()

        with torch.cuda.stream(stream):
            y = torch.zeros(40 * 1024 * 1024, device="cuda")

        for _ in range(100):
            x = torch.empty(40 * 1024 * 1024, device="cuda")
            with torch.cuda.stream(stream):
                y += x
            # delays re-use of `x` until after all operations in `stream`
            x.record_stream(stream)
            del x

        # we've made a mess by allocating up to the device capacity. free any
        # cached blocks in case it affects future tests.
        torch.cuda.empty_cache()

    # Tests for historic illegal memory access, see #17040.
    def test_reduction_gpu_memory_accessing(self):
        x = torch.ones(512, 8, dtype=torch.float32, device="cuda")
        torch.sum(x, 0)

    def test_sum_fp16(self):
        x = torch.zeros(10, device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum(), 0)

        x = torch.ones(65504, device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum(), 65504)
        self.assertEqual(x.sum(dtype=torch.float32), 65504)

        x = torch.ones(65536, device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum(dtype=torch.float32), 65536)

        a = torch.zeros(1203611).bernoulli_(0.0005)
        x = a.to(device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum().item(), a.sum().item())

        a = torch.zeros(100, 121, 80).bernoulli_(0.0005)
        x = a.to(device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2)))

    def test_mean_fp16(self):
        x = torch.ones(65536, device="cuda", dtype=torch.float16)
        self.assertEqual(x.mean(), 1)

        x = torch.ones(65536, device="cuda", dtype=torch.float16)
        self.assertEqual(x.mean(dtype=torch.float32), 1)

    def test_prod_large(self):
        # tests global reduction (should_global_reduce = true) in case of non-zero identity element
        x = torch.ones(240000, device="cuda", dtype=torch.float32)
        self.assertEqual(x.prod(), 1)

        # test for complex types. Note 240k is divisible by 4
        for dtype in [torch.cfloat, torch.cdouble]:
            x = torch.ones(240000, device="cuda", dtype=dtype) * (0 + 1j)
            self.assertEqual(x.prod(), 1)

    def test_multinomial_ext(self):
        # Test two corner cases from older PyTorch (Issue #4858)
        freqs = torch.cuda.FloatTensor(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.03178183361887932,
                0.027680952101945877,
                0.033176131546497345,
                0.046052902936935425,
                0.07742464542388916,
                0.11543981730937958,
                0.14148041605949402,
                0.15784293413162231,
                0.13180233538150787,
                0.08271478116512299,
                0.049702685326337814,
                0.027557924389839172,
                0.018125897273421288,
                0.011851548217236996,
                0.010252203792333603,
                0.007422595750540495,
                0.005372154992073774,
                0.0045109698548913,
                0.0036087757907807827,
                0.0035267581697553396,
                0.0018864056328311563,
                0.0024605290964245796,
                0.0022964938543736935,
                0.0018453967059031129,
                0.0010662291897460818,
                0.0009842115687206388,
                0.00045109697384759784,
                0.0007791675161570311,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00012302644609007984,
                0.0,
                0.00012302644609007984,
                4.100881778867915e-05,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ]
        )

        torch.cuda.manual_seed(11042)
        sample = torch.multinomial(freqs, 1000, True)
        self.assertNotEqual(freqs[sample].min(), 0)

        p = torch.zeros(3421, 2, device="cuda", dtype=torch.float)
        p[:, 1] = 1
        torch.cuda.manual_seed(5214)
        r = torch.multinomial(p, 1)
        self.assertNotEqual(r.min().item(), 0)

        # test corner case from Issue #13867
        torch.cuda.manual_seed(33)
        probs = torch.randn(1000000, device="cuda").clamp(min=0) * 3e-5
        samples = probs.multinomial(1000000, replacement=True)
        self.assertGreater(probs[samples].min().item(), 0)

    def _spawn_test_multinomial_invalid_probs_cuda(self, probs):
        # Runs in a subprocess because a device-side assert leaves the CUDA context
        # unusable, so the failing call must not run in this process.
        import subprocess

        try:
            p = subprocess.Popen(
                [
                    sys.executable,
                    "-c",
                    f"""\
import sys
import torch
from torch import inf, nan
try:
    with torch.random.fork_rng(devices=[0]):
        torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True)
        torch.cuda.synchronize()
    sys.exit(-1)  # Should not be reached
except RuntimeError as e:
    sys.exit(-2)
""",
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
            out, err = p.communicate(timeout=10)
            p.wait(timeout=10)
        except subprocess.TimeoutExpired as e:
            p.kill()
            out, err = p.communicate()
        expected_messages = [
            "device-side assert triggered",  # CUDA
            "Assertion",  # CUDA
            "HSA_STATUS_ERROR_EXCEPTION",  # ROCm
            "Device-side assertion",  # ROCm
        ]
        self.assertTrue(any(msg in out or msg in err for msg in expected_messages))

    @slowTest
    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
        don't support multiprocessing with spawn start method",
    )
    def test_multinomial_invalid_probs_cuda(self):
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -1.0, 1.0])
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, inf, 1.0])
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -inf, 1.0])
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, 1.0, nan])

    @staticmethod
    def _mute_init():
        os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())

    def _spawn_method(self, method, arg):
        ctx = torch.multiprocessing.get_context("spawn")
        with ctx.Pool(1, initializer=self._mute_init) as pool:
            errors = pool.map(method, [arg])
            for e in errors:
                if "device-side assert triggered" not in str(e):
                    self.fail(e)

    @staticmethod
    def _test_index_bounds_cuda(idx):
        x = torch.arange(10, device="cuda")
        try:
            y = x[torch.tensor([idx])]
            return f"x[torch.tensor([{idx}])]={y}"
        except RuntimeError as err:
            return err

    @slowTest
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
        don't support multiprocessing with spawn start method",
    )
    @skipIfRocm
    def test_index_out_of_bounds_exception_cuda(self):
        test_method = TestCuda._test_index_bounds_cuda
        # Test that in-bound access works fine
        self.assertEqual(
            test_method(1), "x[torch.tensor([1])]=tensor([1], device='cuda:0')"
        )
        # Test that indexing out of bounds causes an assert
        self._spawn_method(test_method, 11)

    @slowTest
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @serialTest()
    def test_huge_index(self):
        src = torch.empty(15000000, 45, device="cuda", dtype=torch.long).random_(
            0, 2**22
        )
        idx = torch.randperm(src.shape[0], device="cuda")
        res = src[idx]
        res_cpu = src.cpu()[idx.cpu()]
        self.assertEqual(res.cpu(), res_cpu)

    def test_randint_randomness_for_large_range(self) -> None:
        # For large ranges, randint generation is slightly different. This led to a subtle bug where some Philox
        # offsets were not calculated correctly, resulting in reused random states.
        # See https://github.com/pytorch/pytorch/issues/125224
        size = 1_000_000
        high = 6_000_000_000  # Keep this above 2**32

        def run(dev: torch.device) -> int:
            # Measure how many unique numbers are generated in 2 consecutive calls to randint. If random states are
            # reused, this will yield fewer unique numbers.
            gen = torch.Generator(device=dev)
            gen.manual_seed(0)
            t1 = torch.randint(
                0, high, [size], device=dev, generator=gen, dtype=torch.int64
            )
            t2 = torch.randint(
                0, high, [size], device=dev, generator=gen, dtype=torch.int64
            )
            return torch.stack([t1, t2]).unique().shape[0]

        # Use CPU as reference. The results should not deviate too much.
        assert abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000

    @parametrize("dtype", [torch.float32, torch.double])
    def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
        # Test if random states do not overlap between consecutive rand/randn calls.
        # See https://github.com/pytorch/pytorch/issues/125224

        def run(func, dev: torch.device, dtype: torch.dtype) -> int:
            # Measure how many unique numbers are generated in 2 consecutive calls. If random states are
            # reused, this will yield fewer unique numbers.
            size = 1000000
            gen = torch.Generator(device=dev)
            gen.manual_seed(0)
            t1 = func((size,), device=dev, generator=gen, dtype=dtype)
            t2 = func((size,), device=dev, generator=gen, dtype=dtype)
            return torch.stack([t1, t2]).unique().shape[0]

        # Use CPU as reference. The results should not deviate too much.
        for func in [torch.rand, torch.randn]:
            deviation = abs(
                run(func, torch.device("cuda"), dtype)
                - run(func, torch.device("cpu"), dtype)
            )
            assert deviation < 50_000, deviation
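    # Illustrative helper (assumption: ours, not collected by the test runner) backing
    # the tolerances used above: drawing n samples uniformly from N possible values is
    # expected to produce roughly n**2 / (2 * N) duplicates, so healthy generators lose
    # only a handful of unique values, while reused Philox states collapse the count.
    @staticmethod
    def _expected_duplicate_draws(n, num_values):
        # e.g. n=1_000_000 draws from num_values=6_000_000_000 gives ~83 expected
        # duplicates, far below the 10_000 / 50_000 slack the asserts above allow.
        return n * n / (2.0 * num_values)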
    def test_min_max_inits(self):
        # Testing if THC_reduceAll received the correct index initialization.
        # This affects the result of THC_reduceAll operations at extreme values
        x = torch.cuda.ByteTensor([0])
        y = torch.cuda.ByteTensor([255])
        expected = torch.cuda.LongTensor([0])[0]

        _, v = x.max(dim=0)
        self.assertEqual(v, expected)

        _, v = y.min(dim=0)
        self.assertEqual(v, expected)

    def test_nvtx(self):
        # Just making sure we can see the symbols
        torch.cuda.nvtx.range_push("foo")
        torch.cuda.nvtx.mark("bar")
        torch.cuda.nvtx.range_pop()
        range_handle = torch.cuda.nvtx.range_start("range_start")
        torch.cuda.nvtx.range_end(range_handle)

    def test_bincount_ext(self):
        # ensure CUDA code coverage
        input_size = (100000,)
        w = torch.randn(input_size, dtype=torch.double, device="cuda")
        w_cpu = w.cpu()
        # test shared memory impl
        t = torch.randint(50, input_size, dtype=torch.int8, device="cuda")
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
        # test global memory impl
        # see `CUDAHistogramMemoryType` in SummaryOps.cu
        # 50000 * sizeof(int64_t) == 390 KiB, which should exceed smem of any known GPU
        t = torch.randint(50000, input_size, dtype=torch.int64, device="cuda")
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.zeros([10], dtype=torch.int32, device="cuda")
        # 35488 * 65536 as int32 would cause overflow to negative value
        # giving negative bin offset
        t[0] = 35488
        counted = t.bincount(minlength=65536)
        self.assertEqual(torch.sum(counted), 10)

    def test_tiny_half_norm_(self):
        a = torch.arange(25).cuda().float()
        a /= 100000000
        b = a.half()
        self.assertGreater(b.norm().item(), 0)

    def test_norm_type_conversion(self):
        a = torch.ones(65536).cuda().half()
        self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)

    def test_cuda_memory_leak_detection_propagates_errors(self):
        with self.assertRaisesRegex(
            RuntimeError, r"The size of tensor a \(3\) must match"
        ):
            with self.assertLeaksNoCudaTensors():
                x = torch.randn(3, 1, device="cuda")
                y = torch.randn(2, 1, device="cuda")
                z = x + y

    @unittest.skipIf(not TEST_MEDIUM_TENSOR, "not enough memory")
    @serialTest()
    def test_cuda_kernel_loop_overflow(self):
        # Issue #24309: In extreme cases, the loop variable could overflow and continue
        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
        x = torch.randn(1, 1, 1, 2**30 + 1, dtype=torch.float16, device="cuda")
        expected = x[0, 0, 0, 2**30]
        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
        torch.cuda.synchronize()
        self.assertEqual(y[0, 0, 0, 2**30], expected)
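    # Note (ours): the sizes in the overflow tests above and below are picked around
    # INT_MAX (2**31 - 1 = 2,147,483,647). 2**30 + 1 and 2**31 - 1 elements keep
    # numel() in 32-bit range while stressing the kernel's index arithmetic near the
    # overflow boundary, whereas 2**31 elements exceed it, which is why the large
    # variant below expects "integer out of range".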
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @gcIfJetson
    @serialTest()
    def test_cuda_kernel_loop_overflow_large(self):
        # Make sure input.numel() > INT_MAX is handled:
        x = torch.randn(1, 1, 1, 2**31, dtype=torch.float16, device="cuda")
        with self.assertRaisesRegex(RuntimeError, "integer out of range"):
            y = torch.nn.functional.avg_pool2d(x, kernel_size=1)

        # Issue #24309: In extreme cases, the loop variable could overflow and continue
        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
        x = torch.randn(1, 1, 1, 2**31 - 1, dtype=torch.float16, device="cuda")
        expected = x[0, 0, 0, 2**31 - 2]
        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
        torch.cuda.synchronize()
        self.assertEqual(y[0, 0, 0, 2**31 - 2], expected)

    # this might create a reference cycle on self...
    def _make_multiply_in_stream(self):
        class MultiplyInStream(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, val):
                ctx.val = val
                ctx.stream = torch.cuda.current_stream()
                return x * val

            @staticmethod
            def backward(ctx, grad):
                self.assertEqual(torch.cuda.current_stream(), ctx.stream)
                # delays the operation in the background stream
                torch.cuda._sleep(1000 * 5000)
                return grad * ctx.val, None

        return MultiplyInStream

    @skipCUDANonDefaultStreamIf(True)
    def test_streaming_backwards_sync(self):
        default_stream = torch.cuda.current_stream()
        stream = torch.cuda.Stream()

        MultiplyInStream = self._make_multiply_in_stream()

        # Tests using grads outside the backward() stream context
        # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
        x = torch.randn(5, 5, device="cuda", requires_grad=True)
        with torch.cuda.stream(stream):
            stream.wait_stream(default_stream)
            output = MultiplyInStream.apply(x, 2)
            output.sum().backward()
        # sync needed
        default_stream.wait_stream(stream)
        self.assertEqual(x.grad, torch.ones_like(x) * 2)
        self.assertEqual(torch.cuda.current_stream(), default_stream)

        # Tests that using grads in the same stream context as backward()
        # is safe regardless of what streams bwd ops ran on
        bwd_ambient_stream = torch.cuda.Stream()
        x = torch.randn(5, 5, device="cuda", requires_grad=True)
        with torch.cuda.stream(stream):
            stream.wait_stream(default_stream)
            output = MultiplyInStream.apply(x, 3)
        with torch.cuda.stream(bwd_ambient_stream):
            bwd_ambient_stream.wait_stream(stream)
            output.sum().backward()
            # x was first used on "stream" so its AccumulateGrad leaf should run on "stream".
            # The end of backward() should have synced "bwd_ambient_stream" with "stream"
            # so it should be safe to use x.grad here without any syncs.
            self.assertEqual(x.grad, torch.ones_like(x) * 3)
            self.assertEqual(torch.cuda.current_stream(), bwd_ambient_stream)
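    # A minimal, uncollected sketch (assumption: not part of the original suite) of the
    # documented "Stream semantics of backward passes" pattern the two tests above rely
    # on: backward() uses the streams of the corresponding forward ops, so code on a
    # different stream must wait on the stream surrounding backward() before using .grad.
    def _sketch_backward_stream_sync(self):
        side = torch.cuda.Stream()
        x = torch.randn(8, device="cuda", requires_grad=True)
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            (x * 2).sum().backward()
        # Consume x.grad on the ambient stream only after syncing with `side`.
        torch.cuda.current_stream().wait_stream(side)
        return x.grad.clone()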
    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
    @skipIfRocm(msg="flakey on ROCm https://github.com/pytorch/pytorch/issues/53190")
    def test_streaming_backwards_multiple_streams(self):
        MultiplyInStream = self._make_multiply_in_stream()

        class StreamModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.event = torch.cuda.Event()
                self.stream0 = torch.cuda.Stream()
                self.stream1 = torch.cuda.Stream()

            def forward(self, x, x_first_use_on_ambient):
                if x_first_use_on_ambient:
                    x0 = x.clone()
                self.stream0.wait_stream(torch.cuda.current_stream())
                self.stream1.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.stream0):
                    if not x_first_use_on_ambient:
                        x0 = x.clone()
                    y0 = MultiplyInStream.apply(x0, 2)
                    self.event.record(stream=torch.cuda.current_stream())

                with torch.cuda.stream(self.stream1):
                    y1 = MultiplyInStream.apply(x, 3)
                    self.stream1.wait_event(self.event)
                    return y0 + y1

        stream = torch.cuda.Stream()

        for x_first_use_on_ambient in (True, False):
            # the out_of_place=False, iters=1 case stresses if proper syncs are inserted
            # when grads are initially None and stolen by backward ops.
            for out_of_place, iters in ((True, 1), (False, 1), (False, 5)):
                with torch.cuda.stream(stream):
                    x = torch.randn(5, 5, device="cuda", requires_grad=True)
                    model = StreamModel().cuda()
                    x.register_hook(
                        lambda grad: self.assertEqual(
                            torch.cuda.current_stream(),
                            stream if x_first_use_on_ambient else model.stream0,
                        )
                    )
                    for p in model.parameters():
                        self.assertTrue(p.grad is None)
                    for i in range(iters):
                        loss = model(x, x_first_use_on_ambient).sum()
                        if out_of_place:
                            x_grad = torch.autograd.grad((loss,), (x,))[0]
                        else:
                            loss.backward()
                # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
                torch.cuda.current_stream().wait_stream(stream)

                if out_of_place:
                    self.assertEqual(x_grad, torch.ones_like(x) * 5 * iters)
                else:
                    self.assertEqual(x.grad, torch.ones_like(x) * 5 * iters)

    def test_streaming_backwards_sync_graph_root(self):
        # This function tests if bwd ops running on a side stream properly sync with the GraphRoot.
        # The potential bug it targets is a race condition. The test uses multiple trials and
        # torch.cuda._sleep such that if the race condition exists, the test will almost certainly fail,
        # but there's a chance it may spuriously pass. Passing does not guarantee the backend is bug-free,
        # but failure does guarantee there is a bug.
        fwd_bwd_op_stream = torch.cuda.Stream()
        bwd_ambient_stream = torch.cuda.Stream()
        # We need these streams to be different otherwise the test is meaningless.
        self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream)

        size = int(1e3)

        a = torch.full((size,), 2.0, device="cuda", requires_grad=True)
        b = torch.full((size,), 3.0, device="cuda", requires_grad=True)

        # I don't think we need any manual record_streams below.
        # a and b remain in scope for the entire test.
        # c and grad remain in scope for each iteration, and there's a full sync between iterations.
        for trial in range(5):
            torch.cuda.synchronize()
            a.grad = b.grad = None
            with torch.cuda.stream(fwd_bwd_op_stream):
                c = a * b

            with torch.cuda.stream(bwd_ambient_stream):
                torch.cuda.synchronize()
                # Long-running dummy kernel on bwd_ambient_stream delays filling of grad
                torch.cuda._sleep(int(50 * get_cycles_per_ms()))
                # Fills grad on bwd_ambient_stream
                grad = torch.full((size,), float(trial + 1), device="cuda")

                # Bwd ops still run on fwd_bwd_ops_stream, so the following will likely fail if
                # bwd ops don't sync with bwd_ambient_stream before consuming grad.
                torch.autograd.backward(tensors=c, grad_tensors=grad)

                # See https://github.com/pytorch/pytorch/issues/47028
                # The assertEqual checks below run on bwd_ambient_stream, so this test may also fail
                # if backward() fails to sync with bwd_ambient_stream at the end.
                # Synchronizing here works around the issue until a proper fix can be made.
                torch.cuda.synchronize()
                with torch.no_grad():
                    self.assertEqual(a.grad, grad * b)
                    self.assertEqual(b.grad, grad * a)

    def test_streaming_backwards_callback(self):
        # Tests if autograd callbacks sync properly with respect to leaf streams and
        # the user-facing stream surrounding backward(). If it fails, first suspect is
        # sync logic where "final_callbacks_" are called in torch/csrc/autograd/engine.cpp
        MultiplyInStream = self._make_multiply_in_stream()

        size = int(1e3)
        a = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
        b = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)

        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()

        stash = []

        # sets up a nontrivial structure of leaf streams
        s0.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s0):
            c = MultiplyInStream.apply(a, 2)

        s1.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s1):
            d = MultiplyInStream.apply(b, 3)
            s1.wait_stream(s0)
            e = c * d

        def clone_leaf_grads():
            stash.append(a.grad.clone())
            stash.append(b.grad.clone())

        # Use a hook on e to install the callback
        e.register_hook(
            lambda grad: torch.autograd.Variable._execution_engine.queue_callback(
                clone_leaf_grads
            )
        )

        s2.wait_stream(s1)
        with torch.cuda.stream(s2):
            e.sum().backward()
            # The autograd engine should sync s2 with all leaf streams then run the callback clone_leaf_grads on s2.
            # If those things happened properly, checking the values of the cloned grads on s2 should be safe:
            self.assertEqual(stash[0], torch.full_like(a, 6))
            self.assertEqual(stash[1], torch.full_like(a, 6))
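    # A minimal, uncollected sketch (assumption: helper name is ours) of the callback
    # mechanism the test above exercises: queue_callback registers a function that the
    # autograd engine runs once the current backward pass has finished, after leaf
    # grads have been accumulated.
    def _sketch_queue_callback(self):
        seen = []
        x = torch.ones(4, device="cuda", requires_grad=True)
        out = (x * 2).sum()
        out.register_hook(
            lambda grad: torch.autograd.Variable._execution_engine.queue_callback(
                lambda: seen.append(x.grad.clone())
            )
        )
        out.backward()
        return seen  # holds a copy of x.grad taken after backward completed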
    @unittest.skipIf(
        TEST_WITH_ROCM,
        "In ROCm, kernel asserts are disabled due to performance overhead",
    )
    def test_fixed_cuda_assert_async(self):
        with self.assertRaisesRegex(
            RuntimeError, "Boolean value of Tensor with no values is ambiguous"
        ):
            torch._assert_async(torch.tensor([], device="cuda"))
        with self.assertRaisesRegex(
            RuntimeError,
            "Boolean value of Tensor with more than one value is ambiguous",
        ):
            torch._assert_async(torch.tensor([0, 0], device="cuda"))

        torch._assert_async(torch.tensor(1, device="cuda"))
        torch._assert_async(torch.tensor(0.1, device="cuda"))
        torch._assert_async(torch.tensor(-0.1, device="cuda"))
        torch._assert_async(torch.tensor(True, device="cuda"))
        torch._assert_async(torch.tensor(0 + 0.1j, device="cuda"))

        fail_stmts = [
            "torch._assert_async(torch.tensor(0, device='cuda'))",
            "torch._assert_async(torch.tensor(0.0, device='cuda'))",
            "torch._assert_async(torch.tensor(False, device='cuda'))",
            "torch._assert_async(torch.tensor(0 + 0j, device='cuda'))",
        ]

        import subprocess

        for stmt in fail_stmts:
            with self.subTest(stmt=stmt):
                r = subprocess.call(
                    [
                        sys.executable,
                        "-c",
                        f"""\
import torch

{stmt}
torch.cuda.synchronize()
""",
                    ]
                )
                self.assertTrue(r != 0)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "FAIL")
    def test_cublas_multiple_threads_same_device(self):
        # Note: these parameters should be very carefully tuned.
        # Too small a number makes it hard for the race condition
        # to happen, while too large a number sometimes causes a hang.
        size = 1024
        num_threads = 2
        trials = 3
        test_iters = 100

        weight = torch.ones((size, size), device="cuda")
        results = {}
        barrier = threading.Barrier(num_threads)

        def _worker(t):
            my_stream = torch.cuda.Stream()
            # Hard sync so we don't need to worry about creating and using tensors
            # across streams or the fact that default streams are thread-local.
            # Those issues are not the target of this test.
            torch.cuda.synchronize()
            # Line up threads to increase likelihood of race conditions.
            barrier.wait()
            with torch.cuda.stream(my_stream):
                for i in range(test_iters):
                    # If all threads are sharing the same cublas handle,
                    # the following sequence may occur:
                    # thread 0 calls cublasSetStream()
                    # thread 1 calls cublasSetStream()
                    # thread 0 launches its raw gemm, which it thinks is in
                    # its own stream, but is actually in thread 1's stream.
                    # thread 0 enqueues its div_, which IS in its own stream,
                    # but actually now races with its gemm.
                    results[t] = torch.mm(results[t], weight)
                    results[t].div_(float(size))
            torch.cuda.synchronize()

        for _ in range(trials):
            for t in range(num_threads):
                results[t] = torch.ones((size, size), device="cuda")

            threads = [
                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            for t in range(num_threads):
                self.assertEqual(results[t].sum().item(), size * size)
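    # Background note (ours) for the threaded same-device tests above and below: the
    # current CUDA stream is thread-local in PyTorch, and ATen caches cuBLAS/cuDNN
    # handles per thread and per device (see e.g. aten/src/ATen/cudnn/Handles.cpp for
    # the cuDNN pool), so a *SetStream call on one thread should never redirect work
    # launched by another. These tests exist to catch regressions in that caching logic.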
    # Test is flaky on Windows (https://github.com/pytorch/pytorch/issues/57401)
    @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows (see issue 57401)")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    @skipIfRocm
    def test_cudnn_multiple_threads_same_device(self):
        # This function is intended to test the lazy creation and reuse of per-thread
        # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp.
        # Failure here likely indicates something wrong with that logic.
        weight = torch.ones((1, 1, 2, 2), device="cuda")

        results = {}

        num_threads = 2
        trials = 3
        test_iters = 1000
        barrier = threading.Barrier(num_threads)

        with torch.backends.cudnn.flags(enabled=True):

            def _worker(t):
                my_stream = torch.cuda.Stream()
                # Hard sync so we don't need to worry about creating and using tensors
                # across streams or the fact that default streams are thread-local.
                # Those issues are not the target of this test.
                torch.cuda.synchronize()
                # Line up threads to increase likelihood of race conditions.
                barrier.wait()
                with torch.cuda.stream(my_stream):
                    for _ in range(test_iters):
                        # If all threads are sharing the same cudnn handle,
                        # the following sequence may occur:
                        # thread 0 calls setCuDNNStreamToCurrent()
                        # thread 1 calls setCuDNNStreamToCurrent()
                        # thread 0 launches its raw convolution, which it thinks is in
                        # its own stream, but is actually in thread 1's stream.
                        # thread 0 enqueues its div_, which IS in its own stream,
                        # but now races with its convolution.
                        results[t] = torch.nn.functional.conv2d(
                            results[t], weight, padding=0
                        )
                        results[t].div_(4.0)
                torch.cuda.synchronize()

            for _ in range(trials):
                for t in range(num_threads):
                    results[t] = torch.ones((1, 1, 2048, 2048), device="cuda")

                threads = [
                    threading.Thread(target=_worker, args=(t,))
                    for t in range(num_threads)
                ]

                for thread in threads:
                    thread.start()
                for thread in threads:
                    thread.join()

                for t in range(num_threads):
                    self.assertEqual(
                        results[t].sum().item(),
                        (2048 - test_iters) * (2048 - test_iters),
                    )

    def test_cusparse_multiple_threads_same_device(self):
        size = 1024
        num_threads = 2
        trials = 3
        test_iters = 500

        def ones_sparse(size):
            a = torch.arange(size, device="cuda")
            indices = torch.cartesian_prod(a, a).t()
            values = torch.ones(size * size, device="cuda")
size, device="cuda") 1510*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(indices, values) 1511*da0073e9SAndroid Build Coastguard Worker 1512*da0073e9SAndroid Build Coastguard Worker weight = ones_sparse(size) 1513*da0073e9SAndroid Build Coastguard Worker results = {} 1514*da0073e9SAndroid Build Coastguard Worker barrier = threading.Barrier(num_threads) 1515*da0073e9SAndroid Build Coastguard Worker 1516*da0073e9SAndroid Build Coastguard Worker def _worker(t): 1517*da0073e9SAndroid Build Coastguard Worker my_stream = torch.cuda.Stream() 1518*da0073e9SAndroid Build Coastguard Worker # Hard sync so we don't need to worry about creating and using tensors 1519*da0073e9SAndroid Build Coastguard Worker # across streams or the fact that default streams are thread-local. 1520*da0073e9SAndroid Build Coastguard Worker # Those issues are not the target of this test. 1521*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 1522*da0073e9SAndroid Build Coastguard Worker # Line up threads to increase likelihood of race conditions. 1523*da0073e9SAndroid Build Coastguard Worker barrier.wait() 1524*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(my_stream): 1525*da0073e9SAndroid Build Coastguard Worker for i in range(test_iters): 1526*da0073e9SAndroid Build Coastguard Worker # If all threads are sharing the same cublas handle, 1527*da0073e9SAndroid Build Coastguard Worker # the following sequence may occur: 1528*da0073e9SAndroid Build Coastguard Worker # thread 0 calls cublasSetStream() 1529*da0073e9SAndroid Build Coastguard Worker # thread 1 calls cublasSetStream() 1530*da0073e9SAndroid Build Coastguard Worker # thread 0 launches its raw gemm, which it thinks is in 1531*da0073e9SAndroid Build Coastguard Worker # its own stream, but is actually in thread 1's stream. 1532*da0073e9SAndroid Build Coastguard Worker # thread 0 enqueues its div_, which IS is its own stream, 1533*da0073e9SAndroid Build Coastguard Worker # but actually now races with its gemm. 
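                    # In practice each thread should get its own handle (ATen keeps
                    # per-thread, per-device handle pools), so the mm and div_ below
                    # should both run on my_stream; the sum checks after join() would
                    # catch a cross-stream race like the one described above.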
                    results[t] = weight.mm(results[t])
                    results[t].div_(float(size))
            torch.cuda.synchronize()

        for _ in range(trials):
            for t in range(num_threads):
                results[t] = torch.ones((size, size), device="cuda")

            threads = [
                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            for t in range(num_threads):
                self.assertEqual(results[t].sum().item(), size * size)

    @slowTest
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @serialTest()
    def test_max_large_axis(self):
        x = torch.zeros(2**32, device="cuda", dtype=torch.int8)
        x[-1] = 1
        val, idx = x.max(0)
        self.assertEqual(val, 1)
        self.assertEqual(idx, x.shape[0] - 1)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_to_numpy(self):
        self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy())

    def test_graph_is_current_stream_capturing(self):
        self.assertFalse(torch.cuda.is_current_stream_capturing())

        if TEST_CUDA and (not TEST_WITH_ROCM):
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
                g = torch.cuda.CUDAGraph()
                self.assertFalse(torch.cuda.is_current_stream_capturing())
                g.capture_begin()
                self.assertTrue(torch.cuda.is_current_stream_capturing())
                g.capture_end()

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_simple(self):
        s = torch.cuda.Stream()

        with torch.cuda.stream(s):
            a = torch.full((1000,), 1, device="cuda")
            g = torch.cuda.CUDAGraph()
            torch.cuda.empty_cache()
            g.capture_begin()
            b = a
            for _ in range(10):
                b = b + 1
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()

        self.assertTrue(b.sum().item() == 11000.0)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graphsafe_set_get_rng_state(self):
        # Define a function to create generator states, with optional graph registration
        def create_states(generator):
            """Initializes generator states and registers them with a CUDA graph if provided."""
            # Ensure the CUDA generator is initialized
            torch.rand(1, device="cuda")
            generator.manual_seed(0)

            # Save the current state of the generator
            old_state = generator.graphsafe_get_state()
            # Create and save a cloned state of the generator
            new_state = generator.clone_state()
            # Return the original generator and its two states
            return generator, old_state, new_state

        def register_states_to_graph(generator_state, graph):
            generator, old_state, new_state = generator_state
            graph.register_generator_state(old_state)
            graph.register_generator_state(new_state)

        # Define a function to perform specific RNG actions using the generator's states
        def perform_random_generation_steps(generator_state):
            generator, old_state, new_state = generator_state
            random_values = []

            # Generate random numbers with the new generator state
            generator.graphsafe_set_state(new_state)
            random_values.append(torch.rand(5, device="cuda", generator=generator))

            # Generate random numbers twice with the old generator state
            generator.graphsafe_set_state(old_state)
            random_values.extend(
                [torch.rand(5, device="cuda", generator=generator) for _ in range(2)]
            )

            return random_values

        # Define a function to retrieve the final offsets of the original and new generator states
        def get_final_offsets_of_states(generator_state):
            generator, old_state, new_state = generator_state
            old_state_offset = old_state.get_offset()
            new_state_offset = new_state.get_offset()
            return old_state_offset, new_state_offset

        # Set up and test a new CUDA generator
        generator = torch.Generator(device="cuda")
        generator_state = create_states(generator)

        # Set up and test the default CUDA generator with a CUDA Graph
        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()
        default_generator = torch.cuda.default_generators[0]
        default_generator_state = create_states(default_generator)
        register_states_to_graph(default_generator_state, g)

        # Perform random number generation within a CUDA graph
        with torch.cuda.stream(s):
            g.capture_begin()
            graphed_random_values = perform_random_generation_steps(
                default_generator_state
            )
            g.capture_end()

        # Synchronize the streams and replay the graph
        torch.cuda.current_stream().wait_stream(s)
        for _ in range(3):
            random_values = perform_random_generation_steps(generator_state)
            g.replay()
            offset = get_final_offsets_of_states(generator_state)
            graph_offset = get_final_offsets_of_states(default_generator_state)

            # Compare the final offsets of states for both generators to ensure consistency
            self.assertTrue(offset == graph_offset)
            # Compare the states generated outside and inside the graph
            self.assertEqual(random_values, graphed_random_values)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_memory_stats_of_multiple_generators_and_graphs(self):
        # Function to clear CUDA cache and collect garbage
        def clear_cuda_cache():
            gc.collect()
            torch.cuda.empty_cache()

        # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph.
        def simple_graph_task(graph):
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
                graph.capture_begin()
                torch.rand(1, device="cuda")
                graph.capture_end()
            torch.cuda.current_stream().wait_stream(s)
            graph.replay()  # Replays the captured operations

        def get_memory_stats():
            stats = torch.cuda.memory_stats()
            num_blocks = stats["active.all.current"]
            total_size = stats["active_bytes.all.current"]
            return num_blocks, total_size

        def test(num_graphs, num_generators):
            baseline = get_memory_stats()
            baseline_num_blocks, baseline_total_size = baseline

            # Allocate CUDA graphs
            graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)]

            # Allocate and manage generator states
            default_generator = torch.cuda.default_generators[0]
            generators = [default_generator.graphsafe_get_state()]

            # Starts from 1 as one state is already added
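            # clone_state() gives each additional entry its own graph-registrable copy
            # of the default generator's state, each tracking its own offset from here on.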
            for _ in range(1, num_generators):
                generators.append(default_generator.clone_state())

            for graph in graphs:
                for generator_state in generators:
                    graph.register_generator_state(generator_state)
                simple_graph_task(graph)

            # Assert conditions after graph tasks
            num_blocks, total_size = get_memory_stats()
            # The allocated blocks should only be proportional to the number of generators
            expected_blocks_diff = 2 * num_generators
            expected_size_diff = 2 * 512 * num_generators  # Each block's size is 512

            self.assertTrue(
                (num_blocks - baseline_num_blocks) == expected_blocks_diff,
                "Unexpected number of active blocks.",
            )
            self.assertTrue(
                (total_size - baseline_total_size) == expected_size_diff,
                "Unexpected total memory size.",
            )

            # Cleanup graphs and clear CUDA cache
            while graphs:
                graph = graphs.pop()
                del graph
            clear_cuda_cache()

            # Assert that memory stats return to baseline after cleanup
            self.assertTrue(
                get_memory_stats() == baseline,
                "Memory stats do not match baseline after cleanup.",
            )

        # Running the test function with different parameters
        test(1, 1)
        test(3, 2)
        test(10, 20)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_reset_recapture(self):
        s = torch.cuda.Stream()

        with torch.cuda.stream(s):
            a = torch.full((1000,), 1, device="cuda")
            g = torch.cuda.CUDAGraph()
            torch.cuda.empty_cache()
            g.capture_begin()
            b = a
            for _ in range(10):
                b = b + 1
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()

        self.assertTrue(b.sum().item() == 11000.0)

        g.reset()

        with torch.cuda.stream(s):
            g.capture_begin()
            b.fill_(2.0)
            for _ in range(10):
                b = b + 2
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()
        self.assertTrue(b.sum().item() == 22000.0)

        g.reset()
        del g

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_debugdump(self):
        torch.cuda.empty_cache()
        x = torch.randn(10240000, device="cuda")
        y = torch.rand_like(x)
        g = torch.cuda.CUDAGraph()
        g.enable_debug_mode()
        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s0.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s0):
            g.capture_begin()
            z = x + y
            with torch.cuda.stream(s1):
                s1.wait_stream(s0)
                w = z + y
            s0.wait_stream(s1)
            g.capture_end()
        s0.synchronize()
        torch.cuda.synchronize()
        with tempfile.TemporaryDirectory() as tempdir:
            g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot"))
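
    # The graph tests here drive capture by hand: pick a side stream, call
    # capture_begin()/capture_end() on it, wait_stream() back, then replay().
    # For reference, a minimal sketch of the same flow using the higher-level
    # torch.cuda.graph context manager (tensor names below are illustrative
    # placeholders, not part of this test suite):
    #
    #     g = torch.cuda.CUDAGraph()
    #     static_input = torch.zeros(1000, device="cuda")
    #     with torch.cuda.graph(g):          # captures on a side stream internally
    #         static_output = static_input + 1
    #     static_input.copy_(torch.randn(1000, device="cuda"))
    #     g.replay()                         # recomputes static_output in place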

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_error(self):
        # We need to run this test in a separate process, as the error we trigger
        # puts the cuda context in a bad state
        script = """
import torch

g = torch.cuda.CUDAGraph()
try:
    g.capture_begin()
except RuntimeError as e:
    if "CUDA graphs must be captured on a non-default stream." in str(e):
        exit(0)
    else:
        exit(1)
exit(2)
"""
        try:
            a = subprocess.check_output(
                [sys.executable, "-c", script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
        except subprocess.CalledProcessError as e:
            if e.returncode == 1:
                self.assertTrue(
                    False,
                    "Error raised by starting capture without a stream is not the expected one",
                )
            elif e.returncode == 2:
                self.assertTrue(
                    False,
                    "Error raised by starting capture without a stream was not caught",
                )

    @unittest.skipIf(
        (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11,
        "CUDA >= 11.0 required for graphs",
    )
    def test_graph_warn_if_has_zero_nodes(self):
        with warnings.catch_warnings(record=True) as caught:
            g = torch.cuda.CUDAGraph()
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
                g.capture_begin()
                g.capture_end()
        self.assertTrue(
            any("The CUDA Graph is empty" in str(w.message) for w in caught)
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @unittest.skipIf(
        IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
    )
    def test_graph_capture_oom(self):
        oom_regex = (
            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            with torch.cuda.graph(torch.cuda.CUDAGraph()):
                torch.zeros(2**40, device="cuda")

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @serialTest()
    def test_repeat_graph_capture_cublas_workspace_memory(self):
        (x, y, z) = 1024, 512, 64
        a = torch.rand((x, y), device="cuda")
        b = torch.rand((y, z), device="cuda")

        # warmup
        torch.mm(a, b)

        free_bytes_before, total_bytes = torch.cuda.mem_get_info()
        used_gb_before = (total_bytes - free_bytes_before) / 1e9

        for _ in range(100):
            torch_graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(torch_graph):
                torch.mm(a, b)
            torch_graph.replay()

        free_bytes_after, _ = torch.cuda.mem_get_info()
        used_gb_after = (total_bytes - free_bytes_after) / 1e9

        self.assertFalse(used_gb_before + 0.1 < used_gb_after)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_rng_functional(self):
        ops_with_kwargs = (
            (torch.nn.functional.dropout, {"p": 0.1}),
            (torch.nn.functional.rrelu, {"training": True}),
        )
        size = 10000

        def run(op, kwargs):
            a = torch.randn((size,), device="cuda", dtype=torch.float)

            # Control
            torch.cuda.manual_seed(5)
            eager_out = a
            for _ in range(6):
                eager_out = op(eager_out, **kwargs)

            graph_in = a.clone()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                torch.cuda.manual_seed(5)

                g = torch.cuda.CUDAGraph()
                torch.cuda.empty_cache()
                g.capture_begin()
                graph_out = graph_in
                for _ in range(2):
                    graph_out = op(graph_out, **kwargs)
                g.capture_end()
            torch.cuda.current_stream().wait_stream(stream)

            # Runs a graphed->eager->graphed sequence of RNG ops.
            # replay() plays 2 invocations of the op, so the sequence has 6
            # invocations total, matching Control.
            # replay() reads from graph_in and writes to graph_out.
            g.replay()
            out = op(graph_out, **kwargs)
            out = op(out, **kwargs)
            graph_in.copy_(out)
            g.replay()

            # If replay() updated RNG state correctly, graph_out
            # should now hold data equal to eager_out.
            try:
                self.assertEqual(eager_out, graph_out)
            except Exception as e:
                raise RuntimeError("Failed on ", op) from e

            # Do the same operations varying seeds
            seeds = [6, 128, 9999]

            for seed in seeds:
                torch.cuda.manual_seed(seed)
                graph_in.copy_(a)
                for _ in range(3):
                    g.replay()

                # If the random seed was not updated then the graph would
                # generate the same output as in previous check.
                try:
                    self.assertNotEqual(eager_out, graph_out)
                except Exception as e:
                    raise RuntimeError("Failed on ", op) from e

                # Now repeat the same operations in non-graphed mode.
                torch.cuda.manual_seed(seed)
                for _ in range(3):
                    eager_out.copy_(a)
                    eager_out = op(eager_out, **kwargs)
                    eager_out = op(eager_out, **kwargs)

                # In the end, graph_out and eager_out must be equal
                # as they went through the same set of operations.
                try:
                    self.assertEqual(eager_out, graph_out)
                except Exception as e:
                    raise RuntimeError("Failed on ", op) from e

            # We hold references to all tensors used across streams up til this sync,
            # so no need to call record_stream on those tensors.
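            # (record_stream would tell the caching allocator those tensors are still in
            # use on the side stream; keeping Python references alive until the sync
            # below serves the same purpose here.)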
            torch.cuda.synchronize()

        for op, kwargs in ops_with_kwargs:
            run(op, kwargs)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_rng_distributions(self):
        size = 10000
        input = torch.rand((size,), device="cuda", dtype=torch.float)
        alloc = torch.empty((size,), device="cuda", dtype=torch.float)

        # Torch ops to test with sample args (tuple) and kwargs (dict)
        torch_with_args = (
            ("bernoulli", (input.clone(),), {}),
            # multinomial uses some uncapturable CUDA calls.
            # TODO: reenable multinomial tests if/when the implementation is capturable.
            # ("multinomial", (input.clone(), size, True), {}),
            # ("multinomial", (input.clone(), size // 2, False), {}),
            # TODO: reenable normal test, where std is a device
            # tensor, when graph test failures are fixed
            # ("normal", (input.clone() + 1, input.clone()), {}),
            ("normal", (input.clone() + 1, 1.0), {}),
            ("poisson", (input.clone(),), {}),
            ("rand", (size,), {"device": "cuda", "dtype": torch.float}),
            ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}),
            ("randn", (size,), {"device": "cuda", "dtype": torch.float}),
        )

        # Tensor methods to test with sample args (tuple)
        tensor_with_args = (
            ("bernoulli_", (input.clone(),)),
            ("cauchy_", ()),
            ("exponential_", ()),
            ("geometric_", (0.3,)),
            ("log_normal_", ()),
            ("normal_", ()),
            ("random_", ()),
            ("uniform_", ()),
        )

        def run(module, op, args, kwargs):
            torch.cuda.manual_seed(5)

            # Each path runs a dummy op to increment the state a bit before creating controls.
            if module == "torch":
                dummy = getattr(torch, op)(*args, **kwargs)
                control1 = getattr(torch, op)(*args, **kwargs)
                control2 = getattr(torch, op)(*args, **kwargs)
            else:
                dummy = alloc.clone()
                control1 = alloc.clone()
                control2 = alloc.clone()
                getattr(dummy, op)(*args)
                getattr(control1, op)(*args)
                getattr(control2, op)(*args)

            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                torch.cuda.manual_seed(5)

                g = torch.cuda.CUDAGraph()
                torch.cuda.empty_cache()
                if module == "torch":
                    g.capture_begin()
                    t1 = getattr(torch, op)(*args, **kwargs)
                    t2 = getattr(torch, op)(*args, **kwargs)
                    g.capture_end()
                else:
                    t1 = alloc.clone()
                    t2 = alloc.clone()
                    g.capture_begin()
                    getattr(t1, op)(*args)
                    getattr(t2, op)(*args)
                    g.capture_end()
            torch.cuda.current_stream().wait_stream(stream)

            if not TEST_CUDAMALLOCASYNC:
                # Makes sure values haven't been populated yet
                # (in other words, makes sure capture didn't actually run ops).
                # We can only try this with the native allocator, for which captured
                # addresses are already backed by cudaMalloced memory.
                # If we try it with cudaMallocAsync, CUDA won't even consider
                # the captured addresses allocated until replay(), and if we
                # access them before replay() we get IMAs (illegal memory accesses).
                try:
                    self.assertNotEqual(control1, t1)
                    self.assertNotEqual(control2, t2)
                except Exception as e:
                    raise RuntimeError("Failed on " + module + "." + op) from e

            # Set a new seed to check if graph would use it
            for seed in [6, 314, 271]:
                torch.cuda.manual_seed(seed)
                # Runs a dummy op prelude, as for controls, to make sure replay()
                # picks up the dummy op's state increment.
                if module == "torch":
                    dummy = getattr(torch, op)(*args, **kwargs)
                    control1 = getattr(torch, op)(*args, **kwargs)
                    control2 = getattr(torch, op)(*args, **kwargs)
                else:
                    getattr(dummy, op)(*args)
                    getattr(control1, op)(*args)
                    getattr(control2, op)(*args)

                torch.cuda.manual_seed(seed)
                if module == "torch":
                    dummy = getattr(torch, op)(*args, **kwargs)
                else:
                    getattr(dummy, op)(*args)

                # see above comment on TEST_CUDAMALLOCASYNC
                if not TEST_CUDAMALLOCASYNC:
                    t1.copy_(alloc)
                    t2.copy_(alloc)

                # Runs RNG ops that fill t1 and t2.
                g.replay()

                try:
                    self.assertEqual(control1, t1)
                    self.assertEqual(control2, t2)
                except Exception as e:
                    raise RuntimeError("Failed on " + module + "." + op) from e

            # We hold references to all tensors used across streams up til this sync,
            # so no need to call record_stream on those tensors.
            torch.cuda.synchronize()

        for op_with_args in torch_with_args:
            run("torch", *op_with_args)

        for meth_with_args in tensor_with_args:
            # Adds an empty dict for kwargs, which none of the Tensor methods use
            run("Tensor", *(meth_with_args + ({},)))

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_two_successive(self):
        torch.cuda.empty_cache()

        size = 1000
        kSmallBuffer = 2097152  # 2 MiB, the native caching allocator's small-pool segment size

        def func_with_temps(t, val):
            x = t.clone() + val
            y = t.clone() + val
            return x + y

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()

            a = torch.ones((size,), device="cuda")

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = a.clone()
                for _ in range(5):
                    b = func_with_temps(b, 1)
                g0.capture_end()

                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
                g1.capture_begin(*g1_args)
                for _ in range(5):
                    b = func_with_temps(b, 1)
                g1.capture_end()
            torch.cuda.current_stream().wait_stream(s)

            # mixes unrelated eager ops with replays
            c = a.clone()
            for _ in range(2):
                c = func_with_temps(c, 3)
            g0.replay()
            for _ in range(2):
                c = func_with_temps(c, 3)
            g1.replay()
            for _ in range(2):
                c = func_with_temps(c, 3)

            self.assertEqual(b.sum().item(), size * 3070)
            self.assertEqual(c.sum().item(), size * 442)

            if not TEST_CUDAMALLOCASYNC:
                # These stat checks are specific to the native allocator.
                if share_mem != "Don't share":
                    self.assertEqual(
                        reserved_no_sharing  # noqa: F821
                        - torch.cuda.memory_stats()["reserved_bytes.all.current"],
                        kSmallBuffer,
                    )
                else:
                    reserved_no_sharing = torch.cuda.memory_stats()[
                        "reserved_bytes.all.current"
                    ]

            del a, b, c, g0, g1
            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

    @unittest.skipIf(
        (not TEST_CUDA_GRAPH)
        or IS_WINDOWS
        or (  # appears to still be broken on Windows as of 11.4+
            torch.version.cuda
            and int(torch.version.cuda.split(".")[0]) == 11
            and int(torch.version.cuda.split(".")[1]) < 4
        ),
        "Graph bindings disallow concurrent replay for CUDA < 11.4, see "
        + "https://github.com/pytorch/pytorch/pull/57556",
    )
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_concurrent_replay(self):
        torch.cuda.empty_cache()

        size = 1000000  # largeish to help expose race conditions

        def func_with_temps(t, val):
            x = t.clone() + val
            y = t.clone() + val
            return x + y

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()

            s0 = torch.cuda.Stream()
            s1 = torch.cuda.Stream()

            a = torch.ones((size,), device="cuda")

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = a.clone()
                for _ in range(5):
                    b = func_with_temps(b, 1)
                g0.capture_end()

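                # Two ways for g1 to share g0's memory pool: pass g0.pool(), the handle
                # of the pool g0 captured into, or create a handle up front with
                # torch.cuda.graph_pool_handle() and hand it to both capture_begin() calls.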
                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
                g1.capture_begin(*g1_args)
                c = a.clone()
                for _ in range(5):
                    c = func_with_temps(c, 2)
                g1.capture_end()

            # To reproduce data corruption, I need g0 and g1's kernels to run concurrently.
            # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead.
            # The following pattern helps align device-side execution of g0 and g1's kernels.
            torch.cuda.synchronize()
            with torch.cuda.stream(s0):
                torch.cuda._sleep(1000000)
                s1.wait_stream(s0)
                g0.replay()
            with torch.cuda.stream(s1):
                g1.replay()
            torch.cuda.current_stream().wait_stream(s0)
            torch.cuda.current_stream().wait_stream(s1)

            if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"):
                # If we used the native allocator and shared mempools,
                # we expect the concurrent replays corrupted each other.
                self.assertNotEqual(b.sum().item(), size * 94)
                self.assertNotEqual(c.sum().item(), size * 156)
            else:
                # If we EITHER
                #   - used the native allocator without sharing mempools, OR
                #   - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe
                # we don't expect memory corruption.
                self.assertEqual(b.sum().item(), size * 94)
                self.assertEqual(c.sum().item(), size * 156)

            del a, b, c, g0, g1
            # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them.

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_three_successive(self):
        torch.cuda.empty_cache()

        size = 1000

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            a = torch.ones((size,), device="cuda")

            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()
            g2 = torch.cuda.CUDAGraph()

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = a.clone()
                c = b + 1
                d = b + 2
                g0.capture_end()

                args = (g0.pool(),) if share_mem == "via pool()" else g0_args

                g1.capture_begin(*args)
                e = c + 3
                del c
                g1.capture_end()

                g2.capture_begin(*args)
                f = d + 4
                g2.capture_end()
            torch.cuda.current_stream().wait_stream(s)

            # Tests that replaying in capture order is valid
            g0.replay()
            g1.replay()
            g2.replay()

            self.assertEqual(e.sum().item(), size * 5)
            self.assertEqual(f.sum().item(), size * 7)

            # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool
            g0.replay()
            g2.replay()
            g1.replay()

            expect_corruption = (not TEST_CUDAMALLOCASYNC) and (
                share_mem != "Don't share"
            )
            # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
            # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
            self.assertEqual(
                e.sum().item(), size * (7 + 3) if expect_corruption else size * 5
            )
            self.assertEqual(f.sum().item(), size * 7)

            del a, b, d, e, f, g0, g1, g2
            # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
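
    # A minimal illustrative sketch (hypothetical helper, not a test) of the
    # same capture/replay idea using the torch.cuda.graph context manager,
    # which wraps the side-stream and synchronization boilerplate that the
    # tests above spell out with capture_begin()/capture_end().
    def _sketch_graph_context_manager(self):
        static_x = torch.ones(1000, device="cuda")
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_y = static_x * 2 + 1
        # Update the capture's static input in place, then replay.
        static_x.fill_(3.0)
        g.replay()  # static_y now holds 3 * 2 + 1
        return static_y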

    @unittest.skipIf(
        (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC,
        "CUDA >= 11.0 or ROCM >= 5.3 required for graphs",
    )
    def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
        kSmallSize = 1048576
        kSmallBuffer = 2097152
        kLargeBuffer = 20971520
        kMinLargeAlloc = 10485760
        kRoundLarge = 2097152

        elem = 4

        # this was annoying to write but stresses the expectations pretty rigorously
        cases = (
            (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"),
            (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"),
            ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"),
            (
                (kMinLargeAlloc - 512) // elem,
                2,
                2 * kLargeBuffer,
                kLargeBuffer,
                "large_pool",
            ),
            (
                (kMinLargeAlloc + 512) // elem,
                3,
                3
                * (
                    kRoundLarge
                    * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge)
                ),
                kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge),
                "large_pool",
            ),
        )

        stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.")

        gc.collect()
        torch.cuda.empty_cache()

        s = torch.cuda.Stream()

        for (
            numel,
            delta_cudaMallocs,
            delta_cudaMalloc_bytes,
            delta_cudaMalloc_bytes_post_del_g,
            pool_string,
        ) in cases:
            if pool_string == "small_pool":
                delta_active_blocks = 3  # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders
                delta_active_bytes = (
                    numel * elem + 1024
                )  # + 1024 for CUDAGraph's rng seed and offset holders each
            else:
                delta_active_blocks = 1  # We only check the large pool, which isn't affected by rng offset holder
                delta_active_bytes = numel * elem

            g = torch.cuda.CUDAGraph()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                # Allocation stat estimates assume input is created on the same stream as capture_begin()
                # (in other words, the same stream silo as the rng offset holder, which is not allocated from the
                # capture's private pool).
                a = torch.ones((numel,), device="cuda")

                precapture_stats = torch.cuda.memory_stats()

                g.capture_begin()
                b = a.clone()
                for _ in range(5):
                    b = b.clone() + 1
                g.capture_end()
            torch.cuda.current_stream().wait_stream(s)

            gc.collect()

            postcapture_stats = torch.cuda.memory_stats()

            expecteds = (
                delta_cudaMallocs,
                delta_cudaMalloc_bytes,
                delta_active_blocks,
                delta_active_bytes,
            )
            # Double checks replay and stats before and after a call to empty_cache
            for i in range(2):
                for stat, expected in zip(stats_to_check, expecteds):
                    stat = stat + pool_string + ".current"
                    current = postcapture_stats[stat] - precapture_stats[stat]

                    # There will only ever be one expandable segment in each of the small and large pools. The way the
                    # bookkeeping is done in the allocator means that we never increment the number of segments.
                    if self.expandable_segments and "segment" in stat:
                        expected = 0
                    # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
                    # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
                    # smaller than the page size
                    if (
                        self.expandable_segments
                        and "reserved" in stat
                        and (numel == cases[3][0] or numel == cases[4][0])
                    ):
                        expected = 2 * kLargeBuffer

                    self.assertEqual(
                        current,
                        expected,
                        "Pre to post capture delta of "
                        + stat
                        + f" = {current}, expected = {expected}, numel = {numel}",
                    )

                g.replay()
                self.assertEqual(b.sum().item(), 6 * numel)
                if i == 0:
                    torch.cuda.empty_cache()

            del g
            gc.collect()
            torch.cuda.empty_cache()
            postdel_stats = torch.cuda.memory_stats()

            # Uses graph result b after graph has been deleted
            self.assertEqual(b.sum().item(), 6 * numel)

            # b should be the only live reference remaining from the graph's private pool
            expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem)
            for stat, expected in zip(stats_to_check, expecteds):
                stat = stat + pool_string + ".current"
                current = postdel_stats[stat] - precapture_stats[stat]

                # There will only ever be one expandable segment in each of the small and large pools. The way the
                # bookkeeping is done in the allocator means that we never increment the number of segments.
                if self.expandable_segments and "segment" in stat:
                    expected = 0
                # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
                # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
                # smaller than the page size
                if (
                    self.expandable_segments
                    and "reserved" in stat
                    and numel == cases[3][0]
                ):
                    expected = 2 * kLargeBuffer
                if (
                    self.expandable_segments
                    and "reserved" in stat
                    and numel == cases[4][0]
                ):
                    expected = kLargeBuffer

                self.assertEqual(
                    current,
                    expected,
                    "Pre capture to post graph delete delta of "
                    + stat
                    + f" = {current}, expected = {expected}, numel = {numel}",
                )

            # del a, b before the next case is essential, otherwise overwriting a and b in the next case
            # can throw off its allocation/deallocation counts.
            del a, b
            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
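
    # A minimal illustrative sketch (hypothetical helper, not a test) of the
    # stat-delta bookkeeping used above: snapshot torch.cuda.memory_stats()
    # before and after an allocation and compare the "<prefix><pool>.current"
    # entries of interest.
    def _sketch_memory_stats_delta(self):
        before = torch.cuda.memory_stats()
        x = torch.ones(1 << 20, device="cuda")  # ~4 MiB, lands in the large pool
        after = torch.cuda.memory_stats()
        deltas = {
            key: after[key] - before.get(key, 0)
            for key in (
                "segment.large_pool.current",
                "reserved_bytes.large_pool.current",
            )
        }
        del x
        return deltas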

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_record_stream(self):
        # Makes sure graph capture defers attempting to reclaim allocations used across streams. See
        # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp
        torch.cuda.empty_cache()

        potential_problem = torch.zeros((3,), device="cuda")
        a = torch.zeros((3,), device="cuda")
        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()
        g = torch.cuda.CUDAGraph()

        torch.cuda.synchronize()
        with torch.cuda.stream(s0):
            potential_problem.record_stream(s0)
            torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES)
            potential_problem.fill_(1.0)
        del potential_problem

        with torch.cuda.stream(s1):
            g.capture_begin()
            # potential_problem's allocation should still be outstanding. If DeviceCachingAllocator::malloc
            # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life
            # event, which will cause the capture to error.
            b = a.clone()

            # Let's also see what happens if we record_stream on a tensor during capture.
            s2.wait_stream(s1)
            with torch.cuda.stream(s2):
                b.fill_(1.0)
                b.record_stream(s2)  # dummy record_stream
            del b
            s1.wait_stream(s2)
            g.capture_end()
        torch.cuda.synchronize()

        # A dummy allocation triggers process_events, which should successfully process b's end-of-life event.
        c = torch.zeros((3,), device="cuda")
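
    # A minimal illustrative sketch (hypothetical helper, not a test) of the
    # ordinary record_stream() pattern the allocator-facing test above guards:
    # mark a tensor as in use on a side stream so its memory is not reused
    # before that stream's work finishes.
    def _sketch_record_stream_usage(self):
        side = torch.cuda.Stream()
        x = torch.zeros(1 << 20, device="cuda")
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            x.fill_(1.0)
            # Tell the caching allocator that x is in use on `side`.
            x.record_stream(side)
        # Safe to drop the reference: the allocator waits on side's event
        # before handing x's memory out again.
        del x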

    @skipIfRocm
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize
    # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior
    # as a memory leak unless we skip the leak check.
    @skipCUDAMemoryLeakCheckIf(True)
    @serialTest()
    def test_graph_cudnn_dropout(self):
        # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp.
        # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should
        # avoid syncing noncapturing streams with captured events or vice versa.
        torch.cuda.empty_cache()

        model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
        x = torch.ones(100, 192, 512, device="cuda")

        y = model(x)

        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            g.capture_begin()
            y = model(x)
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()

        y = model(x)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize(
        "with_amp,cache_enabled,allow_unused_input",
        [
            subtest((False, False, True), decorators=[skipIfRocm]),
            subtest((True, False, True), decorators=[skipIfRocm]),
            subtest((True, True, True), decorators=[unittest.expectedFailure]),
            subtest((False, False, False), decorators=[unittest.expectedFailure]),
        ],
        name_fn=lambda x, y, z: "{}{}{}".format(
            {True: "with_amp", False: "without_amp"}[x],
            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
        ),
    )
    @serialTest()
    def test_graph_make_graphed_callables(
        self, with_amp, cache_enabled, allow_unused_input
    ):
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)

        N, D_in, H, D_out = 640, 4096, 2048, 1024

        class MLP1(torch.nn.Module):
            def __init__(self, D_in: int, H: int, D_out: int):
                super().__init__()
                self.net_1 = torch.nn.Sequential(
                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
                ).cuda()
                self.net_2 = torch.nn.Sequential(
                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
                ).cuda()

            def forward(self, input_dict: dict):
                x = input_dict["x"]
                return self.net_2(self.net_1(x))

        class MLP2(torch.nn.Module):
            def __init__(self, D_in: int, H: int, D_out: int):
                super().__init__()
                self.net_1 = torch.nn.Sequential(
                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
                ).cuda()
                self.net_2 = torch.nn.Sequential(
                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
                ).cuda()

            def forward(self, x):
                return self.net_2(self.net_1(x))

        class ParameterlessModule(torch.nn.Module):
            def forward(self, x):
                idx = (
                    torch.arange(x.size(0), device=x.device)
                    .view(-1, 1)
                    .repeat(1, x.size(1))
                )
                return {"output": torch.gather(x, 0, idx)}

        models = []
        for _ in range(2):
            model_section1 = MLP1(D_in, H, H).cuda()
            model_section2 = MLP2(H, H, D_out).cuda()
            model_section3 = ParameterlessModule().cuda()
            models.append(
                torch.nn.Sequential(model_section1, model_section2, model_section3)
            )

        model_graphed = models[0]
        model_control = models[1]

        model_graphed.load_state_dict(model_control.state_dict())

        opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1)
        opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1)

        x = torch.randn(N, D_in, device="cuda")
        h = torch.randn(N, H, device="cuda", requires_grad=True)
        h2 = torch.randn(N, D_out, device="cuda", requires_grad=True)
        unused_input = torch.randn(N, H, device="cuda", requires_grad=True)
        y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True)
        y = torch.randn(N, D_out, device="cuda")

        loss_fn_control = torch.nn.functional.mse_loss
        relu_control = torch.nn.functional.relu

        # This is a good stress test. It graphs five callables: three Modules and two python functions.
        with torch.amp.autocast(
            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
        ):
            (
                model_graphed[0],
                model_graphed[1],
                model_graphed[2],
                relu_graphed,
                loss_fn_graphed,
            ) = torch.cuda.make_graphed_callables(
                (
                    model_graphed[0],
                    model_graphed[1],
                    model_graphed[2],
                    relu_control,
                    loss_fn_control,
                ),
                (
                    ({"x": x, "unused_input": unused_input},),
                    (h,),
                    (h2,),
                    (y_pred,),
                    (y_pred, y),
                ),
                allow_unused_input=allow_unused_input,
            )

        real_inputs = [torch.rand_like(x) for _ in range(10)]
        real_targets = [torch.rand_like(y) for _ in range(10)]

        for m, opt, relu, loss_fn in zip(
            (model_graphed, model_control),
            (opt_graphed, opt_control),
            (relu_graphed, relu_control),
            (loss_fn_graphed, loss_fn_control),
        ):
            # Resets RNG states before iterations for graphed and ungraphed models,
            # so dropout math should be bitwise identical for both.
            torch.manual_seed(5)
            torch.cuda.manual_seed(5)
            for data, target in zip(real_inputs, real_targets):
                opt.zero_grad(set_to_none=True)
                with torch.amp.autocast(
                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
                ):
                    y_pred = m({"x": data, "unused_input": unused_input})["output"]
                    y_pred = relu(y_pred)
                    loss = loss_fn(y_pred, target)
                loss.backward()
                opt.step()

        for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
            self.assertEqual(p, pc)

        # We graphed the models in training mode. Eval should still run ungraphed.
        model_graphed.eval()
        model_control.eval()
        self.assertEqual(
            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
        )
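
    # A minimal illustrative sketch (hypothetical helper, not a test) of
    # make_graphed_callables on a single module, without the AMP and
    # multi-callable machinery exercised above.
    def _sketch_make_graphed_callable(self):
        model = torch.nn.Linear(32, 32).cuda()
        sample_input = torch.randn(8, 32, device="cuda", requires_grad=True)
        # Returns a callable that replays captured forward/backward graphs.
        graphed_model = torch.cuda.make_graphed_callables(model, (sample_input,))
        out = graphed_model(torch.randn(8, 32, device="cuda", requires_grad=True))
        out.sum().backward()
        return out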

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize(
        "with_amp,cache_enabled,allow_unused_input",
        [
            subtest((False, False, True), decorators=[skipIfRocm]),
            subtest((True, False, True), decorators=[skipIfRocm]),
            subtest((True, True, True), decorators=[unittest.expectedFailure]),
            subtest((False, False, False), decorators=[skipIfRocm]),
        ],
        name_fn=lambda x, y, z: "{}{}{}".format(
            {True: "with_amp", False: "without_amp"}[x],
            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
        ),
    )
    @serialTest()
    def test_graph_make_graphed_callables_parameterless_nograd_module(
        self, with_amp, cache_enabled, allow_unused_input
    ):
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)

        N, D_in, H, D_out = 640, 4096, 2048, 1024

        class ParameterlessModule(torch.nn.Module):
            def forward(self, input_dict: dict):
                x = input_dict["x"]
                idx = (
                    torch.arange(x.size(0), device=x.device)
                    .view(-1, 1)
                    .repeat(1, x.size(1))
                )
                return {"output": torch.gather(x, 0, idx)}

        models = []
        for _ in range(2):
            model_section1 = ParameterlessModule().cuda()
            models.append(torch.nn.Sequential(model_section1))

        model_graphed = models[0]
        model_control = models[1]

        model_graphed.load_state_dict(model_control.state_dict())

        x = torch.randn(N, D_in, device="cuda", requires_grad=False)
        unused_input = torch.randn(N, H, device="cuda", requires_grad=False)
        y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False)
        y = torch.randn(N, D_in, device="cuda")

        # Graphs a single parameterless Module whose inputs don't require grad.
        with torch.amp.autocast(
            device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
        ):
            model_graphed[0] = torch.cuda.make_graphed_callables(
                model_graphed[0],
                ({"x": x, "unused_input": unused_input},),
                allow_unused_input=allow_unused_input,
            )

        real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)]
        real_targets = [torch.rand_like(y) for _ in range(10)]

        for m in (model_graphed, model_control):
            # Resets RNG states before iterations for graphed and ungraphed models,
            # so results should be bitwise identical for both.
            torch.manual_seed(5)
            torch.cuda.manual_seed(5)
            for data, _ in zip(real_inputs, real_targets):
                with torch.amp.autocast(
                    device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
                ):
                    out = m({"x": data, "unused_input": unused_input})["output"]

        # We graphed the models in training mode. Eval should still run ungraphed.
        model_graphed.eval()
        model_control.eval()
        self.assertEqual(
            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_make_graphed_callables_same_pool(self):
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)
        models = []
        num_models = 3
        for _ in range(num_models):
            models.append(
                torch.nn.Sequential(
                    torch.nn.Linear(32, 128),
                    torch.nn.ReLU(),
                    torch.nn.Linear(128, 128),
                ).cuda()
            )
        # we will reuse the same pool for all graph captures
        mempool = torch.cuda.graph_pool_handle()
        graphed_models = []
        for model in models:
            x = torch.randn([64, 32], device="cuda")
            graphed_model = deepcopy(model)
            graphed_model = torch.cuda.make_graphed_callables(
                graphed_model, (x,), pool=mempool
            )
            graphed_models.append(graphed_model)

        for model, graphed_model in zip(models, graphed_models):
            x = torch.randn([64, 32], device="cuda")
            y = model(x)
            yg = graphed_model(x)
            l = y.norm()
            lg = yg.norm()
            l.backward()
            lg.backward()

            self.assertEqual(y, yg)
            self.assertEqual(l, lg)
            for p, pg in zip(model.parameters(), graphed_model.parameters()):
                self.assertEqual(p, pg)
                self.assertEqual(p.grad, pg.grad)
                self.assertNotEqual(p.data_ptr(), pg.data_ptr())
                self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr())

    def _test_graphed_optimizer(
        self, steps_warmup, steps_train, optimizer_ctor, kwargs
    ):
        for actually_do_graphs in (True, False):
            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [
                torch.randn((), device="cuda")
            ]
            params_control = [p.clone().requires_grad_() for p in params]
            params_graphed = [p.clone().requires_grad_() for p in params]

            grads = [
                [torch.randn_like(p) for p in params]
                for _ in range(steps_warmup + steps_train)
            ]

            # Control (capturable=False)

            opt = optimizer_ctor(params_control, capturable=False, **kwargs)

            for i in range(steps_warmup + steps_train):
                for j, p in enumerate(params_control):
                    p.grad = grads[i][j]
                opt.step()

            # capturable=True

            opt = optimizer_ctor(params_graphed, capturable=True, **kwargs)

            for i in range(steps_warmup):
                for j, p in enumerate(params_graphed):
                    p.grad = grads[i][j]
                opt.step()

            if actually_do_graphs:
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    opt.step()

            for i in range(steps_train):
                if actually_do_graphs:
                    for j, p in enumerate(params_graphed):
                        p.grad.copy_(grads[i + steps_warmup][j])
                    g.replay()
                else:
                    # Passing capturable=True to the constructor and running without graphs should still be
                    # numerically correct, even if it's not ideal for performance.
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i + steps_warmup][j]
                    opt.step()

            for p_control, p_graphed in zip(params_control, params_graphed):
                self.assertEqual(p_control, p_graphed)
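
    # A minimal illustrative sketch (hypothetical helper, not a test) of the
    # capturable-optimizer pattern the helper above exercises: warm up eagerly,
    # capture one opt.step() into a graph, then replay it with refreshed grads.
    def _sketch_capturable_optimizer_step(self):
        param = torch.randn(10, device="cuda", requires_grad=True)
        opt = torch.optim.Adam([param], lr=0.1, capturable=True)
        param.grad = torch.randn_like(param)
        opt.step()  # eager warmup initializes optimizer state on the GPU
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            opt.step()
        for _ in range(3):
            # Refresh gradients in place, then replay the captured step.
            param.grad.copy_(torch.randn_like(param))
            g.replay()
        return param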

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_optims_with_explicitly_capturable_param_groups(self):
        # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__
        n_warmup, n_replay = 3, 2
        for optimizer, second_param_group_capturable in product(
            (
                torch.optim.Adam,
                torch.optim.AdamW,
                torch.optim.ASGD,
                torch.optim.Adamax,
                torch.optim.NAdam,
                torch.optim.RAdam,
                torch.optim.Adadelta,
                torch.optim.RMSprop,
                torch.optim.Rprop,
            ),
            (True, False),
        ):
            ref_p1, param1 = (
                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
            )
            ref_p2, param2 = (
                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
            )
            grads1, grads2 = (
                [torch.randn_like(param1) for _ in range(n_warmup + n_replay)]
                for _ in range(2)
            )
            ref_grads1, ref_grads2 = (
                [t.clone() for t in tensors] for tensors in (grads1, grads2)
            )
            params = [
                {"params": [param1], "capturable": True},
                {"params": [param2], "capturable": second_param_group_capturable},
            ]
            opt = optimizer(params)
            opt_ = optimizer(
                [
                    {"params": [ref_p1], "capturable": False},
                    {"params": [ref_p2], "capturable": False},
                ]
            )

            for i in range(n_warmup + n_replay):
                ref_p1.grad = ref_grads1[i]
                ref_p2.grad = ref_grads2[i]
                opt_.step()

            for i in range(n_warmup):
                param1.grad = grads1[i]
                param2.grad = grads2[i]
                opt.step()

            g = torch.cuda.CUDAGraph()
            if not second_param_group_capturable:
                with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"):
                    with torch.cuda.graph(g):
                        opt.step()
            else:
                with torch.cuda.graph(g):
                    opt.step()

                for i in range(n_replay):
                    param1.grad.copy_(grads1[n_warmup + i])
                    param2.grad.copy_(grads2[n_warmup + i])
                    g.replay()
                self.assertEqual(ref_p1, param1)
                self.assertEqual(ref_p2, param2)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_cuda_graph_error_options(self):
        def fn():
            x = torch.zeros([2000], device="cuda")
            y = x + x + x
            return y

        mem = None

        def raw_malloc():
            global mem
            mem = None
            stream = torch.cuda.Stream()
            try:
                with torch.cuda.stream(stream):
                    mem = torch.cuda.caching_allocator_alloc(1024)
            except BaseException:
                if mem is None:
                    return
            try:
                torch.cuda.caching_allocator_delete(mem)
                mem = None
                return None
            except BaseException:
                pass

        def throws_on_cuda_event(capture_error_mode):
            graph = torch.cuda.CUDAGraph()
            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                fn()
            stream.synchronize()
            torch.cuda.current_stream().wait_stream(stream)
            torch.cuda.synchronize()
            try:
                with torch.cuda.graph(
                    graph, stream=stream, capture_error_mode=capture_error_mode
                ):
                    out = fn()
                    thread = threading.Thread(target=raw_malloc)
                    thread.start()
                    thread.join()
            except Exception:
                if mem is not None:
                    torch.cuda.caching_allocator_delete(mem)
                return True

            return False

        self.assertFalse(throws_on_cuda_event("thread_local"))
        self.assertFalse(throws_on_cuda_event("relaxed"))

        # An exception here would corrupt the process and make other tests fail
        # self.assertTrue(throws_on_cuda_event("global"))
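
    # A minimal illustrative sketch (hypothetical helper, not a test) of
    # selecting a capture_error_mode with torch.cuda.graph, as exercised above.
    # "relaxed" tolerates unrelated CUDA work (e.g. from other threads) while
    # capture is underway; the default "global" mode is the strictest.
    def _sketch_capture_error_mode(self):
        g = torch.cuda.CUDAGraph()
        x = torch.zeros(2000, device="cuda")
        with torch.cuda.graph(g, capture_error_mode="relaxed"):
            y = x + 1
        g.replay()
        return y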
    def test_batch_norm_gather_stats(self):
        input = torch.randn(1, 3, 3, 3, device="cuda")
        mean, invstd = torch.batch_norm_gather_stats(
            input,
            mean=torch.ones(2, 3, device="cuda"),
            invstd=torch.ones(2, 3, device="cuda"),
            running_mean=None,
            running_var=None,
            momentum=0.1,
            eps=1e-5,
            count=2,
        )
        self.assertEqual(mean, torch.ones(3, device="cuda"))
        self.assertEqual(invstd, torch.ones(3, device="cuda"))

    def test_matmul_memory_use(self):
        def get_max_used():
            torch.cuda.synchronize()
            val = torch.cuda.max_memory_allocated()
            torch.cuda.reset_peak_memory_stats()
            return val

        a = torch.rand(1, 32, 32, device="cuda")
        b = torch.rand(24, 32, 1, device="cuda")

        get_max_used()

        torch.matmul(a, b)

        matmul_mem = get_max_used()

        a = a.expand(24, 32, 32)
        torch.matmul(a, b)

        matmul_expand_mem = get_max_used()

        torch.bmm(a, b)

        bmm_mem = get_max_used()

        self.assertEqual(matmul_expand_mem, matmul_mem)
        self.assertEqual(bmm_mem, matmul_mem)

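    # Illustrative sketch (not part of the original test suite): the
    # peak-memory-measurement idiom that get_max_used() above relies on. The
    # synchronize() calls make sure pending kernels have run before the peak
    # counter is reset and read.
    def _example_measure_peak_memory(self, fn):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        fn()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()
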
    @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test")
    def test_rocm_backward_pass_guard(self):
        # The test exercises a ROCm-specific feature.

        class MyFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, tensor, constant):
                self.assertFalse(torch._C._rocm_is_backward_pass())
                ctx.constant = constant
                return tensor * constant

            @staticmethod
            def backward(ctx, grad_output):
                self.assertTrue(torch._C._rocm_is_backward_pass())
                return grad_output * ctx.constant, None

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.nn.Parameter(torch.randn(()))

            def forward(self, x):
                return MyFunction.apply(x, self.a)

        model = MyModule()
        criterion = torch.nn.MSELoss(reduction="sum")
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)

        x = torch.randn(5, 5)
        result = model(x)
        loss = criterion(result, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def test_matmul_device_mismatch(self):
        cpu = torch.rand((10, 10))
        cuda = cpu.cuda()
        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            cpu @ cuda
        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            cuda @ cpu

        for s, m1, m2 in product((cpu, cuda), repeat=3):
            if s.device == m1.device == m2.device:
                torch.addmm(s, m1, m2)
            else:
                with self.assertRaisesRegex(
                    RuntimeError, "Expected all tensors to be on the same device"
                ):
                    torch.addmm(s, m1, m2)

    @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient")
    def test_lazy_init(self):
        """Validate that no CUDA calls are made during `import torch` call"""

        def check_output(script: str) -> str:
            return (
                subprocess.check_output([sys.executable, "-c", script])
                .decode("ascii")
                .strip()
            )

        VISIBLE_DEVICES = (
            "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES"
        )
        test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())"
        rc = check_output(test_script)
        self.assertEqual(rc, "0")
        if not TEST_WITH_ROCM:
            # Check that `cuInit` was not called during the import by using ctypes
            # to call cuDeviceGetCount() and expecting CUDA_ERROR_NOT_INITIALIZED == 3.
            # See https://github.com/pytorch/pytorch/issues/116276 for more details
            libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll"
            cuda_driver_api_call = (
                f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))"
            )
            rc = check_output(
                f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})"
            )
            self.assertEqual(rc, "3")

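    # Illustrative sketch (not part of the original test suite): a lighter-weight
    # way to check lazy initialization from Python, using the public
    # torch.cuda.is_initialized() query in a fresh subprocess.
    def _example_check_lazy_init(self):
        script = "import torch; print(torch.cuda.is_initialized())"
        out = (
            subprocess.check_output([sys.executable, "-c", script])
            .decode("ascii")
            .strip()
        )
        return out == "False"  # importing torch alone should not initialize CUDA
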
    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
    def test_hip_device_count(self):
        """Validate device_count works with both CUDA/HIP visible devices"""
        test_script = """\
import torch
import os
print(f"{torch.cuda.device_count()}")
"""
        custom_envs = [
            {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
            {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
            {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
        ]

        for env_config in custom_envs:
            env = os.environ.copy()
            for key, value in env_config.items():
                if value is None:
                    env.pop(key, None)
                else:
                    env[key] = value
            r = (
                subprocess.check_output([sys.executable, "-c", test_script], env=env)
                .decode("ascii")
                .strip()
            )
            self.assertEqual("1", r)

    @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices")
    def test_device_count_not_cached_pre_init(self):
        visible_devices = (
            "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES"
        )
        test_script = f"""\
import torch
import os
r1 = torch.cuda.device_count()
os.environ['{visible_devices}'] = '0'
r2 = torch.cuda.device_count()
torch.empty(10, device='cuda')
print(f"{{r1}}, {{r2}}")
"""

        r = (
            subprocess.check_output([sys.executable, "-c", test_script])
            .decode("ascii")
            .strip()
        )

        x = torch.cuda.device_count()
        self.assertEqual(f"{x}, 1", r)

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_fails_in_ci(self):
        if IS_WINDOWS or TEST_WITH_ROCM:
            error_msg = "is not supported on this platform"
        else:
            error_msg = "cuFileHandleRegister failed"
        with TemporaryFileName() as f:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaMallocAsync(TestCase):
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_memory_snapshot(self):
        try:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history("state", stacks="python")
            # make x the second block in a segment
            torch.rand(2 * 311, 411, device="cuda")
            unused = torch.rand(310, 410, device="cuda")
            x = torch.rand(311, 411, device="cuda")

            # create a bunch of tensors that all will tile into the
            # same segment to exercise the history merging code
            # 512B is the minimum block size,
            # so we allocate all the tensors to this size to make sure
            # they tile evenly
            tensors = [torch.rand(128, device="cuda") for _ in range(1000)]
            while tensors:
                del tensors[randint(0, len(tensors) - 1)]

            # exercise the history trimming code
            torch.rand(128 * 5, device="cuda")

            ss = torch.cuda.memory._snapshot()
            found_it = False
            for seg in ss["segments"]:
                self.assertTrue("frames" in seg)
                for b in seg["blocks"]:
                    if b["requested_size"] == 311 * 411 * 4:
                        self.assertTrue("test_cuda" in b["frames"][0]["filename"])
                        found_it = True
                        self.assertEqual(x.untyped_storage().data_ptr(), b["address"])
            self.assertTrue(found_it)

            if not IS_WINDOWS:
                with tempfile.NamedTemporaryFile() as f:
                    torch.cuda.memory._save_segment_usage(f.name)
                    with open(f.name) as f2:
                        self.assertTrue("test_cuda.py" in f2.read())
            del unused
            del x
            torch.cuda.empty_cache()
            ss = torch.cuda.memory._snapshot()
            self.assertTrue(
                ss["device_traces"][0][-1]["action"]
                in ("segment_free", "segment_unmap")
            )

        finally:
            torch.cuda.memory._record_memory_history(None)

Worker self.assertTrue("::rand" in str(b["frames"])) 3372*da0073e9SAndroid Build Coastguard Worker found_it = True 3373*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_it) 3374*da0073e9SAndroid Build Coastguard Worker 3375*da0073e9SAndroid Build Coastguard Worker finally: 3376*da0073e9SAndroid Build Coastguard Worker torch.cuda.memory._record_memory_history(None) 3377*da0073e9SAndroid Build Coastguard Worker 3378*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 3379*da0073e9SAndroid Build Coastguard Worker def test_memory_profiler_viz(self): 3380*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile( 3381*da0073e9SAndroid Build Coastguard Worker with_stack=True, profile_memory=True, record_shapes=True 3382*da0073e9SAndroid Build Coastguard Worker ) as prof: 3383*da0073e9SAndroid Build Coastguard Worker x = torch.rand(128, 128, device="cuda") 3384*da0073e9SAndroid Build Coastguard Worker x * x + x * x 3385*da0073e9SAndroid Build Coastguard Worker plot = profile_plot(prof) 3386*da0073e9SAndroid Build Coastguard Worker plot = json.dumps(_profile_to_snapshot(prof)) 3387*da0073e9SAndroid Build Coastguard Worker self.assertTrue("test_cuda.py" in plot) 3388*da0073e9SAndroid Build Coastguard Worker self.assertTrue("test_memory_profiler_viz" in plot) 3389*da0073e9SAndroid Build Coastguard Worker self.assertTrue("category" in plot) 3390*da0073e9SAndroid Build Coastguard Worker 3391*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3392*da0073e9SAndroid Build Coastguard Worker TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" 3393*da0073e9SAndroid Build Coastguard Worker ) 3394*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") 3395*da0073e9SAndroid Build Coastguard Worker def test_cycles(self): 3396*da0073e9SAndroid Build Coastguard Worker fired = False 3397*da0073e9SAndroid Build Coastguard Worker 3398*da0073e9SAndroid Build Coastguard Worker def observer(html): 3399*da0073e9SAndroid Build Coastguard Worker nonlocal fired 3400*da0073e9SAndroid Build Coastguard Worker fired = True 3401*da0073e9SAndroid Build Coastguard Worker self.assertTrue("torch.Tensor" in html) 3402*da0073e9SAndroid Build Coastguard Worker self.assertTrue("test_cuda" in html) 3403*da0073e9SAndroid Build Coastguard Worker self.assertTrue("cell_contents" in html) 3404*da0073e9SAndroid Build Coastguard Worker 3405*da0073e9SAndroid Build Coastguard Worker disarm = observe_tensor_cycles(observer) 3406*da0073e9SAndroid Build Coastguard Worker 3407*da0073e9SAndroid Build Coastguard Worker def noop(): 3408*da0073e9SAndroid Build Coastguard Worker pass 3409*da0073e9SAndroid Build Coastguard Worker 3410*da0073e9SAndroid Build Coastguard Worker try: 3411*da0073e9SAndroid Build Coastguard Worker 3412*da0073e9SAndroid Build Coastguard Worker def create(): 3413*da0073e9SAndroid Build Coastguard Worker x = torch.empty(3, 4, device="cuda") 3414*da0073e9SAndroid Build Coastguard Worker 3415*da0073e9SAndroid Build Coastguard Worker def foo(p): 3416*da0073e9SAndroid Build Coastguard Worker if p: 3417*da0073e9SAndroid Build Coastguard Worker return foo(not p) 3418*da0073e9SAndroid Build Coastguard Worker else: 3419*da0073e9SAndroid Build Coastguard Worker return x 3420*da0073e9SAndroid Build Coastguard Worker 3421*da0073e9SAndroid Build Coastguard Worker return foo 3422*da0073e9SAndroid Build Coastguard Worker 3423*da0073e9SAndroid Build Coastguard Worker create() 3424*da0073e9SAndroid Build Coastguard Worker 
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots(self):
        for context, stacks in (
            ("all", "all" if IS_LINUX else "python"),
            ("all", "python"),
            (None, "python"),
        ):
            try:
                torch.cuda.memory.empty_cache()
                torch.cuda.memory._record_memory_history(
                    "all", context=context, stacks=stacks
                )

                def run():
                    x = torch.rand(128, 128, device="cuda")
                    x * x + x * x

                run()
                cpp = stacks == "all"
                record_context = context is not None
                ss = torch.cuda.memory._snapshot()

                tplot = trace_plot(ss)
                splot = segment_plot(ss)
                text = json.dumps(ss)

                self.assertTrue(record_context == ("test_memory_plots" in text))
                self.assertTrue(cpp == ("::rand" in text))
                self.assertTrue(str(128 * 128 * 4) in text)

            finally:
                torch.cuda.memory._record_memory_history(None)

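    # Illustrative sketch (not part of the original test suite): writing the plots
    # produced above out to files for viewing in a browser. It assumes, as the test
    # does, that trace_plot/segment_plot return self-contained HTML strings; the
    # output paths are just examples.
    def _example_plot_snapshot_to_html(self):
        ss = torch.cuda.memory._snapshot()
        with open("trace.html", "w") as f:
            f.write(trace_plot(ss))
        with open("segments.html", "w") as f:
            f.write(segment_plot(ss))
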
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_free_stack(self):
        for context in ["alloc", "all", "state"]:
            try:
                torch.cuda.memory.empty_cache()
                torch.cuda.memory._record_memory_history(context=context)
                x = None

                def thealloc():
                    nonlocal x
                    x = torch.rand(3, 4, device="cuda")

                def thefree():
                    nonlocal x
                    del x

                thealloc()
                thefree()
                ss = json.dumps(torch.cuda.memory._snapshot())
                self.assertTrue(("thefree" in ss) == (context == "all"))
                self.assertTrue(("thealloc" in ss) == (context != "state"))
            finally:
                torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_history_context(self):
        try:
            torch.cuda.memory.empty_cache()
            x = None

            def should_capture1():
                nonlocal x
                x = torch.rand(4, 4, device="cuda")

            def should_not_capture():
                nonlocal x
                x = torch.rand(3, 4, device="cuda")

            def should_capture2():
                nonlocal x
                x = torch.rand(4, 4, device="cuda")

            # Recording with context and python call stacks should capture the call stack.
            torch.cuda.memory._record_memory_history(context="all", stacks="python")
            should_capture1()
            # Recording with context=None should not capture the call stack.
            torch.cuda.memory._record_memory_history(context=None)
            should_not_capture()
            # Recording with context and python call stacks should capture the call stack.
            torch.cuda.memory._record_memory_history(context="all", stacks="python")
            should_capture2()

            ss = json.dumps(torch.cuda.memory._snapshot())
            self.assertTrue("should_capture1" in ss)
            self.assertTrue("should_not_capture" not in ss)
            self.assertTrue("should_capture2" in ss)
        finally:
            torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_free_segment_stack(self):
        for context in ["alloc", "all", "state"]:
            try:
                torch.cuda.memory.empty_cache()
                torch.cuda.memory._record_memory_history(context=context)
                x = torch.rand(3, 4, device="cuda")
                del x
                torch.cuda.memory.empty_cache()

                ss = json.dumps(torch.cuda.memory._snapshot())
                self.assertTrue(("empty_cache" in ss) == (context == "all"))
            finally:
                torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_memory_snapshot_script(self):
        try:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history("state", stacks="python")

            @torch.jit.script
            def foo():
                return torch.rand(311, 411, device="cuda")

            x = foo()

            ss = torch.cuda.memory._snapshot()["segments"]
            found_it = False
            for seg in ss:
                for b in seg["blocks"]:
                    if b["requested_size"] == 311 * 411 * 4:
                        self.assertTrue(b["frames"][0]["name"] == "foo")
                        found_it = True
            self.assertTrue(found_it)

        finally:
            torch.cuda.memory._record_memory_history(None)

    def test_max_split_expandable(self):
        torch.cuda.memory.empty_cache()
        mb = 1024 * 1024
        _, all_memory = torch.cuda.memory.mem_get_info()
        total_allowed = 120 * mb
        fraction_allowed = total_allowed / all_memory
        assert int(fraction_allowed * all_memory) == total_allowed
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

        def alloc(n):
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")

        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,max_split_size_mb:40"
        )
        a = alloc(40)
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:True,max_split_size_mb:40"
        )
        b = alloc(40)
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,max_split_size_mb:40"
        )
        c = alloc(40)
        with self.assertRaises(torch.OutOfMemoryError):
            alloc(40)
        del a, b, c
        # force release_cached_blocks to run with some expandable segments in the free list
        alloc(120)

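    # Illustrative sketch (not part of the original test suite): outside of tests,
    # the allocator knobs exercised above are normally configured through the
    # PYTORCH_CUDA_ALLOC_CONF environment variable before the first CUDA
    # allocation, e.g.
    #
    #   PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128 python train.py
    #
    # The helper below mirrors what the tests do at runtime via the private
    # _set_allocator_settings call; the setting string is just an example value.
    def _example_set_allocator_settings(self):
        torch.cuda.memory._set_allocator_settings("max_split_size_mb:128")
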
    def test_garbage_collect_expandable(self):
        torch.cuda.memory.empty_cache()
        mb = 1024 * 1024
        _, all_memory = torch.cuda.memory.mem_get_info()
        total_allowed = 120 * mb
        fraction_allowed = total_allowed / all_memory
        assert int(fraction_allowed * all_memory) == total_allowed
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

        def alloc(n):
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")

        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,garbage_collection_threshold:0.5"
        )
        a = alloc(40)
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:True,garbage_collection_threshold:0.5"
        )
        b = alloc(40)
        del a, b
        # causes GC to run. The expandable segment block will be split
        # so GC would not attempt to free it anyway, but this at least makes sure
        # expandable_segment blocks can be in the free list when this is called.
        alloc(80)

    def test_allocator_settings(self):
        def power2_div(size, div_factor):
            pow2 = 1
            while pow2 < size:
                pow2 = pow2 * 2
            if pow2 == size:
                return pow2
            step = pow2 / 2 / div_factor
            ret = pow2 / 2
            while ret < size:
                ret = ret + step
            return ret

        torch.cuda.memory.empty_cache()
        key_allocated = (
            "active_bytes.all.allocated"
            if not TEST_CUDAMALLOCASYNC
            else "allocated_bytes.all.current"
        )
        key_requested = "requested_bytes.all.allocated"

        nelems = 21 * 1024 * 1024
        nbytes = 4 * nelems  # floats are 4 bytes

        nelems_big = 100 * 1024 * 1024
        nbytes_big = 4 * nelems_big  # floats are 4 bytes

        start_mem = torch.cuda.memory_stats()[key_allocated]
        torch.cuda.memory._set_allocator_settings("")
        x = torch.rand(nelems, device="cuda")

        # test roundup_power2_divisions single value syntax
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        start_requested = torch.cuda.memory_stats()[key_requested]
        torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4")
        y = torch.rand(nelems, device="cuda")

        pow2_div4_mem = torch.cuda.memory_stats()[key_allocated]
        current_requested = torch.cuda.memory_stats()[key_requested]

        self.assertTrue(reg_mem - start_mem == nbytes)
        if not TEST_CUDAMALLOCASYNC:
            # not supported with the cudaMallocAsync backend
            self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
            self.assertTrue(current_requested - start_requested == nbytes)

        torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
        torch.cuda.memory._set_allocator_settings(
            "garbage_collection_threshold:0.5,max_split_size_mb:40"
        )

        # should have reset the power2 divisions now
        torch.cuda.memory.empty_cache()
        start_mem = torch.cuda.memory_stats()[key_allocated]
        z = torch.rand(nelems, device="cuda")
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        self.assertTrue(reg_mem - start_mem == nbytes)

        # roundup_power2_divisions knob array syntax
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._set_allocator_settings(
            "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]"
        )
        start_mem = torch.cuda.memory_stats()[key_allocated]
        w = torch.rand(nelems, device="cuda")

        pow2_div8_mem = torch.cuda.memory_stats()[key_allocated]
        if not TEST_CUDAMALLOCASYNC:
            # not supported with the cudaMallocAsync backend
            self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8))

        torch.cuda.memory.empty_cache()
        start_mem = torch.cuda.memory_stats()[key_allocated]
        v = torch.rand(nelems_big, device="cuda")

        pow2_div2_mem = torch.cuda.memory_stats()[key_allocated]
        if not TEST_CUDAMALLOCASYNC:
            # not supported with the cudaMallocAsync backend
            self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2))

        torch.cuda.memory.empty_cache()
        torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True")
        start_mem = torch.cuda.memory_stats()[key_allocated]
        w = torch.rand(nelems, device="cuda")
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        self.assertTrue(reg_mem - start_mem == nbytes)

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("foo:1,bar:2")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "garbage_collection_threshold:1.2"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("max_split_size_mb:2")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_use_cuda_host_register:none"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_num_register_threads:none"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_num_register_threads:1024"
            )

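    # Illustrative sketch (not part of the original test suite): a worked example
    # of the rounding scheme that power2_div() above models. With
    # roundup_power2_divisions:4, the interval between two powers of two is split
    # into 4 steps, so a request between 64MiB and 128MiB is rounded up to the
    # next multiple of 16MiB; 70MiB becomes 80MiB.
    def _example_power2_rounding(self):
        mib = 1024 * 1024
        size = 70 * mib
        lower_pow2 = 64 * mib      # power of two just below the request
        step = (128 * mib - lower_pow2) // 4  # 4 divisions -> 16MiB steps
        rounded = lower_pow2
        while rounded < size:
            rounded += step
        return rounded             # 80 * mib
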
    @parametrize("max_split_size_mb_setting", [False, True])
    def test_raises_oom(self, max_split_size_mb_setting):
        if max_split_size_mb_setting:
            # CudaCachingAllocator does early return when searching available blocks
            # if max_split_size_mb is not set
            # Setting this triggers more parts of the code
            torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024")
            torch.cuda.memory.empty_cache()
        with self.assertRaises(torch.cuda.OutOfMemoryError):
            torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")

    @unittest.skipIf(
        not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux"
    )
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_cpp_memory_snapshot_pickle(self):
        from torch.utils.cpp_extension import load_inline

        source = """
        #include <torch/csrc/cuda/memory_snapshot.h>
        py::object do_snapshot() {
            std::string data = torch::cuda::_memory_snapshot_pickled();
            return py::bytes(data);
        }
        void record(bool e, bool ctx) {
            torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx);
        }
        """
        m = load_inline(
            name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"]
        )
        for ctx in (False, True):
            try:
                m.record(True, ctx)

                @torch.jit.script
                def the_script_fn():
                    return torch.rand(311, 411, device="cuda")

                def run():
                    t = the_script_fn()
                    return pickle.loads(m.do_snapshot())

                mem = run()
                found = False
                for s in mem["segments"]:
                    for b in s["blocks"]:
                        if b["state"] == "active_allocated":
                            if b["requested_size"] == 311 * 411 * 4:
                                if ctx:
                                    frame_text = str(b["frames"])
                                    # C++ frame
                                    self.assertTrue("::rand" in frame_text)
                                    # script frame
                                    self.assertTrue("the_script_fn" in frame_text)
                                    # python frame
                                    self.assertTrue("case.py" in frame_text)
                                found = True
                last_action = mem["device_traces"][0][-1]
                self.assertTrue(last_action["action"] == "alloc")
                self.assertTrue(last_action["size"] == 311 * 411 * 4)
                self.assertTrue(found)
            finally:
                m.record(False, False)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    def test_notifies_oom(self):
        x = False

        def cb(device, alloc, device_alloc, device_free):
            nonlocal x
            x = True

        torch._C._cuda_attach_out_of_memory_observer(cb)
        with self.assertRaises(torch.cuda.OutOfMemoryError):
            torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")
        self.assertTrue(x)

    def test_allocator_fuzz(self):
        # fuzz
        state = random.getstate()
        random.seed(123)
        N = 10000
        try:
            mem = []
            total = 0
            c = 0

            def alloc():
                nonlocal total, c
                b = random.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4)
                mem.append((c, torch.full((b,), c, dtype=torch.int32, device="cuda")))
                c += 1
                total += b

            def free():
                nonlocal total
                idx = random.randrange(0, len(mem))
                v, x = mem.pop(idx)
                assert torch.all(v == x)
                total -= x.numel()

            choices = [alloc, free, torch.cuda.memory.empty_cache]
            for i in range(N):
                while total >= 1024 * 1024 * 1024 / (4 * 10):
                    free()
                (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1])
                action()
        finally:
            random.setstate(state)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_nvml_get_handler(self):
        if not torch.version.hip:
            self.assertTrue(torch.cuda._get_pynvml_handler() is not None)
        else:
            self.assertTrue(torch.cuda._get_amdsmi_handler() is not None)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_temperature(self):
        self.assertTrue(0 <= torch.cuda.temperature() <= 150)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_power_draw(self):
        self.assertTrue(torch.cuda.power_draw() >= 0)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_clock_speed(self):
        self.assertTrue(torch.cuda.clock_rate() >= 0)


MIN_BLOCK_SIZE = 512
SMALL_SIZE = 1048576
SMALL_BUFFER = 2097152
LARGE_BUFFER = 20971520


def get_cudagraph_segments(pool_id):
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] == pool_id]


def get_all_cudagraph_segments():
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]

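# Illustrative sketch (not part of the original test suite): using the helpers
# above to inspect which allocator segments belong to a CUDA graph's private
# memory pool. It assumes a pool handle from the public
# torch.cuda.graph_pool_handle() API shows up as the segments' segment_pool_id.
def _example_inspect_graph_pool():
    pool = torch.cuda.graph_pool_handle()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        out = torch.ones(SMALL_SIZE, device="cuda") * 2
    # Segments whose segment_pool_id matches the pool were allocated for the graph.
    return get_cudagraph_segments(pool), out
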


def get_all_cudagraph_segments():
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]


def cudagraphify(fn, inputs, pool=None):
    if not TEST_CUDA_GRAPH:
        raise unittest.SkipTest("cuda graph test is skipped")

    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        fn(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, pool=pool):
        static_outputs = fn(*inputs)

    return graph, static_outputs


def int8_cuda(size):
    return torch.ones([size], device="cuda", dtype=torch.uint8)


def live_blocks(pool_id):
    blocks = 0
    seg = get_cudagraph_segments(pool_id)
    for segment in get_cudagraph_segments(pool_id):
        for block in segment["blocks"]:
            blocks += block["state"] == "active_allocated"
    return blocks


def tensor_metadata(x):
    return {
        "nbytes": x.untyped_storage().nbytes(),
        "data_ptr": x.untyped_storage().data_ptr(),
        "size": x.shape,
        "stride": x.stride(),
        "dtype": x.dtype,
        "device": x.device,
        "storage_offset": x.storage_offset(),
    }
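

# Illustrative sketch (not part of the test suite): the capture/replay pattern the tests
# build on top of `cudagraphify`. The lambda and sizes here are arbitrary examples.
def _example_cudagraph_replay():
    static_inp = torch.ones(8, device="cuda")
    graph, static_outs = cudagraphify(lambda x: (x * 2,), [static_inp])
    # Replaying reruns the captured kernels on the same static input/output storages.
    static_inp.fill_(3.0)
    graph.replay()
    torch.cuda.synchronize()
    assert torch.equal(static_outs[0], torch.full_like(static_inp, 6.0))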


def reconstruct_from_tensor_metadata(metadata):
    s = torch._C._construct_storage_from_data_pointer(
        metadata["data_ptr"], metadata["device"], metadata["nbytes"]
    )
    t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"])
    t.set_(
        source=s,
        storage_offset=metadata["storage_offset"],
        size=metadata["size"],
        stride=metadata["stride"],
    )
    return t
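

# Illustrative sketch (not used by the tests): round-tripping a tensor through the two
# helpers above. The reconstructed tensor aliases the original storage, so the original
# must stay alive (and unfreed) while the alias is in use.
def _example_metadata_round_trip():
    src = torch.arange(16, device="cuda", dtype=torch.float32)
    alias = reconstruct_from_tensor_metadata(tensor_metadata(src))
    assert torch.equal(alias, src)
    src.add_(1)  # visible through the alias, since both share one storage
    assert torch.equal(alias, src)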


@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestBlockStateAbsorption(TestCase):
    @property
    def expandable_segments(self):
        return EXPANDABLE_SEGMENTS

    def checkCheckpointedBlock(self, before_block, after_block):
        for field in ("size", "state"):
            self.assertEqual(before_block[field], after_block[field])

    def checkCheckpointedState(self, before_segments, after_segments):
        # `after_segments` may contain additional segments, but every segment in
        # `before_segments` must have an exactly equivalent counterpart in it.
        after_ptr_to_segment = {
            segment["address"]: segment for segment in after_segments
        }

        for before_segment in before_segments:
            self.assertTrue(before_segment["address"] in after_ptr_to_segment)
            after_segment = after_ptr_to_segment[before_segment["address"]]

            for field in (
                "device",
                "total_size",
                "allocated_size",
                "active_size",
                "segment_type",
                "segment_pool_id",
            ):
                self.assertEqual(before_segment[field], after_segment[field])

            self.assertEqual(
                len(before_segment["blocks"]), len(after_segment["blocks"])
            )
            for before_block, after_block in zip(
                before_segment["blocks"], after_segment["blocks"]
            ):
                self.checkCheckpointedBlock(before_block, after_block)

    @staticmethod
    def setCheckpointPoolState(
        device, state, stale_storages_ptr, storages_deleters=None
    ):
        stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr]
        storages_deleters = (
            []
            if not storages_deleters
            else [t.untyped_storage()._cdata for t in storages_deleters]
        )
        torch._C._cuda_setCheckpointPoolState(
            device, state, stale_storages_ptr, storages_deleters
        )

    def checkFunction(self, fn, inputs, pool=None):
        graph, outputs = cudagraphify(fn, inputs, pool=pool)

        pool_id = graph.pool()
        device = outputs[0].device.index

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(device, pool_id)
        self.setCheckpointPoolState(device, state, [], [])

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

    def setUp(self):
        super().setUp()
        self.segment_length = len(get_all_cudagraph_segments())

    def tearDown(self):
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length)

        super().tearDown()
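
    # The tests below all follow the same shape: capture a function with `cudagraphify`,
    # snapshot the graph's private-pool segments, take a checkpoint of the allocator state
    # for that pool (torch._C._cuda_getCheckpointState), set it back, and verify that the
    # segments/blocks are unchanged. These comments describe the observed behaviour of
    # the private APIs as exercised by this file, not a documented contract.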

    def test_simple(self):
        def foo():
            x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8)
            x = x + x
            x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE)
            y = int8_cuda(SMALL_SIZE) + x1
            z = int8_cuda(SMALL_SIZE)
            return x, y, z

        self.checkFunction(foo, [])

    def test_allocated_in_middle_of_segment(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5].add_(2)

        self.checkFunction(foo, [])

    def test_multiple_middle_allocations(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[8]

        self.checkFunction(foo, [])

    def test_middle_allocations_contiguous(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[6]

        self.checkFunction(foo, [])

    def test_additional_free_following_checkpoint(self):
        def foo():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        def foo2():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        graph, outputs = cudagraphify(foo, [])
        pool_id = graph.pool()

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())

        self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])

        del outputs2

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

    # TODO: re-enable
    # def test_additional_free_error(self):
    #     def foo():
    #         return int8_cuda(MIN_BLOCK_SIZE),

    #     def foo2():
    #         return int8_cuda(MIN_BLOCK_SIZE),

    #     graph, outputs = cudagraphify(foo, [])
    #     pool_id = graph.pool()

    #     segments_before_checkpoint = get_cudagraph_segments(pool_id)

    #     state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

    #     graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
    #     with self.assertRaisesRegex(Exception, "being manually freed must be passed"):
    #         self.setCheckpointPoolState(outputs[0].device.index, state, [], [])

    def test_tensor_dies_after_checkpoint(self):
        def foo():
            return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE)

        graph, outputs = cudagraphify(foo, [])
        pool_id = graph.pool()
        device = outputs[0].device.index

        segments_before_checkpoint = get_cudagraph_segments(pool_id)
        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        output_data_ptrs = [output.data_ptr() for output in outputs]

        del outputs

        self.setCheckpointPoolState(device, state, [], [])

        self.assertEqual(live_blocks(pool_id), 2)
        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0])
        self.assertEqual(live_blocks(pool_id), 1)
        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1])
        self.assertEqual(live_blocks(pool_id), 0)
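
    # Once a tensor's Python object is gone, its block can still be resurrected from the
    # raw data pointer (see reconstruct_from_tensor_metadata above). Passing the rebuilt
    # tensors to setCheckpointPoolState as `storages_deleters` re-attaches deleters so
    # that dropping them frees the blocks again; that is what the next test exercises.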
    def test_assigning_back_deleter_fns_to_tensor(self):
        def foo(x):
            return (
                int8_cuda(SMALL_BUFFER) + x,
                int8_cuda(SMALL_BUFFER) + x,
                int8_cuda(LARGE_BUFFER) + x,
            )

        inp = torch.tensor([1], device="cuda")
        graph, outputs = cudagraphify(foo, [inp])
        pool_id = graph.pool()
        graph.replay()

        device = outputs[0].device.index

        for i in range(len(outputs)):
            self.assertTrue(outputs[i].mean(dtype=torch.float) == 2)

        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        output_ptrs = [output.untyped_storage().data_ptr() for output in outputs]
        ten_metadata = [tensor_metadata(t) for t in outputs]

        self.assertEqual(live_blocks(pool_id), 3)

        del outputs

        self.assertEqual(live_blocks(pool_id), 0)

        reconstructed_tensors = [
            reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata
        ]

        for i in range(len(reconstructed_tensors)):
            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2)

        inp.add_(1)
        graph.replay()

        for i in range(len(reconstructed_tensors)):
            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3)

        self.setCheckpointPoolState(
            device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]]
        )

        self.assertEqual(live_blocks(pool_id), 3)

        reconstructed_tensors[0] = None
        self.assertEqual(live_blocks(pool_id), 2)

        reconstructed_tensors[1] = None
        self.assertEqual(live_blocks(pool_id), 1)

        # should not change, we did not pass it in to swap data ptrs
        reconstructed_tensors[2] = None
        self.assertEqual(live_blocks(pool_id), 1)

        torch._C._cuda_cudaCachingAllocator_raw_delete(output_ptrs[2])

        self.assertEqual(live_blocks(pool_id), 0)

    @skipIfNoTorchVision
    def test_resnet(self):
        import torchvision

        m = torchvision.models.resnet50()
        m.eval()
        m = m.cuda()

        inp = torch.rand([1, 3, 255, 255], device="cuda")
        self.checkFunction(m, [inp])

    def test_check_pool_live_allocations(self):
        def foo():
            return torch.ones([4], device="cuda")

        pool = torch.cuda.graph_pool_handle()
        graph, outputs = cudagraphify(foo, [], pool=pool)

        index = outputs[0].device.index

        def check(live_dps):
            return torch._C._cuda_checkPoolLiveAllocations(index, pool, live_dps)

        self.assertTrue(check({outputs[0].data_ptr()}))

        self.assertFalse(check({outputs[0].data_ptr(), 0}))
        self.assertFalse(check(set()))

        del outputs
        self.assertTrue(check(set()))

    def test_allocate_in_thread_to_pool(self):
        def foo():
            return torch.rand([4], device="cuda")

        pool = torch.cuda.graph_pool_handle()
        graph, outputs = cudagraphify(foo, [], pool=pool)
        device = outputs[0].device.index
        del outputs

        @contextlib.contextmanager
        def _use_cuda_memory_pool_manager(device, mem_pool):
            """
            Context manager that routes new allocations on the current stream to a cuda
            graph pool. While it is active, every cudagraph tensor still in use must be
            reflected in the allocator, or its memory may be overwritten. The mem_pool
            must already exist (here it comes from torch.cuda.graph_pool_handle()).
            """
            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            stream_context = torch.cuda.stream(stream)
            stream_context.__enter__()
            torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
            try:
                yield
            finally:
                torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
                torch._C._cuda_releasePool(device, mem_pool)
                stream_context.__exit__(None, None, None)

        segments = get_cudagraph_segments(pool)
        self.assertEqual(len(get_cudagraph_segments(pool)), 1)

        def use_pool():
            def alloc_three():
                a = int8_cuda(LARGE_BUFFER)
                b = int8_cuda(LARGE_BUFFER)
                c = a + b

            with _use_cuda_memory_pool_manager(device, pool):
                # three allocations
                for _ in range(10):
                    alloc_three()

            # three more allocations not in pool
            alloc_three()

        def no_pool():
            # two allocations
            for _ in range(10):
                a = int8_cuda(LARGE_BUFFER)
                b = int8_cuda(LARGE_BUFFER)
                del a, b

        graph_thread = threading.Thread(target=use_pool)
        no_graph_thread = threading.Thread(target=no_pool)
        graph_thread.start()
        no_graph_thread.start()

        graph_thread.join()
        no_graph_thread.join()

        self.assertEqual(
            len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4
        )

        del graph

        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        self.assertEqual(len(get_cudagraph_segments(pool)), 0)

    def test_no_triton_on_import(self):
        """Test that Triton is not imported on first GPU use"""
        script = "import sys; import torch; torch.rand(2, device='cuda'); print('triton' in sys.modules)"

        rc = (
            subprocess.check_output(
                [sys.executable, "-c", script],
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
            .strip()
            .decode("ascii")
        )
        self.assertEqual(rc, "False", "Triton was imported when importing torch!")
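

# Illustrative sketch (not part of the test suite): two captures can share one private
# memory pool by passing the same handle to `cudagraphify`, which is the setup the
# block-state-absorption tests above build on before manipulating the pool's state.
def _example_shared_graph_pool():
    pool = torch.cuda.graph_pool_handle()
    graph1, _ = cudagraphify(lambda: (torch.ones(4, device="cuda"),), [], pool=pool)
    graph2, _ = cudagraphify(lambda: (torch.zeros(4, device="cuda"),), [], pool=pool)
    # Both graphs allocate out of the same pool, so their segments show up under a
    # single segment_pool_id in torch.cuda.memory_snapshot().
    assert graph1.pool() == graph2.pool()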


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestMemPool(TestCase):
    def test_mempool_id(self):
        pool1 = torch.cuda.graph_pool_handle()
        pool2 = torch.cuda.MemPool().id

        # first value of id in a user created pool is always zero
        self.assertEqual(pool1[0] == 0, pool2[0] == 0)

        # each call to torch.cuda.graph_pool_handle() or torch.cuda.MemPool()
        # increments the id
        self.assertTrue(abs(pool2[1] - pool1[1]) > 0)

    def test_mempool_with_allocator(self):
        pool = torch.cuda.MemPool()

        # MemPool doesn't have an allocator by default
        self.assertEqual(pool.allocator, None)

        from torch.utils.cpp_extension import load_inline

        dummy_allocator_source = """
        #include <torch/extension.h>
        #include <ATen/cuda/Exceptions.h>
        #include <cuda_runtime_api.h>

        extern "C" {
          C10_EXPORT int called_dummy_alloc = 0;
          C10_EXPORT int called_dummy_free = 0;

          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
            called_dummy_alloc = 123;
            void* ptr;
            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
            return ptr;
          }

          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
            called_dummy_free = 321;
            C10_CUDA_CHECK(cudaFree(ptr));
          }
        }
        """
        dummy_allocator_libname = "dummy_allocator"
        dummy_allocator = load_inline(
            name=dummy_allocator_libname,
            cpp_sources=dummy_allocator_source,
            is_python_module=False,
            keep_intermediates=False,
            verbose=True,
            with_cuda=True,
        )
        allocator = torch.cuda.memory.CUDAPluggableAllocator(
            dummy_allocator,
            "dummy_alloc",
            "dummy_free",
        )
        pool = torch.cuda.MemPool(allocator.allocator())

        # pool should point to the same allocator as the one passed into it
        self.assertEqual(allocator.allocator(), pool.allocator)

        # no allocations happened yet, so called_dummy_alloc should be 0
        alloc_lib = ctypes.CDLL(dummy_allocator)
        called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
        self.assertEqual(called_dummy_alloc.value, 0)

        with torch.cuda.use_mem_pool(pool):
            out = torch.randn(1, device="cuda")

        # called_dummy_alloc should be 123 if dummy_alloc was used to allocate
        # the out tensor
        self.assertEqual(called_dummy_alloc.value, 123)

    def test_mempool_context(self):
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # there is no active pool if none was made active
        self.assertEqual(active_pool, None)

        pool = torch.cuda.MemPool()
        ctx = torch.cuda.MemPoolContext(pool)
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # pool was made active
        self.assertEqual(active_pool, pool)

        del ctx
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # ctx was deleted, so the previously active pool (None here) is restored
        self.assertEqual(active_pool, None)

    def test_mempool_multithread(self):
        pool_ids = []
        active_pool_ids = []

        def create_mempool_and_make_active():
            pool = torch.cuda.MemPool()
            pool_ids.extend([pool.id])

            ctx = torch.cuda.MemPoolContext(pool)
            active_pool = torch.cuda.MemPoolContext.active_pool()
            active_pool_ids.extend([active_pool.id])
            del ctx

        num_threads = 4
        threads = [
            threading.Thread(target=create_mempool_and_make_active)
            for t in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # each thread should create a unique mempool, since
        # mempool id creation is atomic
        self.assertEqual(len(set(pool_ids)), 4)

        # each thread should have a different active mempool, since
        # the pointer to the mempool is thread local
        self.assertEqual(len(set(active_pool_ids)), 4)
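

# Illustrative sketch (not part of the test suite): the user-facing flow the MemPool tests
# above exercise piecewise. A MemPool can be paired with a pluggable allocator (as in
# test_mempool_with_allocator) or used on its own; while `use_mem_pool` is active on this
# thread, allocations are routed into that pool instead of the default one.
def _example_route_allocations_to_mempool():
    pool = torch.cuda.MemPool()
    with torch.cuda.use_mem_pool(pool):
        scratch = torch.empty(1024, device="cuda")
    # Allocations made outside the context go back to the default pool.
    other = torch.empty(1024, device="cuda")
    return scratch, other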


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
    # These tests will be instantiated with instantiate_device_type_tests
    # to apply the new OptimizerInfo structure.

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_graph_optims(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        steps_warmup = 3
        steps_train = 2

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            # lr as a Tensor is not supported when capturable=False and foreach=True
            # for torch.optim.adam and torch.optim.adamw
            kwargs["lr"] = 0.1

            for actually_do_graphs in (True, False):
                params = [
                    torch.randn((i + 5, i + 5), device=device) for i in range(2)
                ] + [torch.randn((), device=device)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]

                # Control (capturable=False)
                kwargs["capturable"] = False

                opt = optim_cls(params_control, **kwargs)
                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads[i][j]
                    opt.step()

                # capturable=True
                kwargs["capturable"] = True
                opt = optim_cls(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i][j]
                    opt.step()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        opt.step()

                for i in range(steps_train):
                    if actually_do_graphs:
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads[i + steps_warmup][j])
                        g.replay()
                    else:
                        # Passing capturable=True to the constructor and running without graphs should still be
                        # numerically correct, even if it's not ideal for performance.
                        for j, p in enumerate(params_graphed):
                            p.grad = grads[i + steps_warmup][j]
                        opt.step()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)
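
    # The next test repeats the graph-vs-eager comparison above, but drives the optimizer
    # through torch.cuda.amp.GradScaler, so scaler.step() and scaler.update() are what get
    # captured and replayed together with the fused/capturable optimizer.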
4559*da0073e9SAndroid Build Coastguard Worker grads = [ 4560*da0073e9SAndroid Build Coastguard Worker [torch.randn_like(p) for p in params] 4561*da0073e9SAndroid Build Coastguard Worker for _ in range(steps_warmup + steps_train) 4562*da0073e9SAndroid Build Coastguard Worker ] 4563*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 4564*da0073e9SAndroid Build Coastguard Worker grads_control = [[g.clone() for g in gs] for gs in grads] 4565*da0073e9SAndroid Build Coastguard Worker grads_graphed = [[g.clone() for g in gs] for gs in grads] 4566*da0073e9SAndroid Build Coastguard Worker 4567*da0073e9SAndroid Build Coastguard Worker # Gradient Scaler 4568*da0073e9SAndroid Build Coastguard Worker scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) 4569*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 4570*da0073e9SAndroid Build Coastguard Worker scaler_for_control._lazy_init_scale_growth_tracker(device) 4571*da0073e9SAndroid Build Coastguard Worker 4572*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed = torch.cuda.amp.GradScaler() 4573*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) 4574*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 4575*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed._lazy_init_scale_growth_tracker(device) 4576*da0073e9SAndroid Build Coastguard Worker 4577*da0073e9SAndroid Build Coastguard Worker # Control (capturable=False) 4578*da0073e9SAndroid Build Coastguard Worker if optim_info.has_capturable_arg: 4579*da0073e9SAndroid Build Coastguard Worker kwargs["capturable"] = False 4580*da0073e9SAndroid Build Coastguard Worker opt = optim_cls(params_control, **kwargs) 4581*da0073e9SAndroid Build Coastguard Worker 4582*da0073e9SAndroid Build Coastguard Worker for i in range(steps_warmup + steps_train): 4583*da0073e9SAndroid Build Coastguard Worker for j, p in enumerate(params_control): 4584*da0073e9SAndroid Build Coastguard Worker p.grad = grads_control[i][j] 4585*da0073e9SAndroid Build Coastguard Worker scaler_for_control.step(opt) 4586*da0073e9SAndroid Build Coastguard Worker scaler_for_control.update() 4587*da0073e9SAndroid Build Coastguard Worker 4588*da0073e9SAndroid Build Coastguard Worker # capturable=True 4589*da0073e9SAndroid Build Coastguard Worker if optim_info.has_capturable_arg: 4590*da0073e9SAndroid Build Coastguard Worker kwargs["capturable"] = True 4591*da0073e9SAndroid Build Coastguard Worker opt = optim_cls(params_graphed, **kwargs) 4592*da0073e9SAndroid Build Coastguard Worker 4593*da0073e9SAndroid Build Coastguard Worker for i in range(steps_warmup): 4594*da0073e9SAndroid Build Coastguard Worker for j, p in enumerate(params_graphed): 4595*da0073e9SAndroid Build Coastguard Worker p.grad = grads_graphed[i][j] 4596*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed.step(opt) 4597*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed.update() 4598*da0073e9SAndroid Build Coastguard Worker 4599*da0073e9SAndroid Build Coastguard Worker if actually_do_graphs: 4600*da0073e9SAndroid Build Coastguard Worker g = torch.cuda.CUDAGraph() 4601*da0073e9SAndroid Build Coastguard Worker with torch.cuda.graph(g): 4602*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed.step(opt) 4603*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed.update() 4604*da0073e9SAndroid Build Coastguard Worker 4605*da0073e9SAndroid Build Coastguard Worker for i in range(steps_train): 4606*da0073e9SAndroid Build Coastguard Worker if 
actually_do_graphs: 4607*da0073e9SAndroid Build Coastguard Worker for j, p in enumerate(params_graphed): 4608*da0073e9SAndroid Build Coastguard Worker p.grad.copy_(grads_graphed[i + steps_warmup][j]) 4609*da0073e9SAndroid Build Coastguard Worker g.replay() 4610*da0073e9SAndroid Build Coastguard Worker else: 4611*da0073e9SAndroid Build Coastguard Worker # Passing capturable=True to the constructor and running without graphs should still be 4612*da0073e9SAndroid Build Coastguard Worker # numerically correct, even if it's not ideal for performance. 4613*da0073e9SAndroid Build Coastguard Worker for j, p in enumerate(params_graphed): 4614*da0073e9SAndroid Build Coastguard Worker p.grad = grads_graphed[i + steps_warmup][j] 4615*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed.step(opt) 4616*da0073e9SAndroid Build Coastguard Worker scaler_for_graphed.update() 4617*da0073e9SAndroid Build Coastguard Worker 4618*da0073e9SAndroid Build Coastguard Worker for p_control, p_graphed in zip(params_control, params_graphed): 4619*da0073e9SAndroid Build Coastguard Worker self.assertEqual(p_control, p_graphed) 4620*da0073e9SAndroid Build Coastguard Worker 4621*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 4622*da0073e9SAndroid Build Coastguard Worker @optims( 4623*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "fused" in optim.supported_impls], 4624*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float32], 4625*da0073e9SAndroid Build Coastguard Worker ) 4626*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): 4627*da0073e9SAndroid Build Coastguard Worker device = device.split(":")[0] 4628*da0073e9SAndroid Build Coastguard Worker if device not in optim_info.supports_fused_on: 4629*da0073e9SAndroid Build Coastguard Worker self.skipTest( 4630*da0073e9SAndroid Build Coastguard Worker f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" 4631*da0073e9SAndroid Build Coastguard Worker ) 4632*da0073e9SAndroid Build Coastguard Worker optim_inputs = optim_info.optim_inputs_func(device=device) 4633*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 4634*da0073e9SAndroid Build Coastguard Worker for optim_input in optim_inputs: 4635*da0073e9SAndroid Build Coastguard Worker for _separate_unscale in (True, False): 4636*da0073e9SAndroid Build Coastguard Worker kwargs = optim_input.kwargs 4637*da0073e9SAndroid Build Coastguard Worker kwargs["fused"] = True 4638*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(20) 4639*da0073e9SAndroid Build Coastguard Worker ( 4640*da0073e9SAndroid Build Coastguard Worker mod_control, 4641*da0073e9SAndroid Build Coastguard Worker mod_scaling, 4642*da0073e9SAndroid Build Coastguard Worker opt_control, 4643*da0073e9SAndroid Build Coastguard Worker opt_scaling, 4644*da0073e9SAndroid Build Coastguard Worker data, 4645*da0073e9SAndroid Build Coastguard Worker loss_fn, 4646*da0073e9SAndroid Build Coastguard Worker _, 4647*da0073e9SAndroid Build Coastguard Worker ) = _create_scaling_case( 4648*da0073e9SAndroid Build Coastguard Worker optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device 4649*da0073e9SAndroid Build Coastguard Worker ) 4650*da0073e9SAndroid Build Coastguard Worker optimizer_kwargs = deepcopy(kwargs) 4651*da0073e9SAndroid Build Coastguard Worker optimizer_kwargs["fused"] = False 4652*da0073e9SAndroid Build Coastguard Worker if "lr" not in kwargs: 4653*da0073e9SAndroid Build Coastguard Worker # 
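
    # The next test compares a fused optimizer against an unfused control under
    # autocast + GradScaler, using TensorTracker to check losses, gradients, parameters
    # and optimizer state step by step.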
    @onlyNativeDeviceTypes
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=[torch.float32],
    )
    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
        device = device.split(":")[0]
        if device not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        optim_inputs = optim_info.optim_inputs_func(device=device)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            for _separate_unscale in (True, False):
                kwargs = optim_input.kwargs
                kwargs["fused"] = True
                torch.manual_seed(20)
                (
                    mod_control,
                    mod_scaling,
                    opt_control,
                    opt_scaling,
                    data,
                    loss_fn,
                    _,
                ) = _create_scaling_case(
                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device
                )
                optimizer_kwargs = deepcopy(kwargs)
                optimizer_kwargs["fused"] = False
                if "lr" not in kwargs:
                    # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
                    optimizer_kwargs["lr"] = 1.0
                opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs)
                scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
                scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
                tracker = TensorTracker()
                for input, target in data:
                    opt_control.zero_grad()
                    with torch.autocast(device_type=device, dtype=torch.half):
                        output_control = mod_control(input)
                        loss_control = loss_fn(output_control, target)
                    scaler_control.scale(loss_control).backward()
                    scaler_control.step(opt_control)
                    scaler_control.update()

                    opt_scaling.zero_grad()
                    with torch.autocast(device_type=device, dtype=torch.half):
                        output_scaling = mod_scaling(input)
                        loss_scaling = loss_fn(output_scaling, target)
                    scaler_scaling.scale(loss_scaling).backward()
                    if _separate_unscale:
                        scaler_scaling.unscale_(opt_scaling)
                    scaler_scaling.step(opt_scaling)
                    scaler_scaling.update()

                    tracker.add(loss_control)
                    tracker.pop_check_set(loss_scaling, self)
                    for param_control, param_scaling in zip(
                        mod_control.parameters(), mod_scaling.parameters()
                    ):
                        tracker.add(param_control.grad)
                        tracker.pop_check_set(param_scaling.grad, self)
                        tracker.add(param_control)
                        tracker.pop_check_set(param_scaling, self)

                        state_control, state_scaling = (
                            opt_control.state[param_control],
                            opt_scaling.state[param_scaling],
                        )

                        for k in state_control:
                            actual = state_scaling[k]
                            if k == "step":
                                actual = actual.squeeze()
                            tracker.add(state_control[k])
                            tracker.pop_check_set(actual, self)

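    # Minimal sketch (not part of the original test suite) of the training-step pattern
    # that the fused-vs-unfused comparison above replicates twice per iteration: run the
    # forward pass under autocast, scale the loss before backward, then let the scaler
    # drive the optimizer step and its own update. The model, optimizer, and data here
    # are placeholders, not the helpers produced by _create_scaling_case.
    def _example_scaler_training_step(self):
        model = torch.nn.Linear(8, 8).cuda()
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        scaler = torch.amp.GradScaler("cuda")
        data = torch.randn(4, 8, device="cuda")
        target = torch.randn(4, 8, device="cuda")

        opt.zero_grad(set_to_none=True)
        with torch.autocast(device_type="cuda", dtype=torch.half):
            loss = torch.nn.functional.mse_loss(model(data), target)
        scaler.scale(loss).backward()  # gradients are scaled
        scaler.step(opt)  # unscales (or skips the step on inf/nan) before stepping
        scaler.update()  # adjusts the scale for the next iteration
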
    @onlyCUDA
    @parametrize("in_place_unscale", [False, True])
    @optims(
        [optim for optim in optim_db if "cuda" in optim.supports_fused_on],
        dtypes=[torch.float32],
    )
    def test_grad_scaler_with_preset_grad_scale(
        self, device, dtype, optim_info, in_place_unscale
    ):
        weight = torch.ones((5, 5), device="cuda", requires_grad=True)
        weight.grad = torch.full_like(weight, fill_value=15)
        opt = optim_info.optim_cls([weight], lr=0.1, fused=True)
        scaler = torch.amp.GradScaler(init_scale=5)

        # simulate scaling a loss
        scaler.scale(torch.ones(5))

        if in_place_unscale:
            scaler.unscale_(opt)
            # the gradient should have been divided in-place
            self.assertEqual(weight.grad, torch.full_like(weight, fill_value=3))

        # the user sets a `grad_scale` value which should be fused with the optimizer step
        opt.grad_scale = torch.Tensor([3]).cuda()
        scaler.step(opt)

        # check that the user's grad_scale was respected (i.e. the gradient was divided by 5 * 3)
        self.assertEqual(weight.grad, torch.full_like(weight, fill_value=1))

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize("foreach, fused", [(False, False), (True, False), (False, True)])
    @optims(
        [
            optim
            for optim in optim_db
            if "foreach" in optim.supported_impls and "cuda" in optim.supports_fused_on
        ],
        dtypes=[torch.float32],
    )
    def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused):
        torch.cuda.empty_cache()

        scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()

        weight = torch.ones((100,), device="cuda", requires_grad=True)
        opt = optim_info.optim_cls([weight], lr=0.1, foreach=foreach, fused=fused)
        static_input = torch.ones_like(weight)
        static_grad = torch.ones_like(weight)

        # warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            loss = (weight.half() * static_input).sum()
            scaler.scale(loss).backward()
        torch.cuda.current_stream().wait_stream(s)

        opt.zero_grad(set_to_none=True)

        # capture
        with torch.cuda.stream(s):
            g.capture_begin()
            loss = (weight.half() * static_input).sum()
            scaler.scale(loss).backward()
            g.capture_end()

        input_vals = [5, 20000, 5, 40000]
        # If the scale gets updated properly, these are the scale, growth tracker,
        # and grad values we expect.
        expected_scales = [4, 2, 2, 1]
        expected_growth_trackers = [1, 0, 1, 0]
        expected_grad_vals = [5 * 4, float("inf"), 5 * 2, float("inf")]

        for data, scale, growth_tracker, grad_val in zip(
            input_vals, expected_scales, expected_growth_trackers, expected_grad_vals
        ):
            static_input.fill_(data)
            g.replay()
            self.assertEqual(weight.grad, torch.full_like(weight.grad, grad_val))
            scaler.step(opt)
            scaler.update()
            self.assertEqual(scaler._scale, scale)
            self.assertEqual(scaler._growth_tracker, growth_tracker)


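# Illustrative sketch (not part of the original test suite) of the graph-capture pattern
# exercised by test_graph_grad_scaling above, written with the higher-level
# torch.cuda.graph context manager instead of raw capture_begin()/capture_end().
# Only the scaled backward is captured; scaler.step() and scaler.update() stay outside
# the graph so the scale can keep adjusting between replays.
def _example_graph_capture_with_grad_scaler():
    weight = torch.ones(100, device="cuda", requires_grad=True)
    opt = torch.optim.SGD([weight], lr=0.1)
    scaler = torch.amp.GradScaler("cuda", init_scale=4.0)
    static_input = torch.ones_like(weight)

    # Warmup on a side stream so the work to be captured has already run once.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        scaler.scale((weight.half() * static_input).sum()).backward()
    torch.cuda.current_stream().wait_stream(s)
    opt.zero_grad(set_to_none=True)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        scaler.scale((weight.half() * static_input).sum()).backward()

    for value in (5.0, 7.0):
        static_input.fill_(value)
        g.replay()  # recomputes the scaled gradients into weight.grad
        scaler.step(opt)  # outside the graph
        scaler.update()

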
@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestGDS(TestCase):
    def _get_tmp_dir_fs_type(self):
        my_path = os.path.realpath("/tmp")
        root_type = ""
        for part in psutil.disk_partitions():
            if part.mountpoint == "/":
                root_type = part.fstype
                continue
            if part.mountpoint == my_path:
                return part.fstype
        return root_type

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_read_write_tensors(self):
        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
        src1 = torch.randn(1024, device="cuda")
        src2 = torch.randn(2, 1024, device="cuda")
        torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
        dest1 = torch.empty(1024, device="cuda")
        dest2 = torch.empty(2, 1024, device="cuda")
        with TemporaryFileName() as f:
            file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
            file.save_storage(src1.untyped_storage(), offset=0)
            file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
            file.load_storage(dest1.untyped_storage(), offset=0)
            file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
        self.assertEqual(src1, dest1)
        self.assertEqual(src2, dest2)
        torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestCudaAutocast(TestAutocast):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op_with_args in self.autocast_lists.torch_fp16:
                skip_test = False
                op, args = op_with_args[0], op_with_args[1]
                if len(op_with_args) == 3:
                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
                if not skip_test:
                    self._run_autocast_outofplace(
                        op, args, torch.float16, device="cuda", amp_dtype=torch.float16
                    )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_bf16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op_with_args in self.autocast_lists.torch_fp16:
                skip_test = False
                op, args = op_with_args[0], op_with_args[1]
                if len(op_with_args) == 3:
                    skip_test = op_with_args[2]  # TEST_WITH_ROCM
                should_error_from_cudnn = "cudnn" in op and (
                    "TORCH_CUDNN_V8_API_DISABLED" in os.environ
                    and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
                    or torch.cuda.get_device_capability() < (8, 0)
                )
                should_error_from_not_implemented = should_error_from_cudnn
                if not skip_test:
                    if should_error_from_not_implemented:
                        with self.assertRaises(
                            RuntimeError,
                            msg=str(op) + " should not be supported for bfloat16!",
                        ):
                            self._run_autocast_outofplace(
                                op, args, torch.bfloat16, device="cuda"
                            )
                    else:
                        if torch.cuda.is_bf16_supported():
                            self._run_autocast_outofplace(
                                op, args, torch.bfloat16, device="cuda"
                            )
                        else:
                            with self.assertRaisesRegex(
                                RuntimeError, "Device does not support bfloat16"
                            ):
                                self._run_autocast_outofplace(
                                    op, args, torch.bfloat16, device="cuda"
                                )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cuda",
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_need_autocast_promote(self):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(
                op, args, torch.float32, device="cuda", amp_dtype=torch.float16
            )

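    # Small sketch (not part of the original test suite) of what the list-driven tests
    # above verify op by op: under autocast, fp16-policy ops such as torch.mm run in
    # half precision even for float32 inputs, a nested enabled=False region restores
    # the input dtype, and bfloat16 autocast should only be requested on devices where
    # torch.cuda.is_bf16_supported() reports support.
    def _example_autocast_dtype_policies(self):
        a = torch.randn(8, 8, device="cuda")
        b = torch.randn(8, 8, device="cuda")
        with torch.autocast("cuda"):
            assert torch.mm(a, b).dtype is torch.float16
            with torch.autocast("cuda", enabled=False):
                assert torch.mm(a, b).dtype is torch.float32
        if torch.cuda.is_bf16_supported():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                assert torch.mm(a, b).dtype is torch.bfloat16
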
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cuda",
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.nn_fp16:
                self._run_autocast_outofplace(
                    op,
                    args,
                    torch.float16,
                    device="cuda",
                    module=torch._C._nn,
                    amp_dtype=torch.float16,
                )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_bf16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.nn_fp16:
                if torch.cuda.is_bf16_supported():
                    self._run_autocast_outofplace(
                        op, args, torch.bfloat16, device="cuda", module=torch._C._nn
                    )
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "Device does not support bfloat16"
                    ):
                        self._run_autocast_outofplace(
                            op, args, torch.bfloat16, device="cuda", module=torch._C._nn
                        )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_fp32(self):
        for op, args in self.autocast_lists.nn_fp32:
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cuda",
                module=torch._C._nn,
                amp_dtype=torch.float16,
            )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_linalg_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.linalg_fp16:
                self._run_autocast_outofplace(
                    op,
                    args,
                    torch.float16,
                    device="cuda",
                    module=torch._C._linalg,
                    amp_dtype=torch.float16,
                )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.methods_fp16:
                self._run_autocast_outofplace(
                    op,
                    args,
                    torch.float16,
                    device="cuda",
                    module=None,
                    amp_dtype=torch.float16,
                )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_fp32(self):
        for op, args in self.autocast_lists.methods_fp32:
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cuda",
                module=None,
                amp_dtype=torch.float16,
            )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                device="cuda",
                module=None,
                out_type=out_type,
                amp_dtype=torch.float16,
            )

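    # Hedged sketch (not part of the original test suite) of the behavior exercised by
    # test_autocast_banned below: a few ops are deliberately disallowed inside autocast
    # because they are numerically unsafe in half precision. In current PyTorch releases
    # torch.nn.functional.binary_cross_entropy is one such op; the autocast-safe route
    # is binary_cross_entropy_with_logits on the raw logits.
    def _example_autocast_banned_op(self):
        logits = torch.randn(4, 1, device="cuda")
        target = torch.rand(4, 1, device="cuda")
        with torch.autocast("cuda"):
            try:
                torch.nn.functional.binary_cross_entropy(torch.sigmoid(logits), target)
            except RuntimeError:
                pass  # expected: the op is rejected while autocast is enabled
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, target)
        return loss
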
    def test_autocast_banned(self):
        with torch.autocast("cuda"):
            for op, args, module in self.autocast_lists.banned:
                with self.assertRaises(RuntimeError):
                    getattr(module, op)(*args)

    def test_autocast_ignored_types(self):
        with torch.autocast("cuda"):
            for ignore_type in (torch.double, torch.int32):
                a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
                b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
                c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")

                # Tests if CastPolicy::fp16 ops ignore double and int
                # Currently, no ops belonging to this policy support integer inputs.
                if ignore_type is torch.double:
                    with self.assertRaises(RuntimeError):
                        torch.mm(a_ignore, c_16)
                    with torch.autocast("cuda", enabled=False):
                        type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
                    self.assertTrue(
                        torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
                    )

                # Tests if CastPolicy::fp32 ops ignore double and int
                with torch.autocast("cuda", enabled=False):
                    type_no_autocast = torch.pow(a_ignore, 2.0).dtype
                self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)

                # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
                with torch.autocast("cuda", enabled=False):
                    type_no_autocast = torch.sum(a_ignore).dtype
                self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)

                # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
                # Currently, no ops belonging to this policy support integer inputs.
                if ignore_type is torch.double:
                    with torch.autocast("cuda", enabled=False):
                        type_no_autocast = torch.norm(a_ignore).dtype
                    self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)

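    # Brief sketch (not part of the original test suite) restating what
    # test_autocast_ignored_types checks above: autocast only rewrites eligible float32
    # CUDA inputs, so float64 (and integer) tensors pass through untouched, and mixing
    # them with half tensors is an ordinary dtype-mismatch error rather than a cast.
    def _example_autocast_ignores_double(self):
        a = torch.ones(8, 8, dtype=torch.double, device="cuda")
        b = torch.ones(8, 8, dtype=torch.double, device="cuda")
        with torch.autocast("cuda"):
            assert torch.mm(a, b).dtype is torch.float64  # not downcast by autocast
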
    def test_autocast_custom_enabled(self):
        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda")
            def forward(ctx, a, b):
                self.assertTrue(a.dtype is torch.float32)
                self.assertTrue(b.dtype is torch.float32)
                self.assertTrue(torch.is_autocast_enabled())
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                self.assertTrue(torch.is_autocast_enabled())
                a, b = ctx.saved_tensors
                a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
                self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
                return a_grad, b_grad

        mymm = MyMM.apply

        x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
        y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)

        dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
        for dtype in dtypes:
            with torch.cuda.amp.autocast(dtype=dtype):
                output = mymm(x, y)
                self.assertTrue(output.dtype is dtype)
                loss = output.sum()
                loss.backward()

    def test_autocast_custom_cast_inputs(self):
        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
            def forward(ctx, a, container, expect_type):
                b = container[1][0]
                self.assertTrue(a.dtype is expect_type)
                self.assertTrue(b.dtype is expect_type)
                self.assertFalse(torch.is_autocast_enabled())
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                self.assertFalse(torch.is_autocast_enabled())
                a, b = ctx.saved_tensors
                return grad.mm(b.t()), None, None

        mymm = MyMM.apply

        x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
        # Puts one input tensor in a nested container.  y's contained Tensor won't receive a gradient,
        # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
        # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
        y = (
            0,
            {
                0: torch.randn(
                    (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
                )
            },
        )

        with torch.autocast("cuda"):
            output = mymm(x, y, torch.float32)
            self.assertTrue(output.dtype is torch.float32)
            loss = output.sum()
            loss.backward()

        # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
        output = mymm(x, y, torch.float16)
        self.assertTrue(output.dtype is torch.float16)
        loss = output.sum()
        loss.backward()

    def test_autocast_custom_deprecated_warning(self):
        with warnings.catch_warnings(record=True) as w:

            class MyMM(torch.autograd.Function):
                @staticmethod
                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
                def forward(ctx, x, y):
                    ctx.save_for_backward(x, y)
                    self.assertFalse(torch.is_autocast_enabled())
                    return x + y

                @staticmethod
                @torch.cuda.amp.custom_bwd
                def backward(ctx, grad):
                    _, _ = ctx.saved_tensors
                    self.assertFalse(torch.is_autocast_enabled())
                    return grad, grad

            self.assertRegex(
                str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
            )
            self.assertRegex(
                str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
            )

        mymm = MyMM.apply
        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)
        with torch.amp.autocast("cuda"):
            output = mymm(x, y)
            loss = output.sum()
            loss.backward()

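    # Sketch (not part of the original test suite) of the replacement spelling that the
    # deprecation warnings checked above point to: torch.amp.custom_fwd /
    # torch.amp.custom_bwd with an explicit device_type instead of the torch.cuda.amp.*
    # decorators. The class is illustrative only and is never applied by the tests.
    class _ExampleAddFn(torch.autograd.Function):
        @staticmethod
        @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
        def forward(ctx, x, y):
            # Inputs arriving from an autocast region are cast to float32 first.
            return x + y

        @staticmethod
        @torch.amp.custom_bwd(device_type="cuda")
        def backward(ctx, grad):
            return grad, grad
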
    def test_autocast_cat_jit(self):
        # Reported at https://github.com/pytorch/pytorch/issues/38958
        class Model(torch.nn.Module):
            def forward(self):
                a = torch.randn(1)
                b = torch.randn(1)
                c = torch.cat((a, b), 0)
                d = torch.stack([c, c], 0)
                return d

        # The JIT here doesn't really matter; we just need to call
        # cat via the boxed API.
        model = Model()
        model_jit_script = torch.jit.script(model)

        with torch.autocast("cuda", enabled=True):
            model()
            model_jit_script()

    # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened),
    # so they get a dedicated test.
    # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_rnn(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            # seq, batch, features, hidden size
            clses = ("RNN", "GRU", "LSTM")
            T, B, F, H = 3, 4, 5, 6
            dtypes = (torch.float16, torch.float32)
            input_layouts = ("seq_first", "batch_first", "packed")

            for (
                cls,
                num_layers,
                bias,
                input_layout,
                bidirectional,
                try_nonpreflattened_weights,
                input_dtype,
                hidden_dtype,
                weight_dtype,
            ) in product(
                clses,
                (1, 2),
                (True, False),
                input_layouts,
                (True, False),
                (True, False),
                dtypes,
                dtypes,
                dtypes,
            ):
                if input_layout == "seq_first":
                    batch_first = False
                    x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
                elif input_layout == "batch_first":
                    batch_first = True
                    x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
                elif input_layout == "packed":
                    batch_first = False
                    x = torch.nn.utils.rnn.pack_padded_sequence(
                        torch.randn((T, B, F), device="cuda", dtype=input_dtype),
                        lengths=(3, 2, 1, 3),
                        enforce_sorted=False,
                    )

                rnn = (
                    getattr(torch.nn, cls)(
                        F,
                        H,
                        num_layers=num_layers,
                        bidirectional=bidirectional,
                        bias=bias,
                        batch_first=batch_first,
                    )
                    .cuda()
                    .to(dtype=weight_dtype)
                )

                if try_nonpreflattened_weights:
                    for p in rnn.parameters():
                        with torch.no_grad():
                            p.set_(p.clone())

                h = torch.randn(
                    (num_layers * (2 if bidirectional else 1), B, H),
                    device="cuda",
                    dtype=hidden_dtype,
                )
                if cls == "LSTM":
                    c = torch.randn(
                        (num_layers * (2 if bidirectional else 1), B, H),
                        device="cuda",
                        dtype=hidden_dtype,
                    )
                    h = (h, c)

                with torch.autocast("cuda"):
                    out, h_out = rnn(x, h)
                out = out.data if input_layout == "packed" else out
                self.assertEqual(out.dtype, torch.float16)
                # The autocast wrapper requires that at::_cudnn_rnn is autograd-exposed.  This check
                # can't guarantee at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates
                # some funny business has occurred and we should double check that at::_cudnn_rnn
                # remains autograd-exposed.
                self.assertEqual(
                    out.grad_fn.name(),
                    "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
                )
                out.sum().backward()
                grads = [p.grad.clone() for p in rnn.parameters()]

                rnn.zero_grad()

                if cls == "LSTM":
                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
                        x.half(), (h[0].half(), h[1].half())
                    )
                else:
                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
                        x.half(), h.half()
                    )
                out_control = (
                    out_control.data if input_layout == "packed" else out_control
                )
                out_control.sum().backward()
                grads_control = [p.grad.clone() for p in rnn.parameters()]

                # Compares with default tolerances, even for FP16 execution.  Barring nondeterminism,
                # autocast and control results should be bitwise identical.
                self.assertEqual(out, out_control)

                if cls == "LSTM":
                    self.assertTrue(
                        h_out[0].dtype is torch.float16
                        and h_out[1].dtype is torch.float16
                    )
                    self.assertEqual(h_out[0], h_out_control[0])
                    self.assertEqual(h_out[1], h_out_control[1])
                else:
                    self.assertEqual(h_out.dtype, torch.float16)
                    self.assertEqual(h_out, h_out_control)
                for grad, grad_control in zip(grads, grads_control):
                    self.assertEqual(grad.half(), grad_control)

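    # Minimal sketch (not part of the original test suite) of the property the RNN test
    # above verifies case by case: running a cudnn-backed recurrent module under autocast
    # yields float16 outputs, because the backend casts the weights to FP16 and
    # reflattens them for the fused kernel.
    def _example_autocast_lstm_output_dtype(self):
        lstm = torch.nn.LSTM(5, 6, num_layers=1).cuda()
        x = torch.randn(3, 4, 5, device="cuda")
        with torch.autocast("cuda"):
            out, (h, c) = lstm(x)
        assert out.dtype is torch.float16
        assert h.dtype is torch.float16 and c.dtype is torch.float16
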
    def test_autocast_cache_leak(self):
        # Reported at https://github.com/pytorch/pytorch/issues/48049
        # Test is used to check if autocast recaches the same parameters
        # when executed in a `torch.no_grad()` block.

        linear = torch.nn.Linear(10, 10).to("cuda")
        data = torch.randn(1, 10, device="cuda")

        with torch.autocast("cuda"):
            with torch.no_grad():
                out = linear(data)
                first_iter_mem = torch.cuda.memory_allocated()
                for _ in range(3):
                    out = linear(data)
                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())

    def test_autocast_checkpointing(self):
        model = torch.nn.Sequential(
            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
        ).cuda()
        input = torch.rand(
            (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
        )
        for reentrant in (True, False):
            with torch.autocast("cuda"):
                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
            self.assertTrue(output.requires_grad)
            self.assertTrue(output.dtype is torch.float16)
            output.sum().backward()

    def test_cuda_autocast_deprecated_warning(self):
        with self.assertWarnsRegex(
            FutureWarning,
            r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
        ):
            with torch.cuda.amp.autocast():
                _ = torch.ones(10)


instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_device_type_tests(TestCudaOptims, globals())

if __name__ == "__main__":
    run_tests()