# Owner(s): ["module: cuda"]
2
3import collections
4import contextlib
5import ctypes
6import gc
7import io
8import queue
9import sys
10import tempfile
11import threading
12import unittest
13from itertools import chain, repeat
14from typing import NamedTuple, Union
15
16import torch
17import torch.cuda.comm as comm
18from torch.nn.parallel import scatter_gather
19from torch.testing._internal.common_cuda import (
20    _create_scaling_case,
21    _create_scaling_models_optimizers,
22    TEST_MULTIGPU,
23)
24from torch.testing._internal.common_utils import (
25    get_cycles_per_ms,
26    instantiate_parametrized_tests,
27    IS_JETSON,
28    IS_REMOTE_GPU,
29    IS_SANDCASTLE,
30    NoTest,
31    run_tests,
32    serialTest,
33    skipCUDANonDefaultStreamIf,
34    skipIfRocm,
35    TEST_CUDA,
36    TestCase,
37)
38
39
40TEST_CUDAMALLOCASYNC = TEST_CUDA and (
41    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
42)
43
44if not TEST_CUDA:
45    print("CUDA not available, skipping tests", file=sys.stderr)
46    TestCase = NoTest  # noqa: F811
47
48
49class TestCudaMultiGPU(TestCase):
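    # Number of GPU cycles passed to torch.cuda._sleep below; large enough to
    # keep a stream busy while cross-stream/event ordering is checked.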
50    FIFTY_MIL_CYCLES = 50000000
51
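    # Cross-checks torch.cuda.memory_snapshot() against torch.cuda.memory_stats():
    # re-derives the per-device counters from the snapshot and asserts that the
    # reported stats match.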
52    def _check_memory_stat_consistency(self):
53        snapshot = torch.cuda.memory_snapshot()
54
55        expected_each_device = collections.defaultdict(
56            lambda: collections.defaultdict(int)
57        )
58
59        for segment in snapshot:
60            expandable = segment["is_expandable"]
61            expected = expected_each_device[segment["device"]]
62            pool_str = segment["segment_type"] + "_pool"
63
64            if not expandable:
65                expected["segment.all.current"] += 1
66                expected["segment." + pool_str + ".current"] += 1
67
68            expected["allocated_bytes.all.current"] += segment["allocated_size"]
69            expected["allocated_bytes." + pool_str + ".current"] += segment[
70                "allocated_size"
71            ]
72
73            expected["reserved_bytes.all.current"] += segment["total_size"]
74            expected["reserved_bytes." + pool_str + ".current"] += segment["total_size"]
75
76            expected["active_bytes.all.current"] += segment["active_size"]
77            expected["active_bytes." + pool_str + ".current"] += segment["active_size"]
78
79            expected["requested_bytes.all.current"] += segment["requested_size"]
80            expected["requested_bytes." + pool_str + ".current"] += segment[
81                "requested_size"
82            ]
83
84            sum_requested = 0
85            is_split = len(segment["blocks"]) > 1
86            for block in segment["blocks"]:
87                if block["state"] == "active_allocated":
88                    expected["allocation.all.current"] += 1
89                    expected["allocation." + pool_str + ".current"] += 1
90
91                if block["state"].startswith("active_"):
92                    sum_requested += block["requested_size"]
93                    expected["active.all.current"] += 1
94                    expected["active." + pool_str + ".current"] += 1
95
96                if block["state"] == "inactive" and is_split and not expandable:
97                    expected["inactive_split.all.current"] += 1
98                    expected["inactive_split." + pool_str + ".current"] += 1
99                    expected["inactive_split_bytes.all.current"] += block["size"]
100                    expected["inactive_split_bytes." + pool_str + ".current"] += block[
101                        "size"
102                    ]
103
104            self.assertEqual(sum_requested, segment["requested_size"])
105
106        for device, expected in expected_each_device.items():
107            stats = torch.cuda.memory_stats(device)
108            for k, v in expected.items():
109                self.assertEqual(v, stats[k])
110
111    def test_cuda_synchronize(self):
112        torch.cuda.synchronize()
113        torch.cuda.synchronize("cuda")
114        torch.cuda.synchronize("cuda:0")
115        torch.cuda.synchronize(0)
116        torch.cuda.synchronize(torch.device("cuda:0"))
117
118        if TEST_MULTIGPU:
119            torch.cuda.synchronize("cuda:1")
120            torch.cuda.synchronize(1)
121            torch.cuda.synchronize(torch.device("cuda:1"))
122
123        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
124            torch.cuda.synchronize(torch.device("cpu"))
125
126        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
127            torch.cuda.synchronize("cpu")
128
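    # Generator that allocates and frees tensors on `device`, asserting after
    # each step that memory_allocated/memory_reserved (and their peaks) move in
    # the expected direction. It yields between steps so callers can interleave
    # several devices (see test_memory_stats_multigpu).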
129    @staticmethod
130    def _test_memory_stats_generator(self, device=None, N=35):
131        if device is None:
132            device = torch.cuda.current_device()
133
134        m0 = torch.cuda.memory_allocated(device)
135        last_m_arr = [torch.cuda.memory_allocated(device)]
136        max_m_arr = [torch.cuda.max_memory_allocated(device)]
137        last_r_arr = [torch.cuda.memory_reserved(device)]
138        max_r_arr = [torch.cuda.max_memory_reserved(device)]
139
140        def alloc(*size):
141            with torch.cuda.device(device):
142                # NOTE: do **not** use methods that can have additional
143                #       memory overhead, e.g., inplace random sampling methods.
144                #       they can leave some memory occupied even after being
145                #       deallocated, e.g., initialized RNG state, causing some
146                #       memory checks below to fail.
147                return torch.cuda.FloatTensor(*size)
148
149        def assert_change(comp=1, empty_cache=False, reset_peak=False):
150            # comp > 0: increased
151            # comp = 0: equal
152            # comp < 0: decreased
153            new_m = torch.cuda.memory_allocated(device)
154            new_max_m = torch.cuda.max_memory_allocated(device)
155            if comp > 0:
156                self.assertGreater(new_m, last_m_arr[0])
157            elif comp < 0:
158                self.assertLess(new_m, last_m_arr[0])
159            else:
160                self.assertEqual(new_m, last_m_arr[0])
161            self.assertLessEqual(new_m, new_max_m)
162            self.assertGreaterEqual(new_max_m, max_m_arr[0])
163            last_m_arr[0] = new_m
164            max_m_arr[0] = new_max_m
165
166            new_r = torch.cuda.memory_reserved(device)
167            new_max_r = torch.cuda.max_memory_reserved(device)
            # emptying the cache may happen (due to allocation or empty_cache), so
            # we can't assert new_r >= last_r_arr[0]
170            self.assertLessEqual(new_r, new_max_r)
171            self.assertGreaterEqual(new_max_r, max_r_arr[0])
172            last_r_arr[0] = new_r
173            max_r_arr[0] = new_max_r
174
175            stat_key_n_sync = "num_sync_all_streams"
176            stat_key_n_alloc = "num_device_alloc"
177            stat_key_n_free = "num_device_free"
178            if empty_cache:
179                num_sync_1 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
180                self.assertGreaterEqual(num_sync_1, 0)
181                num_alloc_1 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
182                # if current memory usage is greater than zero we must have
183                # allocated something
184                self.assertGreaterEqual(num_alloc_1, 0 if new_m == 0 else 1)
185                num_free_1 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
186                self.assertGreaterEqual(num_free_1, 0)
                # empty_cache forces a call to release_cached_blocks
188                torch.cuda.empty_cache()
189                num_sync_2 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
190                self.assertEqual(num_sync_1 + 1, num_sync_2)
191                num_alloc_2 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
192                self.assertGreaterEqual(num_alloc_2, num_alloc_1)
193                num_free_2 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
194                self.assertGreaterEqual(num_free_2, num_free_1)
195
196                new_r = torch.cuda.memory_reserved(device)
197                new_max_r = torch.cuda.max_memory_reserved(device)
198                self.assertLessEqual(new_r, last_r_arr[0])
199                self.assertLessEqual(new_r, new_max_r)
200                self.assertEqual(new_max_r, max_r_arr[0])
201                last_r_arr[0] = new_r
202
203            if reset_peak:
204                torch.cuda.reset_peak_memory_stats(device)
205                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
206                self.assertEqual(torch.cuda.max_memory_allocated(device), last_m_arr[0])
207                max_m_arr[0] = last_m_arr[0]
208                self.assertEqual(torch.cuda.memory_reserved(device), last_r_arr[0])
209                self.assertEqual(torch.cuda.max_memory_reserved(device), last_r_arr[0])
210                max_r_arr[0] = last_r_arr[0]
211
212        assert_change(0)
213        assert_change(0, reset_peak=True)
214        assert_change(0, empty_cache=True)
215        assert_change(0, reset_peak=True)
216        assert_change(0)
217        yield
218
219        tensors1 = [alloc(1), alloc(10, 20), alloc(200, 300, 2000)]
220        m1 = torch.cuda.memory_allocated(device)
221        assert_change(1)
222        yield
223
224        tensors2 = []
225
226        for i in range(1, int(N / 2) + 1):
227            # small ones
228            tensors2.append(alloc(i, i * 4))
229            assert_change(1)
230            yield
231
232        for i in range(5, int(N / 2) + 5):
233            # large ones
234            tensors2.append(alloc(i, i * 7, i * 9, i * 11))
235            assert_change(1, reset_peak=(i % 2 == 0))
236            yield
237
238        tensors2.append(alloc(0, 0, 0))
239        assert_change(0)
240        yield
241
242        permute = []
243        for i in torch.randperm(len(tensors2)):
244            permute.append(tensors2[i])
245            assert_change(0)
246            yield
247
248        del tensors2
249        assert_change(0)
250        yield
251        tensors2 = permute
252        assert_change(0)
253        yield
254        del permute
255        assert_change(0, reset_peak=True)
256        yield
257
258        for i in range(int(N / 2)):
259            x = tensors2[i].numel()
260            del tensors2[i]
            assert_change(-x)  # -x is 0 when tensors2[i] is empty, so no change is expected
262            yield
263
264        for i in range(2, int(2 * N / 3) + 2):
265            tensors2.append(alloc(i, i * 3, i * 8))
266            assert_change(1)
267            yield
268
269        del tensors2
270        assert_change(-1, reset_peak=True)
271        assert_change(0)
272        self.assertEqual(torch.cuda.memory_allocated(device), m1)
273        yield True
274
275        del tensors1
276        assert_change(-1, reset_peak=True)
277        self.assertEqual(torch.cuda.memory_allocated(device), m0)
278
279        # test empty_cache and reset_peak
280        assert_change(0, empty_cache=True)
281        assert_change(0, reset_peak=True)
282
283    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
284    @serialTest()
285    def test_memory_stats(self):
286        gc.collect()
287        torch.cuda.empty_cache()
288        for _ in self._test_memory_stats_generator(self):
289            self._check_memory_stat_consistency()
290
291    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
292    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
293    def test_memory_stats_multigpu(self):
        # advance a generator with an end flag
295        def advance(gen, end):
296            if not end:
297                try:
298                    next(gen)
299                except StopIteration:
300                    end = True
301            return end
302
        # interleave the two generators in lock step
304        torch.cuda.empty_cache()
305        gen0 = self._test_memory_stats_generator(self, device="cuda:0", N=35)
306        gen1 = self._test_memory_stats_generator(
307            self, device=torch.device("cuda:1"), N=35
308        )
309        end0 = end1 = False
310        while not (end0 and end1):
311            end0 = advance(gen0, end0)
312            end1 = advance(gen1, end1)
313
314        # semi-random order
315        torch.cuda.empty_cache()
316        gen0 = self._test_memory_stats_generator(self, device=0, N=35)
317        gen1 = self._test_memory_stats_generator(
318            self, device=torch.device("cuda:1"), N=35
319        )
320        end0 = end1 = False
321
322        while not (end0 and end1):
323            end0 = advance(gen0, end0)
324            if not end0:
325                gen1_max_times = torch.LongTensor(1).random_(0, 3)[0]
326            else:
327                gen1_max_times = torch.inf
328            t = 0
329            while t < gen1_max_times and not end1:
330                end1 = advance(gen1, end1)
331                t += 1
332
333    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
334    def test_autogpu(self):
335        x = torch.randn(5, 5).cuda()
336        y = torch.randn(5, 5).cuda()
337        self.assertEqual(x.get_device(), 0)
338        self.assertEqual(x.get_device(), 0)
339        with torch.cuda.device(1):
340            z = torch.randn(5, 5).cuda()
341            self.assertEqual(z.get_device(), 1)
342            q = x.add(y)
343            self.assertEqual(q.get_device(), 0)
344            w = torch.randn(5, 5).cuda()
345            self.assertEqual(w.get_device(), 1)
346            self.assertEqual(y.cuda().get_device(), 1)
347        z = z.cuda()
348        self.assertEqual(z.get_device(), 0)
349
350    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
351    def test_new(self):
352        x = torch.randn(3, 3).cuda()
353        self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
354        self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)
355
356        with torch.cuda.device(1):
357            self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
358            self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)
359
360    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
361    def test_copy_device(self):
362        x = torch.randn(5, 5).cuda()
363        with torch.cuda.device(1):
364            y = x.cuda()
365            self.assertEqual(y.get_device(), 1)
366            self.assertIs(y.cuda(), y)
367            z = y.cuda(0)
368            self.assertEqual(z.get_device(), 0)
369            self.assertIs(z.cuda(0), z)
370
371        x = torch.randn(5, 5)
372        with torch.cuda.device(1):
373            y = x.cuda()
374            self.assertEqual(y.get_device(), 1)
375            self.assertIs(y.cuda(), y)
376            z = y.cuda(0)
377
378            self.assertEqual(z.get_device(), 0)
379            self.assertIs(z.cuda(0), z)
380
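    # Issues two copies into y from different source (and then destination)
    # streams while the first copy is delayed by _sleep, verifying that copy_()
    # is ordered via the current streams of both the source and destination
    # devices.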
381    def _test_copy_sync_current_stream(self, x, y):
382        x_plus_one = x + 1
383        s0 = torch.cuda.Stream(device=x.device)
384        s1 = torch.cuda.Stream(device=y.device)
385        s2 = torch.cuda.Stream(device=x.device)
386        s3 = torch.cuda.Stream(device=y.device)
387
388        # same dst stream different src streams
389        with torch.cuda.stream(s0):
390            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
391            with torch.cuda.stream(s1):
392                y.copy_(x_plus_one)
393
394        with torch.cuda.stream(s2), torch.cuda.stream(s1):
395            y.copy_(x)
396
397        s1.synchronize()
398        # The copy() is synchronized on the current streams of both src and dst.
399        # In the above test, the _sleep() op on s0 will not block the copy() on
400        # s2, but both copies are synchronized on s1 in the dst device. Hence,
401        # x is copied to y after x_plus_one is copied to y. If x and y are on
402        # the same device, both copy() ops are synchronized on s1.
403        self.assertEqual(y, x)
404
405        # same src stream different dst streams
406        with torch.cuda.stream(s1):
407            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
408            with torch.cuda.stream(s0):
409                y.copy_(x_plus_one)
410
411        with torch.cuda.stream(s3), torch.cuda.stream(s0):
412            y.copy_(x)
413
414        s0.synchronize()
415        # Similarly, both copy() ops are synchronized on s0.
416        self.assertEqual(y, x)
417
418    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
419    def test_copy_streams(self):
420        d0 = torch.device("cuda:0")
421        x0 = torch.zeros(5, 5, device=d0)
422
423        d1 = torch.device("cuda:1")
424        x1 = torch.zeros(5, 5, device=d1)
425        self._test_copy_sync_current_stream(x0, x1)
426
427        x2 = torch.zeros(5, 5, device=d0)
428        self._test_copy_sync_current_stream(x0, x2)
429
430    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
431    def test_cat_autogpu(self):
432        x = torch.randn(4, 4).cuda(1)
433        y = torch.randn(4, 4).cuda(1)
434        z = torch.cat([x, y], 0)
435        self.assertEqual(z.get_device(), x.get_device())
436
    @unittest.skipIf(torch.cuda.device_count() >= 10, "cuda:9 exists, so loading would not fail")
438    def test_load_nonexistent_device(self):
439        # Setup: create a serialized file object with a 'cuda:9' restore location
440        tensor = torch.randn(2, device="cuda")
441        buf = io.BytesIO()
442        torch.save(tensor, buf)
443        # NB: this might not work in the future if serialization changes
444        buf = io.BytesIO(buf.getvalue().replace(b"cuda:0", b"cuda:9"))
445
446        msg = r"Attempting to deserialize object on CUDA device 9"
447        with self.assertRaisesRegex(RuntimeError, msg):
448            _ = torch.load(buf)
449
450    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
451    def test_multigpu_serialization_remap(self):
452        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
453
454        def gpu_remap(storage, location):
455            if location == "cuda:1":
456                return storage.cuda(0)
457
458        with tempfile.NamedTemporaryFile() as f:
459            torch.save(x, f)
460            f.seek(0)
461            x_copy = torch.load(f, map_location=gpu_remap)
462
463        for original, copy in zip(x, x_copy):
464            self.assertEqual(copy, original)
465            self.assertIs(type(copy), type(original))
466            self.assertEqual(copy.get_device(), 0)
467
468    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
469    def test_multigpu_serialization_remap_dict(self):
470        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
471        with tempfile.NamedTemporaryFile() as f:
472            torch.save(x, f)
473            f.seek(0)
474            x_copy = torch.load(f, map_location={"cuda:1": "cuda:0"})
475        for original, copy in zip(x, x_copy):
476            self.assertEqual(copy, original)
477            self.assertIs(type(copy), type(original))
478            self.assertEqual(copy.get_device(), 0)
479
480    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
481    def test_multigpu_storage_clone(self):
482        x = torch.randn(4, 4, device="cuda:1").storage()
483        y = x.clone()
484        self.assertEqual(x.get_device(), y.get_device())
485        for t in ["byte", "char", "short", "int", "long", "half", "double"]:
486            self.assertEqual(getattr(x, t)().get_device(), x.get_device())
487
488    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
489    def test_cuda_set_device(self):
490        x = torch.randn(5, 5)
491        with torch.cuda.device(1):
492            self.assertEqual(x.cuda().get_device(), 1)
493            torch.cuda.set_device(0)
494            self.assertEqual(x.cuda().get_device(), 0)
495            with torch.cuda.device(1):
496                self.assertEqual(x.cuda().get_device(), 1)
497            self.assertEqual(x.cuda().get_device(), 0)
498            torch.cuda.set_device(1)
499        self.assertEqual(x.cuda().get_device(), 0)
500
501    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
502    def test_current_stream(self):
503        d0 = torch.device("cuda:0")
504        d1 = torch.device("cuda:1")
505
506        s0 = torch.cuda.current_stream()
507        s1 = torch.cuda.current_stream(device=1)
508        s2 = torch.cuda.current_stream(device=0)
509
510        self.assertEqual(d0, s0.device)
511        self.assertEqual(d1, s1.device)
512        self.assertEqual(d0, s2.device)
513        self.assertEqual(s0, s2)
514
515        with torch.cuda.device(d1):
516            s0 = torch.cuda.current_stream()
517            s1 = torch.cuda.current_stream(1)
518            s2 = torch.cuda.current_stream(d0)
519
520        self.assertEqual(d1, s0.device)
521        self.assertEqual(d1, s1.device)
522        self.assertEqual(d0, s2.device)
523        self.assertEqual(s0, s1)
524
525        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
526            torch.cuda.current_stream(torch.device("cpu"))
527
528    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
529    @skipCUDANonDefaultStreamIf(True)
530    def test_default_stream(self):
531        d0 = torch.device("cuda:0")
532        d1 = torch.device("cuda:1")
533
534        with torch.cuda.device(d0):
535            s0 = torch.cuda.default_stream()
536
537        with torch.cuda.device(d1):
538            s1 = torch.cuda.default_stream()
539
540        s2 = torch.cuda.default_stream(device=0)
541        s3 = torch.cuda.default_stream(d1)
542
543        self.assertEqual(d0, s0.device)
544        self.assertEqual(d1, s1.device)
545        self.assertEqual(d0, s2.device)
546        self.assertEqual(d1, s3.device)
547        self.assertEqual(s0, s2)
548        self.assertEqual(s1, s3)
549
550        with torch.cuda.device(d0):
551            self.assertEqual(torch.cuda.current_stream(), s0)
552
553        with torch.cuda.device(d1):
554            self.assertEqual(torch.cuda.current_stream(), s1)
555
556        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
557            torch.cuda.default_stream(torch.device("cpu"))
558
559    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
560    def test_stream_event_device(self):
561        d0 = torch.device("cuda:0")
562        d1 = torch.device("cuda:1")
563        e0 = torch.cuda.Event()
564
565        self.assertEqual(None, e0.device)
566
567        with torch.cuda.device(d0):
568            s0 = torch.cuda.current_stream()
569            s0.record_event(e0)
570
571        with torch.cuda.device(d1):
572            s1 = torch.cuda.Stream()
573            e1 = s1.record_event()
574
575        self.assertEqual(s0.device, torch.device("cuda:0"))
576        self.assertEqual(e0.device, torch.device("cuda:0"))
577        self.assertEqual(s1.device, torch.device("cuda:1"))
578        self.assertEqual(e1.device, torch.device("cuda:1"))
579
580    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
581    def test_stream_context(self):
582        s0 = torch.cuda.current_stream()
583        s1 = torch.cuda.Stream(device=1)
584        s2 = torch.cuda.Stream(device=0)
585
586        with torch.cuda.device(s1.device):
587            prev_stream_on_cuda1 = torch.cuda.current_stream()
588
589        self.assertEqual(torch.cuda.current_stream(), s0)
590        self.assertEqual(0, torch.cuda.current_device())
591        with torch.cuda.stream(s1):
592            self.assertEqual(torch.cuda.current_stream(), s1)
593            self.assertEqual(1, torch.cuda.current_device())
594            with torch.cuda.stream(s2):
595                self.assertEqual(torch.cuda.current_stream(), s2)
596                self.assertEqual(0, torch.cuda.current_device())
597                with torch.cuda.stream(s0):
598                    self.assertEqual(torch.cuda.current_stream(), s0)
599                    self.assertEqual(0, torch.cuda.current_device())
600                self.assertEqual(torch.cuda.current_stream(), s2)
601                self.assertEqual(0, torch.cuda.current_device())
602            self.assertEqual(torch.cuda.current_stream(), s1)
603            self.assertEqual(1, torch.cuda.current_device())
604
605        with torch.cuda.device(s1.device):
606            self.assertEqual(prev_stream_on_cuda1, torch.cuda.current_stream())
607
608        self.assertEqual(torch.cuda.current_stream(), s0)
609        self.assertEqual(0, torch.cuda.current_device())
610
611    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
612    def test_streams_multi_gpu(self):
613        default_stream = torch.cuda.current_stream()
614        self.assertEqual(default_stream.device, torch.device("cuda:0"))
615        stream = torch.cuda.Stream(device=1)
616        self.assertEqual(stream.device, torch.device("cuda:1"))
617        with torch.cuda.device(1):
618            self.assertEqual(torch.cuda.current_stream().device, torch.device("cuda:1"))
619            self.assertNotEqual(torch.cuda.current_stream(), default_stream)
620
621    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
622    def test_streams_multi_gpu_query(self):
623        d0 = torch.device("cuda:0")
624        d1 = torch.device("cuda:1")
625        torch.cuda.synchronize(d0)
626        torch.cuda.synchronize(d1)
627
628        with torch.cuda.device(d0):
629            s0 = torch.cuda.current_stream()
630
631        with torch.cuda.device(d1):
632            s1 = torch.cuda.current_stream()
633            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
634
635        self.assertTrue(s0.query())
636        self.assertFalse(s1.query())
637
638        with torch.cuda.device(d0):
639            self.assertTrue(s0.query())
640            self.assertFalse(s1.query())
641
642        with torch.cuda.device(d1):
643            self.assertTrue(s0.query())
644            self.assertFalse(s1.query())
645
646        # deliberately using a different device
647        with torch.cuda.device(d0):
648            s1.synchronize()
649
650        self.assertTrue(s0.query())
651        self.assertTrue(s1.query())
652
653        with torch.cuda.device(d0):
654            self.assertTrue(s0.query())
655            self.assertTrue(s1.query())
656
657        with torch.cuda.device(d1):
658            self.assertTrue(s0.query())
659            self.assertTrue(s1.query())
660
661    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
662    def test_streams_multi_gpu_eq(self):
663        d0 = torch.device("cuda:0")
664        d1 = torch.device("cuda:1")
665
666        with torch.cuda.device(d0):
667            s0 = torch.cuda.current_stream()
668            s1 = torch.cuda.current_stream()
669
670        with torch.cuda.device(d1):
671            s2 = torch.cuda.current_stream()
672            s3 = torch.cuda.current_stream()
673
674        self.assertTrue(s0 == s0)
675        self.assertTrue(s0 == s1)
676        self.assertTrue(s2 == s2)
677        self.assertTrue(s2 == s3)
678        self.assertFalse(s0 == s2)
679        self.assertFalse(s1 == s3)
680
681        self.assertEqual(s0.device, s1.device)
682        self.assertEqual(s0.cuda_stream, s1.cuda_stream)
683        self.assertEqual(s2.device, s3.device)
684        self.assertEqual(s2.cuda_stream, s3.cuda_stream)
685        self.assertNotEqual(s0.device, s3.device)
686
687        self.assertEqual(hash(s0), hash(s1))
688        self.assertEqual(hash(s2), hash(s3))
689        self.assertNotEqual(hash(s0), hash(s3))
690
691    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
692    def test_streams_priority(self):
693        low, high = torch.cuda.Stream.priority_range()
694        s0 = torch.cuda.Stream(device=0, priority=low)
695
696        self.assertEqual(low, s0.priority)
697        self.assertEqual(torch.device("cuda:0"), s0.device)
698
699        s1 = torch.cuda.Stream(device=1, priority=high)
700
701        self.assertEqual(high, s1.priority)
702        self.assertEqual(torch.device("cuda:1"), s1.device)
703
704    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
705    def test_tensor_device(self):
706        self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 0)
707        self.assertEqual(torch.cuda.FloatTensor(1, device=1).get_device(), 1)
708        with torch.cuda.device(1):
709            self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 1)
710            self.assertEqual(torch.cuda.FloatTensor(1, device=0).get_device(), 0)
711            self.assertEqual(torch.cuda.FloatTensor(1, device=None).get_device(), 1)
712
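    # The three helpers below time a torch.cuda._sleep spin using different
    # synchronization primitives (Stream.synchronize, Event.synchronize, and
    # Event.wait); test_stream_event_nogil runs them from two threads to check
    # that they release the GIL.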
713    @staticmethod
714    def _stream_synchronize(self, spin_time_cycles):
715        s = torch.cuda.current_stream()
716        e_tik = torch.cuda.Event(enable_timing=True)
717        e_tok = torch.cuda.Event(enable_timing=True)
718
719        e_tik.record(s)
720        torch.cuda._sleep(spin_time_cycles)
721        e_tok.record(s)
722        s.synchronize()
723
724        self.assertTrue(s.query())
725
        # no need to check e_tik and e_tok explicitly: elapsed_time would throw
        # an exception if either event had not been recorded properly.
728        return e_tik.elapsed_time(e_tok)
729
730    @staticmethod
731    def _event_synchronize(self, spin_time_cycles):
732        s = torch.cuda.current_stream()
733        e_tik = torch.cuda.Event(enable_timing=True)
734        e_tok = torch.cuda.Event(enable_timing=True)
735
736        e_tik.record(s)
737        torch.cuda._sleep(spin_time_cycles)
738        s.record_event(e_tok)
739        e_tok.synchronize()
740
741        self.assertTrue(s.query())
742
        # no need to check e_tik and e_tok explicitly: elapsed_time would throw
        # an exception if either event had not been recorded properly.
745        return e_tik.elapsed_time(e_tok)
746
747    @staticmethod
748    def _event_wait(self, spin_time_cycles):
749        s0 = torch.cuda.current_stream()
750        s1 = torch.cuda.Stream()
751        e_tik = torch.cuda.Event(blocking=True, enable_timing=True)
752        e_tok = torch.cuda.Event(blocking=True, enable_timing=True)
753
754        e_tik.record(s0)
755        torch.cuda._sleep(spin_time_cycles - 10)
756        e_sync = torch.cuda.Event(blocking=True)
757        e_sync.record()
758        e_sync.wait(s1)
759        with torch.cuda.stream(s1):
760            torch.cuda._sleep(10)
761        s1.synchronize()
762        e_tok.record()
763        e_tok.synchronize()
764
765        self.assertTrue(s0.query())
766        self.assertTrue(s1.query())
767        self.assertTrue(e_sync.query())
768
        # no need to check e_tik and e_tok explicitly: elapsed_time would throw
        # an exception if either event had not been recorded properly.
771        return e_tik.elapsed_time(e_tok)
772
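    # Child-thread body for test_stream_event_nogil: signals readiness on c2p,
    # waits for the parent on p2c, then sends back the child-side elapsed time.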
773    @staticmethod
774    def _test_stream_event_nogil(self, sync_func, p2c, c2p):
775        with torch.cuda.device("cuda:1"):
776            c2p.put(0)
777            p2c.get()
778            c2p.put(sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES))
779
780    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
781    @skipIfRocm
782    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
783    def test_stream_event_nogil(self):
784        for sync_func in [
785            TestCudaMultiGPU._stream_synchronize,
786            TestCudaMultiGPU._event_synchronize,
787            TestCudaMultiGPU._event_wait,
788        ]:
789            p2c = queue.Queue()
790            c2p = queue.Queue()
791            e_tik = torch.cuda.Event(enable_timing=True)
792            e_tok = torch.cuda.Event(enable_timing=True)
793
794            t = threading.Thread(
795                target=TestCudaMultiGPU._test_stream_event_nogil,
796                args=(self, sync_func, p2c, c2p),
797            )
798            t.daemon = True
799            t.start()
800
801            c2p.get()
802            with torch.cuda.device("cuda:0"):
803                e_tik.record()
804                p2c.put(0)
805                parent_time = sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES)
806                child_time = c2p.get()
807                e_tok.record()
808                e_tok.synchronize()
809                total_time = e_tik.elapsed_time(e_tok)
810
            # Because these synchronizations release the GIL, the parent and
            # child threads can overlap. The total execution time should be a
            # little longer than spinning fifty million cycles and much shorter
            # than twice that. However, testing absolute execution time is not
            # reliable, as it may vary across hardware and environments.
            # Therefore, this test uses a relative comparison, checking that the
            # sum of the parent and child thread execution times exceeds the
            # measured total time by at least 40%.
819            self.assertGreater(parent_time + child_time, total_time * 1.4)
820
821    # This test is flaky for ROCm, see issue #62602
822    @skipIfRocm
823    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
824    def test_events_wait(self):
825        d0 = torch.device("cuda:0")
826        d1 = torch.device("cuda:1")
827        torch.cuda.synchronize(d0)
828        torch.cuda.synchronize(d1)
829
830        with torch.cuda.device(d0):
831            s0 = torch.cuda.current_stream()
832            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
833            e0 = torch.cuda.Event()
834            s0.record_event(e0)
835
836        with torch.cuda.device(d1):
837            s1 = torch.cuda.current_stream()
838
839        self.assertFalse(s0.query())
840        self.assertTrue(s1.query())
841
842        s1.wait_event(e0)
843        s1.synchronize()
844
845        self.assertTrue(e0.query())
846        self.assertTrue(s0.query())
847        self.assertTrue(s1.query())
848
849    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
850    def test_events_multi_gpu_query(self):
851        d0 = torch.device("cuda:0")
852        d1 = torch.device("cuda:1")
853
854        with torch.cuda.device(d0):
855            s0 = torch.cuda.current_stream()
856            e0 = s0.record_event()
857            s0.synchronize()
858
859        with torch.cuda.device(d1):
860            s1 = torch.cuda.current_stream()
861            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
862            e1 = s1.record_event()
863
864        self.assertTrue(e0.query())
865        self.assertFalse(e1.query())
866
867        with torch.cuda.device(d0):
868            self.assertTrue(e0.query())
869            self.assertFalse(e1.query())
870
871        with torch.cuda.device(d1):
872            self.assertTrue(e0.query())
873            self.assertFalse(e1.query())
874
875        # deliberately using a different device
876        with torch.cuda.device(d0):
877            e1.synchronize()
878
879        self.assertTrue(e0.query())
880        self.assertTrue(e1.query())
881
882        with torch.cuda.device(d0):
883            self.assertTrue(e0.query())
884            self.assertTrue(e1.query())
885
886        with torch.cuda.device(d1):
887            self.assertTrue(e0.query())
888            self.assertTrue(e1.query())
889
890    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
891    @skipIfRocm
892    def test_events_multi_gpu_elapsed_time(self):
893        d0 = torch.device("cuda:0")
894        d1 = torch.device("cuda:1")
895
896        with torch.cuda.device(d0):
897            s0 = torch.cuda.current_stream()
898            e0 = torch.cuda.Event(enable_timing=True)
899            torch.cuda._sleep(10)
900            s0.record_event(e0)
901
902        with torch.cuda.device(d1):
903            s1 = torch.cuda.current_stream()
904            e1 = torch.cuda.Event(enable_timing=True)
905            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
906            s1.record_event(e1)
907
908        e0.synchronize()
909        e1.synchronize()
910        with torch.cuda.device(d0):
911            with self.assertRaises(RuntimeError):
912                self.assertGreater(e0.elapsed_time(e1), 0)
913
914        with torch.cuda.device(d1):
915            with self.assertRaises(RuntimeError):
916                self.assertGreater(e0.elapsed_time(e1), 0)
917
918        with torch.cuda.device(d0):
919            s0 = torch.cuda.current_stream()
920            e2 = torch.cuda.Event(enable_timing=True)
921            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
922            s0.record_event(e2)
923            s0.synchronize()
924
925        self.assertGreater(e0.elapsed_time(e2), 0)
926
927        # deliberately calling from a different device
928        with torch.cuda.device(d1):
929            self.assertGreater(e0.elapsed_time(e2), 0)
930
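    # Creates a raw CUDA stream on `device` through the cudart bindings, yields
    # its handle (for wrapping in torch.cuda.ExternalStream), and destroys it
    # on exit.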
931    @contextlib.contextmanager
932    def _get_external_stream(self, device):
933        cudart = torch.cuda.cudart()
934        stream = ctypes.c_ulonglong(0)
935        stream_p = ctypes.POINTER(ctypes.c_void_p)(stream)
936        stream_p_int = ctypes.cast(stream_p, ctypes.c_void_p).value
937        with device:
938            try:
939                out = cudart.cudaStreamCreate(stream_p_int)
940                self.assertEqual(out, 0)
941                self.assertNotEqual(stream.value, 0)
942                yield stream.value
943            finally:
944                out = cudart.cudaStreamDestroy(stream.value)
945                self.assertEqual(out, 0)
946
947    def test_external_streams(self):
948        device = torch.cuda.device(0)
949        with self._get_external_stream(device) as stream_v:
950            ext_stream = torch.cuda.ExternalStream(stream_v)
951            self.assertEqual(stream_v, ext_stream.cuda_stream)
952            self.assertEqual(ext_stream.device.index, device.idx)
953
954    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
955    def test_external_streams_multi_device(self):
956        device = torch.cuda.device(1)
957        with self._get_external_stream(device) as stream_v:
958            ext_stream = torch.cuda.ExternalStream(stream_v, device=device)
959            self.assertEqual(stream_v, ext_stream.cuda_stream)
960            self.assertEqual(ext_stream.device.index, device.idx)
961
962    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
963    def test_caching_pinned_memory_multi_gpu(self):
964        # checks that the events preventing pinned memory from being re-used
965        # too early are recorded on the correct GPU
966        cycles_per_ms = get_cycles_per_ms()
967
968        t = torch.FloatTensor([1]).pin_memory()
969        ptr = t.data_ptr()
970        gpu_tensor0 = torch.cuda.FloatTensor([0], device=0)
971        gpu_tensor1 = torch.cuda.FloatTensor([0], device=1)
972
973        with torch.cuda.device(1):
974            torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by 1s
975            gpu_tensor1.copy_(t, non_blocking=True)
976
977        del t
978        t = torch.FloatTensor([2]).pin_memory()
979        self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon")
980
981        with torch.cuda.device(0):
982            gpu_tensor0.copy_(t, non_blocking=True)
983
984        self.assertEqual(gpu_tensor1[0], 1)
985        self.assertEqual(gpu_tensor0[0], 2)
986
987    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
988    def test_get_set_rng_state_all(self):
989        states = torch.cuda.get_rng_state_all()
990        before0 = torch.cuda.FloatTensor(100, device=0).normal_()
991        before1 = torch.cuda.FloatTensor(100, device=1).normal_()
992        torch.cuda.set_rng_state_all(states)
993        after0 = torch.cuda.FloatTensor(100, device=0).normal_()
994        after1 = torch.cuda.FloatTensor(100, device=1).normal_()
995        self.assertEqual(before0, after0, atol=0, rtol=0)
996        self.assertEqual(before1, after1, atol=0, rtol=0)
997
998    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
999    def test_rng_state_offset(self):
1000        before = torch.cuda.get_rng_state()
1001        torch.cuda._set_rng_state_offset(100)
1002        offset = torch.cuda._get_rng_state_offset()
1003        torch.cuda.set_rng_state(before)
1004        self.assertEqual(offset, 100)
1005
1006    # Verifies that mem_get_info works, including when called for a different device
1007    def test_mem_get_info(self):
1008        def _test(device: Union[str, int, torch.device]):
1009            # Prevent PyTorch from reusing the allocated memory
1010            torch.cuda.empty_cache()
1011            torch.cuda.synchronize()
1012            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(device)
1013            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
1014            t = torch.randn(1024 * 1024 * 8, device=device)
1015            if IS_JETSON:
                # Without syncing, mem_get_info may run before the allocation has
                # actually hit the device; this race condition causes consistent
                # failures on Jetson.
1018                torch.cuda.synchronize()
1019            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(device)
1020
1021            self.assertLess(after_free_bytes, before_free_bytes)
1022            self.assertEqual(before_available_bytes, after_available_bytes)
1023
1024        # Test calls with different device representations
1025        _test(0)
1026        _test(torch.device("cuda"))
1027        _test(torch.device("cuda:0"))
1028        _test("cuda")
1029        _test("cuda:0")
1030        if TEST_MULTIGPU:
1031            _test(1)
1032            _test(torch.device("cuda:1"))
1033            _test("cuda:1")
1034
    # Test that wrap_with_cuda_memory_check successfully detects a leak
1036    def test_cuda_memory_leak_detection(self):
1037        l = []
1038
1039        @self.wrap_with_cuda_memory_check
1040        def no_leak():
1041            pass
1042
1043        @self.wrap_with_cuda_memory_check
1044        def leak_gpu0():
1045            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
1046            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:0")))
1047
1048        no_leak()
1049        regex = r"CUDA driver API confirmed .+ on device 0.+"
1050        if IS_JETSON:
1051            try:
1052                leak_gpu0()
1053            except RuntimeError as e:
1054                import re
1055
1056                assert re.match(regex, str(e)), str(e) + "\n does not match: \n" + regex
1057        else:
            # assertRaisesRegex does not pass on Jetson even though the
            # RuntimeError matches the regex via re.match, hence the special
            # case above.
1060            with self.assertRaisesRegex(RuntimeError, regex):
1061                leak_gpu0()
1062
1063        if TEST_MULTIGPU:
1064
1065            @self.wrap_with_cuda_memory_check
1066            def leak_gpu1():
1067                # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
1068                l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:1")))
1069
1070            with self.assertRaisesRegex(
1071                RuntimeError, r"CUDA driver API confirmed .+ on device 1.+"
1072            ):
1073                leak_gpu1()
1074
1075    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1076    def test_streaming_backwards_device_transfer(self):
1077        # This function must run with non-default current streams on all devices, otherwise it's meaningless.
1078        # The intention is to test that to()'s backward (CopyBackward) interacts properly with the
1079        # synchronization logic in torch/csrc/autograd/input_buffer.cpp.
1080        dev0 = torch.device("cuda:0")
1081        dev1 = torch.device("cuda:1")
1082
1083        # Unfortunately I need to make the tensors largeish.
1084        # Bigger tensors = longer D2D transfers = more likely to expose races.
1085        size = 2**26
1086
1087        a = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)
1088        b = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)
1089
1090        # Here to_backward_recipient = a*b is used only once, so MulBackward's InputBuffer slot only expects 1 input.
1091        # This tests the situation where we don't call InputBuffer::accumulate for MulBackward's InputBuffer.
1092        to_backward_recipient = a * b
1093        s = to_backward_recipient.to(device="cuda:0").sum()
1094        torch.cuda.synchronize(device=dev0)
1095        torch.cuda.synchronize(device=dev1)
1096        s.backward()
1097        self.assertTrue(a.grad.sum().item() == size)
1098        self.assertTrue(b.grad.sum().item() == size)
1099
1100        # Here to_backward_recipient = a*b is used twice, so MulBackward's InputBuffer slot expects 2 inputs.
1101        # This tests the situation where we do call InputBuffer::accumulate for MulBackward's InputBuffer.
1102        a.grad = None
1103        b.grad = None
1104        to_backward_recipient = a * b
1105        # Multiply by 2 here so to's backward creates gradient values that are different from the case above,
1106        # to mitigate weirdness if the caching allocator happens to reuse memory regions that were populated
1107        # with 1s by the case above
1108        s0 = to_backward_recipient.to(device="cuda:0").sum() * 2.0
1109        s1 = to_backward_recipient.to(device="cuda:0").sum() * 2.0
1110        torch.cuda.synchronize(device=dev0)
1111        torch.cuda.synchronize(device=dev1)
1112        s0.backward(retain_graph=True)
1113        s1.backward()
1114        self.assertTrue(a.grad.sum().item() == 4 * size)
1115        self.assertTrue(b.grad.sum().item() == 4 * size)
1116
1117    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1118    @unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle")
1119    def test_cuda_init_race(self):
1120        # See https://github.com/pytorch/pytorch/issues/16559
1121        import subprocess
1122
1123        subprocess.check_call(
1124            [
1125                sys.executable,
1126                "-c",
1127                """\
1128import torch
1129import threading
1130
1131def worker(rank):
1132    torch.tensor([1.]).cuda(rank)
1133
1134t1 = threading.Thread(target=worker, args=(0,))
1135t2 = threading.Thread(target=worker, args=(1,))
1136t1.start()
1137t2.start()
1138""",
1139            ]
1140        )
1141
1142    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1143    def test_grad_scaling_device_as_key(self):
1144        # Ensure that different instances of "device" objects that point to the same device
1145        # are treated as identical keys by dicts.  GradScaler relies on this behavior, and may
1146        # error otherwise in a way that's difficult to detect (a silent performance hit).
1147        d = {}
1148        t = torch.empty((1,), device="cuda:0")
1149        dev0a = torch.device("cuda:0")
1150        dev0b = torch.device("cuda:0")
1151        dev1a = torch.device("cuda:1")
1152        dev1b = torch.device("cuda:1")
1153
1154        self.assertTrue(hash(dev0a) == hash(dev0b))
1155        self.assertTrue(hash(dev1a) == hash(dev1b))
1156
1157        d[dev0a] = "0a"
1158        d[dev0b] = "0b"
1159        self.assertTrue(len(d) == 1)
1160        self.assertTrue(d[dev0a] == "0b")
1161        d[t.device] = "t"
1162        self.assertTrue(len(d) == 1)
1163        self.assertTrue(d[dev0a] == "t")
1164
1165        d[dev1a] = "1a"
1166        d[dev1b] = "1b"
1167        self.assertTrue(len(d) == 2)
1168        self.assertTrue(d[dev1a] == "1b")
1169
1170    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1171    def test_grad_scaling_scale(self):
1172        scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
1173        t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")
1174        t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1")
1175        # Create some nested iterables of tensors on different devices.
1176        outputs = (
1177            t1.clone(),
1178            (t0.clone(), t1.clone()),
1179            [t0.clone(), (t1.clone(), t0.clone())],
1180        )
1181        outputs = scaler.scale(outputs)
1182        self.assertTrue(
1183            outputs[0] == 8.0
1184            and outputs[1][0] == 8.0
1185            and outputs[1][1] == 8.0
1186            and outputs[2][0] == 8.0
1187            and outputs[2][1][0] == 8.0
1188            and outputs[2][1][1] == 8.0
1189        )
1190        self.assertTrue(scaler._scale.device == t1.device)
1191
1192    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1193    def test_grad_scaling_multigpu(self):
1194        # Same as above, but runs some of the models on device 1.
1195        # GradScaler should transparently handle losses and gradients on multiple devices.
1196        # This test could be combined with the test above, but I think it makes sense to treat
1197        # multi-GPU operations separately.
1198        dev0 = torch.device("cuda:0")
1199        dev1 = torch.device("cuda:1")
1200
1201        for enabled in True, False:
1202            (
1203                mod_control0,
1204                mod_scaling0,
1205                opt_control0,
1206                opt_scaling0,
1207                data,
1208                loss_fn,
1209                skip_iter,
1210            ) = _create_scaling_case()
1211            (
1212                mod_control1,
1213                mod_scaling1,
1214                opt_control1,
1215                opt_scaling1,
1216            ) = _create_scaling_models_optimizers(device=dev1)
1217
1218            scaler = torch.amp.GradScaler(
1219                device="cuda",
1220                init_scale=128.0,
1221                growth_factor=2.0,
1222                enabled=enabled,
1223                growth_interval=1,
1224            )
1225
1226            def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
1227                for i, (input, target) in enumerate(data):
1228                    optimizer0.zero_grad()
1229                    optimizer1.zero_grad()
1230                    output0 = model0(input)
1231                    output1 = model1(input.to(dev1))
1232                    loss0 = loss_fn(0.3 * output0 + 0.7 * output1.to(dev0), target)
1233                    loss1 = loss_fn(
1234                        0.6 * output0.to(dev1) - 0.4 * output1, target.to(dev1)
1235                    )
1236
1237                    if try_scaling_api:
1238                        scaler.scale(loss0).backward(retain_graph=True)
1239                        scaler.scale(loss1).backward()
1240                        if i == skip_iter and scaler.is_enabled():
1241                            model1[1].weight.grad.data.fill_(float("inf"))
1242
1243                        # As an additional stress test, separately unscale for one of the optimizers.
1244                        scaler.unscale_(optimizer0)
1245
1246                        scaler.step(optimizer0)
1247                        scaler.step(optimizer1)
1248
1249                        # Make sure the found_infs were collected properly across optimizers and devices.
1250                        if scaler.is_enabled():
1251                            self.assertTrue(
1252                                len(scaler._found_inf_per_device(optimizer0)) == 1
1253                            )
1254                            self.assertTrue(
1255                                len(scaler._found_inf_per_device(optimizer1)) == 1
1256                            )
1257                            self.assertTrue(
1258                                scaler._found_inf_per_device(optimizer0)[dev0].item()
1259                                == 0.0
1260                            )
1261                            self.assertTrue(
1262                                scaler._found_inf_per_device(optimizer1)[dev1].item()
1263                                == float(i == skip_iter)
1264                            )
1265
1266                        scaler.update()
1267                    else:
1268                        loss0.backward(retain_graph=True)
1269                        loss1.backward()
1270                        optimizer0.step()
1271                        if (not scaler.is_enabled()) or (i != skip_iter):
1272                            optimizer1.step()
1273
1274            run(mod_control0, mod_control1, opt_control0, opt_control1, False)
1275            run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)
1276
1277            # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once.
1278            self.assertTrue(
1279                scaler.get_scale()
1280                == (
1281                    128.0
1282                    * scaler.get_growth_factor() ** 3
1283                    * scaler.get_backoff_factor() ** 1
1284                )
1285                if enabled
1286                else 1.0
1287            )
1288
            # Copy mod_control1 and mod_scaling1 back to device 0 for comparison
1290            mod_control1.to(dev0)
1291            mod_scaling1.to(dev0)
1292
1293            for c, s in zip(
1294                chain(mod_control0.parameters(), mod_control1.parameters()),
1295                chain(mod_scaling0.parameters(), mod_scaling1.parameters()),
1296            ):
1297                self.assertEqual(c, s, rtol=1e-5, atol=1e-7)
1298
1299    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
1300    def test_cuda_device_memory_allocated(self):
1301        from torch.cuda import memory_allocated
1302
1303        device_count = torch.cuda.device_count()
1304        current_alloc = [memory_allocated(idx) for idx in range(device_count)]
1305        x = torch.ones(10, device="cuda:0")
1306        self.assertGreater(memory_allocated(0), current_alloc[0])
1307        self.assertTrue(
1308            all(
1309                memory_allocated(torch.cuda.device(idx)) == current_alloc[idx]
1310                for idx in range(1, device_count)
1311            )
1312        )
1313
1314
1315class TestCudaComm(TestCase):
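    # Shared helper: broadcasts `input` to devices 0 and 1 via comm.broadcast,
    # exercising both the devices= and the out= forms and the associated error
    # messages.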
1316    def _test_broadcast(self, input):
1317        if not TEST_MULTIGPU:
1318            raise unittest.SkipTest("only one GPU detected")
1319        # test regular
1320        results = comm.broadcast(input, (0, 1))
1321        for i, t in enumerate(results):
1322            self.assertEqual(t.get_device(), i)
1323            self.assertEqual(t, input)
1324            if (
1325                input.is_cuda and input.get_device() == i
1326            ):  # test not copying on same device
1327                self.assertEqual(t.data_ptr(), input.data_ptr())
1328        # test out=
1329        for inplace in [True, False]:
1330            if inplace:
1331                outputs = [
1332                    torch.empty_like(input, device=0),
1333                    torch.empty_like(input, device=1),
1334                ]
1335            else:
1336                outputs = [input.cuda(0), torch.empty_like(input, device=1)]
1337            results = comm.broadcast(input, out=outputs)
1338            for r, o in zip(results, outputs):
1339                self.assertIs(r, o)
1340            for i, t in enumerate(results):
1341                self.assertEqual(t.get_device(), i)
1342                self.assertEqual(t, input)
1343        # test error msg
1344        with self.assertRaisesRegex(
1345            RuntimeError, r"Exactly one of 'devices' and 'out'"
1346        ):
1347            comm.broadcast(input, (0, 1), out=outputs)
1348        with self.assertRaisesRegex(
1349            RuntimeError,
1350            r"Expected all output tensors to be CUDA tensors, but output tensor at index 1",
1351        ):
1352            comm.broadcast(input, out=[input.cuda(0), input.cpu()])
1353        with self.assertRaisesRegex(
1354            RuntimeError,
1355            r"Expected all output tensors to have same shape as the source .+ at index 1",
1356        ):
1357            comm.broadcast(input, out=[input.cuda(0), input.cuda(1).unsqueeze(0)])
1358
1359    def test_broadcast_cpu(self):
1360        self._test_broadcast(torch.randn(5, 5))
1361
1362    def test_broadcast_gpu(self):
1363        self._test_broadcast(torch.randn(5, 5).cuda())
1364
1365    def _test_broadcast_coalesced(self, tensors, buffer_size):
1366        b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
1367        for (_, bt), t in zip(b_tensors, tensors):
1368            self.assertEqual(bt.get_device(), 1)
1369            self.assertEqual(bt, t)
1370            self.assertIsInstance(bt, type(t))
1371
1372        bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size)
1373        bc_tensors_t = list(zip(*bc_tensors))
1374        self.assertEqual(b_tensors, bc_tensors_t)
1375        for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
1376            self.assertEqual(bt.get_device(), bct.get_device())
1377            self.assertIsInstance(bct, type(bt))
1378
1379        # check that tensors on device[0] are returned as-is
1380        for out_tensors in (b_tensors, bc_tensors_t):
1381            for inp_t, (out_t, _) in zip(tensors, out_tensors):
1382                self.assertIs(inp_t, out_t)
1383
1384        # check that the tensors not on device[0] have different version counters
1385        # NOTE [ Version Counter in comm.*_coalesced ]
1386        versions = [t._version for _, t in bc_tensors_t]
1387        for old_version, (_, t) in zip(versions, bc_tensors_t):
1388            self.assertEqual(t._version, old_version)
1389            t.zero_()
1390            self.assertEqual(t._version, old_version + 1)
1391
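    # NOTE [ Version Counter in comm.*_coalesced ] is documented elsewhere in the
    # codebase; the check above only relies on the fact that Tensor._version is
    # bumped by every in-place mutation, e.g.:
    #
    #     t = torch.zeros(1)
    #     v = t._version
    #     t.add_(1)                 # any in-place op
    #     assert t._version == v + 1
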
1392    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1393    # Note: sometimes fails on CI, passes on dual gfx906
1394    def test_broadcast_coalesced(self):
1395        numel = 5
1396        num_bytes = numel * 8
1397        tensors = [
1398            self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0],
1399            torch.randn(numel).long().cuda(),
1400            torch.randn(numel).cuda(),
1401            self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0],
1402            self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0],
1403            self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0],
1404            self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0],
1405            torch.randn(numel).long().cuda(),
1406            torch.randn(numel).long().cuda(),
1407            self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0],
1408            torch.randn(numel * 2).int().cuda(),  # int32 is half the size of int64, so double numel
1409            torch.randn(numel).cuda(),
1410        ]
1411        self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)
1412
1413    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1414    def test_broadcast_coalesced_dense_only(self):
1415        numel = 5
1416        num_bytes = numel * 8
1417        tensors = [
1418            torch.randn(numel).long().cuda(),
1419            torch.randn(numel).cuda(),
1420            torch.randn(numel).long().cuda(),
1421            torch.randn(numel).long().cuda(),
1422            torch.randn(numel * 2).int().cuda(),  # int32 is half the size of int64, so double numel
1423            torch.randn(numel).cuda(),
1424        ]
1425        self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)
1426
1427    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1428    def test_broadcast_coalesced_empty_tensors(self):
1429        tensors = [
1430            torch.tensor([]).byte().cuda(),
1431            torch.randn(5).cuda(),
1432            torch.randn(5).double().cuda(),
1433        ]
1434        self._test_broadcast_coalesced(tensors, 256)
1435
1436    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1437    def test_reduce_add(self):
1438        x = torch.randn(5, 5)
1439        y = torch.randn(5, 5)
1440        x_cuda = x.cuda(0)
1441        y_cuda = y.cuda(1)
1442        result = comm.reduce_add((x_cuda, y_cuda))
1443        self.assertEqual(result.get_device(), 0)
1444        self.assertEqual(result.cpu(), x + y)
1445
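    # comm.reduce_add sums same-shaped tensors that live on different GPUs; in the
    # test above the inputs sit on cuda:0 and cuda:1 and the summed result lands on
    # device 0.  A rough sketch of the call:
    #
    #     total = comm.reduce_add((x.cuda(0), y.cuda(1)))  # equals x + y, on device 0 here
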
1446    def _test_reduce_add_coalesced(self, tensors, buffer_size):
1447        dup_tensors = [tensors, [t.cuda(1) for t in tensors]]
1448
1449        r_tensors = [comm.reduce_add(t) for t in zip(*dup_tensors)]
1450        for r, t in zip(r_tensors, tensors):
1451            self.assertEqualTypeString(r, t)
1452            self.assertEqual(r.coalesce() if r.is_sparse else r, t * 2)
1453
1454        rc_tensors = comm.reduce_add_coalesced(dup_tensors, buffer_size=buffer_size)
1455        self.assertEqual(r_tensors, rc_tensors)
1456        for r, rc in zip(r_tensors, rc_tensors):
1457            self.assertEqualTypeString(rc, r)
1458
1459        # Since we have both cuda:0 and cuda:1 inputs, the outputs must be new.
1460        # We can check that they have different version counters.
1461        # NOTE [ Version Counter in comm.*_coalesced ]
1462        versions = [t._version for t in rc_tensors]
1463        for old_version, t in zip(versions, rc_tensors):
1464            self.assertEqual(t._version, old_version)
1465            t.zero_()
1466            self.assertEqual(t._version, old_version + 1)
1467
1468    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1469    def test_reduce_add_coalesced(self):
1470        numel = 5
1471        num_bytes = numel * 8
1472        tensors = [
1473            self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0],
1474            torch.randn(numel).long().cuda(),
1475            torch.randn(numel).cuda(),
1476            self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0],
1477            self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0],
1478            self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0],
1479            self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0],
1480            torch.randn(numel).long().cuda(),
1481            torch.randn(numel).long().cuda(),
1482            self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0],
1483            torch.randn(numel * 2).int().cuda(),  # int32 is half the size of int64, so double numel
1484            torch.randn(numel).cuda(),
1485        ]
1486        self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)
1487
1488    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1489    def test_reduce_add_coalesced_dense_only(self):
1490        numel = 5
1491        num_bytes = numel * 8
1492        tensors = [
1493            torch.randn(numel).long().cuda(),
1494            torch.randn(numel).cuda(),
1495            torch.randn(numel).long().cuda(),
1496            torch.randn(numel).long().cuda(),
1497            torch.randn(numel * 2).int().cuda(),  # int32 is half the size of int64, so double numel
1498            torch.randn(numel).cuda(),
1499        ]
1500        self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)
1501
1502    def _test_scatter(self, input, chunk_sizes=None, dim=0):
1503        if not TEST_MULTIGPU:
1504            raise unittest.SkipTest("only one GPU detected")
1505        if chunk_sizes is None:
1506            ref_chunk_sizes = tuple(repeat(input.size(dim) // 2, 2))
1507        else:
1508            ref_chunk_sizes = chunk_sizes
1509
1510        # test regular
1511        result = comm.scatter(input, (0, 1), chunk_sizes, dim)
1512        self.assertEqual(len(result), 2)
1513        chunk_start = 0
1514        for i, r in enumerate(result):
1515            chunk_end = chunk_start + ref_chunk_sizes[i]
1516            index = [slice(None, None) for _ in range(input.dim())]
1517            index[dim] = slice(chunk_start, chunk_end)
1518            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
1519            chunk_start = chunk_end
1520            if r.device == input.device:
1521                self.assertEqual(
1522                    r.data_ptr(), input.data_ptr()
1523                )  # for a target on the same device, a view should be returned
1524
1525        # test out
1526        out = [torch.empty_like(t) for t in result]
1527        result = comm.scatter(input, dim=dim, out=out)
1528        self.assertEqual(len(result), 2)
1529        chunk_start = 0
1530        for i, r in enumerate(result):
1531            self.assertIs(r, out[i])
1532            chunk_end = chunk_start + ref_chunk_sizes[i]
1533            index = [slice(None, None) for _ in range(input.dim())]
1534            index[dim] = slice(chunk_start, chunk_end)
1535            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
1536            chunk_start = chunk_end
1537
1538        # test error msg
1539        if chunk_sizes is not None:
1540            with self.assertRaisesRegex(
1541                RuntimeError, r"Expected devices and chunk_sizes to be of same length"
1542            ):
1543                comm.scatter(
1544                    input,
1545                    [0 for _ in range(len(chunk_sizes) + 1)],
1546                    dim=dim,
1547                    chunk_sizes=chunk_sizes,
1548                )
1549        with self.assertRaisesRegex(RuntimeError, r"'devices' must not be specified"):
1550            comm.scatter(input, (0, 1), dim=dim, out=out)
1551        with self.assertRaisesRegex(
1552            RuntimeError, r"Expected at least one device to scatter to"
1553        ):
1554            comm.scatter(input, (), dim=dim)
1555        with self.assertRaisesRegex(
1556            RuntimeError, r"Expected at least one output tensor to scatter to"
1557        ):
1558            comm.scatter(input, dim=dim, out=[])
1559        with self.assertRaisesRegex(
1560            RuntimeError,
1561            r"Expected all output tensors to be CUDA tensors, but output tensor at index 0",
1562        ):
1563            comm.scatter(input, dim=dim, out=([out[0].cpu()] + out[1:]))
1564        with self.assertRaisesRegex(
1565            RuntimeError, r"Output tensor at index 0 has incorrect shape"
1566        ):
1567            comm.scatter(input, dim=dim, out=([out[0].unsqueeze(0)] + out[1:]))
1568        with self.assertRaisesRegex(
1569            RuntimeError,
1570            r"Total size for output tensors along scatter dim \d+ does not match",
1571        ):
1572            index = [slice(None, None) for _ in range(input.dim())]
1573            index[dim] = slice(1, None)
1574            comm.scatter(input, dim=dim, out=([out[0][tuple(index)]] + out[1:]))
1575
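    # comm.scatter splits the input along `dim` across the target devices.  Without
    # chunk_sizes the helper above expects an even two-way split; with explicit
    # sizes each device receives the requested slice, e.g. (hypothetical shapes):
    #
    #     chunks = comm.scatter(torch.randn(6, 4), (0, 1), chunk_sizes=(2, 4))
    #     # chunks[0] holds rows 0:2 and chunks[1] rows 2:6, one per target device
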
1576    def test_scatter_cpu(self):
1577        self._test_scatter(torch.randn(4, 4), dim=0)
1578
1579    def test_scatter_cpu_dim(self):
1580        self._test_scatter(torch.randn(4, 4), dim=1)
1581
1582    def test_scatter_cpu_neg_dim(self):
1583        self._test_scatter(torch.randn(4, 4), dim=-2)
1584
1585    def test_scatter_cpu_sizes(self):
1586        self._test_scatter(torch.randn(6, 4), chunk_sizes=(2, 4))
1587
1588    def test_scatter_gpu(self):
1589        self._test_scatter(torch.randn(4, 4).cuda(), dim=0)
1590
1591    def test_scatter_gpu_dim(self):
1592        self._test_scatter(torch.randn(4, 4).cuda(), dim=1)
1593
1594    def test_scatter_gpu_neg_dim(self):
1595        self._test_scatter(torch.randn(4, 4).cuda(), dim=-2)
1596
1597    def test_scatter_gpu_sizes(self):
1598        self._test_scatter(torch.randn(6, 4).cuda(), chunk_sizes=(2, 4))
1599
1600    def _test_gather(self, dim):
1601        if not TEST_MULTIGPU:
1602            raise unittest.SkipTest("only one GPU detected")
1603        x = torch.randn(2, 5, device=0)
1604        y = torch.randn(2, 5, device=1)
1605        expected_size = list(x.size())
1606        expected_size[dim] += y.size(dim)
1607        expected_size = torch.Size(expected_size)
1608
1609        destinations = [None, torch.device("cuda:0"), torch.device("cpu")]
1610        if torch.cuda.device_count() > 2:
1611            destinations.append(torch.device("cuda:2"))
1612        with torch.cuda.device(1):
1613            for destination in destinations:
1614                if destination is None:
1615                    expected_device = torch.device("cuda", torch.cuda.current_device())
1616                else:
1617                    expected_device = destination
1618                for use_out in [True, False]:
1619                    if use_out:
1620                        out = torch.empty(expected_size, device=expected_device)
1621                        result = comm.gather((x, y), dim, out=out)
1622                        self.assertIs(out, result)
1623                    else:
1624                        result = comm.gather((x, y), dim, destination=destination)
1625
1626                    self.assertEqual(result.device, expected_device)
1627                    self.assertEqual(result.size(), expected_size)
1628
1629                    index = [slice(None, None), slice(None, None)]
1630                    index[dim] = slice(0, x.size(dim))
1631                    self.assertEqual(result[tuple(index)], x)
1632                    index[dim] = slice(x.size(dim), x.size(dim) + y.size(dim))
1633                    self.assertEqual(result[tuple(index)], y)
1634
1635        # test error msg
1636        with self.assertRaisesRegex(
1637            RuntimeError, r"'destination' must not be specified"
1638        ):
1639            comm.gather(
1640                (x, y),
1641                dim,
1642                destination="cpu",
1643                out=torch.empty(expected_size, device="cpu"),
1644            )
1645        with self.assertRaisesRegex(
1646            RuntimeError, r"Expected at least one tensor to gather from"
1647        ):
1648            comm.gather(())
1649        with self.assertRaisesRegex(
1650            RuntimeError, r"Expected all input tensors to be CUDA tensors, "
1651        ):
1652            comm.gather((x.cpu(), y))
1653        with self.assertRaisesRegex(
1654            RuntimeError,
1655            r"Expected all input tensors to have the same number of dimensions",
1656        ):
1657            comm.gather((x, y.unsqueeze(0)))
1658        with self.assertRaisesRegex(
1659            RuntimeError, r"Input tensor at index 1 has invalid shape"
1660        ):
1661            if dim in [0, -2]:
1662                comm.gather((x, y[:, 1:]), dim=dim)
1663            elif dim in [1, -1]:
1664                comm.gather((x, y[1:, :]), dim=dim)
1665
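    # comm.gather concatenates CUDA tensors along `dim` onto a destination that may
    # be another GPU or the CPU; with no destination the test above expects the
    # current CUDA device.  A rough sketch for the dim=0 case:
    #
    #     x = torch.randn(2, 5, device=0)
    #     y = torch.randn(2, 5, device=1)
    #     z = comm.gather((x, y), 0)    # shape (4, 5)
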
1666    def test_gather(self):
1667        self._test_gather(0)
1668
1669    def test_gather_dim(self):
1670        self._test_gather(1)
1671
1672    def test_gather_neg_dim(self):
1673        self._test_gather(-1)
1674
1675    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1676    def test_memory_format_scatter_gather(self):
1677        nhwc = torch.randn((10, 3, 32, 32), device="cpu").contiguous(
1678            memory_format=torch.channels_last
1679        )
1680        results = torch.cuda.comm.scatter(nhwc, (0, 1), None, 0)
1681        for result in results:
1682            self.assertFalse(result.is_contiguous())
1683            self.assertTrue(result.is_contiguous(memory_format=torch.channels_last))
1684
1685        gathered = torch.cuda.comm.gather(results)
1686        self.assertTrue(gathered.is_contiguous(memory_format=torch.channels_last))
1687
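    # A channels_last tensor is non-contiguous in the default (NCHW) sense yet
    # contiguous under memory_format=torch.channels_last, which is exactly the pair
    # of checks applied to every scattered chunk above, and the gather preserves
    # that layout end to end:
    #
    #     t = torch.randn(10, 3, 32, 32).contiguous(memory_format=torch.channels_last)
    #     assert not t.is_contiguous()
    #     assert t.is_contiguous(memory_format=torch.channels_last)
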
1688    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
1689    def test_scatter_namedtuple(self):
1690        # tests ability to scatter namedtuples and retrieve a list where each
1691        # element is of the expected namedtuple type.
1692        fields = ("a", "b")
1693        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
1694        num_gpus = torch.cuda.device_count()
1695        a = torch.rand(num_gpus * 2, device=0)
1696        b = torch.rand(num_gpus * 2, device=0)
1697        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1698        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1699
1700        inp = TestNamedTupleInput_0(a, b)
1701        target_gpus = [torch.device(i) for i in range(num_gpus)]
1702        scatter_out = scatter_gather.scatter(inp, target_gpus)
1703
1704        for i, x in enumerate(scatter_out):
1705            self.assertTrue(isinstance(x, type(inp)))
1706            self.assertEqual(x._fields, fields)
1707            expected_a = a_tensors_for_gpu[i]
1708            expected_b = b_tensors_for_gpu[i]
1709            self.assertEqual(expected_a, x.a)
1710            self.assertEqual(expected_b, x.b)
1711
1712        class TestNamedTupleInput_1(NamedTuple):
1713            a: torch.Tensor
1714            b: torch.Tensor
1715
1716        a = torch.rand(num_gpus * 2, device=0)
1717        b = torch.rand(num_gpus * 2, device=0)
1718        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1719        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1720        inp = TestNamedTupleInput_1(a, b)
1721
1722        scatter_out = scatter_gather.scatter(inp, target_gpus)
1723        for i, x in enumerate(scatter_out):
1724            self.assertTrue(isinstance(x, type(inp)))
1725            self.assertEqual(x._fields, fields)
1726            expected_a = a_tensors_for_gpu[i]
1727            expected_b = b_tensors_for_gpu[i]
1728            self.assertEqual(expected_a, x.a)
1729            self.assertEqual(expected_b, x.b)
1730
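    # scatter_gather.scatter rebuilds each scattered chunk with the original
    # namedtuple type, so per-device field access (x.a, x.b) keeps working; the
    # assertions above verify both the type and the per-GPU slices of each field.
    # A rough sketch (hypothetical names, assuming the tensors above):
    #
    #     Point = collections.namedtuple("Point", ["a", "b"])
    #     chunks = scatter_gather.scatter(Point(a, b), target_gpus)
    #     # chunks[i] is a Point holding the i-th slices of a and b
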
1731    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
1732    def test_gather_namedtuple(self):
1733        # tests ability to gather a list of namedtuples and return a namedtuple where each
1734        # element is of the expected tensor type.
1735        fields = ["a", "b"]
1736        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
1737
1738        num_gpus = torch.cuda.device_count()
1739        a = torch.rand(num_gpus * 2, device=0)
1740        b = torch.rand(num_gpus * 2, device=1)
1741        out1 = TestNamedTupleInput_0(a, b)
1742
1743        a = torch.rand(num_gpus * 2, device=1)
1744        b = torch.rand(num_gpus * 2, device=0)
1745        out2 = TestNamedTupleInput_0(a, b)
1746
1747        outputs = [out1, out2]
1748
1749        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
1750        for i, x in enumerate(out):
1751            self.assertTrue(isinstance(x, type(out2[-1])))  # x must be a tensor
1752            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
1753            self.assertTrue(torch.equal(x, cat))
1754
1755        out = scatter_gather.gather(outputs, 0)  # test on GPU
1756        for i, x in enumerate(out):
1757            self.assertTrue(isinstance(x, type(out2[-1])))
1758            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
1759            self.assertTrue(torch.equal(x, cat))
1760
1761        class TestNamedTupleInput_1(NamedTuple):
1762            a: torch.Tensor
1763            b: torch.Tensor
1764
1765        a = torch.rand(num_gpus * 2, device=0)
1766        b = torch.rand(num_gpus * 2, device=1)
1767        out1 = TestNamedTupleInput_1(a, b)
1768
1769        a = torch.rand(num_gpus * 2, device=1)
1770        b = torch.rand(num_gpus * 2, device=0)
1771        out2 = TestNamedTupleInput_1(a, b)
1772
1773        outputs = [out1, out2]
1774
1775        out = scatter_gather.gather(outputs, 0)  # test on GPU
1776        for i, x in enumerate(out):
1777            self.assertTrue(isinstance(x, type(out2[-1])))
1778            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
1779            self.assertTrue(torch.equal(x, cat))
1780
1781        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
1782        for i, x in enumerate(out):
1783            self.assertTrue(isinstance(x, type(out2[-1])))
1784            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
1785            self.assertTrue(torch.equal(x, cat))
1786
1787
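# Illustrative-only sketch (a hypothetical helper, not part of the original suite):
# the equivalence asserted by _test_broadcast_coalesced above, i.e. that
# comm.broadcast_coalesced matches per-tensor comm.broadcast results regrouped
# per device.
def _broadcast_coalesced_reference(tensors, devices):
    per_tensor = [comm.broadcast(t, devices) for t in tensors]
    return [list(group) for group in zip(*per_tensor)]

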
1788instantiate_parametrized_tests(TestCudaMultiGPU)
1789
1790
1791if __name__ == "__main__":
1792    run_tests()
1793