# Owner(s): ["module: cuda"]

import collections
import contextlib
import ctypes
import gc
import io
import queue
import sys
import tempfile
import threading
import unittest
from itertools import chain, repeat
from typing import NamedTuple, Union

import torch
import torch.cuda.comm as comm
from torch.nn.parallel import scatter_gather
from torch.testing._internal.common_cuda import (
    _create_scaling_case,
    _create_scaling_models_optimizers,
    TEST_MULTIGPU,
)
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    instantiate_parametrized_tests,
    IS_JETSON,
    IS_REMOTE_GPU,
    IS_SANDCASTLE,
    NoTest,
    run_tests,
    serialTest,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    TEST_CUDA,
    TestCase,
)


TEST_CUDAMALLOCASYNC = TEST_CUDA and (
    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
)

if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


class TestCudaMultiGPU(TestCase):
    FIFTY_MIL_CYCLES = 50000000

    def _check_memory_stat_consistency(self):
        snapshot = torch.cuda.memory_snapshot()

        expected_each_device = collections.defaultdict(
            lambda: collections.defaultdict(int)
        )

        for segment in snapshot:
            expandable = segment["is_expandable"]
            expected = expected_each_device[segment["device"]]
            pool_str = segment["segment_type"] + "_pool"

            if not expandable:
                expected["segment.all.current"] += 1
                expected["segment." + pool_str + ".current"] += 1

            expected["allocated_bytes.all.current"] += segment["allocated_size"]
            expected["allocated_bytes." + pool_str + ".current"] += segment[
                "allocated_size"
            ]

            expected["reserved_bytes.all.current"] += segment["total_size"]
            expected["reserved_bytes." + pool_str + ".current"] += segment["total_size"]

            expected["active_bytes.all.current"] += segment["active_size"]
            expected["active_bytes." + pool_str + ".current"] += segment["active_size"]

            expected["requested_bytes.all.current"] += segment["requested_size"]
            expected["requested_bytes." + pool_str + ".current"] += segment[
                "requested_size"
            ]

            sum_requested = 0
            is_split = len(segment["blocks"]) > 1
            for block in segment["blocks"]:
                if block["state"] == "active_allocated":
                    expected["allocation.all.current"] += 1
                    expected["allocation." + pool_str + ".current"] += 1

                if block["state"].startswith("active_"):
                    sum_requested += block["requested_size"]
                    expected["active.all.current"] += 1
                    expected["active." + pool_str + ".current"] += 1

                if block["state"] == "inactive" and is_split and not expandable:
                    expected["inactive_split.all.current"] += 1
                    expected["inactive_split." + pool_str + ".current"] += 1
                    expected["inactive_split_bytes.all.current"] += block["size"]
                    expected["inactive_split_bytes." + pool_str + ".current"] += block[
                        "size"
                    ]

            self.assertEqual(sum_requested, segment["requested_size"])

        for device, expected in expected_each_device.items():
            stats = torch.cuda.memory_stats(device)
            for k, v in expected.items():
                self.assertEqual(v, stats[k])

    def test_cuda_synchronize(self):
        torch.cuda.synchronize()
        torch.cuda.synchronize("cuda")
        torch.cuda.synchronize("cuda:0")
        torch.cuda.synchronize(0)
        torch.cuda.synchronize(torch.device("cuda:0"))

        if TEST_MULTIGPU:
            torch.cuda.synchronize("cuda:1")
            torch.cuda.synchronize(1)
            torch.cuda.synchronize(torch.device("cuda:1"))

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
            torch.cuda.synchronize(torch.device("cpu"))

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
            torch.cuda.synchronize("cpu")

    @staticmethod
    def _test_memory_stats_generator(self, device=None, N=35):
        if device is None:
            device = torch.cuda.current_device()

        m0 = torch.cuda.memory_allocated(device)
        last_m_arr = [torch.cuda.memory_allocated(device)]
        max_m_arr = [torch.cuda.max_memory_allocated(device)]
        last_r_arr = [torch.cuda.memory_reserved(device)]
        max_r_arr = [torch.cuda.max_memory_reserved(device)]

        def alloc(*size):
            with torch.cuda.device(device):
                # NOTE: do **not** use methods that can have additional
                # memory overhead, e.g., inplace random sampling methods.
                # they can leave some memory occupied even after being
                # deallocated, e.g., initialized RNG state, causing some
                # memory checks below to fail.
                return torch.cuda.FloatTensor(*size)

        def assert_change(comp=1, empty_cache=False, reset_peak=False):
            # comp > 0: increased
            # comp = 0: equal
            # comp < 0: decreased
            new_m = torch.cuda.memory_allocated(device)
            new_max_m = torch.cuda.max_memory_allocated(device)
            if comp > 0:
                self.assertGreater(new_m, last_m_arr[0])
            elif comp < 0:
                self.assertLess(new_m, last_m_arr[0])
            else:
                self.assertEqual(new_m, last_m_arr[0])
            self.assertLessEqual(new_m, new_max_m)
            self.assertGreaterEqual(new_max_m, max_m_arr[0])
            last_m_arr[0] = new_m
            max_m_arr[0] = new_max_m

            new_r = torch.cuda.memory_reserved(device)
            new_max_r = torch.cuda.max_memory_reserved(device)
            # emptying cache may happen (due to allocation or empty_cache), so
            # we can't assert new_c >= last_c
            self.assertLessEqual(new_r, new_max_r)
            self.assertGreaterEqual(new_max_r, max_r_arr[0])
            last_r_arr[0] = new_r
            max_r_arr[0] = new_max_r

            stat_key_n_sync = "num_sync_all_streams"
            stat_key_n_alloc = "num_device_alloc"
            stat_key_n_free = "num_device_free"
            if empty_cache:
                num_sync_1 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
                self.assertGreaterEqual(num_sync_1, 0)
                num_alloc_1 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
                # if current memory usage is greater than zero we must have
                # allocated something
                self.assertGreaterEqual(num_alloc_1, 0 if new_m == 0 else 1)
                num_free_1 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
                self.assertGreaterEqual(num_free_1, 0)
                # empty_cache will enforce the call of release_cached_blocks
                torch.cuda.empty_cache()
                num_sync_2 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
                self.assertEqual(num_sync_1 + 1, num_sync_2)
                num_alloc_2 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
                self.assertGreaterEqual(num_alloc_2, num_alloc_1)
                num_free_2 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
                self.assertGreaterEqual(num_free_2, num_free_1)

                new_r = torch.cuda.memory_reserved(device)
                new_max_r = torch.cuda.max_memory_reserved(device)
                self.assertLessEqual(new_r, last_r_arr[0])
                self.assertLessEqual(new_r, new_max_r)
                self.assertEqual(new_max_r, max_r_arr[0])
                last_r_arr[0] = new_r

            if reset_peak:
                torch.cuda.reset_peak_memory_stats(device)
                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
                self.assertEqual(torch.cuda.max_memory_allocated(device), last_m_arr[0])
                max_m_arr[0] = last_m_arr[0]
                self.assertEqual(torch.cuda.memory_reserved(device), last_r_arr[0])
                self.assertEqual(torch.cuda.max_memory_reserved(device), last_r_arr[0])
                max_r_arr[0] = last_r_arr[0]

        assert_change(0)
        assert_change(0, reset_peak=True)
        assert_change(0, empty_cache=True)
        assert_change(0, reset_peak=True)
        assert_change(0)
        yield

        tensors1 = [alloc(1), alloc(10, 20), alloc(200, 300, 2000)]
        m1 = torch.cuda.memory_allocated(device)
        assert_change(1)
        yield

        tensors2 = []

        for i in range(1, int(N / 2) + 1):
            # small ones
            tensors2.append(alloc(i, i * 4))
            assert_change(1)
            yield

        for i in range(5, int(N / 2) + 5):
            # large ones
            tensors2.append(alloc(i, i * 7, i * 9, i * 11))
            assert_change(1, reset_peak=(i % 2 == 0))
            yield

        tensors2.append(alloc(0, 0, 0))
        assert_change(0)
        yield

        permute = []
        for i in torch.randperm(len(tensors2)):
            permute.append(tensors2[i])
            assert_change(0)
            yield

        del tensors2
        assert_change(0)
        yield
        tensors2 = permute
        assert_change(0)
        yield
        del permute
        assert_change(0, reset_peak=True)
        yield

        for i in range(int(N / 2)):
            x = tensors2[i].numel()
            del tensors2[i]
            assert_change(-x)  # in case that tensors2[i] is empty
            yield

        for i in range(2, int(2 * N / 3) + 2):
            tensors2.append(alloc(i, i * 3, i * 8))
            assert_change(1)
            yield

        del tensors2
        assert_change(-1, reset_peak=True)
        assert_change(0)
        self.assertEqual(torch.cuda.memory_allocated(device), m1)
        yield True

        del tensors1
        assert_change(-1, reset_peak=True)
        self.assertEqual(torch.cuda.memory_allocated(device), m0)

        # test empty_cache and reset_peak
        assert_change(0, empty_cache=True)
        assert_change(0, reset_peak=True)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    @serialTest()
    def test_memory_stats(self):
        gc.collect()
        torch.cuda.empty_cache()
        for _ in self._test_memory_stats_generator(self):
            self._check_memory_stat_consistency()

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_memory_stats_multigpu(self):
        # advance a generator with an end flag
        def advance(gen, end):
            if not end:
                try:
                    next(gen)
                except StopIteration:
                    end = True
            return end

        # interlace
        torch.cuda.empty_cache()
        gen0 = self._test_memory_stats_generator(self, device="cuda:0", N=35)
        gen1 = self._test_memory_stats_generator(
            self, device=torch.device("cuda:1"), N=35
        )
        end0 = end1 = False
        while not (end0 and end1):
            end0 = advance(gen0, end0)
            end1 = advance(gen1, end1)

        # semi-random order
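        # The loop below advances gen1 a random number of steps (0-2) for every
        # gen0 step, so the per-device allocation patterns interleave less
        # regularly than the strict interlace above; once gen0 is exhausted,
        # gen1 simply runs to completion.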
        torch.cuda.empty_cache()
        gen0 = self._test_memory_stats_generator(self, device=0, N=35)
        gen1 = self._test_memory_stats_generator(
            self, device=torch.device("cuda:1"), N=35
        )
        end0 = end1 = False

        while not (end0 and end1):
            end0 = advance(gen0, end0)
            if not end0:
                gen1_max_times = torch.LongTensor(1).random_(0, 3)[0]
            else:
                gen1_max_times = torch.inf
            t = 0
            while t < gen1_max_times and not end1:
                end1 = advance(gen1, end1)
                t += 1

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_autogpu(self):
        x = torch.randn(5, 5).cuda()
        y = torch.randn(5, 5).cuda()
        self.assertEqual(x.get_device(), 0)
        self.assertEqual(x.get_device(), 0)
        with torch.cuda.device(1):
            z = torch.randn(5, 5).cuda()
            self.assertEqual(z.get_device(), 1)
            q = x.add(y)
            self.assertEqual(q.get_device(), 0)
            w = torch.randn(5, 5).cuda()
            self.assertEqual(w.get_device(), 1)
            self.assertEqual(y.cuda().get_device(), 1)
        z = z.cuda()
        self.assertEqual(z.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_new(self):
        x = torch.randn(3, 3).cuda()
        self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
        self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)

        with torch.cuda.device(1):
            self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
            self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_copy_device(self):
        x = torch.randn(5, 5).cuda()
        with torch.cuda.device(1):
            y = x.cuda()
            self.assertEqual(y.get_device(), 1)
            self.assertIs(y.cuda(), y)
            z = y.cuda(0)
            self.assertEqual(z.get_device(), 0)
            self.assertIs(z.cuda(0), z)

        x = torch.randn(5, 5)
        with torch.cuda.device(1):
            y = x.cuda()
            self.assertEqual(y.get_device(), 1)
            self.assertIs(y.cuda(), y)
            z = y.cuda(0)

            self.assertEqual(z.get_device(), 0)
            self.assertIs(z.cuda(0), z)

    def _test_copy_sync_current_stream(self, x, y):
        x_plus_one = x + 1
        s0 = torch.cuda.Stream(device=x.device)
        s1 = torch.cuda.Stream(device=y.device)
        s2 = torch.cuda.Stream(device=x.device)
        s3 = torch.cuda.Stream(device=y.device)

        # same dst stream different src streams
        with torch.cuda.stream(s0):
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            with torch.cuda.stream(s1):
                y.copy_(x_plus_one)

        with torch.cuda.stream(s2), torch.cuda.stream(s1):
            y.copy_(x)

        s1.synchronize()
        # The copy() is synchronized on the current streams of both src and dst.
        # In the above test, the _sleep() op on s0 will not block the copy() on
        # s2, but both copies are synchronized on s1 in the dst device. Hence,
        # x is copied to y after x_plus_one is copied to y. If x and y are on
        # the same device, both copy() ops are synchronized on s1.
        self.assertEqual(y, x)

        # same src stream different dst streams
        with torch.cuda.stream(s1):
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            with torch.cuda.stream(s0):
                y.copy_(x_plus_one)

        with torch.cuda.stream(s3), torch.cuda.stream(s0):
            y.copy_(x)

        s0.synchronize()
        # Similarly, both copy() ops are synchronized on s0.
        self.assertEqual(y, x)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_copy_streams(self):
        d0 = torch.device("cuda:0")
        x0 = torch.zeros(5, 5, device=d0)

        d1 = torch.device("cuda:1")
        x1 = torch.zeros(5, 5, device=d1)
        self._test_copy_sync_current_stream(x0, x1)

        x2 = torch.zeros(5, 5, device=d0)
        self._test_copy_sync_current_stream(x0, x2)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_cat_autogpu(self):
        x = torch.randn(4, 4).cuda(1)
        y = torch.randn(4, 4).cuda(1)
        z = torch.cat([x, y], 0)
        self.assertEqual(z.get_device(), x.get_device())

    @unittest.skipIf(torch.cuda.device_count() >= 10, "Loading a cuda:9 tensor")
    def test_load_nonexistent_device(self):
        # Setup: create a serialized file object with a 'cuda:9' restore location
        tensor = torch.randn(2, device="cuda")
        buf = io.BytesIO()
        torch.save(tensor, buf)
        # NB: this might not work in the future if serialization changes
        buf = io.BytesIO(buf.getvalue().replace(b"cuda:0", b"cuda:9"))

        msg = r"Attempting to deserialize object on CUDA device 9"
        with self.assertRaisesRegex(RuntimeError, msg):
            _ = torch.load(buf)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_serialization_remap(self):
        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]

        def gpu_remap(storage, location):
            if location == "cuda:1":
                return storage.cuda(0)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f, map_location=gpu_remap)

        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_serialization_remap_dict(self):
        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f, map_location={"cuda:1": "cuda:0"})
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_storage_clone(self):
        x = torch.randn(4, 4, device="cuda:1").storage()
        y = x.clone()
        self.assertEqual(x.get_device(), y.get_device())
        for t in ["byte", "char", "short", "int", "long", "half", "double"]:
            self.assertEqual(getattr(x, t)().get_device(), x.get_device())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_cuda_set_device(self):
        x = torch.randn(5, 5)
        with torch.cuda.device(1):
            self.assertEqual(x.cuda().get_device(), 1)
            torch.cuda.set_device(0)
            self.assertEqual(x.cuda().get_device(), 0)
            with torch.cuda.device(1):
                self.assertEqual(x.cuda().get_device(), 1)
            self.assertEqual(x.cuda().get_device(), 0)
            torch.cuda.set_device(1)
        self.assertEqual(x.cuda().get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        s0 = torch.cuda.current_stream()
        s1 = torch.cuda.current_stream(device=1)
        s2 = torch.cuda.current_stream(device=0)

        self.assertEqual(d0, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(s0, s2)

        with torch.cuda.device(d1):
            s0 = torch.cuda.current_stream()
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(d0)

        self.assertEqual(d1, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(s0, s1)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
            torch.cuda.current_stream(torch.device("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @skipCUDANonDefaultStreamIf(True)
    def test_default_stream(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        with torch.cuda.device(d0):
            s0 = torch.cuda.default_stream()

        with torch.cuda.device(d1):
            s1 = torch.cuda.default_stream()

        s2 = torch.cuda.default_stream(device=0)
        s3 = torch.cuda.default_stream(d1)

        self.assertEqual(d0, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(d1, s3.device)
        self.assertEqual(s0, s2)
        self.assertEqual(s1, s3)

        with torch.cuda.device(d0):
            self.assertEqual(torch.cuda.current_stream(), s0)

        with torch.cuda.device(d1):
            self.assertEqual(torch.cuda.current_stream(), s1)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
            torch.cuda.default_stream(torch.device("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_event_device(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        e0 = torch.cuda.Event()

        self.assertEqual(None, e0.device)

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            s0.record_event(e0)

        with torch.cuda.device(d1):
            s1 = torch.cuda.Stream()
            e1 = s1.record_event()

        self.assertEqual(s0.device, torch.device("cuda:0"))
        self.assertEqual(e0.device, torch.device("cuda:0"))
        self.assertEqual(s1.device, torch.device("cuda:1"))
        self.assertEqual(e1.device, torch.device("cuda:1"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_context(self):
        s0 = torch.cuda.current_stream()
        s1 = torch.cuda.Stream(device=1)
        s2 = torch.cuda.Stream(device=0)

        with torch.cuda.device(s1.device):
            prev_stream_on_cuda1 = torch.cuda.current_stream()

        self.assertEqual(torch.cuda.current_stream(), s0)
        self.assertEqual(0, torch.cuda.current_device())
        with torch.cuda.stream(s1):
            self.assertEqual(torch.cuda.current_stream(), s1)
            self.assertEqual(1, torch.cuda.current_device())
            with torch.cuda.stream(s2):
                self.assertEqual(torch.cuda.current_stream(), s2)
                self.assertEqual(0, torch.cuda.current_device())
                with torch.cuda.stream(s0):
                    self.assertEqual(torch.cuda.current_stream(), s0)
                    self.assertEqual(0, torch.cuda.current_device())
                self.assertEqual(torch.cuda.current_stream(), s2)
                self.assertEqual(0, torch.cuda.current_device())
            self.assertEqual(torch.cuda.current_stream(), s1)
            self.assertEqual(1, torch.cuda.current_device())

        with torch.cuda.device(s1.device):
            self.assertEqual(prev_stream_on_cuda1, torch.cuda.current_stream())

        self.assertEqual(torch.cuda.current_stream(), s0)
        self.assertEqual(0, torch.cuda.current_device())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_streams_multi_gpu(self):
        default_stream = torch.cuda.current_stream()
        self.assertEqual(default_stream.device, torch.device("cuda:0"))
        stream = torch.cuda.Stream(device=1)
        self.assertEqual(stream.device, torch.device("cuda:1"))
        with torch.cuda.device(1):
            self.assertEqual(torch.cuda.current_stream().device, torch.device("cuda:1"))
            self.assertNotEqual(torch.cuda.current_stream(), default_stream)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_streams_multi_gpu_query(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        torch.cuda.synchronize(d0)
        torch.cuda.synchronize(d1)

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()

        with torch.cuda.device(d1):
            s1 = torch.cuda.current_stream()
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)

        self.assertTrue(s0.query())
        self.assertFalse(s1.query())

        with torch.cuda.device(d0):
            self.assertTrue(s0.query())
            self.assertFalse(s1.query())

        with torch.cuda.device(d1):
            self.assertTrue(s0.query())
            self.assertFalse(s1.query())

        # deliberately using a different device
        with torch.cuda.device(d0):
            s1.synchronize()

        self.assertTrue(s0.query())
        self.assertTrue(s1.query())

        with torch.cuda.device(d0):
            self.assertTrue(s0.query())
            self.assertTrue(s1.query())

        with torch.cuda.device(d1):
            self.assertTrue(s0.query())
            self.assertTrue(s1.query())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_streams_multi_gpu_eq(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            s1 = torch.cuda.current_stream()

        with torch.cuda.device(d1):
            s2 = torch.cuda.current_stream()
            s3 = torch.cuda.current_stream()

        self.assertTrue(s0 == s0)
        self.assertTrue(s0 == s1)
        self.assertTrue(s2 == s2)
        self.assertTrue(s2 == s3)
        self.assertFalse(s0 == s2)
        self.assertFalse(s1 == s3)

        self.assertEqual(s0.device, s1.device)
        self.assertEqual(s0.cuda_stream, s1.cuda_stream)
        self.assertEqual(s2.device, s3.device)
        self.assertEqual(s2.cuda_stream, s3.cuda_stream)
        self.assertNotEqual(s0.device, s3.device)

        self.assertEqual(hash(s0), hash(s1))
        self.assertEqual(hash(s2), hash(s3))
        self.assertNotEqual(hash(s0), hash(s3))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_streams_priority(self):
        low, high = torch.cuda.Stream.priority_range()
        s0 = torch.cuda.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("cuda:0"), s0.device)

        s1 = torch.cuda.Stream(device=1, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("cuda:1"), s1.device)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_tensor_device(self):
        self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 0)
        self.assertEqual(torch.cuda.FloatTensor(1, device=1).get_device(), 1)
        with torch.cuda.device(1):
            self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 1)
            self.assertEqual(torch.cuda.FloatTensor(1, device=0).get_device(), 0)
            self.assertEqual(torch.cuda.FloatTensor(1, device=None).get_device(), 1)

    @staticmethod
    def _stream_synchronize(self, spin_time_cycles):
        s = torch.cuda.current_stream()
        e_tik = torch.cuda.Event(enable_timing=True)
        e_tok = torch.cuda.Event(enable_timing=True)

        e_tik.record(s)
        torch.cuda._sleep(spin_time_cycles)
        e_tok.record(s)
        s.synchronize()

        self.assertTrue(s.query())

        # not necessary to check e_tik and e_tok, as elapsed_time would throw
        # exception if otherwise.
        return e_tik.elapsed_time(e_tok)

    @staticmethod
    def _event_synchronize(self, spin_time_cycles):
        s = torch.cuda.current_stream()
        e_tik = torch.cuda.Event(enable_timing=True)
        e_tok = torch.cuda.Event(enable_timing=True)

        e_tik.record(s)
        torch.cuda._sleep(spin_time_cycles)
        s.record_event(e_tok)
        e_tok.synchronize()

        self.assertTrue(s.query())

        # not necessary to check e_tik and e_tok, as elapsed_time would throw
        # exception if otherwise.
        return e_tik.elapsed_time(e_tok)

    @staticmethod
    def _event_wait(self, spin_time_cycles):
        s0 = torch.cuda.current_stream()
        s1 = torch.cuda.Stream()
        e_tik = torch.cuda.Event(blocking=True, enable_timing=True)
        e_tok = torch.cuda.Event(blocking=True, enable_timing=True)

        e_tik.record(s0)
        torch.cuda._sleep(spin_time_cycles - 10)
        e_sync = torch.cuda.Event(blocking=True)
        e_sync.record()
        e_sync.wait(s1)
        with torch.cuda.stream(s1):
            torch.cuda._sleep(10)
        s1.synchronize()
        e_tok.record()
        e_tok.synchronize()

        self.assertTrue(s0.query())
        self.assertTrue(s1.query())
        self.assertTrue(e_sync.query())

        # not necessary to check e_tik and e_tok, as elapsed_time would throw
        # exception if otherwise.
        return e_tik.elapsed_time(e_tok)

    @staticmethod
    def _test_stream_event_nogil(self, sync_func, p2c, c2p):
        with torch.cuda.device("cuda:1"):
            c2p.put(0)
            p2c.get()
            c2p.put(sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES))

    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_event_nogil(self):
        for sync_func in [
            TestCudaMultiGPU._stream_synchronize,
            TestCudaMultiGPU._event_synchronize,
            TestCudaMultiGPU._event_wait,
        ]:
            p2c = queue.Queue()
            c2p = queue.Queue()
            e_tik = torch.cuda.Event(enable_timing=True)
            e_tok = torch.cuda.Event(enable_timing=True)

            t = threading.Thread(
                target=TestCudaMultiGPU._test_stream_event_nogil,
                args=(self, sync_func, p2c, c2p),
            )
            t.daemon = True
            t.start()

            c2p.get()
            with torch.cuda.device("cuda:0"):
                e_tik.record()
                p2c.put(0)
                parent_time = sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES)
                child_time = c2p.get()
                e_tok.record()
                e_tok.synchronize()
                total_time = e_tik.elapsed_time(e_tok)

            # Without GIL, synchronizations in parent and child threads can
            # overlap. The total execution time should be a little bit longer
            # than spinning fifty million cycles and much shorter than twice of
            # that. However, testing absolute execution time is not reliable as
            # it may vary on different hardware in different environments.
            # Therefore, this test uses relative comparisons, checking if the
            # sum of parent and child threads execution time is greater than the
            # real execution time by at least 40%.
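            # For example, if one synchronization takes roughly T ms and the two
            # threads overlap fully, total_time is close to T while
            # parent_time + child_time is close to 2 * T, which satisfies the
            # 1.4x check below with a comfortable margin.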
            self.assertGreater(parent_time + child_time, total_time * 1.4)

    # This test is flaky for ROCm, see issue #62602
    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_events_wait(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        torch.cuda.synchronize(d0)
        torch.cuda.synchronize(d1)

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            e0 = torch.cuda.Event()
            s0.record_event(e0)

        with torch.cuda.device(d1):
            s1 = torch.cuda.current_stream()

        self.assertFalse(s0.query())
        self.assertTrue(s1.query())

        s1.wait_event(e0)
        s1.synchronize()

        self.assertTrue(e0.query())
        self.assertTrue(s0.query())
        self.assertTrue(s1.query())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_events_multi_gpu_query(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            e0 = s0.record_event()
            s0.synchronize()

        with torch.cuda.device(d1):
            s1 = torch.cuda.current_stream()
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            e1 = s1.record_event()

        self.assertTrue(e0.query())
        self.assertFalse(e1.query())

        with torch.cuda.device(d0):
            self.assertTrue(e0.query())
            self.assertFalse(e1.query())

        with torch.cuda.device(d1):
            self.assertTrue(e0.query())
            self.assertFalse(e1.query())

        # deliberately using a different device
        with torch.cuda.device(d0):
            e1.synchronize()

        self.assertTrue(e0.query())
        self.assertTrue(e1.query())

        with torch.cuda.device(d0):
            self.assertTrue(e0.query())
            self.assertTrue(e1.query())

        with torch.cuda.device(d1):
            self.assertTrue(e0.query())
            self.assertTrue(e1.query())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @skipIfRocm
    def test_events_multi_gpu_elapsed_time(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            e0 = torch.cuda.Event(enable_timing=True)
            torch.cuda._sleep(10)
            s0.record_event(e0)

        with torch.cuda.device(d1):
            s1 = torch.cuda.current_stream()
            e1 = torch.cuda.Event(enable_timing=True)
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            s1.record_event(e1)

        e0.synchronize()
        e1.synchronize()
        with torch.cuda.device(d0):
            with self.assertRaises(RuntimeError):
                self.assertGreater(e0.elapsed_time(e1), 0)

        with torch.cuda.device(d1):
            with self.assertRaises(RuntimeError):
                self.assertGreater(e0.elapsed_time(e1), 0)

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            e2 = torch.cuda.Event(enable_timing=True)
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            s0.record_event(e2)
            s0.synchronize()

        self.assertGreater(e0.elapsed_time(e2), 0)

        # deliberately calling from a different device
        with torch.cuda.device(d1):
            self.assertGreater(e0.elapsed_time(e2), 0)

    @contextlib.contextmanager
    def _get_external_stream(self, device):
        cudart = torch.cuda.cudart()
        stream = ctypes.c_ulonglong(0)
        stream_p = ctypes.POINTER(ctypes.c_void_p)(stream)
        stream_p_int = ctypes.cast(stream_p, ctypes.c_void_p).value
        with device:
            try:
                out = cudart.cudaStreamCreate(stream_p_int)
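                # cudaStreamCreate returns a cudaError_t; cudaSuccess is 0,
                # which is what the checks below expect for both the create
                # and the destroy calls.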
                self.assertEqual(out, 0)
                self.assertNotEqual(stream.value, 0)
                yield stream.value
            finally:
                out = cudart.cudaStreamDestroy(stream.value)
                self.assertEqual(out, 0)

    def test_external_streams(self):
        device = torch.cuda.device(0)
        with self._get_external_stream(device) as stream_v:
            ext_stream = torch.cuda.ExternalStream(stream_v)
            self.assertEqual(stream_v, ext_stream.cuda_stream)
            self.assertEqual(ext_stream.device.index, device.idx)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_external_streams_multi_device(self):
        device = torch.cuda.device(1)
        with self._get_external_stream(device) as stream_v:
            ext_stream = torch.cuda.ExternalStream(stream_v, device=device)
            self.assertEqual(stream_v, ext_stream.cuda_stream)
            self.assertEqual(ext_stream.device.index, device.idx)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_caching_pinned_memory_multi_gpu(self):
        # checks that the events preventing pinned memory from being re-used
        # too early are recorded on the correct GPU
        cycles_per_ms = get_cycles_per_ms()

        t = torch.FloatTensor([1]).pin_memory()
        ptr = t.data_ptr()
        gpu_tensor0 = torch.cuda.FloatTensor([0], device=0)
        gpu_tensor1 = torch.cuda.FloatTensor([0], device=1)

        with torch.cuda.device(1):
            torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by 1s
            gpu_tensor1.copy_(t, non_blocking=True)

        del t
        t = torch.FloatTensor([2]).pin_memory()
        self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon")

        with torch.cuda.device(0):
            gpu_tensor0.copy_(t, non_blocking=True)

        self.assertEqual(gpu_tensor1[0], 1)
        self.assertEqual(gpu_tensor0[0], 2)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_get_set_rng_state_all(self):
        states = torch.cuda.get_rng_state_all()
        before0 = torch.cuda.FloatTensor(100, device=0).normal_()
        before1 = torch.cuda.FloatTensor(100, device=1).normal_()
        torch.cuda.set_rng_state_all(states)
        after0 = torch.cuda.FloatTensor(100, device=0).normal_()
        after1 = torch.cuda.FloatTensor(100, device=1).normal_()
        self.assertEqual(before0, after0, atol=0, rtol=0)
        self.assertEqual(before1, after1, atol=0, rtol=0)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_rng_state_offset(self):
        before = torch.cuda.get_rng_state()
        torch.cuda._set_rng_state_offset(100)
        offset = torch.cuda._get_rng_state_offset()
        torch.cuda.set_rng_state(before)
        self.assertEqual(offset, 100)

    # Verifies that mem_get_info works, including when called for a different device
    def test_mem_get_info(self):
        def _test(device: Union[str, int, torch.device]):
            # Prevent PyTorch from reusing the allocated memory
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(device)
            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            t = torch.randn(1024 * 1024 * 8, device=device)
            if IS_JETSON:
                # w/o syncing, mem_get_info will run before memory allocated has actually increased.
                # This race condition causes consistent failure
                torch.cuda.synchronize()
            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(device)

            self.assertLess(after_free_bytes, before_free_bytes)
            self.assertEqual(before_available_bytes, after_available_bytes)

        # Test calls with different device representations
        _test(0)
        _test(torch.device("cuda"))
        _test(torch.device("cuda:0"))
        _test("cuda")
        _test("cuda:0")
        if TEST_MULTIGPU:
            _test(1)
            _test(torch.device("cuda:1"))
            _test("cuda:1")

    # Test that wrap_with_cuda_memory_check successfully detects leak
    def test_cuda_memory_leak_detection(self):
        l = []

        @self.wrap_with_cuda_memory_check
        def no_leak():
            pass

        @self.wrap_with_cuda_memory_check
        def leak_gpu0():
            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:0")))

        no_leak()
        regex = r"CUDA driver API confirmed .+ on device 0.+"
        if IS_JETSON:
            try:
                leak_gpu0()
            except RuntimeError as e:
                import re

                assert re.match(regex, str(e)), str(e) + "\n does not match: \n" + regex
        else:
            # assertRaisesRegex does not pass with Python for Jetson,
            # even though the RuntimeError matches regex using re.match
            with self.assertRaisesRegex(RuntimeError, regex):
                leak_gpu0()

        if TEST_MULTIGPU:

            @self.wrap_with_cuda_memory_check
            def leak_gpu1():
                # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
                l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:1")))

            with self.assertRaisesRegex(
                RuntimeError, r"CUDA driver API confirmed .+ on device 1.+"
            ):
                leak_gpu1()

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_streaming_backwards_device_transfer(self):
        # This function must run with non-default current streams on all devices, otherwise it's meaningless.
        # The intention is to test that to()'s backward (CopyBackward) interacts properly with the
        # synchronization logic in torch/csrc/autograd/input_buffer.cpp.
        dev0 = torch.device("cuda:0")
        dev1 = torch.device("cuda:1")

        # Unfortunately I need to make the tensors largeish.
        # Bigger tensors = longer D2D transfers = more likely to expose races.
        size = 2**26

        a = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)
        b = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)

        # Here to_backward_recipient = a*b is used only once, so MulBackward's InputBuffer slot only expects 1 input.
        # This tests the situation where we don't call InputBuffer::accumulate for MulBackward's InputBuffer.
        to_backward_recipient = a * b
        s = to_backward_recipient.to(device="cuda:0").sum()
        torch.cuda.synchronize(device=dev0)
        torch.cuda.synchronize(device=dev1)
        s.backward()
        self.assertTrue(a.grad.sum().item() == size)
        self.assertTrue(b.grad.sum().item() == size)

        # Here to_backward_recipient = a*b is used twice, so MulBackward's InputBuffer slot expects 2 inputs.
        # This tests the situation where we do call InputBuffer::accumulate for MulBackward's InputBuffer.
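        # With a and b filled with ones, each of the two backward passes below
        # contributes a per-element gradient of 2 (from the * 2.0 factor), so
        # a.grad and b.grad should each sum to 4 * size.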
        a.grad = None
        b.grad = None
        to_backward_recipient = a * b
        # Multiply by 2 here so to's backward creates gradient values that are different from the case above,
        # to mitigate weirdness if the caching allocator happens to reuse memory regions that were populated
        # with 1s by the case above
        s0 = to_backward_recipient.to(device="cuda:0").sum() * 2.0
        s1 = to_backward_recipient.to(device="cuda:0").sum() * 2.0
        torch.cuda.synchronize(device=dev0)
        torch.cuda.synchronize(device=dev1)
        s0.backward(retain_graph=True)
        s1.backward()
        self.assertTrue(a.grad.sum().item() == 4 * size)
        self.assertTrue(b.grad.sum().item() == 4 * size)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle")
    def test_cuda_init_race(self):
        # See https://github.com/pytorch/pytorch/issues/16559
        import subprocess

        subprocess.check_call(
            [
                sys.executable,
                "-c",
                """\
import torch
import threading

def worker(rank):
    torch.tensor([1.]).cuda(rank)

t1 = threading.Thread(target=worker, args=(0,))
t2 = threading.Thread(target=worker, args=(1,))
t1.start()
t2.start()
""",
            ]
        )

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_grad_scaling_device_as_key(self):
        # Ensure that different instances of "device" objects that point to the same device
        # are treated as identical keys by dicts. GradScaler relies on this behavior, and may
        # error otherwise in a way that's difficult to detect (a silent performance hit).
        d = {}
        t = torch.empty((1,), device="cuda:0")
        dev0a = torch.device("cuda:0")
        dev0b = torch.device("cuda:0")
        dev1a = torch.device("cuda:1")
        dev1b = torch.device("cuda:1")

        self.assertTrue(hash(dev0a) == hash(dev0b))
        self.assertTrue(hash(dev1a) == hash(dev1b))

        d[dev0a] = "0a"
        d[dev0b] = "0b"
        self.assertTrue(len(d) == 1)
        self.assertTrue(d[dev0a] == "0b")
        d[t.device] = "t"
        self.assertTrue(len(d) == 1)
        self.assertTrue(d[dev0a] == "t")

        d[dev1a] = "1a"
        d[dev1b] = "1b"
        self.assertTrue(len(d) == 2)
        self.assertTrue(d[dev1a] == "1b")

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_grad_scaling_scale(self):
        scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
        t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")
        t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1")
        # Create some nested iterables of tensors on different devices.
        outputs = (
            t1.clone(),
            (t0.clone(), t1.clone()),
            [t0.clone(), (t1.clone(), t0.clone())],
        )
        outputs = scaler.scale(outputs)
        self.assertTrue(
            outputs[0] == 8.0
            and outputs[1][0] == 8.0
            and outputs[1][1] == 8.0
            and outputs[2][0] == 8.0
            and outputs[2][1][0] == 8.0
            and outputs[2][1][1] == 8.0
        )
        self.assertTrue(scaler._scale.device == t1.device)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_grad_scaling_multigpu(self):
        # Same as above, but runs some of the models on device 1.
        # GradScaler should transparently handle losses and gradients on multiple devices.
        # This test could be combined with the test above, but I think it makes sense to treat
        # multi-GPU operations separately.
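        # Note: the assertions below poke at GradScaler internals
        # (_found_inf_per_device), which in this test hold one found_inf tensor
        # per device for each optimizer; this is a private API and may change.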
        dev0 = torch.device("cuda:0")
        dev1 = torch.device("cuda:1")

        for enabled in True, False:
            (
                mod_control0,
                mod_scaling0,
                opt_control0,
                opt_scaling0,
                data,
                loss_fn,
                skip_iter,
            ) = _create_scaling_case()
            (
                mod_control1,
                mod_scaling1,
                opt_control1,
                opt_scaling1,
            ) = _create_scaling_models_optimizers(device=dev1)

            scaler = torch.amp.GradScaler(
                device="cuda",
                init_scale=128.0,
                growth_factor=2.0,
                enabled=enabled,
                growth_interval=1,
            )

            def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
                for i, (input, target) in enumerate(data):
                    optimizer0.zero_grad()
                    optimizer1.zero_grad()
                    output0 = model0(input)
                    output1 = model1(input.to(dev1))
                    loss0 = loss_fn(0.3 * output0 + 0.7 * output1.to(dev0), target)
                    loss1 = loss_fn(
                        0.6 * output0.to(dev1) - 0.4 * output1, target.to(dev1)
                    )

                    if try_scaling_api:
                        scaler.scale(loss0).backward(retain_graph=True)
                        scaler.scale(loss1).backward()
                        if i == skip_iter and scaler.is_enabled():
                            model1[1].weight.grad.data.fill_(float("inf"))

                        # As an additional stress test, separately unscale for one of the optimizers.
                        scaler.unscale_(optimizer0)

                        scaler.step(optimizer0)
                        scaler.step(optimizer1)

                        # Make sure the found_infs were collected properly across optimizers and devices.
                        if scaler.is_enabled():
                            self.assertTrue(
                                len(scaler._found_inf_per_device(optimizer0)) == 1
                            )
                            self.assertTrue(
                                len(scaler._found_inf_per_device(optimizer1)) == 1
                            )
                            self.assertTrue(
                                scaler._found_inf_per_device(optimizer0)[dev0].item()
                                == 0.0
                            )
                            self.assertTrue(
                                scaler._found_inf_per_device(optimizer1)[dev1].item()
                                == float(i == skip_iter)
                            )

                        scaler.update()
                    else:
                        loss0.backward(retain_graph=True)
                        loss1.backward()
                        optimizer0.step()
                        if (not scaler.is_enabled()) or (i != skip_iter):
                            optimizer1.step()

            run(mod_control0, mod_control1, opt_control0, opt_control1, False)
            run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)

            # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once.
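            # Assuming the default backoff factor of 0.5, that works out to
            # 128.0 * 2.0**3 * 0.5 = 512.0 when scaling is enabled.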
            self.assertTrue(
                scaler.get_scale()
                == (
                    128.0
                    * scaler.get_growth_factor() ** 3
                    * scaler.get_backoff_factor() ** 1
                )
                if enabled
                else 1.0
            )

            # Copy mod_control1 and mod_scaling1 back to device 0 for comparison
            mod_control1.to(dev0)
            mod_scaling1.to(dev0)

            for c, s in zip(
                chain(mod_control0.parameters(), mod_control1.parameters()),
                chain(mod_scaling0.parameters(), mod_scaling1.parameters()),
            ):
                self.assertEqual(c, s, rtol=1e-5, atol=1e-7)

    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
    def test_cuda_device_memory_allocated(self):
        from torch.cuda import memory_allocated

        device_count = torch.cuda.device_count()
        current_alloc = [memory_allocated(idx) for idx in range(device_count)]
        x = torch.ones(10, device="cuda:0")
        self.assertGreater(memory_allocated(0), current_alloc[0])
        self.assertTrue(
            all(
                memory_allocated(torch.cuda.device(idx)) == current_alloc[idx]
                for idx in range(1, device_count)
            )
        )


class TestCudaComm(TestCase):
    def _test_broadcast(self, input):
        if not TEST_MULTIGPU:
            raise unittest.SkipTest("only one GPU detected")
        # test regular
        results = comm.broadcast(input, (0, 1))
        for i, t in enumerate(results):
            self.assertEqual(t.get_device(), i)
            self.assertEqual(t, input)
            if (
                input.is_cuda and input.get_device() == i
            ):  # test not copying on same device
                self.assertEqual(t.data_ptr(), input.data_ptr())
        # test out=
        for inplace in [True, False]:
            if inplace:
                outputs = [
                    torch.empty_like(input, device=0),
                    torch.empty_like(input, device=1),
                ]
            else:
                outputs = [input.cuda(0), torch.empty_like(input, device=1)]
            results = comm.broadcast(input, out=outputs)
            for r, o in zip(results, outputs):
                self.assertIs(r, o)
            for i, t in enumerate(results):
                self.assertEqual(t.get_device(), i)
                self.assertEqual(t, input)
        # test error msg
        with self.assertRaisesRegex(
            RuntimeError, r"Exactly one of 'devices' and 'out'"
        ):
            comm.broadcast(input, (0, 1), out=outputs)
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected all output tensors to be CUDA tensors, but output tensor at index 1",
        ):
            comm.broadcast(input, out=[input.cuda(0), input.cpu()])
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected all output tensors to have same shape as the source .+ at index 1",
        ):
            comm.broadcast(input, out=[input.cuda(0), input.cuda(1).unsqueeze(0)])

    def test_broadcast_cpu(self):
        self._test_broadcast(torch.randn(5, 5))

    def test_broadcast_gpu(self):
        self._test_broadcast(torch.randn(5, 5).cuda())

    def _test_broadcast_coalesced(self, tensors, buffer_size):
        b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
        for (_, bt), t in zip(b_tensors, tensors):
            self.assertEqual(bt.get_device(), 1)
            self.assertEqual(bt, t)
            self.assertIsInstance(bt, type(t))

        bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size)
        bc_tensors_t = list(zip(*bc_tensors))
        self.assertEqual(b_tensors, bc_tensors_t)
        for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
            self.assertEqual(bt.get_device(), bct.get_device())
            self.assertIsInstance(bct, type(bt))

        # check that tensors on device[0] are returned as-is
        for out_tensors in (b_tensors, bc_tensors_t):
            for inp_t, (out_t, _) in zip(tensors, out_tensors):
                self.assertIs(inp_t, out_t)

        # check that the tensors not on device[0] have different version counters
        # NOTE [ Version Counter in comm.*_coalesced ]
        versions = [t._version for _, t in bc_tensors_t]
        for old_version, (_, t) in zip(versions, bc_tensors_t):
            self.assertEqual(t._version, old_version)
            t.zero_()
            self.assertEqual(t._version, old_version + 1)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    # Note: fails sometimes on the CI, passes on dual gfx906
    def test_broadcast_coalesced(self):
        numel = 5
        num_bytes = numel * 8
        tensors = [
            self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0],
            torch.randn(numel).long().cuda(),
            torch.randn(numel).cuda(),
            self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0],
            self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0],
            self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0],
            self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0],
            torch.randn(numel).long().cuda(),
            torch.randn(numel).long().cuda(),
            self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0],
            torch.randn(numel * 2).int().cuda(),  # int is 2x shorter
            torch.randn(numel).cuda(),
        ]
        self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_broadcast_coalesced_dense_only(self):
        numel = 5
        num_bytes = numel * 8
        tensors = [
            torch.randn(numel).long().cuda(),
            torch.randn(numel).cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel * 2).int().cuda(),  # int is 2x shorter
            torch.randn(numel).cuda(),
        ]
        self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_broadcast_coalesced_empty_tensors(self):
        tensors = [
            torch.tensor([]).byte().cuda(),
            torch.randn(5).cuda(),
            torch.randn(5).double().cuda(),
        ]
        self._test_broadcast_coalesced(tensors, 256)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_reduce_add(self):
        x = torch.randn(5, 5)
        y = torch.randn(5, 5)
        x_cuda = x.cuda(0)
        y_cuda = y.cuda(1)
        result = comm.reduce_add((x_cuda, y_cuda))
        self.assertEqual(result.get_device(), 0)
        self.assertEqual(result.cpu(), x + y)

    def _test_reduce_add_coalesced(self, tensors, buffer_size):
        dup_tensors = [tensors, [t.cuda(1) for t in tensors]]

        r_tensors = [comm.reduce_add(t) for t in zip(*dup_tensors)]
        for r, t in zip(r_tensors, tensors):
            self.assertEqualTypeString(r, t)
            self.assertEqual(r.coalesce() if r.is_sparse else r, t * 2)

        rc_tensors = comm.reduce_add_coalesced(dup_tensors, buffer_size=buffer_size)
        self.assertEqual(r_tensors, rc_tensors)
        for r, rc in zip(r_tensors, rc_tensors):
            self.assertEqualTypeString(rc, r)

        # Since we have both cuda:0 and cuda:1 inputs, the outputs must be new.
        # We can check that they have different version counters.
        # NOTE [ Version Counter in comm.*_coalesced ]
        versions = [t._version for t in rc_tensors]
        for old_version, t in zip(versions, rc_tensors):
            self.assertEqual(t._version, old_version)
            t.zero_()
            self.assertEqual(t._version, old_version + 1)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_reduce_add_coalesced(self):
        numel = 5
        num_bytes = numel * 8
        tensors = [
            self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0],
            torch.randn(numel).long().cuda(),
            torch.randn(numel).cuda(),
            self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0],
            self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0],
            self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0],
            self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0],
            torch.randn(numel).long().cuda(),
            torch.randn(numel).long().cuda(),
            self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0],
            torch.randn(numel * 2).int().cuda(),  # int is 2x shorter
            torch.randn(numel).cuda(),
        ]
        self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_reduce_add_coalesced_dense_only(self):
        numel = 5
        num_bytes = numel * 8
        tensors = [
            torch.randn(numel).long().cuda(),
            torch.randn(numel).cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel * 2).int().cuda(),  # int is 2x shorter
            torch.randn(numel).cuda(),
        ]
        self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)

    def _test_scatter(self, input, chunk_sizes=None, dim=0):
        if not TEST_MULTIGPU:
            raise unittest.SkipTest("only one GPU detected")
        if chunk_sizes is None:
            ref_chunk_sizes = tuple(repeat(input.size(dim) // 2, 2))
        else:
            ref_chunk_sizes = chunk_sizes

        # test regular
        result = comm.scatter(input, (0, 1), chunk_sizes, dim)
        self.assertEqual(len(result), 2)
        chunk_start = 0
        for i, r in enumerate(result):
            chunk_end = chunk_start + ref_chunk_sizes[i]
            index = [slice(None, None) for _ in range(input.dim())]
            index[dim] = slice(chunk_start, chunk_end)
            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
            chunk_start = chunk_end
            if r.device == input.device:
                self.assertEqual(
                    r.data_ptr(), input.data_ptr()
                )  # for target @ same device, a view should be returned

        # test out
        out = [torch.empty_like(t) for t in result]
        result = comm.scatter(input, dim=dim, out=out)
        self.assertEqual(len(result), 2)
        chunk_start = 0
        for i, r in enumerate(result):
            self.assertIs(r, out[i])
            chunk_end = chunk_start + ref_chunk_sizes[i]
            index = [slice(None, None) for _ in range(input.dim())]
            index[dim] = slice(chunk_start, chunk_end)
            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
            chunk_start = chunk_end

        # test error msg
        if chunk_sizes is not None:
            with self.assertRaisesRegex(
                RuntimeError, r"Expected devices and chunk_sizes to be of same length"
            ):
                comm.scatter(
                    input,
                    [0 for _ in range(len(chunk_sizes) + 1)],
                    dim=dim,
                    chunk_sizes=chunk_sizes,
                )
        with self.assertRaisesRegex(RuntimeError, r"'devices' must not be specified"):
            comm.scatter(input, (0, 1), dim=dim, out=out)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected at least one device to scatter to"
        ):
            comm.scatter(input, (), dim=dim)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected at least one output tensor to scatter to"
        ):
            comm.scatter(input, dim=dim, out=[])
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected all output tensors to be CUDA tensors, but output tensor at index 0",
        ):
            comm.scatter(input, dim=dim, out=([out[0].cpu()] + out[1:]))
        with self.assertRaisesRegex(
            RuntimeError, r"Output tensor at index 0 has incorrect shape"
        ):
            comm.scatter(input, dim=dim, out=([out[0].unsqueeze(0)] + out[1:]))
        with self.assertRaisesRegex(
            RuntimeError,
            r"Total size for output tensors along scatter dim \d+ does not match",
        ):
            index = [slice(None, None) for _ in range(input.dim())]
            index[dim] = slice(1, None)
            comm.scatter(input, dim=dim, out=([out[0][tuple(index)]] + out[1:]))

    def test_scatter_cpu(self):
        self._test_scatter(torch.randn(4, 4), dim=0)

    def test_scatter_cpu_dim(self):
        self._test_scatter(torch.randn(4, 4), dim=1)

    def test_scatter_cpu_neg_dim(self):
        self._test_scatter(torch.randn(4, 4), dim=-2)

    def test_scatter_cpu_sizes(self):
        self._test_scatter(torch.randn(6, 4), chunk_sizes=(2, 4))

    def test_scatter_gpu(self):
        self._test_scatter(torch.randn(4, 4).cuda(), dim=0)

    def test_scatter_gpu_dim(self):
        self._test_scatter(torch.randn(4, 4).cuda(), dim=1)

    def test_scatter_gpu_neg_dim(self):
        self._test_scatter(torch.randn(4, 4).cuda(), dim=-2)

    def test_scatter_gpu_sizes(self):
        self._test_scatter(torch.randn(6, 4).cuda(), chunk_sizes=(2, 4))

    def _test_gather(self, dim):
        if not TEST_MULTIGPU:
            raise unittest.SkipTest("only one GPU detected")
        x = torch.randn(2, 5, device=0)
        y = torch.randn(2, 5, device=1)
        expected_size = list(x.size())
        expected_size[dim] += y.size(dim)
        expected_size = torch.Size(expected_size)

        destinations = [None, torch.device("cuda:0"), torch.device("cpu")]
        if torch.cuda.device_count() > 2:
            destinations.append(torch.device("cuda:2"))
        with torch.cuda.device(1):
            for destination in destinations:
                if destination is None:
                    expected_device = torch.device("cuda", torch.cuda.current_device())
                else:
                    expected_device = destination
                for use_out in [True, False]:
                    if use_out:
                        out = torch.empty(expected_size, device=expected_device)
                        result = comm.gather((x, y), dim, out=out)
                        self.assertIs(out, result)
                    else:
                        result = comm.gather((x, y), dim, destination=destination)

                    self.assertEqual(result.device, expected_device)
                    self.assertEqual(result.size(), expected_size)

                    index = [slice(None, None), slice(None, None)]
                    index[dim] = slice(0, x.size(dim))
                    self.assertEqual(result[tuple(index)], x)
                    index[dim] = slice(x.size(dim), x.size(dim) + y.size(dim))
                    self.assertEqual(result[tuple(index)], y)

        # test error msg
        with self.assertRaisesRegex(
            RuntimeError, r"'destination' must not be specified"
        ):
            comm.gather(
                (x, y),
                dim,
                destination="cpu",
                out=torch.empty(expected_size, device="cpu"),
            )
        with self.assertRaisesRegex(
            RuntimeError, r"Expected at least one tensor to gather from"
        ):
            comm.gather(())
        with self.assertRaisesRegex(
            RuntimeError, r"Expected all input tensors to be CUDA tensors, "
        ):
            comm.gather((x.cpu(), y))
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected all input tensors to have the same number of dimensions",
        ):
            comm.gather((x, y.unsqueeze(0)))
        with self.assertRaisesRegex(
            RuntimeError, r"Input tensor at index 1 has invalid shape"
        ):
            if dim in [0, -2]:
                comm.gather((x, y[:, 1:]), dim=dim)
            elif dim in [1, -1]:
                comm.gather((x, y[1:, :]), dim=dim)

    def test_gather(self):
        self._test_gather(0)

    def test_gather_dim(self):
        self._test_gather(1)

    def test_gather_neg_dim(self):
        self._test_gather(-1)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_memory_format_scatter_gather(self):
        nhwc = torch.randn((10, 3, 32, 32), device="cpu").contiguous(
            memory_format=torch.channels_last
        )
        results = torch.cuda.comm.scatter(nhwc, (0, 1), None, 0)
        for result in results:
            self.assertFalse(result.is_contiguous())
            self.assertTrue(result.is_contiguous(memory_format=torch.channels_last))

        gathered = torch.cuda.comm.gather(results)
        self.assertTrue(gathered.is_contiguous(memory_format=torch.channels_last))

    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
    def test_scatter_namedtuple(self):
        # tests ability to scatter namedtuples and retrieve a list where each
        # element is of the expected namedtuple type.
        fields = ("a", "b")
        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
        num_gpus = torch.cuda.device_count()
        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=0)
        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]

        inp = TestNamedTupleInput_0(a, b)
        target_gpus = [torch.device(i) for i in range(num_gpus)]
        scatter_out = scatter_gather.scatter(inp, target_gpus)

        for i, x in enumerate(scatter_out):
            self.assertTrue(isinstance(x, type(inp)))
            self.assertEqual(x._fields, fields)
            expected_a = a_tensors_for_gpu[i]
            expected_b = b_tensors_for_gpu[i]
            self.assertEqual(expected_a, x.a)
            self.assertEqual(expected_b, x.b)

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.tensor
            b: torch.tensor

        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=0)
        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
        inp = TestNamedTupleInput_1(a, b)

        scatter_out = scatter_gather.scatter(inp, target_gpus)
        for i, x in enumerate(scatter_out):
            self.assertTrue(isinstance(x, type(inp)))
            self.assertEqual(x._fields, fields)
            expected_a = a_tensors_for_gpu[i]
            expected_b = b_tensors_for_gpu[i]
            self.assertEqual(expected_a, x.a)
            self.assertEqual(expected_b, x.b)

    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
    def test_gather_namedtuple(self):
        # tests ability to gather a list of namedtuples and return a namedtuple where each
        # element is of the expected tensor type.
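        # gather concatenates the namedtuples field-wise, so each field of the
        # gathered result should equal torch.cat of that field across the
        # inputs, as the comparisons below verify.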
        fields = ["a", "b"]
        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)

        num_gpus = torch.cuda.device_count()
        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=1)
        out1 = TestNamedTupleInput_0(a, b)

        a = torch.rand(num_gpus * 2, device=1)
        b = torch.rand(num_gpus * 2, device=0)
        out2 = TestNamedTupleInput_0(a, b)

        outputs = [out1, out2]

        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))  # x must be a tensor
            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
            self.assertTrue(torch.equal(x, cat))

        out = scatter_gather.gather(outputs, 0)  # test on GPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))
            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
            self.assertTrue(torch.equal(x, cat))

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.tensor
            b: torch.tensor

        a = torch.rand(num_gpus * 2, device=0)
        b = torch.rand(num_gpus * 2, device=1)
        out1 = TestNamedTupleInput_1(a, b)

        a = torch.rand(num_gpus * 2, device=1)
        b = torch.rand(num_gpus * 2, device=0)
        out2 = TestNamedTupleInput_1(a, b)

        outputs = [out1, out2]

        out = scatter_gather.gather(outputs, 0)  # test on GPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))
            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
            self.assertTrue(torch.equal(x, cat))

        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
        for i, x in enumerate(out):
            self.assertTrue(isinstance(x, type(out2[-1])))
            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
            self.assertTrue(torch.equal(x, cat))


instantiate_parametrized_tests(TestCudaMultiGPU)


if __name__ == "__main__":
    run_tests()