# Owner(s): ["module: cuda"]

import collections
import contextlib
import ctypes
import gc
import io
import queue
import sys
import tempfile
import threading
import unittest
from itertools import chain, repeat
from typing import NamedTuple, Union

import torch
import torch.cuda.comm as comm
from torch.nn.parallel import scatter_gather
from torch.testing._internal.common_cuda import (
    _create_scaling_case,
    _create_scaling_models_optimizers,
    TEST_MULTIGPU,
)
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    instantiate_parametrized_tests,
    IS_JETSON,
    IS_REMOTE_GPU,
    IS_SANDCASTLE,
    NoTest,
    run_tests,
    serialTest,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    TEST_CUDA,
    TestCase,
)


TEST_CUDAMALLOCASYNC = TEST_CUDA and (
    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
)

if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


class TestCudaMultiGPU(TestCase):
    FIFTY_MIL_CYCLES = 50000000

    def _check_memory_stat_consistency(self):
        snapshot = torch.cuda.memory_snapshot()

        expected_each_device = collections.defaultdict(
            lambda: collections.defaultdict(int)
        )

        for segment in snapshot:
            expandable = segment["is_expandable"]
            expected = expected_each_device[segment["device"]]
            pool_str = segment["segment_type"] + "_pool"

            if not expandable:
                expected["segment.all.current"] += 1
                expected["segment." + pool_str + ".current"] += 1

            expected["allocated_bytes.all.current"] += segment["allocated_size"]
            expected["allocated_bytes." + pool_str + ".current"] += segment[
                "allocated_size"
            ]

            expected["reserved_bytes.all.current"] += segment["total_size"]
            expected["reserved_bytes." + pool_str + ".current"] += segment["total_size"]

            expected["active_bytes.all.current"] += segment["active_size"]
            expected["active_bytes." + pool_str + ".current"] += segment["active_size"]

            expected["requested_bytes.all.current"] += segment["requested_size"]
            expected["requested_bytes." + pool_str + ".current"] += segment[
                "requested_size"
            ]

            sum_requested = 0
            is_split = len(segment["blocks"]) > 1
            for block in segment["blocks"]:
                if block["state"] == "active_allocated":
                    expected["allocation.all.current"] += 1
                    expected["allocation." + pool_str + ".current"] += 1

                if block["state"].startswith("active_"):
                    sum_requested += block["requested_size"]
                    expected["active.all.current"] += 1
                    expected["active." + pool_str + ".current"] += 1

                if block["state"] == "inactive" and is_split and not expandable:
                    expected["inactive_split.all.current"] += 1
                    expected["inactive_split." + pool_str + ".current"] += 1
                    expected["inactive_split_bytes.all.current"] += block["size"]
                    expected["inactive_split_bytes." + pool_str + ".current"] += block[
                        "size"
                    ]

            self.assertEqual(sum_requested, segment["requested_size"])

        for device, expected in expected_each_device.items():
            stats = torch.cuda.memory_stats(device)
            for k, v in expected.items():
                self.assertEqual(v, stats[k])

    def test_cuda_synchronize(self):
        torch.cuda.synchronize()
        torch.cuda.synchronize("cuda")
        torch.cuda.synchronize("cuda:0")
        torch.cuda.synchronize(0)
        torch.cuda.synchronize(torch.device("cuda:0"))

        if TEST_MULTIGPU:
            torch.cuda.synchronize("cuda:1")
            torch.cuda.synchronize(1)
            torch.cuda.synchronize(torch.device("cuda:1"))

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
            torch.cuda.synchronize(torch.device("cpu"))

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
            torch.cuda.synchronize("cpu")

    @staticmethod
    def _test_memory_stats_generator(self, device=None, N=35):
        if device is None:
            device = torch.cuda.current_device()

        m0 = torch.cuda.memory_allocated(device)
        last_m_arr = [torch.cuda.memory_allocated(device)]
        max_m_arr = [torch.cuda.max_memory_allocated(device)]
        last_r_arr = [torch.cuda.memory_reserved(device)]
        max_r_arr = [torch.cuda.max_memory_reserved(device)]

        def alloc(*size):
            with torch.cuda.device(device):
                # NOTE: do **not** use methods that can have additional
                # memory overhead, e.g., inplace random sampling methods.
                # they can leave some memory occupied even after being
                # deallocated, e.g., initialized RNG state, causing some
                # memory checks below to fail.
                return torch.cuda.FloatTensor(*size)

        def assert_change(comp=1, empty_cache=False, reset_peak=False):
            # comp > 0: increased
            # comp = 0: equal
            # comp < 0: decreased
            new_m = torch.cuda.memory_allocated(device)
            new_max_m = torch.cuda.max_memory_allocated(device)
            if comp > 0:
                self.assertGreater(new_m, last_m_arr[0])
            elif comp < 0:
                self.assertLess(new_m, last_m_arr[0])
            else:
                self.assertEqual(new_m, last_m_arr[0])
            self.assertLessEqual(new_m, new_max_m)
            self.assertGreaterEqual(new_max_m, max_m_arr[0])
            last_m_arr[0] = new_m
            max_m_arr[0] = new_max_m

            new_r = torch.cuda.memory_reserved(device)
            new_max_r = torch.cuda.max_memory_reserved(device)
            # emptying cache may happen (due to allocation or empty_cache), so
            # we can't assert new_r >= last_r_arr[0]
            self.assertLessEqual(new_r, new_max_r)
            self.assertGreaterEqual(new_max_r, max_r_arr[0])
            last_r_arr[0] = new_r
            max_r_arr[0] = new_max_r

            stat_key_n_sync = "num_sync_all_streams"
            stat_key_n_alloc = "num_device_alloc"
            stat_key_n_free = "num_device_free"
            if empty_cache:
                num_sync_1 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
                self.assertGreaterEqual(num_sync_1, 0)
                num_alloc_1 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
                # if current memory usage is greater than zero we must have
                # allocated something
                self.assertGreaterEqual(num_alloc_1, 0 if new_m == 0 else 1)
                num_free_1 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
                self.assertGreaterEqual(num_free_1, 0)
                # empty_cache will enforce the call of release_cached_blocks
                torch.cuda.empty_cache()
                num_sync_2 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
                self.assertEqual(num_sync_1 + 1, num_sync_2)
                num_alloc_2 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
                self.assertGreaterEqual(num_alloc_2, num_alloc_1)
                num_free_2 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
                self.assertGreaterEqual(num_free_2, num_free_1)

                new_r = torch.cuda.memory_reserved(device)
                new_max_r = torch.cuda.max_memory_reserved(device)
                self.assertLessEqual(new_r, last_r_arr[0])
                self.assertLessEqual(new_r, new_max_r)
                self.assertEqual(new_max_r, max_r_arr[0])
                last_r_arr[0] = new_r

            if reset_peak:
                torch.cuda.reset_peak_memory_stats(device)
                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
                self.assertEqual(torch.cuda.max_memory_allocated(device), last_m_arr[0])
                max_m_arr[0] = last_m_arr[0]
                self.assertEqual(torch.cuda.memory_reserved(device), last_r_arr[0])
                self.assertEqual(torch.cuda.max_memory_reserved(device), last_r_arr[0])
                max_r_arr[0] = last_r_arr[0]

        assert_change(0)
        assert_change(0, reset_peak=True)
        assert_change(0, empty_cache=True)
        assert_change(0, reset_peak=True)
        assert_change(0)
        yield

        tensors1 = [alloc(1), alloc(10, 20), alloc(200, 300, 2000)]
        m1 = torch.cuda.memory_allocated(device)
        assert_change(1)
        yield

        tensors2 = []

        for i in range(1, int(N / 2) + 1):
            # small ones
            tensors2.append(alloc(i, i * 4))
            assert_change(1)
            yield

        for i in range(5, int(N / 2) + 5):
            # large ones
            tensors2.append(alloc(i, i * 7, i * 9, i * 11))
            assert_change(1, reset_peak=(i % 2 == 0))
            yield

        tensors2.append(alloc(0, 0, 0))
        assert_change(0)
        yield

        permute = []
        for i in torch.randperm(len(tensors2)):
            permute.append(tensors2[i])
            assert_change(0)
            yield

        del tensors2
        assert_change(0)
        yield
        tensors2 = permute
        assert_change(0)
        yield
        del permute
        assert_change(0, reset_peak=True)
        yield

        for i in range(int(N / 2)):
            x = tensors2[i].numel()
            del tensors2[i]
            assert_change(-x)  # in case tensors2[i] is empty
            yield

        for i in range(2, int(2 * N / 3) + 2):
            tensors2.append(alloc(i, i * 3, i * 8))
            assert_change(1)
            yield

        del tensors2
        assert_change(-1, reset_peak=True)
        assert_change(0)
        self.assertEqual(torch.cuda.memory_allocated(device), m1)
        yield True

        del tensors1
        assert_change(-1, reset_peak=True)
        self.assertEqual(torch.cuda.memory_allocated(device), m0)

        # test empty_cache and reset_peak
        assert_change(0, empty_cache=True)
        assert_change(0, reset_peak=True)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    @serialTest()
    def test_memory_stats(self):
        gc.collect()
        torch.cuda.empty_cache()
        for _ in self._test_memory_stats_generator(self):
            self._check_memory_stat_consistency()

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_memory_stats_multigpu(self):
        # advance a generator with an end flag
        def advance(gen, end):
            if not end:
                try:
                    next(gen)
                except StopIteration:
                    end = True
            return end

        # interlace
        torch.cuda.empty_cache()
        gen0 = self._test_memory_stats_generator(self, device="cuda:0", N=35)
        gen1 = self._test_memory_stats_generator(
            self, device=torch.device("cuda:1"), N=35
        )
        end0 = end1 = False
        while not (end0 and end1):
            end0 = advance(gen0, end0)
            end1 = advance(gen1, end1)

        # semi-random order
        torch.cuda.empty_cache()
        gen0 = self._test_memory_stats_generator(self, device=0, N=35)
        gen1 = self._test_memory_stats_generator(
            self, device=torch.device("cuda:1"), N=35
        )
        end0 = end1 = False

        while not (end0 and end1):
            end0 = advance(gen0, end0)
            if not end0:
                gen1_max_times = torch.LongTensor(1).random_(0, 3)[0]
            else:
                gen1_max_times = torch.inf
            t = 0
            while t < gen1_max_times and not end1:
                end1 = advance(gen1, end1)
                t += 1

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_autogpu(self):
        x = torch.randn(5, 5).cuda()
        y = torch.randn(5, 5).cuda()
        self.assertEqual(x.get_device(), 0)
        self.assertEqual(y.get_device(), 0)
        with torch.cuda.device(1):
            z = torch.randn(5, 5).cuda()
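            # With no device argument, .cuda() targets the current CUDA device,
            # so inside torch.cuda.device(1) the new tensor lands on cuda:1.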
            self.assertEqual(z.get_device(), 1)
            q = x.add(y)
            self.assertEqual(q.get_device(), 0)
            w = torch.randn(5, 5).cuda()
            self.assertEqual(w.get_device(), 1)
            self.assertEqual(y.cuda().get_device(), 1)
        z = z.cuda()
        self.assertEqual(z.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_new(self):
        x = torch.randn(3, 3).cuda()
        self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
        self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)

        with torch.cuda.device(1):
            self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
            self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_copy_device(self):
        x = torch.randn(5, 5).cuda()
        with torch.cuda.device(1):
            y = x.cuda()
            self.assertEqual(y.get_device(), 1)
            self.assertIs(y.cuda(), y)
            z = y.cuda(0)
            self.assertEqual(z.get_device(), 0)
            self.assertIs(z.cuda(0), z)

        x = torch.randn(5, 5)
        with torch.cuda.device(1):
            y = x.cuda()
            self.assertEqual(y.get_device(), 1)
            self.assertIs(y.cuda(), y)
            z = y.cuda(0)
            self.assertEqual(z.get_device(), 0)
            self.assertIs(z.cuda(0), z)

    def _test_copy_sync_current_stream(self, x, y):
        x_plus_one = x + 1
        s0 = torch.cuda.Stream(device=x.device)
        s1 = torch.cuda.Stream(device=y.device)
        s2 = torch.cuda.Stream(device=x.device)
        s3 = torch.cuda.Stream(device=y.device)

        # same dst stream different src streams
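        # s0/s2 live on the source device and s1/s3 on the destination device.
        # Each copy below runs with one current stream per device; the two
        # copies in each scenario share one stream (s1 first, then s0), which
        # is what forces them to serialize.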
        with torch.cuda.stream(s0):
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            with torch.cuda.stream(s1):
                y.copy_(x_plus_one)

        with torch.cuda.stream(s2), torch.cuda.stream(s1):
            y.copy_(x)

        s1.synchronize()
        # The copy() is synchronized on the current streams of both src and dst.
        # In the above test, the _sleep() op on s0 will not block the copy() on
        # s2, but both copies are synchronized on s1 in the dst device. Hence,
        # x is copied to y after x_plus_one is copied to y. If x and y are on
        # the same device, both copy() ops are synchronized on s1.
        self.assertEqual(y, x)

        # same src stream different dst streams
        with torch.cuda.stream(s1):
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            with torch.cuda.stream(s0):
                y.copy_(x_plus_one)

        with torch.cuda.stream(s3), torch.cuda.stream(s0):
            y.copy_(x)

        s0.synchronize()
        # Similarly, both copy() ops are synchronized on s0.
        self.assertEqual(y, x)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_copy_streams(self):
        d0 = torch.device("cuda:0")
        x0 = torch.zeros(5, 5, device=d0)

        d1 = torch.device("cuda:1")
        x1 = torch.zeros(5, 5, device=d1)
        self._test_copy_sync_current_stream(x0, x1)

        x2 = torch.zeros(5, 5, device=d0)
        self._test_copy_sync_current_stream(x0, x2)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_cat_autogpu(self):
        x = torch.randn(4, 4).cuda(1)
        y = torch.randn(4, 4).cuda(1)
        z = torch.cat([x, y], 0)
        self.assertEqual(z.get_device(), x.get_device())

    @unittest.skipIf(torch.cuda.device_count() >= 10, "Loading a cuda:9 tensor")
    def test_load_nonexistent_device(self):
        # Setup: create a serialized file object with a 'cuda:9' restore location
        tensor = torch.randn(2, device="cuda")
        buf = io.BytesIO()
        torch.save(tensor, buf)
        # NB: this might not work in the future if serialization changes
        buf = io.BytesIO(buf.getvalue().replace(b"cuda:0", b"cuda:9"))

        msg = r"Attempting to deserialize object on CUDA device 9"
        with self.assertRaisesRegex(RuntimeError, msg):
            _ = torch.load(buf)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_serialization_remap(self):
        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]

        def gpu_remap(storage, location):
            if location == "cuda:1":
                return storage.cuda(0)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f, map_location=gpu_remap)
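            # A callable map_location is invoked once per storage as
            # (storage, location); gpu_remap implicitly returns None for
            # locations other than "cuda:1", which makes torch.load fall back
            # to its default restore behavior for those storages.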

        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_serialization_remap_dict(self):
        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f, map_location={"cuda:1": "cuda:0"})
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_storage_clone(self):
        x = torch.randn(4, 4, device="cuda:1").storage()
        y = x.clone()
        self.assertEqual(x.get_device(), y.get_device())
        for t in ["byte", "char", "short", "int", "long", "half", "double"]:
            self.assertEqual(getattr(x, t)().get_device(), x.get_device())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_cuda_set_device(self):
        x = torch.randn(5, 5)
        with torch.cuda.device(1):
            self.assertEqual(x.cuda().get_device(), 1)
            torch.cuda.set_device(0)
            self.assertEqual(x.cuda().get_device(), 0)
            with torch.cuda.device(1):
                self.assertEqual(x.cuda().get_device(), 1)
            self.assertEqual(x.cuda().get_device(), 0)
            torch.cuda.set_device(1)
        self.assertEqual(x.cuda().get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        s0 = torch.cuda.current_stream()
        s1 = torch.cuda.current_stream(device=1)
        s2 = torch.cuda.current_stream(device=0)

        self.assertEqual(d0, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(s0, s2)

        with torch.cuda.device(d1):
            s0 = torch.cuda.current_stream()
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(d0)

        self.assertEqual(d1, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(s0, s1)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
            torch.cuda.current_stream(torch.device("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @skipCUDANonDefaultStreamIf(True)
    def test_default_stream(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        with torch.cuda.device(d0):
            s0 = torch.cuda.default_stream()

        with torch.cuda.device(d1):
            s1 = torch.cuda.default_stream()

        s2 = torch.cuda.default_stream(device=0)
        s3 = torch.cuda.default_stream(d1)

        self.assertEqual(d0, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(d1, s3.device)
        self.assertEqual(s0, s2)
        self.assertEqual(s1, s3)

        with torch.cuda.device(d0):
            self.assertEqual(torch.cuda.current_stream(), s0)

        with torch.cuda.device(d1):
            self.assertEqual(torch.cuda.current_stream(), s1)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
            torch.cuda.default_stream(torch.device("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_event_device(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        e0 = torch.cuda.Event()

        self.assertEqual(None, e0.device)

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            s0.record_event(e0)

        with torch.cuda.device(d1):
            s1 = torch.cuda.Stream()
            e1 = s1.record_event()

        self.assertEqual(s0.device, torch.device("cuda:0"))
        self.assertEqual(e0.device, torch.device("cuda:0"))
        self.assertEqual(s1.device, torch.device("cuda:1"))
        self.assertEqual(e1.device, torch.device("cuda:1"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_context(self):
        s0 = torch.cuda.current_stream()
        s1 = torch.cuda.Stream(device=1)
        s2 = torch.cuda.Stream(device=0)

        with torch.cuda.device(s1.device):
            prev_stream_on_cuda1 = torch.cuda.current_stream()

        self.assertEqual(torch.cuda.current_stream(), s0)
        self.assertEqual(0, torch.cuda.current_device())
        with torch.cuda.stream(s1):
            self.assertEqual(torch.cuda.current_stream(), s1)
            self.assertEqual(1, torch.cuda.current_device())
            with torch.cuda.stream(s2):
                self.assertEqual(torch.cuda.current_stream(), s2)
                self.assertEqual(0, torch.cuda.current_device())
                with torch.cuda.stream(s0):
                    self.assertEqual(torch.cuda.current_stream(), s0)
                    self.assertEqual(0, torch.cuda.current_device())
                self.assertEqual(torch.cuda.current_stream(), s2)
                self.assertEqual(0, torch.cuda.current_device())
            self.assertEqual(torch.cuda.current_stream(), s1)
            self.assertEqual(1, torch.cuda.current_device())

        with torch.cuda.device(s1.device):
            self.assertEqual(prev_stream_on_cuda1, torch.cuda.current_stream())

        self.assertEqual(torch.cuda.current_stream(), s0)
        self.assertEqual(0, torch.cuda.current_device())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_streams_multi_gpu(self):
        default_stream = torch.cuda.current_stream()
        self.assertEqual(default_stream.device, torch.device("cuda:0"))
        stream = torch.cuda.Stream(device=1)
        self.assertEqual(stream.device, torch.device("cuda:1"))
        with torch.cuda.device(1):
            self.assertEqual(torch.cuda.current_stream().device, torch.device("cuda:1"))
            self.assertNotEqual(torch.cuda.current_stream(), default_stream)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_streams_multi_gpu_query(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        torch.cuda.synchronize(d0)
        torch.cuda.synchronize(d1)

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()

        with torch.cuda.device(d1):
            s1 = torch.cuda.current_stream()
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)

        self.assertTrue(s0.query())
        self.assertFalse(s1.query())

        with torch.cuda.device(d0):
            self.assertTrue(s0.query())
            self.assertFalse(s1.query())

        with torch.cuda.device(d1):
            self.assertTrue(s0.query())
            self.assertFalse(s1.query())

        # deliberately using a different device
        with torch.cuda.device(d0):
            s1.synchronize()

        self.assertTrue(s0.query())
        self.assertTrue(s1.query())

        with torch.cuda.device(d0):
            self.assertTrue(s0.query())
            self.assertTrue(s1.query())

        with torch.cuda.device(d1):
            self.assertTrue(s0.query())
            self.assertTrue(s1.query())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_streams_multi_gpu_eq(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            s1 = torch.cuda.current_stream()

        with torch.cuda.device(d1):
            s2 = torch.cuda.current_stream()
            s3 = torch.cuda.current_stream()

        self.assertTrue(s0 == s0)
        self.assertTrue(s0 == s1)
        self.assertTrue(s2 == s2)
        self.assertTrue(s2 == s3)
        self.assertFalse(s0 == s2)
        self.assertFalse(s1 == s3)

        self.assertEqual(s0.device, s1.device)
        self.assertEqual(s0.cuda_stream, s1.cuda_stream)
        self.assertEqual(s2.device, s3.device)
        self.assertEqual(s2.cuda_stream, s3.cuda_stream)
        self.assertNotEqual(s0.device, s3.device)

        self.assertEqual(hash(s0), hash(s1))
        self.assertEqual(hash(s2), hash(s3))
        self.assertNotEqual(hash(s0), hash(s3))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_streams_priority(self):
        low, high = torch.cuda.Stream.priority_range()
        s0 = torch.cuda.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("cuda:0"), s0.device)

        s1 = torch.cuda.Stream(device=1, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("cuda:1"), s1.device)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_tensor_device(self):
        self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 0)
        self.assertEqual(torch.cuda.FloatTensor(1, device=1).get_device(), 1)
        with torch.cuda.device(1):
            self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 1)
            self.assertEqual(torch.cuda.FloatTensor(1, device=0).get_device(), 0)
            self.assertEqual(torch.cuda.FloatTensor(1, device=None).get_device(), 1)

    @staticmethod
    def _stream_synchronize(self, spin_time_cycles):
        s = torch.cuda.current_stream()
        e_tik = torch.cuda.Event(enable_timing=True)
        e_tok = torch.cuda.Event(enable_timing=True)

        e_tik.record(s)
        torch.cuda._sleep(spin_time_cycles)
        e_tok.record(s)
        s.synchronize()

        self.assertTrue(s.query())

        # not necessary to check e_tik and e_tok, as elapsed_time would throw
        # exception if otherwise.
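        # Event.elapsed_time() reports milliseconds and requires both events to
        # have been created with enable_timing=True and already recorded;
        # otherwise it raises a RuntimeError.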
728*da0073e9SAndroid Build Coastguard Worker return e_tik.elapsed_time(e_tok) 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker @staticmethod 731*da0073e9SAndroid Build Coastguard Worker def _event_synchronize(self, spin_time_cycles): 732*da0073e9SAndroid Build Coastguard Worker s = torch.cuda.current_stream() 733*da0073e9SAndroid Build Coastguard Worker e_tik = torch.cuda.Event(enable_timing=True) 734*da0073e9SAndroid Build Coastguard Worker e_tok = torch.cuda.Event(enable_timing=True) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker e_tik.record(s) 737*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(spin_time_cycles) 738*da0073e9SAndroid Build Coastguard Worker s.record_event(e_tok) 739*da0073e9SAndroid Build Coastguard Worker e_tok.synchronize() 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s.query()) 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker # not necessary to check e_tik and e_tok, as elapsed_time would throw 744*da0073e9SAndroid Build Coastguard Worker # exception if otherwise. 745*da0073e9SAndroid Build Coastguard Worker return e_tik.elapsed_time(e_tok) 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Worker @staticmethod 748*da0073e9SAndroid Build Coastguard Worker def _event_wait(self, spin_time_cycles): 749*da0073e9SAndroid Build Coastguard Worker s0 = torch.cuda.current_stream() 750*da0073e9SAndroid Build Coastguard Worker s1 = torch.cuda.Stream() 751*da0073e9SAndroid Build Coastguard Worker e_tik = torch.cuda.Event(blocking=True, enable_timing=True) 752*da0073e9SAndroid Build Coastguard Worker e_tok = torch.cuda.Event(blocking=True, enable_timing=True) 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker e_tik.record(s0) 755*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(spin_time_cycles - 10) 756*da0073e9SAndroid Build Coastguard Worker e_sync = torch.cuda.Event(blocking=True) 757*da0073e9SAndroid Build Coastguard Worker e_sync.record() 758*da0073e9SAndroid Build Coastguard Worker e_sync.wait(s1) 759*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(s1): 760*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(10) 761*da0073e9SAndroid Build Coastguard Worker s1.synchronize() 762*da0073e9SAndroid Build Coastguard Worker e_tok.record() 763*da0073e9SAndroid Build Coastguard Worker e_tok.synchronize() 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s0.query()) 766*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s1.query()) 767*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e_sync.query()) 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker # not necessary to check e_tik and e_tok, as elapsed_time would throw 770*da0073e9SAndroid Build Coastguard Worker # exception if otherwise. 
771*da0073e9SAndroid Build Coastguard Worker return e_tik.elapsed_time(e_tok) 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker @staticmethod 774*da0073e9SAndroid Build Coastguard Worker def _test_stream_event_nogil(self, sync_func, p2c, c2p): 775*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device("cuda:1"): 776*da0073e9SAndroid Build Coastguard Worker c2p.put(0) 777*da0073e9SAndroid Build Coastguard Worker p2c.get() 778*da0073e9SAndroid Build Coastguard Worker c2p.put(sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES)) 779*da0073e9SAndroid Build Coastguard Worker 780*da0073e9SAndroid Build Coastguard Worker # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190 781*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 782*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") 783*da0073e9SAndroid Build Coastguard Worker def test_stream_event_nogil(self): 784*da0073e9SAndroid Build Coastguard Worker for sync_func in [ 785*da0073e9SAndroid Build Coastguard Worker TestCudaMultiGPU._stream_synchronize, 786*da0073e9SAndroid Build Coastguard Worker TestCudaMultiGPU._event_synchronize, 787*da0073e9SAndroid Build Coastguard Worker TestCudaMultiGPU._event_wait, 788*da0073e9SAndroid Build Coastguard Worker ]: 789*da0073e9SAndroid Build Coastguard Worker p2c = queue.Queue() 790*da0073e9SAndroid Build Coastguard Worker c2p = queue.Queue() 791*da0073e9SAndroid Build Coastguard Worker e_tik = torch.cuda.Event(enable_timing=True) 792*da0073e9SAndroid Build Coastguard Worker e_tok = torch.cuda.Event(enable_timing=True) 793*da0073e9SAndroid Build Coastguard Worker 794*da0073e9SAndroid Build Coastguard Worker t = threading.Thread( 795*da0073e9SAndroid Build Coastguard Worker target=TestCudaMultiGPU._test_stream_event_nogil, 796*da0073e9SAndroid Build Coastguard Worker args=(self, sync_func, p2c, c2p), 797*da0073e9SAndroid Build Coastguard Worker ) 798*da0073e9SAndroid Build Coastguard Worker t.daemon = True 799*da0073e9SAndroid Build Coastguard Worker t.start() 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker c2p.get() 802*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device("cuda:0"): 803*da0073e9SAndroid Build Coastguard Worker e_tik.record() 804*da0073e9SAndroid Build Coastguard Worker p2c.put(0) 805*da0073e9SAndroid Build Coastguard Worker parent_time = sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES) 806*da0073e9SAndroid Build Coastguard Worker child_time = c2p.get() 807*da0073e9SAndroid Build Coastguard Worker e_tok.record() 808*da0073e9SAndroid Build Coastguard Worker e_tok.synchronize() 809*da0073e9SAndroid Build Coastguard Worker total_time = e_tik.elapsed_time(e_tok) 810*da0073e9SAndroid Build Coastguard Worker 811*da0073e9SAndroid Build Coastguard Worker # Without GIL, synchronizations in parent and child threads can 812*da0073e9SAndroid Build Coastguard Worker # overlap. The total execution time should be a little bit longer 813*da0073e9SAndroid Build Coastguard Worker # than spinning fifty million cycles and much shorter than twice of 814*da0073e9SAndroid Build Coastguard Worker # that. However, testing absolute execution time is not reliable as 815*da0073e9SAndroid Build Coastguard Worker # it may vary on different hardware in different environments. 
816*da0073e9SAndroid Build Coastguard Worker # Therefore, this test uses relative comparisons, checking if the 817*da0073e9SAndroid Build Coastguard Worker # sum of the parent and child threads' execution times is greater than the 818*da0073e9SAndroid Build Coastguard Worker # real execution time by at least 40%. 819*da0073e9SAndroid Build Coastguard Worker self.assertGreater(parent_time + child_time, total_time * 1.4) 820*da0073e9SAndroid Build Coastguard Worker 821*da0073e9SAndroid Build Coastguard Worker # This test is flaky for ROCm, see issue #62602 822*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 823*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") 824*da0073e9SAndroid Build Coastguard Worker def test_events_wait(self): 825*da0073e9SAndroid Build Coastguard Worker d0 = torch.device("cuda:0") 826*da0073e9SAndroid Build Coastguard Worker d1 = torch.device("cuda:1") 827*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize(d0) 828*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize(d1) 829*da0073e9SAndroid Build Coastguard Worker 830*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 831*da0073e9SAndroid Build Coastguard Worker s0 = torch.cuda.current_stream() 832*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES) 833*da0073e9SAndroid Build Coastguard Worker e0 = torch.cuda.Event() 834*da0073e9SAndroid Build Coastguard Worker s0.record_event(e0) 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 837*da0073e9SAndroid Build Coastguard Worker s1 = torch.cuda.current_stream() 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker self.assertFalse(s0.query()) 840*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s1.query()) 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker s1.wait_event(e0) 843*da0073e9SAndroid Build Coastguard Worker s1.synchronize() 844*da0073e9SAndroid Build Coastguard Worker 845*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 846*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s0.query()) 847*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s1.query()) 848*da0073e9SAndroid Build Coastguard Worker 849*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") 850*da0073e9SAndroid Build Coastguard Worker def test_events_multi_gpu_query(self): 851*da0073e9SAndroid Build Coastguard Worker d0 = torch.device("cuda:0") 852*da0073e9SAndroid Build Coastguard Worker d1 = torch.device("cuda:1") 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 855*da0073e9SAndroid Build Coastguard Worker s0 = torch.cuda.current_stream() 856*da0073e9SAndroid Build Coastguard Worker e0 = s0.record_event() 857*da0073e9SAndroid Build Coastguard Worker s0.synchronize() 858*da0073e9SAndroid Build Coastguard Worker 859*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 860*da0073e9SAndroid Build Coastguard Worker s1 = torch.cuda.current_stream() 861*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES) 862*da0073e9SAndroid Build Coastguard Worker e1 = s1.record_event() 863*da0073e9SAndroid Build Coastguard Worker 864*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 865*da0073e9SAndroid Build Coastguard
Worker self.assertFalse(e1.query()) 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 868*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 869*da0073e9SAndroid Build Coastguard Worker self.assertFalse(e1.query()) 870*da0073e9SAndroid Build Coastguard Worker 871*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 872*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 873*da0073e9SAndroid Build Coastguard Worker self.assertFalse(e1.query()) 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard Worker # deliberately using a different device 876*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 877*da0073e9SAndroid Build Coastguard Worker e1.synchronize() 878*da0073e9SAndroid Build Coastguard Worker 879*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 880*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e1.query()) 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 883*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 884*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e1.query()) 885*da0073e9SAndroid Build Coastguard Worker 886*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 887*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 888*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e1.query()) 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") 891*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 892*da0073e9SAndroid Build Coastguard Worker def test_events_multi_gpu_elapsed_time(self): 893*da0073e9SAndroid Build Coastguard Worker d0 = torch.device("cuda:0") 894*da0073e9SAndroid Build Coastguard Worker d1 = torch.device("cuda:1") 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 897*da0073e9SAndroid Build Coastguard Worker s0 = torch.cuda.current_stream() 898*da0073e9SAndroid Build Coastguard Worker e0 = torch.cuda.Event(enable_timing=True) 899*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(10) 900*da0073e9SAndroid Build Coastguard Worker s0.record_event(e0) 901*da0073e9SAndroid Build Coastguard Worker 902*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 903*da0073e9SAndroid Build Coastguard Worker s1 = torch.cuda.current_stream() 904*da0073e9SAndroid Build Coastguard Worker e1 = torch.cuda.Event(enable_timing=True) 905*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES) 906*da0073e9SAndroid Build Coastguard Worker s1.record_event(e1) 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Worker e0.synchronize() 909*da0073e9SAndroid Build Coastguard Worker e1.synchronize() 910*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 911*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 912*da0073e9SAndroid Build Coastguard Worker self.assertGreater(e0.elapsed_time(e1), 0) 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 915*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 916*da0073e9SAndroid Build Coastguard Worker self.assertGreater(e0.elapsed_time(e1), 
0) 917*da0073e9SAndroid Build Coastguard Worker 918*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 919*da0073e9SAndroid Build Coastguard Worker s0 = torch.cuda.current_stream() 920*da0073e9SAndroid Build Coastguard Worker e2 = torch.cuda.Event(enable_timing=True) 921*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES) 922*da0073e9SAndroid Build Coastguard Worker s0.record_event(e2) 923*da0073e9SAndroid Build Coastguard Worker s0.synchronize() 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker self.assertGreater(e0.elapsed_time(e2), 0) 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker # deliberately calling from a different device 928*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 929*da0073e9SAndroid Build Coastguard Worker self.assertGreater(e0.elapsed_time(e2), 0) 930*da0073e9SAndroid Build Coastguard Worker 931*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 932*da0073e9SAndroid Build Coastguard Worker def _get_external_stream(self, device): 933*da0073e9SAndroid Build Coastguard Worker cudart = torch.cuda.cudart() 934*da0073e9SAndroid Build Coastguard Worker stream = ctypes.c_ulonglong(0) 935*da0073e9SAndroid Build Coastguard Worker stream_p = ctypes.POINTER(ctypes.c_void_p)(stream) 936*da0073e9SAndroid Build Coastguard Worker stream_p_int = ctypes.cast(stream_p, ctypes.c_void_p).value 937*da0073e9SAndroid Build Coastguard Worker with device: 938*da0073e9SAndroid Build Coastguard Worker try: 939*da0073e9SAndroid Build Coastguard Worker out = cudart.cudaStreamCreate(stream_p_int) 940*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, 0) 941*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(stream.value, 0) 942*da0073e9SAndroid Build Coastguard Worker yield stream.value 943*da0073e9SAndroid Build Coastguard Worker finally: 944*da0073e9SAndroid Build Coastguard Worker out = cudart.cudaStreamDestroy(stream.value) 945*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, 0) 946*da0073e9SAndroid Build Coastguard Worker 947*da0073e9SAndroid Build Coastguard Worker def test_external_streams(self): 948*da0073e9SAndroid Build Coastguard Worker device = torch.cuda.device(0) 949*da0073e9SAndroid Build Coastguard Worker with self._get_external_stream(device) as stream_v: 950*da0073e9SAndroid Build Coastguard Worker ext_stream = torch.cuda.ExternalStream(stream_v) 951*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stream_v, ext_stream.cuda_stream) 952*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ext_stream.device.index, device.idx) 953*da0073e9SAndroid Build Coastguard Worker 954*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") 955*da0073e9SAndroid Build Coastguard Worker def test_external_streams_multi_device(self): 956*da0073e9SAndroid Build Coastguard Worker device = torch.cuda.device(1) 957*da0073e9SAndroid Build Coastguard Worker with self._get_external_stream(device) as stream_v: 958*da0073e9SAndroid Build Coastguard Worker ext_stream = torch.cuda.ExternalStream(stream_v, device=device) 959*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stream_v, ext_stream.cuda_stream) 960*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ext_stream.device.index, device.idx) 961*da0073e9SAndroid Build Coastguard Worker 962*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, 
"only one GPU detected") 963*da0073e9SAndroid Build Coastguard Worker def test_caching_pinned_memory_multi_gpu(self): 964*da0073e9SAndroid Build Coastguard Worker # checks that the events preventing pinned memory from being re-used 965*da0073e9SAndroid Build Coastguard Worker # too early are recorded on the correct GPU 966*da0073e9SAndroid Build Coastguard Worker cycles_per_ms = get_cycles_per_ms() 967*da0073e9SAndroid Build Coastguard Worker 968*da0073e9SAndroid Build Coastguard Worker t = torch.FloatTensor([1]).pin_memory() 969*da0073e9SAndroid Build Coastguard Worker ptr = t.data_ptr() 970*da0073e9SAndroid Build Coastguard Worker gpu_tensor0 = torch.cuda.FloatTensor([0], device=0) 971*da0073e9SAndroid Build Coastguard Worker gpu_tensor1 = torch.cuda.FloatTensor([0], device=1) 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(1): 974*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(int(1000 * cycles_per_ms)) # delay the copy by 1s 975*da0073e9SAndroid Build Coastguard Worker gpu_tensor1.copy_(t, non_blocking=True) 976*da0073e9SAndroid Build Coastguard Worker 977*da0073e9SAndroid Build Coastguard Worker del t 978*da0073e9SAndroid Build Coastguard Worker t = torch.FloatTensor([2]).pin_memory() 979*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon") 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(0): 982*da0073e9SAndroid Build Coastguard Worker gpu_tensor0.copy_(t, non_blocking=True) 983*da0073e9SAndroid Build Coastguard Worker 984*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gpu_tensor1[0], 1) 985*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gpu_tensor0[0], 2) 986*da0073e9SAndroid Build Coastguard Worker 987*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 988*da0073e9SAndroid Build Coastguard Worker def test_get_set_rng_state_all(self): 989*da0073e9SAndroid Build Coastguard Worker states = torch.cuda.get_rng_state_all() 990*da0073e9SAndroid Build Coastguard Worker before0 = torch.cuda.FloatTensor(100, device=0).normal_() 991*da0073e9SAndroid Build Coastguard Worker before1 = torch.cuda.FloatTensor(100, device=1).normal_() 992*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_rng_state_all(states) 993*da0073e9SAndroid Build Coastguard Worker after0 = torch.cuda.FloatTensor(100, device=0).normal_() 994*da0073e9SAndroid Build Coastguard Worker after1 = torch.cuda.FloatTensor(100, device=1).normal_() 995*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before0, after0, atol=0, rtol=0) 996*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before1, after1, atol=0, rtol=0) 997*da0073e9SAndroid Build Coastguard Worker 998*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 999*da0073e9SAndroid Build Coastguard Worker def test_rng_state_offset(self): 1000*da0073e9SAndroid Build Coastguard Worker before = torch.cuda.get_rng_state() 1001*da0073e9SAndroid Build Coastguard Worker torch.cuda._set_rng_state_offset(100) 1002*da0073e9SAndroid Build Coastguard Worker offset = torch.cuda._get_rng_state_offset() 1003*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_rng_state(before) 1004*da0073e9SAndroid Build Coastguard Worker self.assertEqual(offset, 100) 1005*da0073e9SAndroid Build Coastguard Worker 1006*da0073e9SAndroid Build Coastguard Worker # 
Verifies that mem_get_info works, including when called for a different device 1007*da0073e9SAndroid Build Coastguard Worker def test_mem_get_info(self): 1008*da0073e9SAndroid Build Coastguard Worker def _test(device: Union[str, int, torch.device]): 1009*da0073e9SAndroid Build Coastguard Worker # Prevent PyTorch from reusing the allocated memory 1010*da0073e9SAndroid Build Coastguard Worker torch.cuda.empty_cache() 1011*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 1012*da0073e9SAndroid Build Coastguard Worker before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(device) 1013*da0073e9SAndroid Build Coastguard Worker # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms 1014*da0073e9SAndroid Build Coastguard Worker t = torch.randn(1024 * 1024 * 8, device=device) 1015*da0073e9SAndroid Build Coastguard Worker if IS_JETSON: 1016*da0073e9SAndroid Build Coastguard Worker # w/o syncing, mem_get_info will run before memory allocated has actually increased. 1017*da0073e9SAndroid Build Coastguard Worker # This race condition causes consistent failure 1018*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 1019*da0073e9SAndroid Build Coastguard Worker after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(device) 1020*da0073e9SAndroid Build Coastguard Worker 1021*da0073e9SAndroid Build Coastguard Worker self.assertLess(after_free_bytes, before_free_bytes) 1022*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before_available_bytes, after_available_bytes) 1023*da0073e9SAndroid Build Coastguard Worker 1024*da0073e9SAndroid Build Coastguard Worker # Test calls with different device representations 1025*da0073e9SAndroid Build Coastguard Worker _test(0) 1026*da0073e9SAndroid Build Coastguard Worker _test(torch.device("cuda")) 1027*da0073e9SAndroid Build Coastguard Worker _test(torch.device("cuda:0")) 1028*da0073e9SAndroid Build Coastguard Worker _test("cuda") 1029*da0073e9SAndroid Build Coastguard Worker _test("cuda:0") 1030*da0073e9SAndroid Build Coastguard Worker if TEST_MULTIGPU: 1031*da0073e9SAndroid Build Coastguard Worker _test(1) 1032*da0073e9SAndroid Build Coastguard Worker _test(torch.device("cuda:1")) 1033*da0073e9SAndroid Build Coastguard Worker _test("cuda:1") 1034*da0073e9SAndroid Build Coastguard Worker 1035*da0073e9SAndroid Build Coastguard Worker # Test that wrap_with_cuda_memory_check successfully detects leak 1036*da0073e9SAndroid Build Coastguard Worker def test_cuda_memory_leak_detection(self): 1037*da0073e9SAndroid Build Coastguard Worker l = [] 1038*da0073e9SAndroid Build Coastguard Worker 1039*da0073e9SAndroid Build Coastguard Worker @self.wrap_with_cuda_memory_check 1040*da0073e9SAndroid Build Coastguard Worker def no_leak(): 1041*da0073e9SAndroid Build Coastguard Worker pass 1042*da0073e9SAndroid Build Coastguard Worker 1043*da0073e9SAndroid Build Coastguard Worker @self.wrap_with_cuda_memory_check 1044*da0073e9SAndroid Build Coastguard Worker def leak_gpu0(): 1045*da0073e9SAndroid Build Coastguard Worker # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms 1046*da0073e9SAndroid Build Coastguard Worker l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:0"))) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker no_leak() 1049*da0073e9SAndroid Build Coastguard Worker regex = r"CUDA driver API confirmed .+ on device 0.+" 1050*da0073e9SAndroid Build Coastguard Worker 
if IS_JETSON: 1051*da0073e9SAndroid Build Coastguard Worker try: 1052*da0073e9SAndroid Build Coastguard Worker leak_gpu0() 1053*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1054*da0073e9SAndroid Build Coastguard Worker import re 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker assert re.match(regex, str(e)), str(e) + "\n does not match: \n" + regex 1057*da0073e9SAndroid Build Coastguard Worker else: 1058*da0073e9SAndroid Build Coastguard Worker # assertRaisesRegex does not pass with Python for Jetson, 1059*da0073e9SAndroid Build Coastguard Worker # even though the RuntimeError matches regex using re.match 1060*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, regex): 1061*da0073e9SAndroid Build Coastguard Worker leak_gpu0() 1062*da0073e9SAndroid Build Coastguard Worker 1063*da0073e9SAndroid Build Coastguard Worker if TEST_MULTIGPU: 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker @self.wrap_with_cuda_memory_check 1066*da0073e9SAndroid Build Coastguard Worker def leak_gpu1(): 1067*da0073e9SAndroid Build Coastguard Worker # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms 1068*da0073e9SAndroid Build Coastguard Worker l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:1"))) 1069*da0073e9SAndroid Build Coastguard Worker 1070*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1071*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"CUDA driver API confirmed .+ on device 1.+" 1072*da0073e9SAndroid Build Coastguard Worker ): 1073*da0073e9SAndroid Build Coastguard Worker leak_gpu1() 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1076*da0073e9SAndroid Build Coastguard Worker def test_streaming_backwards_device_transfer(self): 1077*da0073e9SAndroid Build Coastguard Worker # This function must run with non-default current streams on all devices, otherwise it's meaningless. 1078*da0073e9SAndroid Build Coastguard Worker # The intention is to test that to()'s backward (CopyBackward) interacts properly with the 1079*da0073e9SAndroid Build Coastguard Worker # synchronization logic in torch/csrc/autograd/input_buffer.cpp. 1080*da0073e9SAndroid Build Coastguard Worker dev0 = torch.device("cuda:0") 1081*da0073e9SAndroid Build Coastguard Worker dev1 = torch.device("cuda:1") 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Worker # Unfortunately I need to make the tensors largeish. 1084*da0073e9SAndroid Build Coastguard Worker # Bigger tensors = longer D2D transfers = more likely to expose races. 1085*da0073e9SAndroid Build Coastguard Worker size = 2**26 1086*da0073e9SAndroid Build Coastguard Worker 1087*da0073e9SAndroid Build Coastguard Worker a = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True) 1088*da0073e9SAndroid Build Coastguard Worker b = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True) 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker # Here to_backward_recipient = a*b is used only once, so MulBackward's InputBuffer slot only expects 1 input. 1091*da0073e9SAndroid Build Coastguard Worker # This tests the situation where we don't call InputBuffer::accumulate for MulBackward's InputBuffer. 
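# As a compact, hedged illustration of the same idea before the larger stress
# version below (assumes two visible CUDA devices): gradients computed on
# cuda:0 must be routed back through CopyBackward to leaves living on cuda:1.
import torch

leaf = torch.ones(4, device="cuda:1", requires_grad=True)
out = (leaf * 3).to("cuda:0").sum()  # CopyBackward sits between MulBackward and SumBackward
out.backward()
assert leaf.grad.device == torch.device("cuda:1")
assert torch.equal(leaf.grad, torch.full_like(leaf, 3.0))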
1092*da0073e9SAndroid Build Coastguard Worker to_backward_recipient = a * b 1093*da0073e9SAndroid Build Coastguard Worker s = to_backward_recipient.to(device="cuda:0").sum() 1094*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize(device=dev0) 1095*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize(device=dev1) 1096*da0073e9SAndroid Build Coastguard Worker s.backward() 1097*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a.grad.sum().item() == size) 1098*da0073e9SAndroid Build Coastguard Worker self.assertTrue(b.grad.sum().item() == size) 1099*da0073e9SAndroid Build Coastguard Worker 1100*da0073e9SAndroid Build Coastguard Worker # Here to_backward_recipient = a*b is used twice, so MulBackward's InputBuffer slot expects 2 inputs. 1101*da0073e9SAndroid Build Coastguard Worker # This tests the situation where we do call InputBuffer::accumulate for MulBackward's InputBuffer. 1102*da0073e9SAndroid Build Coastguard Worker a.grad = None 1103*da0073e9SAndroid Build Coastguard Worker b.grad = None 1104*da0073e9SAndroid Build Coastguard Worker to_backward_recipient = a * b 1105*da0073e9SAndroid Build Coastguard Worker # Multiply by 2 here so to's backward creates gradient values that are different from the case above, 1106*da0073e9SAndroid Build Coastguard Worker # to mitigate weirdness if the caching allocator happens to reuse memory regions that were populated 1107*da0073e9SAndroid Build Coastguard Worker # with 1s by the case above 1108*da0073e9SAndroid Build Coastguard Worker s0 = to_backward_recipient.to(device="cuda:0").sum() * 2.0 1109*da0073e9SAndroid Build Coastguard Worker s1 = to_backward_recipient.to(device="cuda:0").sum() * 2.0 1110*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize(device=dev0) 1111*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize(device=dev1) 1112*da0073e9SAndroid Build Coastguard Worker s0.backward(retain_graph=True) 1113*da0073e9SAndroid Build Coastguard Worker s1.backward() 1114*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a.grad.sum().item() == 4 * size) 1115*da0073e9SAndroid Build Coastguard Worker self.assertTrue(b.grad.sum().item() == 4 * size) 1116*da0073e9SAndroid Build Coastguard Worker 1117*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1118*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle") 1119*da0073e9SAndroid Build Coastguard Worker def test_cuda_init_race(self): 1120*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/16559 1121*da0073e9SAndroid Build Coastguard Worker import subprocess 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker subprocess.check_call( 1124*da0073e9SAndroid Build Coastguard Worker [ 1125*da0073e9SAndroid Build Coastguard Worker sys.executable, 1126*da0073e9SAndroid Build Coastguard Worker "-c", 1127*da0073e9SAndroid Build Coastguard Worker """\ 1128*da0073e9SAndroid Build Coastguard Workerimport torch 1129*da0073e9SAndroid Build Coastguard Workerimport threading 1130*da0073e9SAndroid Build Coastguard Worker 1131*da0073e9SAndroid Build Coastguard Workerdef worker(rank): 1132*da0073e9SAndroid Build Coastguard Worker torch.tensor([1.]).cuda(rank) 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Workert1 = threading.Thread(target=worker, args=(0,)) 1135*da0073e9SAndroid Build Coastguard Workert2 = 
threading.Thread(target=worker, args=(1,)) 1136*da0073e9SAndroid Build Coastguard Workert1.start() 1137*da0073e9SAndroid Build Coastguard Workert2.start() 1138*da0073e9SAndroid Build Coastguard Worker""", 1139*da0073e9SAndroid Build Coastguard Worker ] 1140*da0073e9SAndroid Build Coastguard Worker ) 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1143*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_device_as_key(self): 1144*da0073e9SAndroid Build Coastguard Worker # Ensure that different instances of "device" objects that point to the same device 1145*da0073e9SAndroid Build Coastguard Worker # are treated as identical keys by dicts. GradScaler relies on this behavior, and may 1146*da0073e9SAndroid Build Coastguard Worker # error otherwise in a way that's difficult to detect (a silent performance hit). 1147*da0073e9SAndroid Build Coastguard Worker d = {} 1148*da0073e9SAndroid Build Coastguard Worker t = torch.empty((1,), device="cuda:0") 1149*da0073e9SAndroid Build Coastguard Worker dev0a = torch.device("cuda:0") 1150*da0073e9SAndroid Build Coastguard Worker dev0b = torch.device("cuda:0") 1151*da0073e9SAndroid Build Coastguard Worker dev1a = torch.device("cuda:1") 1152*da0073e9SAndroid Build Coastguard Worker dev1b = torch.device("cuda:1") 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hash(dev0a) == hash(dev0b)) 1155*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hash(dev1a) == hash(dev1b)) 1156*da0073e9SAndroid Build Coastguard Worker 1157*da0073e9SAndroid Build Coastguard Worker d[dev0a] = "0a" 1158*da0073e9SAndroid Build Coastguard Worker d[dev0b] = "0b" 1159*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(d) == 1) 1160*da0073e9SAndroid Build Coastguard Worker self.assertTrue(d[dev0a] == "0b") 1161*da0073e9SAndroid Build Coastguard Worker d[t.device] = "t" 1162*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(d) == 1) 1163*da0073e9SAndroid Build Coastguard Worker self.assertTrue(d[dev0a] == "t") 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker d[dev1a] = "1a" 1166*da0073e9SAndroid Build Coastguard Worker d[dev1b] = "1b" 1167*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(d) == 2) 1168*da0073e9SAndroid Build Coastguard Worker self.assertTrue(d[dev1a] == "1b") 1169*da0073e9SAndroid Build Coastguard Worker 1170*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1171*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_scale(self): 1172*da0073e9SAndroid Build Coastguard Worker scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0) 1173*da0073e9SAndroid Build Coastguard Worker t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0") 1174*da0073e9SAndroid Build Coastguard Worker t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1") 1175*da0073e9SAndroid Build Coastguard Worker # Create some nested iterables of tensors on different devices. 
1176*da0073e9SAndroid Build Coastguard Worker outputs = ( 1177*da0073e9SAndroid Build Coastguard Worker t1.clone(), 1178*da0073e9SAndroid Build Coastguard Worker (t0.clone(), t1.clone()), 1179*da0073e9SAndroid Build Coastguard Worker [t0.clone(), (t1.clone(), t0.clone())], 1180*da0073e9SAndroid Build Coastguard Worker ) 1181*da0073e9SAndroid Build Coastguard Worker outputs = scaler.scale(outputs) 1182*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1183*da0073e9SAndroid Build Coastguard Worker outputs[0] == 8.0 1184*da0073e9SAndroid Build Coastguard Worker and outputs[1][0] == 8.0 1185*da0073e9SAndroid Build Coastguard Worker and outputs[1][1] == 8.0 1186*da0073e9SAndroid Build Coastguard Worker and outputs[2][0] == 8.0 1187*da0073e9SAndroid Build Coastguard Worker and outputs[2][1][0] == 8.0 1188*da0073e9SAndroid Build Coastguard Worker and outputs[2][1][1] == 8.0 1189*da0073e9SAndroid Build Coastguard Worker ) 1190*da0073e9SAndroid Build Coastguard Worker self.assertTrue(scaler._scale.device == t1.device) 1191*da0073e9SAndroid Build Coastguard Worker 1192*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1193*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_multigpu(self): 1194*da0073e9SAndroid Build Coastguard Worker # Same as above, but runs some of the models on device 1. 1195*da0073e9SAndroid Build Coastguard Worker # GradScaler should transparently handle losses and gradients on multiple devices. 1196*da0073e9SAndroid Build Coastguard Worker # This test could be combined with the test above, but I think it makes sense to treat 1197*da0073e9SAndroid Build Coastguard Worker # multi-GPU operations separately. 1198*da0073e9SAndroid Build Coastguard Worker dev0 = torch.device("cuda:0") 1199*da0073e9SAndroid Build Coastguard Worker dev1 = torch.device("cuda:1") 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker for enabled in True, False: 1202*da0073e9SAndroid Build Coastguard Worker ( 1203*da0073e9SAndroid Build Coastguard Worker mod_control0, 1204*da0073e9SAndroid Build Coastguard Worker mod_scaling0, 1205*da0073e9SAndroid Build Coastguard Worker opt_control0, 1206*da0073e9SAndroid Build Coastguard Worker opt_scaling0, 1207*da0073e9SAndroid Build Coastguard Worker data, 1208*da0073e9SAndroid Build Coastguard Worker loss_fn, 1209*da0073e9SAndroid Build Coastguard Worker skip_iter, 1210*da0073e9SAndroid Build Coastguard Worker ) = _create_scaling_case() 1211*da0073e9SAndroid Build Coastguard Worker ( 1212*da0073e9SAndroid Build Coastguard Worker mod_control1, 1213*da0073e9SAndroid Build Coastguard Worker mod_scaling1, 1214*da0073e9SAndroid Build Coastguard Worker opt_control1, 1215*da0073e9SAndroid Build Coastguard Worker opt_scaling1, 1216*da0073e9SAndroid Build Coastguard Worker ) = _create_scaling_models_optimizers(device=dev1) 1217*da0073e9SAndroid Build Coastguard Worker 1218*da0073e9SAndroid Build Coastguard Worker scaler = torch.amp.GradScaler( 1219*da0073e9SAndroid Build Coastguard Worker device="cuda", 1220*da0073e9SAndroid Build Coastguard Worker init_scale=128.0, 1221*da0073e9SAndroid Build Coastguard Worker growth_factor=2.0, 1222*da0073e9SAndroid Build Coastguard Worker enabled=enabled, 1223*da0073e9SAndroid Build Coastguard Worker growth_interval=1, 1224*da0073e9SAndroid Build Coastguard Worker ) 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker def run(model0, model1, optimizer0, optimizer1, try_scaling_api): 
1227*da0073e9SAndroid Build Coastguard Worker for i, (input, target) in enumerate(data): 1228*da0073e9SAndroid Build Coastguard Worker optimizer0.zero_grad() 1229*da0073e9SAndroid Build Coastguard Worker optimizer1.zero_grad() 1230*da0073e9SAndroid Build Coastguard Worker output0 = model0(input) 1231*da0073e9SAndroid Build Coastguard Worker output1 = model1(input.to(dev1)) 1232*da0073e9SAndroid Build Coastguard Worker loss0 = loss_fn(0.3 * output0 + 0.7 * output1.to(dev0), target) 1233*da0073e9SAndroid Build Coastguard Worker loss1 = loss_fn( 1234*da0073e9SAndroid Build Coastguard Worker 0.6 * output0.to(dev1) - 0.4 * output1, target.to(dev1) 1235*da0073e9SAndroid Build Coastguard Worker ) 1236*da0073e9SAndroid Build Coastguard Worker 1237*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 1238*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss0).backward(retain_graph=True) 1239*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss1).backward() 1240*da0073e9SAndroid Build Coastguard Worker if i == skip_iter and scaler.is_enabled(): 1241*da0073e9SAndroid Build Coastguard Worker model1[1].weight.grad.data.fill_(float("inf")) 1242*da0073e9SAndroid Build Coastguard Worker 1243*da0073e9SAndroid Build Coastguard Worker # As an additional stress test, separately unscale for one of the optimizers. 1244*da0073e9SAndroid Build Coastguard Worker scaler.unscale_(optimizer0) 1245*da0073e9SAndroid Build Coastguard Worker 1246*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer0) 1247*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer1) 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker # Make sure the found_infs were collected properly across optimizers and devices. 1250*da0073e9SAndroid Build Coastguard Worker if scaler.is_enabled(): 1251*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1252*da0073e9SAndroid Build Coastguard Worker len(scaler._found_inf_per_device(optimizer0)) == 1 1253*da0073e9SAndroid Build Coastguard Worker ) 1254*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1255*da0073e9SAndroid Build Coastguard Worker len(scaler._found_inf_per_device(optimizer1)) == 1 1256*da0073e9SAndroid Build Coastguard Worker ) 1257*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1258*da0073e9SAndroid Build Coastguard Worker scaler._found_inf_per_device(optimizer0)[dev0].item() 1259*da0073e9SAndroid Build Coastguard Worker == 0.0 1260*da0073e9SAndroid Build Coastguard Worker ) 1261*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1262*da0073e9SAndroid Build Coastguard Worker scaler._found_inf_per_device(optimizer1)[dev1].item() 1263*da0073e9SAndroid Build Coastguard Worker == float(i == skip_iter) 1264*da0073e9SAndroid Build Coastguard Worker ) 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Worker scaler.update() 1267*da0073e9SAndroid Build Coastguard Worker else: 1268*da0073e9SAndroid Build Coastguard Worker loss0.backward(retain_graph=True) 1269*da0073e9SAndroid Build Coastguard Worker loss1.backward() 1270*da0073e9SAndroid Build Coastguard Worker optimizer0.step() 1271*da0073e9SAndroid Build Coastguard Worker if (not scaler.is_enabled()) or (i != skip_iter): 1272*da0073e9SAndroid Build Coastguard Worker optimizer1.step() 1273*da0073e9SAndroid Build Coastguard Worker 1274*da0073e9SAndroid Build Coastguard Worker run(mod_control0, mod_control1, opt_control0, opt_control1, False) 1275*da0073e9SAndroid Build Coastguard Worker run(mod_scaling0, 
mod_scaling1, opt_scaling0, opt_scaling1, True) 1276*da0073e9SAndroid Build Coastguard Worker 1277*da0073e9SAndroid Build Coastguard Worker # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once. 1278*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1279*da0073e9SAndroid Build Coastguard Worker scaler.get_scale() 1280*da0073e9SAndroid Build Coastguard Worker == ( 1281*da0073e9SAndroid Build Coastguard Worker 128.0 1282*da0073e9SAndroid Build Coastguard Worker * scaler.get_growth_factor() ** 3 1283*da0073e9SAndroid Build Coastguard Worker * scaler.get_backoff_factor() ** 1 1284*da0073e9SAndroid Build Coastguard Worker ) 1285*da0073e9SAndroid Build Coastguard Worker if enabled 1286*da0073e9SAndroid Build Coastguard Worker else 1.0 1287*da0073e9SAndroid Build Coastguard Worker ) 1288*da0073e9SAndroid Build Coastguard Worker 1289*da0073e9SAndroid Build Coastguard Worker # Copy mod_control1 and mod_scaling1 back to device 0 for comparison 1290*da0073e9SAndroid Build Coastguard Worker mod_control1.to(dev0) 1291*da0073e9SAndroid Build Coastguard Worker mod_scaling1.to(dev0) 1292*da0073e9SAndroid Build Coastguard Worker 1293*da0073e9SAndroid Build Coastguard Worker for c, s in zip( 1294*da0073e9SAndroid Build Coastguard Worker chain(mod_control0.parameters(), mod_control1.parameters()), 1295*da0073e9SAndroid Build Coastguard Worker chain(mod_scaling0.parameters(), mod_scaling1.parameters()), 1296*da0073e9SAndroid Build Coastguard Worker ): 1297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c, s, rtol=1e-5, atol=1e-7) 1298*da0073e9SAndroid Build Coastguard Worker 1299*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs") 1300*da0073e9SAndroid Build Coastguard Worker def test_cuda_device_memory_allocated(self): 1301*da0073e9SAndroid Build Coastguard Worker from torch.cuda import memory_allocated 1302*da0073e9SAndroid Build Coastguard Worker 1303*da0073e9SAndroid Build Coastguard Worker device_count = torch.cuda.device_count() 1304*da0073e9SAndroid Build Coastguard Worker current_alloc = [memory_allocated(idx) for idx in range(device_count)] 1305*da0073e9SAndroid Build Coastguard Worker x = torch.ones(10, device="cuda:0") 1306*da0073e9SAndroid Build Coastguard Worker self.assertGreater(memory_allocated(0), current_alloc[0]) 1307*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1308*da0073e9SAndroid Build Coastguard Worker all( 1309*da0073e9SAndroid Build Coastguard Worker memory_allocated(torch.cuda.device(idx)) == current_alloc[idx] 1310*da0073e9SAndroid Build Coastguard Worker for idx in range(1, device_count) 1311*da0073e9SAndroid Build Coastguard Worker ) 1312*da0073e9SAndroid Build Coastguard Worker ) 1313*da0073e9SAndroid Build Coastguard Worker 1314*da0073e9SAndroid Build Coastguard Worker 1315*da0073e9SAndroid Build Coastguard Workerclass TestCudaComm(TestCase): 1316*da0073e9SAndroid Build Coastguard Worker def _test_broadcast(self, input): 1317*da0073e9SAndroid Build Coastguard Worker if not TEST_MULTIGPU: 1318*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("only one GPU detected") 1319*da0073e9SAndroid Build Coastguard Worker # test regular 1320*da0073e9SAndroid Build Coastguard Worker results = comm.broadcast(input, (0, 1)) 1321*da0073e9SAndroid Build Coastguard Worker for i, t in enumerate(results): 1322*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.get_device(), i) 1323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t,
input) 1324*da0073e9SAndroid Build Coastguard Worker if ( 1325*da0073e9SAndroid Build Coastguard Worker input.is_cuda and input.get_device() == i 1326*da0073e9SAndroid Build Coastguard Worker ): # test not copying on same device 1327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.data_ptr(), input.data_ptr()) 1328*da0073e9SAndroid Build Coastguard Worker # test out= 1329*da0073e9SAndroid Build Coastguard Worker for inplace in [True, False]: 1330*da0073e9SAndroid Build Coastguard Worker if inplace: 1331*da0073e9SAndroid Build Coastguard Worker outputs = [ 1332*da0073e9SAndroid Build Coastguard Worker torch.empty_like(input, device=0), 1333*da0073e9SAndroid Build Coastguard Worker torch.empty_like(input, device=1), 1334*da0073e9SAndroid Build Coastguard Worker ] 1335*da0073e9SAndroid Build Coastguard Worker else: 1336*da0073e9SAndroid Build Coastguard Worker outputs = [input.cuda(0), torch.empty_like(input, device=1)] 1337*da0073e9SAndroid Build Coastguard Worker results = comm.broadcast(input, out=outputs) 1338*da0073e9SAndroid Build Coastguard Worker for r, o in zip(results, outputs): 1339*da0073e9SAndroid Build Coastguard Worker self.assertIs(r, o) 1340*da0073e9SAndroid Build Coastguard Worker for i, t in enumerate(results): 1341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.get_device(), i) 1342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, input) 1343*da0073e9SAndroid Build Coastguard Worker # test error msg 1344*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1345*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Exactly one of 'devices' and 'out'" 1346*da0073e9SAndroid Build Coastguard Worker ): 1347*da0073e9SAndroid Build Coastguard Worker comm.broadcast(input, (0, 1), out=outputs) 1348*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1349*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1350*da0073e9SAndroid Build Coastguard Worker r"Expected all output tensors to be CUDA tensors, but output tensor at index 1", 1351*da0073e9SAndroid Build Coastguard Worker ): 1352*da0073e9SAndroid Build Coastguard Worker comm.broadcast(input, out=[input.cuda(0), input.cpu()]) 1353*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1354*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1355*da0073e9SAndroid Build Coastguard Worker r"Expected all output tensors to have same shape as the source .+ at index 1", 1356*da0073e9SAndroid Build Coastguard Worker ): 1357*da0073e9SAndroid Build Coastguard Worker comm.broadcast(input, out=[input.cuda(0), input.cuda(1).unsqueeze(0)]) 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Worker def test_broadcast_cpu(self): 1360*da0073e9SAndroid Build Coastguard Worker self._test_broadcast(torch.randn(5, 5)) 1361*da0073e9SAndroid Build Coastguard Worker 1362*da0073e9SAndroid Build Coastguard Worker def test_broadcast_gpu(self): 1363*da0073e9SAndroid Build Coastguard Worker self._test_broadcast(torch.randn(5, 5).cuda()) 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker def _test_broadcast_coalesced(self, tensors, buffer_size): 1366*da0073e9SAndroid Build Coastguard Worker b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors] 1367*da0073e9SAndroid Build Coastguard Worker for (_, bt), t in zip(b_tensors, tensors): 1368*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bt.get_device(), 1) 1369*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bt, t) 
1370*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(bt, type(t)) 1371*da0073e9SAndroid Build Coastguard Worker 1372*da0073e9SAndroid Build Coastguard Worker bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size) 1373*da0073e9SAndroid Build Coastguard Worker bc_tensors_t = list(zip(*bc_tensors)) 1374*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b_tensors, bc_tensors_t) 1375*da0073e9SAndroid Build Coastguard Worker for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t): 1376*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bt.get_device(), bct.get_device()) 1377*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(bct, type(bt)) 1378*da0073e9SAndroid Build Coastguard Worker 1379*da0073e9SAndroid Build Coastguard Worker # check that tensors on device[0] are returned as-is 1380*da0073e9SAndroid Build Coastguard Worker for out_tensors in (b_tensors, bc_tensors_t): 1381*da0073e9SAndroid Build Coastguard Worker for inp_t, (out_t, _) in zip(tensors, out_tensors): 1382*da0073e9SAndroid Build Coastguard Worker self.assertIs(inp_t, out_t) 1383*da0073e9SAndroid Build Coastguard Worker 1384*da0073e9SAndroid Build Coastguard Worker # check that the tensors not on device[0] have different version counters 1385*da0073e9SAndroid Build Coastguard Worker # NOTE [ Version Counter in comm.*_coalesced ] 1386*da0073e9SAndroid Build Coastguard Worker versions = [t._version for _, t in bc_tensors_t] 1387*da0073e9SAndroid Build Coastguard Worker for old_version, (_, t) in zip(versions, bc_tensors_t): 1388*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._version, old_version) 1389*da0073e9SAndroid Build Coastguard Worker t.zero_() 1390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._version, old_version + 1) 1391*da0073e9SAndroid Build Coastguard Worker 1392*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1393*da0073e9SAndroid Build Coastguard Worker # Note: fails sometimes on the CI, passes on dual gfx906 1394*da0073e9SAndroid Build Coastguard Worker def test_broadcast_coalesced(self): 1395*da0073e9SAndroid Build Coastguard Worker numel = 5 1396*da0073e9SAndroid Build Coastguard Worker num_bytes = numel * 8 1397*da0073e9SAndroid Build Coastguard Worker tensors = [ 1398*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0], 1399*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1400*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).cuda(), 1401*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0], 1402*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0], 1403*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0], 1404*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0], 1405*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1406*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1407*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0], 1408*da0073e9SAndroid Build Coastguard Worker torch.randn(numel * 2).int().cuda(), # int is 2x shorter 1409*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).cuda(), 1410*da0073e9SAndroid Build Coastguard Worker ] 
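# Before the mixed sparse/dense list above is handed to the coalescing helper
# below, here is a minimal hedged sketch of broadcast_coalesced on dense tensors
# only (assumes two visible CUDA devices): small tensors are grouped into
# buffers of roughly buffer_size bytes, copied once per device, and the
# per-device copies come back in the original order.
import torch
import torch.cuda.comm as comm

small = [torch.randn(5).cuda(), torch.arange(5, device="cuda"), torch.randn(3).cuda()]
per_device = comm.broadcast_coalesced(small, (0, 1), buffer_size=128)
for dev, copies in zip((0, 1), per_device):
    for src, dst in zip(small, copies):
        assert dst.get_device() == dev
        assert torch.equal(dst, src.cuda(dev))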
1411*da0073e9SAndroid Build Coastguard Worker self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2) 1412*da0073e9SAndroid Build Coastguard Worker 1413*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1414*da0073e9SAndroid Build Coastguard Worker def test_broadcast_coalesced_dense_only(self): 1415*da0073e9SAndroid Build Coastguard Worker numel = 5 1416*da0073e9SAndroid Build Coastguard Worker num_bytes = numel * 8 1417*da0073e9SAndroid Build Coastguard Worker tensors = [ 1418*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1419*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).cuda(), 1420*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1421*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1422*da0073e9SAndroid Build Coastguard Worker torch.randn(numel * 2).int().cuda(), # int is 2x shorter 1423*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).cuda(), 1424*da0073e9SAndroid Build Coastguard Worker ] 1425*da0073e9SAndroid Build Coastguard Worker self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2) 1426*da0073e9SAndroid Build Coastguard Worker 1427*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1428*da0073e9SAndroid Build Coastguard Worker def test_broadcast_coalesced_empty_tensors(self): 1429*da0073e9SAndroid Build Coastguard Worker tensors = [ 1430*da0073e9SAndroid Build Coastguard Worker torch.tensor([]).byte().cuda(), 1431*da0073e9SAndroid Build Coastguard Worker torch.randn(5).cuda(), 1432*da0073e9SAndroid Build Coastguard Worker torch.randn(5).double().cuda(), 1433*da0073e9SAndroid Build Coastguard Worker ] 1434*da0073e9SAndroid Build Coastguard Worker self._test_broadcast_coalesced(tensors, 256) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1437*da0073e9SAndroid Build Coastguard Worker def test_reduce_add(self): 1438*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 1439*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 5) 1440*da0073e9SAndroid Build Coastguard Worker x_cuda = x.cuda(0) 1441*da0073e9SAndroid Build Coastguard Worker y_cuda = y.cuda(1) 1442*da0073e9SAndroid Build Coastguard Worker result = comm.reduce_add((x_cuda, y_cuda)) 1443*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.get_device(), 0) 1444*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.cpu(), x + y) 1445*da0073e9SAndroid Build Coastguard Worker 1446*da0073e9SAndroid Build Coastguard Worker def _test_reduce_add_coalesced(self, tensors, buffer_size): 1447*da0073e9SAndroid Build Coastguard Worker dup_tensors = [tensors, [t.cuda(1) for t in tensors]] 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker r_tensors = [comm.reduce_add(t) for t in zip(*dup_tensors)] 1450*da0073e9SAndroid Build Coastguard Worker for r, t in zip(r_tensors, tensors): 1451*da0073e9SAndroid Build Coastguard Worker self.assertEqualTypeString(r, t) 1452*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.coalesce() if r.is_sparse else r, t * 2) 1453*da0073e9SAndroid Build Coastguard Worker 1454*da0073e9SAndroid Build Coastguard Worker rc_tensors = comm.reduce_add_coalesced(dup_tensors, buffer_size=buffer_size) 1455*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r_tensors, rc_tensors) 1456*da0073e9SAndroid Build 
Coastguard Worker for r, rc in zip(r_tensors, rc_tensors): 1457*da0073e9SAndroid Build Coastguard Worker self.assertEqualTypeString(rc, r) 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Worker # Since we have both cuda:0 and cuda:1 inputs, the outputs must be new. 1460*da0073e9SAndroid Build Coastguard Worker # We can check that they have different version counters. 1461*da0073e9SAndroid Build Coastguard Worker # NOTE [ Version Counter in comm.*_coalesced ] 1462*da0073e9SAndroid Build Coastguard Worker versions = [t._version for t in rc_tensors] 1463*da0073e9SAndroid Build Coastguard Worker for old_version, t in zip(versions, rc_tensors): 1464*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._version, old_version) 1465*da0073e9SAndroid Build Coastguard Worker t.zero_() 1466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._version, old_version + 1) 1467*da0073e9SAndroid Build Coastguard Worker 1468*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1469*da0073e9SAndroid Build Coastguard Worker def test_reduce_add_coalesced(self): 1470*da0073e9SAndroid Build Coastguard Worker numel = 5 1471*da0073e9SAndroid Build Coastguard Worker num_bytes = numel * 8 1472*da0073e9SAndroid Build Coastguard Worker tensors = [ 1473*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0], 1474*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1475*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).cuda(), 1476*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0], 1477*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0], 1478*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0], 1479*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0], 1480*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1481*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1482*da0073e9SAndroid Build Coastguard Worker self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0], 1483*da0073e9SAndroid Build Coastguard Worker torch.randn(numel * 2).int().cuda(), # int is 2x shorter 1484*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).cuda(), 1485*da0073e9SAndroid Build Coastguard Worker ] 1486*da0073e9SAndroid Build Coastguard Worker self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2) 1487*da0073e9SAndroid Build Coastguard Worker 1488*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 1489*da0073e9SAndroid Build Coastguard Worker def test_reduce_add_coalesced_dense_only(self): 1490*da0073e9SAndroid Build Coastguard Worker numel = 5 1491*da0073e9SAndroid Build Coastguard Worker num_bytes = numel * 8 1492*da0073e9SAndroid Build Coastguard Worker tensors = [ 1493*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1494*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).cuda(), 1495*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1496*da0073e9SAndroid Build Coastguard Worker torch.randn(numel).long().cuda(), 1497*da0073e9SAndroid Build Coastguard Worker torch.randn(numel * 2).int().cuda(), # int is 2x shorter 1498*da0073e9SAndroid Build Coastguard 
            torch.randn(numel).cuda(),
        ]
        self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)

    def _test_scatter(self, input, chunk_sizes=None, dim=0):
        if not TEST_MULTIGPU:
            raise unittest.SkipTest("only one GPU detected")
        if chunk_sizes is None:
            ref_chunk_sizes = tuple(repeat(input.size(dim) // 2, 2))
        else:
            ref_chunk_sizes = chunk_sizes

        # test regular
        result = comm.scatter(input, (0, 1), chunk_sizes, dim)
        self.assertEqual(len(result), 2)
        chunk_start = 0
        for i, r in enumerate(result):
            chunk_end = chunk_start + ref_chunk_sizes[i]
            index = [slice(None, None) for _ in range(input.dim())]
            index[dim] = slice(chunk_start, chunk_end)
            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
            chunk_start = chunk_end
            if r.device == input.device:
                self.assertEqual(
                    r.data_ptr(), input.data_ptr()
                )  # for target @ same device, a view should be returned

        # test out
        out = [torch.empty_like(t) for t in result]
        result = comm.scatter(input, dim=dim, out=out)
        self.assertEqual(len(result), 2)
        chunk_start = 0
        for i, r in enumerate(result):
            self.assertIs(r, out[i])
            chunk_end = chunk_start + ref_chunk_sizes[i]
            index = [slice(None, None) for _ in range(input.dim())]
            index[dim] = slice(chunk_start, chunk_end)
            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
            chunk_start = chunk_end

        # test error msg
        if chunk_sizes is not None:
            with self.assertRaisesRegex(
                RuntimeError, r"Expected devices and chunk_sizes to be of same length"
            ):
                comm.scatter(
                    input,
                    [0 for _ in range(len(chunk_sizes) + 1)],
                    dim=dim,
                    chunk_sizes=chunk_sizes,
                )
        with self.assertRaisesRegex(RuntimeError, r"'devices' must not be specified"):
            comm.scatter(input, (0, 1), dim=dim, out=out)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected at least one device to scatter to"
        ):
            comm.scatter(input, (), dim=dim)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected at least one output tensor to scatter to"
        ):
            comm.scatter(input, dim=dim, out=[])
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected all output tensors to be CUDA tensors, but output tensor at index 0",
        ):
            comm.scatter(input, dim=dim, out=([out[0].cpu()] + out[1:]))
        with self.assertRaisesRegex(
            RuntimeError, r"Output tensor at index 0 has incorrect shape"
        ):
            comm.scatter(input, dim=dim, out=([out[0].unsqueeze(0)] + out[1:]))
        with self.assertRaisesRegex(
            RuntimeError,
            r"Total size for output tensors along scatter dim \d+ does not match",
        ):
            index = [slice(None, None) for _ in range(input.dim())]
            index[dim] = slice(1, None)
            comm.scatter(input, dim=dim, out=([out[0][tuple(index)]] + out[1:]))
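
    # Illustrative sketch (added for documentation, not part of the original
    # test suite): a minimal comm.scatter/comm.gather round trip, assuming at
    # least two visible CUDA devices. The helper name
    # `_example_scatter_roundtrip` is hypothetical and is not collected by the
    # test runner because it does not start with "test".
    def _example_scatter_roundtrip(self):
        full = torch.randn(6, 4)
        # Rows 0:2 go to cuda:0 and rows 2:6 go to cuda:1.
        chunks = comm.scatter(full, devices=(0, 1), chunk_sizes=(2, 4), dim=0)
        # Gathering the chunks back along the same dim reproduces the input.
        merged = comm.gather(chunks, dim=0, destination=torch.device("cpu"))
        assert torch.equal(merged, full)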

    def test_scatter_cpu(self):
        self._test_scatter(torch.randn(4, 4), dim=0)

    def test_scatter_cpu_dim(self):
        self._test_scatter(torch.randn(4, 4), dim=1)

    def test_scatter_cpu_neg_dim(self):
        self._test_scatter(torch.randn(4, 4), dim=-2)

    def test_scatter_cpu_sizes(self):
        self._test_scatter(torch.randn(6, 4), chunk_sizes=(2, 4))

    def test_scatter_gpu(self):
        self._test_scatter(torch.randn(4, 4).cuda(), dim=0)

    def test_scatter_gpu_dim(self):
        self._test_scatter(torch.randn(4, 4).cuda(), dim=1)

    def test_scatter_gpu_neg_dim(self):
        self._test_scatter(torch.randn(4, 4).cuda(), dim=-2)

    def test_scatter_gpu_sizes(self):
        self._test_scatter(torch.randn(6, 4).cuda(), chunk_sizes=(2, 4))

    def _test_gather(self, dim):
        if not TEST_MULTIGPU:
            raise unittest.SkipTest("only one GPU detected")
        x = torch.randn(2, 5, device=0)
        y = torch.randn(2, 5, device=1)
        expected_size = list(x.size())
        expected_size[dim] += y.size(dim)
        expected_size = torch.Size(expected_size)

        destinations = [None, torch.device("cuda:0"), torch.device("cpu")]
        if torch.cuda.device_count() > 2:
            destinations.append(torch.device("cuda:2"))
        with torch.cuda.device(1):
            for destination in destinations:
                if destination is None:
                    expected_device = torch.device("cuda", torch.cuda.current_device())
                else:
                    expected_device = destination
                for use_out in [True, False]:
                    if use_out:
                        out = torch.empty(expected_size, device=expected_device)
                        result = comm.gather((x, y), dim, out=out)
                        self.assertIs(out, result)
                    else:
                        result = comm.gather((x, y), dim, destination=destination)

                    self.assertEqual(result.device, expected_device)
                    self.assertEqual(result.size(), expected_size)

                    index = [slice(None, None), slice(None, None)]
                    index[dim] = slice(0, x.size(dim))
                    self.assertEqual(result[tuple(index)], x)
                    index[dim] = slice(x.size(dim), x.size(dim) + y.size(dim))
                    self.assertEqual(result[tuple(index)], y)

        # test error msg
        with self.assertRaisesRegex(
            RuntimeError, r"'destination' must not be specified"
        ):
            comm.gather(
                (x, y),
                dim,
                destination="cpu",
                out=torch.empty(expected_size, device="cpu"),
            )
        with self.assertRaisesRegex(
            RuntimeError, r"Expected at least one tensor to gather from"
        ):
            comm.gather(())
        with self.assertRaisesRegex(
            RuntimeError, r"Expected all input tensors to be CUDA tensors, "
        ):
            comm.gather((x.cpu(), y))
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected all input tensors to have the same number of dimensions",
        ):
            comm.gather((x, y.unsqueeze(0)))
        with self.assertRaisesRegex(
            RuntimeError, r"Input tensor at index 1 has invalid shape"
        ):
            if dim in [0, -2]:
                comm.gather((x, y[:, 1:]), dim=dim)
            elif dim in [1, -1]:
                comm.gather((x, y[1:, :]), dim=dim)

    def test_gather(self):
        self._test_gather(0)

    def test_gather_dim(self):
        self._test_gather(1)

    def test_gather_neg_dim(self):
        self._test_gather(-1)
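
    # Illustrative sketch (added for documentation, not part of the original
    # test suite): when no `destination` is given, comm.gather concatenates
    # along `dim` and places the result on the current CUDA device. Assumes at
    # least two visible devices; `_example_gather_default_destination` is a
    # hypothetical helper that the test runner does not collect.
    def _example_gather_default_destination(self):
        x = torch.randn(2, 5, device=0)
        y = torch.randn(2, 5, device=1)
        with torch.cuda.device(1):
            result = comm.gather((x, y), dim=0)  # destination defaults to current device
        assert result.device == torch.device("cuda", 1)
        assert result.size() == torch.Size([4, 5])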

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_memory_format_scatter_gather(self):
        nhwc = torch.randn((10, 3, 32, 32), device="cpu").contiguous(
            memory_format=torch.channels_last
        )
        results = torch.cuda.comm.scatter(nhwc, (0, 1), None, 0)
        for result in results:
            self.assertFalse(result.is_contiguous())
            self.assertTrue(result.is_contiguous(memory_format=torch.channels_last))

        gathered = torch.cuda.comm.gather(results)
        self.assertTrue(gathered.is_contiguous(memory_format=torch.channels_last))

    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
    def test_scatter_namedtuple(self):
        # tests ability to scatter namedtuples and retrieve a list where each
        # element is of the expected namedtuple type.
        fields = ("a", "b")
        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
        num_gpus = torch.cuda.device_count()
        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=0)
        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]

        inp = TestNamedTupleInput_0(a, b)
        target_gpus = [torch.device(i) for i in range(num_gpus)]
        scatter_out = scatter_gather.scatter(inp, target_gpus)

        for i, x in enumerate(scatter_out):
            self.assertTrue(isinstance(x, type(inp)))
            self.assertEqual(x._fields, fields)
            expected_a = a_tensors_for_gpu[i]
            expected_b = b_tensors_for_gpu[i]
            self.assertEqual(expected_a, x.a)
            self.assertEqual(expected_b, x.b)

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.tensor
            b: torch.tensor

        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=0)
        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
        inp = TestNamedTupleInput_1(a, b)

        scatter_out = scatter_gather.scatter(inp, target_gpus)
        for i, x in enumerate(scatter_out):
            self.assertTrue(isinstance(x, type(inp)))
            self.assertEqual(x._fields, fields)
            expected_a = a_tensors_for_gpu[i]
            expected_b = b_tensors_for_gpu[i]
            self.assertEqual(expected_a, x.a)
            self.assertEqual(expected_b, x.b)
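
    # Illustrative sketch (added for documentation, not part of the original
    # test suite): torch.nn.parallel.scatter_gather keeps namedtuple structure,
    # so scattering a namedtuple of tensors yields one namedtuple per device.
    # Assumes at least two visible CUDA devices; `Pair` and
    # `_example_namedtuple_scatter` are hypothetical names used only here.
    def _example_namedtuple_scatter(self):
        Pair = collections.namedtuple("Pair", ("a", "b"))
        inp = Pair(torch.rand(4, device=0), torch.rand(4, device=0))
        shards = scatter_gather.scatter(inp, [torch.device(0), torch.device(1)])
        # Each shard is still a Pair holding the matching chunk of each field.
        assert len(shards) == 2
        assert all(isinstance(s, Pair) for s in shards)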

    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
    def test_gather_namedtuple(self):
        # tests ability to gather a list of namedtuples and return a namedtuple where each
        # element is of the expected tensor type.
        fields = ["a", "b"]
        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)

        num_gpus = torch.cuda.device_count()
        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=1)
        out1 = TestNamedTupleInput_0(a, b)

        a = torch.rand(num_gpus * 2, device=1)
        b = torch.rand(num_gpus * 2, device=0)
        out2 = TestNamedTupleInput_0(a, b)

        outputs = [out1, out2]

        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))  # x must be a tensor
            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
            self.assertTrue(torch.equal(x, cat))

        out = scatter_gather.gather(outputs, 0)  # test on GPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))
            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
            self.assertTrue(torch.equal(x, cat))

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.tensor
            b: torch.tensor

        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=1)
        out1 = TestNamedTupleInput_1(a, b)

        a = torch.rand(num_gpus * 2, device=1)
        b = torch.rand(num_gpus * 2, device=0)
        out2 = TestNamedTupleInput_1(a, b)

        outputs = [out1, out2]

        out = scatter_gather.gather(outputs, 0)  # test on GPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))
            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
            self.assertTrue(torch.equal(x, cat))

        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))
            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
            self.assertTrue(torch.equal(x, cat))


instantiate_parametrized_tests(TestCudaMultiGPU)


if __name__ == "__main__":
    run_tests()