# Owner(s): ["module: cuda"]

import collections
import contextlib
import ctypes
import gc
import io
import queue
import sys
import tempfile
import threading
import unittest
from itertools import chain, repeat
from typing import NamedTuple, Union

import torch
import torch.cuda.comm as comm
from torch.nn.parallel import scatter_gather
from torch.testing._internal.common_cuda import (
    _create_scaling_case,
    _create_scaling_models_optimizers,
    TEST_MULTIGPU,
)
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    instantiate_parametrized_tests,
    IS_JETSON,
    IS_REMOTE_GPU,
    IS_SANDCASTLE,
    NoTest,
    run_tests,
    serialTest,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    TEST_CUDA,
    TestCase,
)


TEST_CUDAMALLOCASYNC = TEST_CUDA and (
    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
)

if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


class TestCudaMultiGPU(TestCase):
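    # Roughly fifty million GPU cycles: passed to torch.cuda._sleep to keep a
    # stream busy long enough for the query()/synchronize() assertions below.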
    FIFTY_MIL_CYCLES = 50000000

    def _check_memory_stat_consistency(self):
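        """Rebuild the expected per-device counters from torch.cuda.memory_snapshot()
        and check that they match the values reported by torch.cuda.memory_stats()."""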
        snapshot = torch.cuda.memory_snapshot()

        expected_each_device = collections.defaultdict(
            lambda: collections.defaultdict(int)
        )

        for segment in snapshot:
            expandable = segment["is_expandable"]
            expected = expected_each_device[segment["device"]]
            pool_str = segment["segment_type"] + "_pool"

            if not expandable:
                expected["segment.all.current"] += 1
                expected["segment." + pool_str + ".current"] += 1

            expected["allocated_bytes.all.current"] += segment["allocated_size"]
            expected["allocated_bytes." + pool_str + ".current"] += segment[
                "allocated_size"
            ]

            expected["reserved_bytes.all.current"] += segment["total_size"]
            expected["reserved_bytes." + pool_str + ".current"] += segment["total_size"]

            expected["active_bytes.all.current"] += segment["active_size"]
            expected["active_bytes." + pool_str + ".current"] += segment["active_size"]

            expected["requested_bytes.all.current"] += segment["requested_size"]
            expected["requested_bytes." + pool_str + ".current"] += segment[
                "requested_size"
            ]

            sum_requested = 0
            is_split = len(segment["blocks"]) > 1
            for block in segment["blocks"]:
                if block["state"] == "active_allocated":
                    expected["allocation.all.current"] += 1
                    expected["allocation." + pool_str + ".current"] += 1

                if block["state"].startswith("active_"):
                    sum_requested += block["requested_size"]
                    expected["active.all.current"] += 1
                    expected["active." + pool_str + ".current"] += 1

                if block["state"] == "inactive" and is_split and not expandable:
                    expected["inactive_split.all.current"] += 1
                    expected["inactive_split." + pool_str + ".current"] += 1
                    expected["inactive_split_bytes.all.current"] += block["size"]
                    expected["inactive_split_bytes." + pool_str + ".current"] += block[
                        "size"
                    ]

            self.assertEqual(sum_requested, segment["requested_size"])

        for device, expected in expected_each_device.items():
            stats = torch.cuda.memory_stats(device)
            for k, v in expected.items():
                self.assertEqual(v, stats[k])

    def test_cuda_synchronize(self):
        torch.cuda.synchronize()
        torch.cuda.synchronize("cuda")
        torch.cuda.synchronize("cuda:0")
        torch.cuda.synchronize(0)
        torch.cuda.synchronize(torch.device("cuda:0"))

        if TEST_MULTIGPU:
            torch.cuda.synchronize("cuda:1")
            torch.cuda.synchronize(1)
            torch.cuda.synchronize(torch.device("cuda:1"))

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
            torch.cuda.synchronize(torch.device("cpu"))

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
            torch.cuda.synchronize("cpu")

    @staticmethod
    def _test_memory_stats_generator(self, device=None, N=35):
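        """Generator that allocates and frees tensors on `device`, asserting after each
        step that memory_allocated/memory_reserved (and their peaks) move in the
        expected direction. It yields between steps so callers can interleave it with
        other checks or with a second generator running on another device."""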
        if device is None:
            device = torch.cuda.current_device()

        m0 = torch.cuda.memory_allocated(device)
        last_m_arr = [torch.cuda.memory_allocated(device)]
        max_m_arr = [torch.cuda.max_memory_allocated(device)]
        last_r_arr = [torch.cuda.memory_reserved(device)]
        max_r_arr = [torch.cuda.max_memory_reserved(device)]

        def alloc(*size):
            with torch.cuda.device(device):
                # NOTE: do **not** use methods that can have additional
                #       memory overhead, e.g., inplace random sampling methods.
                #       they can leave some memory occupied even after being
                #       deallocated, e.g., initialized RNG state, causing some
                #       memory checks below to fail.
                return torch.cuda.FloatTensor(*size)

        def assert_change(comp=1, empty_cache=False, reset_peak=False):
            # comp > 0: increased
            # comp = 0: equal
            # comp < 0: decreased
            new_m = torch.cuda.memory_allocated(device)
            new_max_m = torch.cuda.max_memory_allocated(device)
            if comp > 0:
                self.assertGreater(new_m, last_m_arr[0])
            elif comp < 0:
                self.assertLess(new_m, last_m_arr[0])
            else:
                self.assertEqual(new_m, last_m_arr[0])
            self.assertLessEqual(new_m, new_max_m)
            self.assertGreaterEqual(new_max_m, max_m_arr[0])
            last_m_arr[0] = new_m
            max_m_arr[0] = new_max_m

            new_r = torch.cuda.memory_reserved(device)
            new_max_r = torch.cuda.max_memory_reserved(device)
            # the cache may get emptied (by the allocation above or by empty_cache),
            # so we cannot assert new_r >= last_r
            self.assertLessEqual(new_r, new_max_r)
            self.assertGreaterEqual(new_max_r, max_r_arr[0])
            last_r_arr[0] = new_r
            max_r_arr[0] = new_max_r

            stat_key_n_sync = "num_sync_all_streams"
            stat_key_n_alloc = "num_device_alloc"
            stat_key_n_free = "num_device_free"
            if empty_cache:
                num_sync_1 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
                self.assertGreaterEqual(num_sync_1, 0)
                num_alloc_1 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
                # if current memory usage is greater than zero we must have
                # allocated something
                self.assertGreaterEqual(num_alloc_1, 0 if new_m == 0 else 1)
                num_free_1 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
                self.assertGreaterEqual(num_free_1, 0)
                # empty_cache forces release_cached_blocks to be called
                torch.cuda.empty_cache()
                num_sync_2 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
                self.assertEqual(num_sync_1 + 1, num_sync_2)
                num_alloc_2 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
                self.assertGreaterEqual(num_alloc_2, num_alloc_1)
                num_free_2 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
                self.assertGreaterEqual(num_free_2, num_free_1)

                new_r = torch.cuda.memory_reserved(device)
                new_max_r = torch.cuda.max_memory_reserved(device)
                self.assertLessEqual(new_r, last_r_arr[0])
                self.assertLessEqual(new_r, new_max_r)
                self.assertEqual(new_max_r, max_r_arr[0])
                last_r_arr[0] = new_r

            if reset_peak:
                torch.cuda.reset_peak_memory_stats(device)
                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
                self.assertEqual(torch.cuda.max_memory_allocated(device), last_m_arr[0])
                max_m_arr[0] = last_m_arr[0]
                self.assertEqual(torch.cuda.memory_reserved(device), last_r_arr[0])
                self.assertEqual(torch.cuda.max_memory_reserved(device), last_r_arr[0])
                max_r_arr[0] = last_r_arr[0]

        assert_change(0)
        assert_change(0, reset_peak=True)
        assert_change(0, empty_cache=True)
        assert_change(0, reset_peak=True)
        assert_change(0)
        yield

        tensors1 = [alloc(1), alloc(10, 20), alloc(200, 300, 2000)]
        m1 = torch.cuda.memory_allocated(device)
        assert_change(1)
        yield

        tensors2 = []

        for i in range(1, int(N / 2) + 1):
            # small ones
            tensors2.append(alloc(i, i * 4))
            assert_change(1)
            yield

        for i in range(5, int(N / 2) + 5):
            # large ones
            tensors2.append(alloc(i, i * 7, i * 9, i * 11))
            assert_change(1, reset_peak=(i % 2 == 0))
            yield

        tensors2.append(alloc(0, 0, 0))
        assert_change(0)
        yield

        permute = []
        for i in torch.randperm(len(tensors2)):
            permute.append(tensors2[i])
            assert_change(0)
            yield

        del tensors2
        assert_change(0)
        yield
        tensors2 = permute
        assert_change(0)
        yield
        del permute
        assert_change(0, reset_peak=True)
        yield

        for i in range(int(N / 2)):
            x = tensors2[i].numel()
            del tensors2[i]
            assert_change(-x)  # pass -x rather than -1: tensors2[i] may be empty, in which case no change is expected
            yield

        for i in range(2, int(2 * N / 3) + 2):
            tensors2.append(alloc(i, i * 3, i * 8))
            assert_change(1)
            yield

        del tensors2
        assert_change(-1, reset_peak=True)
        assert_change(0)
        self.assertEqual(torch.cuda.memory_allocated(device), m1)
        yield True

        del tensors1
        assert_change(-1, reset_peak=True)
        self.assertEqual(torch.cuda.memory_allocated(device), m0)

        # test empty_cache and reset_peak
        assert_change(0, empty_cache=True)
        assert_change(0, reset_peak=True)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    @serialTest()
    def test_memory_stats(self):
        gc.collect()
        torch.cuda.empty_cache()
        for _ in self._test_memory_stats_generator(self):
            self._check_memory_stat_consistency()

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_memory_stats_multigpu(self):
        # advance a generator, tracking whether it has finished via an end flag
        def advance(gen, end):
            if not end:
                try:
                    next(gen)
                except StopIteration:
                    end = True
            return end

        # interlace
        torch.cuda.empty_cache()
        gen0 = self._test_memory_stats_generator(self, device="cuda:0", N=35)
        gen1 = self._test_memory_stats_generator(
            self, device=torch.device("cuda:1"), N=35
        )
        end0 = end1 = False
        while not (end0 and end1):
            end0 = advance(gen0, end0)
            end1 = advance(gen1, end1)

        # semi-random order
        torch.cuda.empty_cache()
        gen0 = self._test_memory_stats_generator(self, device=0, N=35)
        gen1 = self._test_memory_stats_generator(
            self, device=torch.device("cuda:1"), N=35
        )
        end0 = end1 = False

        while not (end0 and end1):
            end0 = advance(gen0, end0)
            if not end0:
                gen1_max_times = torch.LongTensor(1).random_(0, 3)[0]
            else:
                gen1_max_times = torch.inf
            t = 0
            while t < gen1_max_times and not end1:
                end1 = advance(gen1, end1)
                t += 1

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_autogpu(self):
        x = torch.randn(5, 5).cuda()
        y = torch.randn(5, 5).cuda()
        self.assertEqual(x.get_device(), 0)
        self.assertEqual(y.get_device(), 0)
        with torch.cuda.device(1):
            z = torch.randn(5, 5).cuda()
            self.assertEqual(z.get_device(), 1)
            q = x.add(y)
            self.assertEqual(q.get_device(), 0)
            w = torch.randn(5, 5).cuda()
            self.assertEqual(w.get_device(), 1)
            self.assertEqual(y.cuda().get_device(), 1)
        z = z.cuda()
        self.assertEqual(z.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_new(self):
        x = torch.randn(3, 3).cuda()
        self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
        self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)

        with torch.cuda.device(1):
            self.assertEqual(x.new([0, 1, 2]).get_device(), 0)
            self.assertEqual(x.new([0, 1, 2], device=1).get_device(), 1)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_copy_device(self):
        x = torch.randn(5, 5).cuda()
        with torch.cuda.device(1):
            y = x.cuda()
            self.assertEqual(y.get_device(), 1)
            self.assertIs(y.cuda(), y)
            z = y.cuda(0)
            self.assertEqual(z.get_device(), 0)
            self.assertIs(z.cuda(0), z)

        x = torch.randn(5, 5)
        with torch.cuda.device(1):
            y = x.cuda()
            self.assertEqual(y.get_device(), 1)
            self.assertIs(y.cuda(), y)
            z = y.cuda(0)
            self.assertEqual(z.get_device(), 0)
            self.assertIs(z.cuda(0), z)

    def _test_copy_sync_current_stream(self, x, y):
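        """Check that copy_() synchronizes with the current streams of both the source
        and the destination device, so two copies issued on different src streams but
        the same dst stream are observed in issue order."""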
        x_plus_one = x + 1
        s0 = torch.cuda.Stream(device=x.device)
        s1 = torch.cuda.Stream(device=y.device)
        s2 = torch.cuda.Stream(device=x.device)
        s3 = torch.cuda.Stream(device=y.device)

        # same dst stream different src streams
        with torch.cuda.stream(s0):
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            with torch.cuda.stream(s1):
                y.copy_(x_plus_one)

        with torch.cuda.stream(s2), torch.cuda.stream(s1):
            y.copy_(x)

        s1.synchronize()
        # The copy() is synchronized on the current streams of both src and dst.
        # In the above test, the _sleep() op on s0 will not block the copy() on
        # s2, but both copies are synchronized on s1, the current stream on the
        # dst device. Hence, x is copied to y after x_plus_one is copied to y.
        # If x and y are on the same device, both copy() ops are synchronized
        # on s1.
        self.assertEqual(y, x)

        # same src stream different dst streams
        with torch.cuda.stream(s1):
            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
            with torch.cuda.stream(s0):
                y.copy_(x_plus_one)

        with torch.cuda.stream(s3), torch.cuda.stream(s0):
            y.copy_(x)

        s0.synchronize()
        # Similarly, both copy() ops are synchronized on s0.
        self.assertEqual(y, x)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_copy_streams(self):
        d0 = torch.device("cuda:0")
        x0 = torch.zeros(5, 5, device=d0)

        d1 = torch.device("cuda:1")
        x1 = torch.zeros(5, 5, device=d1)
        self._test_copy_sync_current_stream(x0, x1)

        x2 = torch.zeros(5, 5, device=d0)
        self._test_copy_sync_current_stream(x0, x2)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_cat_autogpu(self):
        x = torch.randn(4, 4).cuda(1)
        y = torch.randn(4, 4).cuda(1)
        z = torch.cat([x, y], 0)
        self.assertEqual(z.get_device(), x.get_device())

    @unittest.skipIf(torch.cuda.device_count() >= 10, "cuda:9 exists, so loading would not fail")
    def test_load_nonexistent_device(self):
        # Setup: create a serialized file object with a 'cuda:9' restore location
        tensor = torch.randn(2, device="cuda")
        buf = io.BytesIO()
        torch.save(tensor, buf)
        # NB: this might not work in the future if serialization changes
        buf = io.BytesIO(buf.getvalue().replace(b"cuda:0", b"cuda:9"))

        msg = r"Attempting to deserialize object on CUDA device 9"
        with self.assertRaisesRegex(RuntimeError, msg):
            _ = torch.load(buf)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_serialization_remap(self):
        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]

        def gpu_remap(storage, location):
            if location == "cuda:1":
                return storage.cuda(0)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f, map_location=gpu_remap)

        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_serialization_remap_dict(self):
        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f, map_location={"cuda:1": "cuda:0"})
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_multigpu_storage_clone(self):
        x = torch.randn(4, 4, device="cuda:1").storage()
        y = x.clone()
        self.assertEqual(x.get_device(), y.get_device())
        for t in ["byte", "char", "short", "int", "long", "half", "double"]:
            self.assertEqual(getattr(x, t)().get_device(), x.get_device())

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_cuda_set_device(self):
        x = torch.randn(5, 5)
        with torch.cuda.device(1):
            self.assertEqual(x.cuda().get_device(), 1)
            torch.cuda.set_device(0)
            self.assertEqual(x.cuda().get_device(), 0)
            with torch.cuda.device(1):
                self.assertEqual(x.cuda().get_device(), 1)
            self.assertEqual(x.cuda().get_device(), 0)
            torch.cuda.set_device(1)
        self.assertEqual(x.cuda().get_device(), 0)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        s0 = torch.cuda.current_stream()
        s1 = torch.cuda.current_stream(device=1)
        s2 = torch.cuda.current_stream(device=0)

        self.assertEqual(d0, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(s0, s2)

        with torch.cuda.device(d1):
            s0 = torch.cuda.current_stream()
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(d0)

        self.assertEqual(d1, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(s0, s1)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
            torch.cuda.current_stream(torch.device("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @skipCUDANonDefaultStreamIf(True)
    def test_default_stream(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")

        with torch.cuda.device(d0):
            s0 = torch.cuda.default_stream()

        with torch.cuda.device(d1):
            s1 = torch.cuda.default_stream()

        s2 = torch.cuda.default_stream(device=0)
        s3 = torch.cuda.default_stream(d1)

        self.assertEqual(d0, s0.device)
        self.assertEqual(d1, s1.device)
        self.assertEqual(d0, s2.device)
        self.assertEqual(d1, s3.device)
        self.assertEqual(s0, s2)
        self.assertEqual(s1, s3)

        with torch.cuda.device(d0):
            self.assertEqual(torch.cuda.current_stream(), s0)

        with torch.cuda.device(d1):
            self.assertEqual(torch.cuda.current_stream(), s1)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but got: cpu"):
            torch.cuda.default_stream(torch.device("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_event_device(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        e0 = torch.cuda.Event()

        self.assertEqual(None, e0.device)

        with torch.cuda.device(d0):
            s0 = torch.cuda.current_stream()
            s0.record_event(e0)

        with torch.cuda.device(d1):
            s1 = torch.cuda.Stream()
            e1 = s1.record_event()

        self.assertEqual(s0.device, torch.device("cuda:0"))
        self.assertEqual(e0.device, torch.device("cuda:0"))
        self.assertEqual(s1.device, torch.device("cuda:1"))
        self.assertEqual(e1.device, torch.device("cuda:1"))

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_context(self):
        s0 = torch.cuda.current_stream()
        s1 = torch.cuda.Stream(device=1)
        s2 = torch.cuda.Stream(device=0)

        with torch.cuda.device(s1.device):
            prev_stream_on_cuda1 = torch.cuda.current_stream()

        self.assertEqual(torch.cuda.current_stream(), s0)
        self.assertEqual(0, torch.cuda.current_device())
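        # Entering torch.cuda.stream(s) should also switch the current device to
        # s.device; leaving each context restores the previous stream and device.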
591*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s1):
592*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.current_stream(), s1)
593*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(1, torch.cuda.current_device())
594*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s2):
595*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.cuda.current_stream(), s2)
596*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(0, torch.cuda.current_device())
597*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.stream(s0):
598*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(torch.cuda.current_stream(), s0)
599*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(0, torch.cuda.current_device())
600*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.cuda.current_stream(), s2)
601*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(0, torch.cuda.current_device())
602*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.current_stream(), s1)
603*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(1, torch.cuda.current_device())
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(s1.device):
606*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(prev_stream_on_cuda1, torch.cuda.current_stream())
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.cuda.current_stream(), s0)
609*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, torch.cuda.current_device())
610*da0073e9SAndroid Build Coastguard Worker
611*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
612*da0073e9SAndroid Build Coastguard Worker    def test_streams_multi_gpu(self):
613*da0073e9SAndroid Build Coastguard Worker        default_stream = torch.cuda.current_stream()
614*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(default_stream.device, torch.device("cuda:0"))
615*da0073e9SAndroid Build Coastguard Worker        stream = torch.cuda.Stream(device=1)
616*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(stream.device, torch.device("cuda:1"))
617*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(1):
618*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.current_stream().device, torch.device("cuda:1"))
619*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(torch.cuda.current_stream(), default_stream)
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
622*da0073e9SAndroid Build Coastguard Worker    def test_streams_multi_gpu_query(self):
623*da0073e9SAndroid Build Coastguard Worker        d0 = torch.device("cuda:0")
624*da0073e9SAndroid Build Coastguard Worker        d1 = torch.device("cuda:1")
625*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(d0)
626*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(d1)
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
629*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream()
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
632*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream()
633*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0.query())
636*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(s1.query())
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
639*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s0.query())
640*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(s1.query())
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
643*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s0.query())
644*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(s1.query())
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker        # deliberately using a different device
647*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
648*da0073e9SAndroid Build Coastguard Worker            s1.synchronize()
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0.query())
651*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s1.query())
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
654*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s0.query())
655*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s1.query())
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
658*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s0.query())
659*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(s1.query())
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
662*da0073e9SAndroid Build Coastguard Worker    def test_streams_multi_gpu_eq(self):
663*da0073e9SAndroid Build Coastguard Worker        d0 = torch.device("cuda:0")
664*da0073e9SAndroid Build Coastguard Worker        d1 = torch.device("cuda:1")
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
667*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream()
668*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream()
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
671*da0073e9SAndroid Build Coastguard Worker            s2 = torch.cuda.current_stream()
672*da0073e9SAndroid Build Coastguard Worker            s3 = torch.cuda.current_stream()
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0 == s0)
675*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0 == s1)
676*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s2 == s2)
677*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s2 == s3)
678*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(s0 == s2)
679*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(s1 == s3)
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s0.device, s1.device)
682*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s0.cuda_stream, s1.cuda_stream)
683*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s2.device, s3.device)
684*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s2.cuda_stream, s3.cuda_stream)
685*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(s0.device, s3.device)
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hash(s0), hash(s1))
688*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hash(s2), hash(s3))
689*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(hash(s0), hash(s3))
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
692*da0073e9SAndroid Build Coastguard Worker    def test_streams_priority(self):
693*da0073e9SAndroid Build Coastguard Worker        low, high = torch.cuda.Stream.priority_range()
694*da0073e9SAndroid Build Coastguard Worker        s0 = torch.cuda.Stream(device=0, priority=low)
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(low, s0.priority)
697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.device("cuda:0"), s0.device)
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker        s1 = torch.cuda.Stream(device=1, priority=high)
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(high, s1.priority)
702*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.device("cuda:1"), s1.device)
703*da0073e9SAndroid Build Coastguard Worker
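For context outside the test harness, the priority API above can be exercised directly. A minimal sketch, assuming at least two visible CUDA devices; it mirrors the calls made in test_streams_priority and makes no claim about which end of the reported range is the "higher" priority.

import torch

# Query the allowed priority range once; it applies to newly created streams.
low, high = torch.cuda.Stream.priority_range()

# One stream per device, one at each end of the range.
s_low = torch.cuda.Stream(device=0, priority=low)
s_high = torch.cuda.Stream(device=1, priority=high)

print(s_low.priority, s_low.device)    # the priority value and cuda:0
print(s_high.priority, s_high.device)  # the priority value and cuda:1
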
704*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
705*da0073e9SAndroid Build Coastguard Worker    def test_tensor_device(self):
706*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 0)
707*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.cuda.FloatTensor(1, device=1).get_device(), 1)
708*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(1):
709*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 1)
710*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.FloatTensor(1, device=0).get_device(), 0)
711*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.FloatTensor(1, device=None).get_device(), 1)
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker    @staticmethod
714*da0073e9SAndroid Build Coastguard Worker    def _stream_synchronize(self, spin_time_cycles):
715*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.current_stream()
716*da0073e9SAndroid Build Coastguard Worker        e_tik = torch.cuda.Event(enable_timing=True)
717*da0073e9SAndroid Build Coastguard Worker        e_tok = torch.cuda.Event(enable_timing=True)
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        e_tik.record(s)
720*da0073e9SAndroid Build Coastguard Worker        torch.cuda._sleep(spin_time_cycles)
721*da0073e9SAndroid Build Coastguard Worker        e_tok.record(s)
722*da0073e9SAndroid Build Coastguard Worker        s.synchronize()
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s.query())
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker        # not necessary to check e_tik and e_tok, as elapsed_time would throw
727*da0073e9SAndroid Build Coastguard Worker        # an exception otherwise.
728*da0073e9SAndroid Build Coastguard Worker        return e_tik.elapsed_time(e_tok)
729*da0073e9SAndroid Build Coastguard Worker
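The record/synchronize/elapsed_time pattern used by these helpers is the standard way to time GPU work with CUDA events. A minimal sketch, assuming a single CUDA device; the matmul is just a stand-in workload.

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")

start.record()               # enqueue the start marker on the current stream
y = x @ x                    # the GPU work being timed
end.record()                 # enqueue the end marker
end.synchronize()            # block until the end marker has been reached

print(f"matmul took {start.elapsed_time(end):.3f} ms")  # elapsed_time reports milliseconds
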
730*da0073e9SAndroid Build Coastguard Worker    @staticmethod
731*da0073e9SAndroid Build Coastguard Worker    def _event_synchronize(self, spin_time_cycles):
732*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.current_stream()
733*da0073e9SAndroid Build Coastguard Worker        e_tik = torch.cuda.Event(enable_timing=True)
734*da0073e9SAndroid Build Coastguard Worker        e_tok = torch.cuda.Event(enable_timing=True)
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker        e_tik.record(s)
737*da0073e9SAndroid Build Coastguard Worker        torch.cuda._sleep(spin_time_cycles)
738*da0073e9SAndroid Build Coastguard Worker        s.record_event(e_tok)
739*da0073e9SAndroid Build Coastguard Worker        e_tok.synchronize()
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s.query())
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker        # not necessary to check e_tik and e_tok, as elapsed_time would throw
744*da0073e9SAndroid Build Coastguard Worker        # an exception otherwise.
745*da0073e9SAndroid Build Coastguard Worker        return e_tik.elapsed_time(e_tok)
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker    @staticmethod
748*da0073e9SAndroid Build Coastguard Worker    def _event_wait(self, spin_time_cycles):
749*da0073e9SAndroid Build Coastguard Worker        s0 = torch.cuda.current_stream()
750*da0073e9SAndroid Build Coastguard Worker        s1 = torch.cuda.Stream()
751*da0073e9SAndroid Build Coastguard Worker        e_tik = torch.cuda.Event(blocking=True, enable_timing=True)
752*da0073e9SAndroid Build Coastguard Worker        e_tok = torch.cuda.Event(blocking=True, enable_timing=True)
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Worker        e_tik.record(s0)
755*da0073e9SAndroid Build Coastguard Worker        torch.cuda._sleep(spin_time_cycles - 10)
756*da0073e9SAndroid Build Coastguard Worker        e_sync = torch.cuda.Event(blocking=True)
757*da0073e9SAndroid Build Coastguard Worker        e_sync.record()
758*da0073e9SAndroid Build Coastguard Worker        e_sync.wait(s1)
759*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s1):
760*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(10)
761*da0073e9SAndroid Build Coastguard Worker        s1.synchronize()
762*da0073e9SAndroid Build Coastguard Worker        e_tok.record()
763*da0073e9SAndroid Build Coastguard Worker        e_tok.synchronize()
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0.query())
766*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s1.query())
767*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e_sync.query())
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker        # not necessary to check e_tik and e_tok, as elapsed_time would throw
770*da0073e9SAndroid Build Coastguard Worker        # an exception otherwise.
771*da0073e9SAndroid Build Coastguard Worker        return e_tik.elapsed_time(e_tok)
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker    @staticmethod
774*da0073e9SAndroid Build Coastguard Worker    def _test_stream_event_nogil(self, sync_func, p2c, c2p):
775*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device("cuda:1"):
776*da0073e9SAndroid Build Coastguard Worker            c2p.put(0)
777*da0073e9SAndroid Build Coastguard Worker            p2c.get()
778*da0073e9SAndroid Build Coastguard Worker            c2p.put(sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES))
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
781*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
782*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
783*da0073e9SAndroid Build Coastguard Worker    def test_stream_event_nogil(self):
784*da0073e9SAndroid Build Coastguard Worker        for sync_func in [
785*da0073e9SAndroid Build Coastguard Worker            TestCudaMultiGPU._stream_synchronize,
786*da0073e9SAndroid Build Coastguard Worker            TestCudaMultiGPU._event_synchronize,
787*da0073e9SAndroid Build Coastguard Worker            TestCudaMultiGPU._event_wait,
788*da0073e9SAndroid Build Coastguard Worker        ]:
789*da0073e9SAndroid Build Coastguard Worker            p2c = queue.Queue()
790*da0073e9SAndroid Build Coastguard Worker            c2p = queue.Queue()
791*da0073e9SAndroid Build Coastguard Worker            e_tik = torch.cuda.Event(enable_timing=True)
792*da0073e9SAndroid Build Coastguard Worker            e_tok = torch.cuda.Event(enable_timing=True)
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker            t = threading.Thread(
795*da0073e9SAndroid Build Coastguard Worker                target=TestCudaMultiGPU._test_stream_event_nogil,
796*da0073e9SAndroid Build Coastguard Worker                args=(self, sync_func, p2c, c2p),
797*da0073e9SAndroid Build Coastguard Worker            )
798*da0073e9SAndroid Build Coastguard Worker            t.daemon = True
799*da0073e9SAndroid Build Coastguard Worker            t.start()
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Worker            c2p.get()
802*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.device("cuda:0"):
803*da0073e9SAndroid Build Coastguard Worker                e_tik.record()
804*da0073e9SAndroid Build Coastguard Worker                p2c.put(0)
805*da0073e9SAndroid Build Coastguard Worker                parent_time = sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES)
806*da0073e9SAndroid Build Coastguard Worker                child_time = c2p.get()
807*da0073e9SAndroid Build Coastguard Worker                e_tok.record()
808*da0073e9SAndroid Build Coastguard Worker                e_tok.synchronize()
809*da0073e9SAndroid Build Coastguard Worker                total_time = e_tik.elapsed_time(e_tok)
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker            # Without the GIL, synchronizations in the parent and child threads
812*da0073e9SAndroid Build Coastguard Worker            # can overlap. The total execution time should be a little longer
813*da0073e9SAndroid Build Coastguard Worker            # than spinning fifty million cycles and much shorter than twice
814*da0073e9SAndroid Build Coastguard Worker            # that (e.g. if each synchronization takes ~T and they fully
815*da0073e9SAndroid Build Coastguard Worker            # overlap, the total is ~T while parent + child is ~2T). Testing
816*da0073e9SAndroid Build Coastguard Worker            # absolute execution time is unreliable across hardware and
817*da0073e9SAndroid Build Coastguard Worker            # environments, so this test checks that the sum of the parent and
818*da0073e9SAndroid Build Coastguard Worker            # child execution times exceeds the real execution time by at least 40%.
819*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(parent_time + child_time, total_time * 1.4)
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker    # This test is flaky for ROCm, see issue #62602
822*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
823*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
824*da0073e9SAndroid Build Coastguard Worker    def test_events_wait(self):
825*da0073e9SAndroid Build Coastguard Worker        d0 = torch.device("cuda:0")
826*da0073e9SAndroid Build Coastguard Worker        d1 = torch.device("cuda:1")
827*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(d0)
828*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(d1)
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
831*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream()
832*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
833*da0073e9SAndroid Build Coastguard Worker            e0 = torch.cuda.Event()
834*da0073e9SAndroid Build Coastguard Worker            s0.record_event(e0)
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
837*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream()
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(s0.query())
840*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s1.query())
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        s1.wait_event(e0)
843*da0073e9SAndroid Build Coastguard Worker        s1.synchronize()
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e0.query())
846*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s0.query())
847*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s1.query())
848*da0073e9SAndroid Build Coastguard Worker
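The same cross-device ordering idiom can be used outside the test: work queued on a stream of one GPU is made to wait on an event recorded on another GPU, without blocking the host. A minimal sketch, assuming two visible CUDA devices; the matmul stands in for real work.

import torch

d0, d1 = torch.device("cuda:0"), torch.device("cuda:1")

with torch.cuda.device(d0):
    a = torch.randn(2048, 2048, device=d0)
    b = a @ a                                          # some work on device 0
    e0 = torch.cuda.current_stream().record_event()    # marks completion of that work

with torch.cuda.device(d1):
    s1 = torch.cuda.current_stream()
    s1.wait_event(e0)                      # s1 will not proceed past here until e0 completes
    c = torch.randn(1024, device=d1) * 2   # guaranteed to start after the matmul on d0 finishes

torch.cuda.synchronize(d0)
torch.cuda.synchronize(d1)
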
849*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
850*da0073e9SAndroid Build Coastguard Worker    def test_events_multi_gpu_query(self):
851*da0073e9SAndroid Build Coastguard Worker        d0 = torch.device("cuda:0")
852*da0073e9SAndroid Build Coastguard Worker        d1 = torch.device("cuda:1")
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
855*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream()
856*da0073e9SAndroid Build Coastguard Worker            e0 = s0.record_event()
857*da0073e9SAndroid Build Coastguard Worker            s0.synchronize()
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
860*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream()
861*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
862*da0073e9SAndroid Build Coastguard Worker            e1 = s1.record_event()
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e0.query())
865*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(e1.query())
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
868*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(e0.query())
869*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(e1.query())
870*da0073e9SAndroid Build Coastguard Worker
871*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
872*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(e0.query())
873*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(e1.query())
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker        # deliberately using a different device
876*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
877*da0073e9SAndroid Build Coastguard Worker            e1.synchronize()
878*da0073e9SAndroid Build Coastguard Worker
879*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e0.query())
880*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e1.query())
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
883*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(e0.query())
884*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(e1.query())
885*da0073e9SAndroid Build Coastguard Worker
886*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
887*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(e0.query())
888*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(e1.query())
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
891*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
892*da0073e9SAndroid Build Coastguard Worker    def test_events_multi_gpu_elapsed_time(self):
893*da0073e9SAndroid Build Coastguard Worker        d0 = torch.device("cuda:0")
894*da0073e9SAndroid Build Coastguard Worker        d1 = torch.device("cuda:1")
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
897*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream()
898*da0073e9SAndroid Build Coastguard Worker            e0 = torch.cuda.Event(enable_timing=True)
899*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(10)
900*da0073e9SAndroid Build Coastguard Worker            s0.record_event(e0)
901*da0073e9SAndroid Build Coastguard Worker
902*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
903*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream()
904*da0073e9SAndroid Build Coastguard Worker            e1 = torch.cuda.Event(enable_timing=True)
905*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
906*da0073e9SAndroid Build Coastguard Worker            s1.record_event(e1)
907*da0073e9SAndroid Build Coastguard Worker
908*da0073e9SAndroid Build Coastguard Worker        e0.synchronize()
909*da0073e9SAndroid Build Coastguard Worker        e1.synchronize()
910*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
911*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
912*da0073e9SAndroid Build Coastguard Worker                self.assertGreater(e0.elapsed_time(e1), 0)
913*da0073e9SAndroid Build Coastguard Worker
914*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
915*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
916*da0073e9SAndroid Build Coastguard Worker                self.assertGreater(e0.elapsed_time(e1), 0)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
919*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream()
920*da0073e9SAndroid Build Coastguard Worker            e2 = torch.cuda.Event(enable_timing=True)
921*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(TestCudaMultiGPU.FIFTY_MIL_CYCLES)
922*da0073e9SAndroid Build Coastguard Worker            s0.record_event(e2)
923*da0073e9SAndroid Build Coastguard Worker            s0.synchronize()
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(e0.elapsed_time(e2), 0)
926*da0073e9SAndroid Build Coastguard Worker
927*da0073e9SAndroid Build Coastguard Worker        # deliberately calling from a different device
928*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
929*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(e0.elapsed_time(e2), 0)
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard Worker    @contextlib.contextmanager
932*da0073e9SAndroid Build Coastguard Worker    def _get_external_stream(self, device):
933*da0073e9SAndroid Build Coastguard Worker        cudart = torch.cuda.cudart()
934*da0073e9SAndroid Build Coastguard Worker        stream = ctypes.c_ulonglong(0)
935*da0073e9SAndroid Build Coastguard Worker        stream_p = ctypes.POINTER(ctypes.c_void_p)(stream)
936*da0073e9SAndroid Build Coastguard Worker        stream_p_int = ctypes.cast(stream_p, ctypes.c_void_p).value
937*da0073e9SAndroid Build Coastguard Worker        with device:
938*da0073e9SAndroid Build Coastguard Worker            try:
939*da0073e9SAndroid Build Coastguard Worker                out = cudart.cudaStreamCreate(stream_p_int)
940*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, 0)
941*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(stream.value, 0)
942*da0073e9SAndroid Build Coastguard Worker                yield stream.value
943*da0073e9SAndroid Build Coastguard Worker            finally:
944*da0073e9SAndroid Build Coastguard Worker                out = cudart.cudaStreamDestroy(stream.value)
945*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, 0)
946*da0073e9SAndroid Build Coastguard Worker
947*da0073e9SAndroid Build Coastguard Worker    def test_external_streams(self):
948*da0073e9SAndroid Build Coastguard Worker        device = torch.cuda.device(0)
949*da0073e9SAndroid Build Coastguard Worker        with self._get_external_stream(device) as stream_v:
950*da0073e9SAndroid Build Coastguard Worker            ext_stream = torch.cuda.ExternalStream(stream_v)
951*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stream_v, ext_stream.cuda_stream)
952*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ext_stream.device.index, device.idx)
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
955*da0073e9SAndroid Build Coastguard Worker    def test_external_streams_multi_device(self):
956*da0073e9SAndroid Build Coastguard Worker        device = torch.cuda.device(1)
957*da0073e9SAndroid Build Coastguard Worker        with self._get_external_stream(device) as stream_v:
958*da0073e9SAndroid Build Coastguard Worker            ext_stream = torch.cuda.ExternalStream(stream_v, device=device)
959*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stream_v, ext_stream.cuda_stream)
960*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ext_stream.device.index, device.idx)
961*da0073e9SAndroid Build Coastguard Worker
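torch.cuda.ExternalStream is intended for stream handles created outside PyTorch, such as the one the cudart helper above obtains from cudaStreamCreate. A minimal sketch; for the sake of a self-contained example the "external" handle is borrowed from a PyTorch-created stream, whereas in real use it would be the integer value of a cudaStream_t produced by another library.

import torch

# Illustration only: borrow a raw handle from a PyTorch stream and treat it as external.
donor = torch.cuda.Stream(device=0)
raw_stream_ptr = donor.cuda_stream          # integer value of the underlying cudaStream_t

ext = torch.cuda.ExternalStream(raw_stream_ptr, device=torch.device("cuda:0"))

with torch.cuda.stream(ext):
    x = torch.ones(4, device="cuda:0") * 3  # work enqueued on the wrapped stream

ext.synchronize()                           # PyTorch does not own or destroy the handle
print(x)
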
962*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
963*da0073e9SAndroid Build Coastguard Worker    def test_caching_pinned_memory_multi_gpu(self):
964*da0073e9SAndroid Build Coastguard Worker        # checks that the events preventing pinned memory from being re-used
965*da0073e9SAndroid Build Coastguard Worker        # too early are recorded on the correct GPU
966*da0073e9SAndroid Build Coastguard Worker        cycles_per_ms = get_cycles_per_ms()
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker        t = torch.FloatTensor([1]).pin_memory()
969*da0073e9SAndroid Build Coastguard Worker        ptr = t.data_ptr()
970*da0073e9SAndroid Build Coastguard Worker        gpu_tensor0 = torch.cuda.FloatTensor([0], device=0)
971*da0073e9SAndroid Build Coastguard Worker        gpu_tensor1 = torch.cuda.FloatTensor([0], device=1)
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(1):
974*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by 1s
975*da0073e9SAndroid Build Coastguard Worker            gpu_tensor1.copy_(t, non_blocking=True)
976*da0073e9SAndroid Build Coastguard Worker
977*da0073e9SAndroid Build Coastguard Worker        del t
978*da0073e9SAndroid Build Coastguard Worker        t = torch.FloatTensor([2]).pin_memory()
979*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon")
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(0):
982*da0073e9SAndroid Build Coastguard Worker            gpu_tensor0.copy_(t, non_blocking=True)
983*da0073e9SAndroid Build Coastguard Worker
984*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gpu_tensor1[0], 1)
985*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gpu_tensor0[0], 2)
986*da0073e9SAndroid Build Coastguard Worker
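What the test exercises from the user's side: page-locked (pinned) host memory lets copy_(..., non_blocking=True) overlap with other GPU work, and the caching host allocator has to avoid handing the buffer back out while such a copy may still be reading it. A minimal sketch, assuming one CUDA device.

import torch

host = torch.ones(1024).pin_memory()        # page-locked host buffer
dev = torch.empty(1024, device="cuda:0")

dev.copy_(host, non_blocking=True)          # asynchronous H2D copy on the current stream
# ... other GPU work queued here could overlap with the copy ...

torch.cuda.current_stream().synchronize()   # ensure the copy finished before reading
print(dev[0].item())                        # 1.0
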
987*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
988*da0073e9SAndroid Build Coastguard Worker    def test_get_set_rng_state_all(self):
989*da0073e9SAndroid Build Coastguard Worker        states = torch.cuda.get_rng_state_all()
990*da0073e9SAndroid Build Coastguard Worker        before0 = torch.cuda.FloatTensor(100, device=0).normal_()
991*da0073e9SAndroid Build Coastguard Worker        before1 = torch.cuda.FloatTensor(100, device=1).normal_()
992*da0073e9SAndroid Build Coastguard Worker        torch.cuda.set_rng_state_all(states)
993*da0073e9SAndroid Build Coastguard Worker        after0 = torch.cuda.FloatTensor(100, device=0).normal_()
994*da0073e9SAndroid Build Coastguard Worker        after1 = torch.cuda.FloatTensor(100, device=1).normal_()
995*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(before0, after0, atol=0, rtol=0)
996*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(before1, after1, atol=0, rtol=0)
997*da0073e9SAndroid Build Coastguard Worker
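A minimal sketch of the round trip the test relies on: capturing the RNG state of every visible GPU and restoring it later reproduces the same random draws.

import torch

states = torch.cuda.get_rng_state_all()      # one state tensor per visible device

first = torch.randn(4, device="cuda:0")      # advances device 0's generator

torch.cuda.set_rng_state_all(states)         # rewind all generators
replay = torch.randn(4, device="cuda:0")

assert torch.equal(first, replay)            # identical draws after restoring the state
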
998*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
999*da0073e9SAndroid Build Coastguard Worker    def test_rng_state_offset(self):
1000*da0073e9SAndroid Build Coastguard Worker        before = torch.cuda.get_rng_state()
1001*da0073e9SAndroid Build Coastguard Worker        torch.cuda._set_rng_state_offset(100)
1002*da0073e9SAndroid Build Coastguard Worker        offset = torch.cuda._get_rng_state_offset()
1003*da0073e9SAndroid Build Coastguard Worker        torch.cuda.set_rng_state(before)
1004*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(offset, 100)
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker    # Verifies that mem_get_info works, including when called for a different device
1007*da0073e9SAndroid Build Coastguard Worker    def test_mem_get_info(self):
1008*da0073e9SAndroid Build Coastguard Worker        def _test(device: Union[str, int, torch.device]):
1009*da0073e9SAndroid Build Coastguard Worker            # Prevent PyTorch from reusing the allocated memory
1010*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()
1011*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1012*da0073e9SAndroid Build Coastguard Worker            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(device)
1013*da0073e9SAndroid Build Coastguard Worker            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
1014*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(1024 * 1024 * 8, device=device)
1015*da0073e9SAndroid Build Coastguard Worker            if IS_JETSON:
1016*da0073e9SAndroid Build Coastguard Worker                # w/o syncing, mem_get_info would run before the allocation has actually
1017*da0073e9SAndroid Build Coastguard Worker                # taken effect. This race condition causes consistent failures on Jetson.
1018*da0073e9SAndroid Build Coastguard Worker                torch.cuda.synchronize()
1019*da0073e9SAndroid Build Coastguard Worker            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(device)
1020*da0073e9SAndroid Build Coastguard Worker
1021*da0073e9SAndroid Build Coastguard Worker            self.assertLess(after_free_bytes, before_free_bytes)
1022*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(before_available_bytes, after_available_bytes)
1023*da0073e9SAndroid Build Coastguard Worker
1024*da0073e9SAndroid Build Coastguard Worker        # Test calls with different device representations
1025*da0073e9SAndroid Build Coastguard Worker        _test(0)
1026*da0073e9SAndroid Build Coastguard Worker        _test(torch.device("cuda"))
1027*da0073e9SAndroid Build Coastguard Worker        _test(torch.device("cuda:0"))
1028*da0073e9SAndroid Build Coastguard Worker        _test("cuda")
1029*da0073e9SAndroid Build Coastguard Worker        _test("cuda:0")
1030*da0073e9SAndroid Build Coastguard Worker        if TEST_MULTIGPU:
1031*da0073e9SAndroid Build Coastguard Worker            _test(1)
1032*da0073e9SAndroid Build Coastguard Worker            _test(torch.device("cuda:1"))
1033*da0073e9SAndroid Build Coastguard Worker            _test("cuda:1")
1034*da0073e9SAndroid Build Coastguard Worker
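torch.cuda.mem_get_info reports free and total device memory as seen by the CUDA driver, which is what the test compares before and after an allocation. A minimal sketch, assuming one CUDA device.

import torch

free_before, total = torch.cuda.mem_get_info("cuda:0")

x = torch.empty(32 * 1024 * 1024, device="cuda:0")   # ~128 MiB of float32
torch.cuda.synchronize()

free_after, _ = torch.cuda.mem_get_info("cuda:0")

print(f"total memory: {total / 2**20:.0f} MiB")
print(f"driver-visible free memory dropped by {(free_before - free_after) / 2**20:.0f} MiB")
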
1035*da0073e9SAndroid Build Coastguard Worker    # Test that wrap_with_cuda_memory_check successfully detects leak
1036*da0073e9SAndroid Build Coastguard Worker    def test_cuda_memory_leak_detection(self):
1037*da0073e9SAndroid Build Coastguard Worker        l = []
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker        @self.wrap_with_cuda_memory_check
1040*da0073e9SAndroid Build Coastguard Worker        def no_leak():
1041*da0073e9SAndroid Build Coastguard Worker            pass
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Worker        @self.wrap_with_cuda_memory_check
1044*da0073e9SAndroid Build Coastguard Worker        def leak_gpu0():
1045*da0073e9SAndroid Build Coastguard Worker            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
1046*da0073e9SAndroid Build Coastguard Worker            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:0")))
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker        no_leak()
1049*da0073e9SAndroid Build Coastguard Worker        regex = r"CUDA driver API confirmed .+ on device 0.+"
1050*da0073e9SAndroid Build Coastguard Worker        if IS_JETSON:
1051*da0073e9SAndroid Build Coastguard Worker            try:
1052*da0073e9SAndroid Build Coastguard Worker                leak_gpu0()
1053*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as e:
1054*da0073e9SAndroid Build Coastguard Worker                import re
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Worker                assert re.match(regex, str(e)), str(e) + "\n does not match: \n" + regex
1057*da0073e9SAndroid Build Coastguard Worker        else:
1058*da0073e9SAndroid Build Coastguard Worker            # assertRaisesRegex does not pass on Jetson even though the raised
1059*da0073e9SAndroid Build Coastguard Worker            # RuntimeError matches the regex via re.match, hence the manual check above.
1060*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, regex):
1061*da0073e9SAndroid Build Coastguard Worker                leak_gpu0()
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker        if TEST_MULTIGPU:
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker            @self.wrap_with_cuda_memory_check
1066*da0073e9SAndroid Build Coastguard Worker            def leak_gpu1():
1067*da0073e9SAndroid Build Coastguard Worker                # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
1068*da0073e9SAndroid Build Coastguard Worker                l.append(torch.randn(1024 * 1024 * 8, device=torch.device("cuda:1")))
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
1071*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"CUDA driver API confirmed .+ on device 1.+"
1072*da0073e9SAndroid Build Coastguard Worker            ):
1073*da0073e9SAndroid Build Coastguard Worker                leak_gpu1()
1074*da0073e9SAndroid Build Coastguard Worker
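wrap_with_cuda_memory_check is internal to the test framework; the idea behind it can be approximated with public counters. A rough sketch only, assuming one CUDA device: check_no_cuda_leak is a hypothetical helper, and the real wrapper additionally consults the CUDA driver, as the error message above indicates.

import torch

def check_no_cuda_leak(fn, device="cuda:0"):
    # Crude public-API approximation: the caching-allocator counter should return
    # to its previous value once fn's temporaries have been released.
    torch.cuda.synchronize(device)
    before = torch.cuda.memory_allocated(device)
    fn()
    torch.cuda.synchronize(device)
    after = torch.cuda.memory_allocated(device)
    if after > before:
        raise RuntimeError(f"possible leak: {after - before} bytes still allocated on {device}")

kept = []
check_no_cuda_leak(lambda: torch.randn(1024, device="cuda:0").sum())                # passes
# check_no_cuda_leak(lambda: kept.append(torch.randn(1024, device="cuda:0")))       # would raise
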
1075*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1076*da0073e9SAndroid Build Coastguard Worker    def test_streaming_backwards_device_transfer(self):
1077*da0073e9SAndroid Build Coastguard Worker        # This function must run with non-default current streams on all devices, otherwise it's meaningless.
1078*da0073e9SAndroid Build Coastguard Worker        # The intention is to test that to()'s backward (CopyBackward) interacts properly with the
1079*da0073e9SAndroid Build Coastguard Worker        # synchronization logic in torch/csrc/autograd/input_buffer.cpp.
1080*da0073e9SAndroid Build Coastguard Worker        dev0 = torch.device("cuda:0")
1081*da0073e9SAndroid Build Coastguard Worker        dev1 = torch.device("cuda:1")
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker        # The tensors need to be fairly large: bigger tensors mean longer D2D
1084*da0073e9SAndroid Build Coastguard Worker        # transfers, which are more likely to expose races.
1085*da0073e9SAndroid Build Coastguard Worker        size = 2**26
1086*da0073e9SAndroid Build Coastguard Worker
1087*da0073e9SAndroid Build Coastguard Worker        a = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)
1088*da0073e9SAndroid Build Coastguard Worker        b = torch.full((size,), 1, device=dev1, dtype=torch.float64, requires_grad=True)
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker        # Here to_backward_recipient = a*b is used only once, so MulBackward's InputBuffer slot only expects 1 input.
1091*da0073e9SAndroid Build Coastguard Worker        # This tests the situation where we don't call InputBuffer::accumulate for MulBackward's InputBuffer.
1092*da0073e9SAndroid Build Coastguard Worker        to_backward_recipient = a * b
1093*da0073e9SAndroid Build Coastguard Worker        s = to_backward_recipient.to(device="cuda:0").sum()
1094*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(device=dev0)
1095*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(device=dev1)
1096*da0073e9SAndroid Build Coastguard Worker        s.backward()
1097*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.grad.sum().item() == size)
1098*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.grad.sum().item() == size)
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker        # Here to_backward_recipient = a*b is used twice, so MulBackward's InputBuffer slot expects 2 inputs.
1101*da0073e9SAndroid Build Coastguard Worker        # This tests the situation where we do call InputBuffer::accumulate for MulBackward's InputBuffer.
1102*da0073e9SAndroid Build Coastguard Worker        a.grad = None
1103*da0073e9SAndroid Build Coastguard Worker        b.grad = None
1104*da0073e9SAndroid Build Coastguard Worker        to_backward_recipient = a * b
1105*da0073e9SAndroid Build Coastguard Worker        # Multiply by 2 here so that to()'s backward creates gradient values different from
1106*da0073e9SAndroid Build Coastguard Worker        # the case above, to avoid false positives if the caching allocator happens to reuse
1107*da0073e9SAndroid Build Coastguard Worker        # memory regions that were populated with 1s by the case above.
1108*da0073e9SAndroid Build Coastguard Worker        s0 = to_backward_recipient.to(device="cuda:0").sum() * 2.0
1109*da0073e9SAndroid Build Coastguard Worker        s1 = to_backward_recipient.to(device="cuda:0").sum() * 2.0
1110*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(device=dev0)
1111*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize(device=dev1)
1112*da0073e9SAndroid Build Coastguard Worker        s0.backward(retain_graph=True)
1113*da0073e9SAndroid Build Coastguard Worker        s1.backward()
1114*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.grad.sum().item() == 4 * size)
1115*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.grad.sum().item() == 4 * size)
1116*da0073e9SAndroid Build Coastguard Worker
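From a user's perspective, the machinery under test means that a backward pass through to() delivers gradients on the tensor's original device even when the loss lives on another GPU. A minimal sketch, assuming two visible CUDA devices.

import torch

a = torch.ones(1024, device="cuda:1", requires_grad=True)

loss = (a * 3).to("cuda:0").sum()    # CopyBackward routes the gradient back to cuda:1
loss.backward()

print(a.grad.device)                 # cuda:1
print(a.grad.sum().item())           # 3 * 1024 = 3072.0
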
1117*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1118*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle")
1119*da0073e9SAndroid Build Coastguard Worker    def test_cuda_init_race(self):
1120*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/16559
1121*da0073e9SAndroid Build Coastguard Worker        import subprocess
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker        subprocess.check_call(
1124*da0073e9SAndroid Build Coastguard Worker            [
1125*da0073e9SAndroid Build Coastguard Worker                sys.executable,
1126*da0073e9SAndroid Build Coastguard Worker                "-c",
1127*da0073e9SAndroid Build Coastguard Worker                """\
1128*da0073e9SAndroid Build Coastguard Workerimport torch
1129*da0073e9SAndroid Build Coastguard Workerimport threading
1130*da0073e9SAndroid Build Coastguard Worker
1131*da0073e9SAndroid Build Coastguard Workerdef worker(rank):
1132*da0073e9SAndroid Build Coastguard Worker    torch.tensor([1.]).cuda(rank)
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Workert1 = threading.Thread(target=worker, args=(0,))
1135*da0073e9SAndroid Build Coastguard Workert2 = threading.Thread(target=worker, args=(1,))
1136*da0073e9SAndroid Build Coastguard Workert1.start()
1137*da0073e9SAndroid Build Coastguard Workert2.start()
1138*da0073e9SAndroid Build Coastguard Worker""",
1139*da0073e9SAndroid Build Coastguard Worker            ]
1140*da0073e9SAndroid Build Coastguard Worker        )
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1143*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_device_as_key(self):
1144*da0073e9SAndroid Build Coastguard Worker        # Ensure that different instances of "device" objects that point to the same device
1145*da0073e9SAndroid Build Coastguard Worker        # are treated as identical keys by dicts.  GradScaler relies on this behavior, and
1146*da0073e9SAndroid Build Coastguard Worker        # might otherwise misbehave in a way that's difficult to detect (a silent performance hit).
1147*da0073e9SAndroid Build Coastguard Worker        d = {}
1148*da0073e9SAndroid Build Coastguard Worker        t = torch.empty((1,), device="cuda:0")
1149*da0073e9SAndroid Build Coastguard Worker        dev0a = torch.device("cuda:0")
1150*da0073e9SAndroid Build Coastguard Worker        dev0b = torch.device("cuda:0")
1151*da0073e9SAndroid Build Coastguard Worker        dev1a = torch.device("cuda:1")
1152*da0073e9SAndroid Build Coastguard Worker        dev1b = torch.device("cuda:1")
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hash(dev0a) == hash(dev0b))
1155*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hash(dev1a) == hash(dev1b))
1156*da0073e9SAndroid Build Coastguard Worker
1157*da0073e9SAndroid Build Coastguard Worker        d[dev0a] = "0a"
1158*da0073e9SAndroid Build Coastguard Worker        d[dev0b] = "0b"
1159*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(d) == 1)
1160*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(d[dev0a] == "0b")
1161*da0073e9SAndroid Build Coastguard Worker        d[t.device] = "t"
1162*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(d) == 1)
1163*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(d[dev0a] == "t")
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker        d[dev1a] = "1a"
1166*da0073e9SAndroid Build Coastguard Worker        d[dev1b] = "1b"
1167*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(d) == 2)
1168*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(d[dev1a] == "1b")
1169*da0073e9SAndroid Build Coastguard Worker
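A minimal sketch of the property the test pins down: torch.device objects compare and hash by value, so independently constructed instances (and tensor.device) collapse onto a single dict key.

import torch

per_device_state = {}

per_device_state[torch.device("cuda:0")] = "first"
per_device_state[torch.device("cuda", 0)] = "second"   # same key, overwrites "first"

t = torch.empty(1, device="cuda:0")
per_device_state[t.device] = "third"                   # still the same key

assert len(per_device_state) == 1
assert per_device_state[torch.device("cuda:0")] == "third"
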
1170*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1171*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_scale(self):
1172*da0073e9SAndroid Build Coastguard Worker        scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
1173*da0073e9SAndroid Build Coastguard Worker        t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")
1174*da0073e9SAndroid Build Coastguard Worker        t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1")
1175*da0073e9SAndroid Build Coastguard Worker        # Create some nested iterables of tensors on different devices.
1176*da0073e9SAndroid Build Coastguard Worker        outputs = (
1177*da0073e9SAndroid Build Coastguard Worker            t1.clone(),
1178*da0073e9SAndroid Build Coastguard Worker            (t0.clone(), t1.clone()),
1179*da0073e9SAndroid Build Coastguard Worker            [t0.clone(), (t1.clone(), t0.clone())],
1180*da0073e9SAndroid Build Coastguard Worker        )
1181*da0073e9SAndroid Build Coastguard Worker        outputs = scaler.scale(outputs)
1182*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1183*da0073e9SAndroid Build Coastguard Worker            outputs[0] == 8.0
1184*da0073e9SAndroid Build Coastguard Worker            and outputs[1][0] == 8.0
1185*da0073e9SAndroid Build Coastguard Worker            and outputs[1][1] == 8.0
1186*da0073e9SAndroid Build Coastguard Worker            and outputs[2][0] == 8.0
1187*da0073e9SAndroid Build Coastguard Worker            and outputs[2][1][0] == 8.0
1188*da0073e9SAndroid Build Coastguard Worker            and outputs[2][1][1] == 8.0
1189*da0073e9SAndroid Build Coastguard Worker        )
1190*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(scaler._scale.device == t1.device)
1191*da0073e9SAndroid Build Coastguard Worker
1192*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1193*da0073e9SAndroid Build Coastguard Worker    def test_grad_scaling_multigpu(self):
1194*da0073e9SAndroid Build Coastguard Worker        # Same as above, but runs some of the models on device 1.
1195*da0073e9SAndroid Build Coastguard Worker        # GradScaler should transparently handle losses and gradients on multiple devices.
1196*da0073e9SAndroid Build Coastguard Worker        # This test could be combined with the test above, but I think it makes sense to treat
1197*da0073e9SAndroid Build Coastguard Worker        # multi-GPU operations separately.
1198*da0073e9SAndroid Build Coastguard Worker        dev0 = torch.device("cuda:0")
1199*da0073e9SAndroid Build Coastguard Worker        dev1 = torch.device("cuda:1")
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker        for enabled in True, False:
1202*da0073e9SAndroid Build Coastguard Worker            (
1203*da0073e9SAndroid Build Coastguard Worker                mod_control0,
1204*da0073e9SAndroid Build Coastguard Worker                mod_scaling0,
1205*da0073e9SAndroid Build Coastguard Worker                opt_control0,
1206*da0073e9SAndroid Build Coastguard Worker                opt_scaling0,
1207*da0073e9SAndroid Build Coastguard Worker                data,
1208*da0073e9SAndroid Build Coastguard Worker                loss_fn,
1209*da0073e9SAndroid Build Coastguard Worker                skip_iter,
1210*da0073e9SAndroid Build Coastguard Worker            ) = _create_scaling_case()
1211*da0073e9SAndroid Build Coastguard Worker            (
1212*da0073e9SAndroid Build Coastguard Worker                mod_control1,
1213*da0073e9SAndroid Build Coastguard Worker                mod_scaling1,
1214*da0073e9SAndroid Build Coastguard Worker                opt_control1,
1215*da0073e9SAndroid Build Coastguard Worker                opt_scaling1,
1216*da0073e9SAndroid Build Coastguard Worker            ) = _create_scaling_models_optimizers(device=dev1)
1217*da0073e9SAndroid Build Coastguard Worker
1218*da0073e9SAndroid Build Coastguard Worker            scaler = torch.amp.GradScaler(
1219*da0073e9SAndroid Build Coastguard Worker                device="cuda",
1220*da0073e9SAndroid Build Coastguard Worker                init_scale=128.0,
1221*da0073e9SAndroid Build Coastguard Worker                growth_factor=2.0,
1222*da0073e9SAndroid Build Coastguard Worker                enabled=enabled,
1223*da0073e9SAndroid Build Coastguard Worker                growth_interval=1,
1224*da0073e9SAndroid Build Coastguard Worker            )
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker            def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
1227*da0073e9SAndroid Build Coastguard Worker                for i, (input, target) in enumerate(data):
1228*da0073e9SAndroid Build Coastguard Worker                    optimizer0.zero_grad()
1229*da0073e9SAndroid Build Coastguard Worker                    optimizer1.zero_grad()
1230*da0073e9SAndroid Build Coastguard Worker                    output0 = model0(input)
1231*da0073e9SAndroid Build Coastguard Worker                    output1 = model1(input.to(dev1))
1232*da0073e9SAndroid Build Coastguard Worker                    loss0 = loss_fn(0.3 * output0 + 0.7 * output1.to(dev0), target)
1233*da0073e9SAndroid Build Coastguard Worker                    loss1 = loss_fn(
1234*da0073e9SAndroid Build Coastguard Worker                        0.6 * output0.to(dev1) - 0.4 * output1, target.to(dev1)
1235*da0073e9SAndroid Build Coastguard Worker                    )
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker                    if try_scaling_api:
1238*da0073e9SAndroid Build Coastguard Worker                        scaler.scale(loss0).backward(retain_graph=True)
1239*da0073e9SAndroid Build Coastguard Worker                        scaler.scale(loss1).backward()
1240*da0073e9SAndroid Build Coastguard Worker                        if i == skip_iter and scaler.is_enabled():
1241*da0073e9SAndroid Build Coastguard Worker                            model1[1].weight.grad.data.fill_(float("inf"))
1242*da0073e9SAndroid Build Coastguard Worker
1243*da0073e9SAndroid Build Coastguard Worker                        # As an additional stress test, separately unscale for one of the optimizers.
1244*da0073e9SAndroid Build Coastguard Worker                        scaler.unscale_(optimizer0)
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker                        scaler.step(optimizer0)
1247*da0073e9SAndroid Build Coastguard Worker                        scaler.step(optimizer1)
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker                        # Make sure the found_infs were collected properly across optimizers and devices.
1250*da0073e9SAndroid Build Coastguard Worker                        if scaler.is_enabled():
1251*da0073e9SAndroid Build Coastguard Worker                            self.assertTrue(
1252*da0073e9SAndroid Build Coastguard Worker                                len(scaler._found_inf_per_device(optimizer0)) == 1
1253*da0073e9SAndroid Build Coastguard Worker                            )
1254*da0073e9SAndroid Build Coastguard Worker                            self.assertTrue(
1255*da0073e9SAndroid Build Coastguard Worker                                len(scaler._found_inf_per_device(optimizer1)) == 1
1256*da0073e9SAndroid Build Coastguard Worker                            )
1257*da0073e9SAndroid Build Coastguard Worker                            self.assertTrue(
1258*da0073e9SAndroid Build Coastguard Worker                                scaler._found_inf_per_device(optimizer0)[dev0].item()
1259*da0073e9SAndroid Build Coastguard Worker                                == 0.0
1260*da0073e9SAndroid Build Coastguard Worker                            )
1261*da0073e9SAndroid Build Coastguard Worker                            self.assertTrue(
1262*da0073e9SAndroid Build Coastguard Worker                                scaler._found_inf_per_device(optimizer1)[dev1].item()
1263*da0073e9SAndroid Build Coastguard Worker                                == float(i == skip_iter)
1264*da0073e9SAndroid Build Coastguard Worker                            )
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker                        scaler.update()
1267*da0073e9SAndroid Build Coastguard Worker                    else:
1268*da0073e9SAndroid Build Coastguard Worker                        loss0.backward(retain_graph=True)
1269*da0073e9SAndroid Build Coastguard Worker                        loss1.backward()
1270*da0073e9SAndroid Build Coastguard Worker                        optimizer0.step()
1271*da0073e9SAndroid Build Coastguard Worker                        if (not scaler.is_enabled()) or (i != skip_iter):
1272*da0073e9SAndroid Build Coastguard Worker                            optimizer1.step()
1273*da0073e9SAndroid Build Coastguard Worker
1274*da0073e9SAndroid Build Coastguard Worker            run(mod_control0, mod_control1, opt_control0, opt_control1, False)
1275*da0073e9SAndroid Build Coastguard Worker            run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker            # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once.
1278*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1279*da0073e9SAndroid Build Coastguard Worker                scaler.get_scale()
1280*da0073e9SAndroid Build Coastguard Worker                == (
1281*da0073e9SAndroid Build Coastguard Worker                    128.0
1282*da0073e9SAndroid Build Coastguard Worker                    * scaler.get_growth_factor() ** 3
1283*da0073e9SAndroid Build Coastguard Worker                    * scaler.get_backoff_factor() ** 1
1284*da0073e9SAndroid Build Coastguard Worker                )
1285*da0073e9SAndroid Build Coastguard Worker                if enabled
1286*da0073e9SAndroid Build Coastguard Worker                else 1.0
1287*da0073e9SAndroid Build Coastguard Worker            )
1288*da0073e9SAndroid Build Coastguard Worker
1289*da0073e9SAndroid Build Coastguard Worker            # Copy mod_control1 and mod_scaling1 back to device 0 for comparison
1290*da0073e9SAndroid Build Coastguard Worker            mod_control1.to(dev0)
1291*da0073e9SAndroid Build Coastguard Worker            mod_scaling1.to(dev0)
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker            for c, s in zip(
1294*da0073e9SAndroid Build Coastguard Worker                chain(mod_control0.parameters(), mod_control1.parameters()),
1295*da0073e9SAndroid Build Coastguard Worker                chain(mod_scaling0.parameters(), mod_scaling1.parameters()),
1296*da0073e9SAndroid Build Coastguard Worker            ):
1297*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c, s, rtol=1e-5, atol=1e-7)
1298*da0073e9SAndroid Build Coastguard Worker
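For reference, typical single-device use of the GradScaler API stressed above follows a scale → step → update pattern, usually combined with autocast. A minimal sketch; the model, optimizer, and random data are placeholders.

import torch

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randn(8, 1, device="cuda")

    optimizer.zero_grad()
    with torch.autocast("cuda"):                 # mixed-precision forward pass
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    scaler.scale(loss).backward()                # backward on the scaled loss
    scaler.step(optimizer)                       # unscales grads; skips the step on inf/NaN
    scaler.update()                              # adjust the scale for the next iteration
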
1299*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
1300*da0073e9SAndroid Build Coastguard Worker    def test_cuda_device_memory_allocated(self):
1301*da0073e9SAndroid Build Coastguard Worker        from torch.cuda import memory_allocated
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        device_count = torch.cuda.device_count()
1304*da0073e9SAndroid Build Coastguard Worker        current_alloc = [memory_allocated(idx) for idx in range(device_count)]
1305*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(10, device="cuda:0")
1306*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(memory_allocated(0), current_alloc[0])
1307*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1308*da0073e9SAndroid Build Coastguard Worker            all(
1309*da0073e9SAndroid Build Coastguard Worker                memory_allocated(torch.cuda.device(idx)) == current_alloc[idx]
1310*da0073e9SAndroid Build Coastguard Worker                for idx in range(1, device_count)
1311*da0073e9SAndroid Build Coastguard Worker            )
1312*da0073e9SAndroid Build Coastguard Worker        )
1313*da0073e9SAndroid Build Coastguard Worker
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Workerclass TestCudaComm(TestCase):
1316*da0073e9SAndroid Build Coastguard Worker    def _test_broadcast(self, input):
1317*da0073e9SAndroid Build Coastguard Worker        if not TEST_MULTIGPU:
1318*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("only one GPU detected")
1319*da0073e9SAndroid Build Coastguard Worker        # test regular
1320*da0073e9SAndroid Build Coastguard Worker        results = comm.broadcast(input, (0, 1))
1321*da0073e9SAndroid Build Coastguard Worker        for i, t in enumerate(results):
1322*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.get_device(), i)
1323*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, input)
1324*da0073e9SAndroid Build Coastguard Worker            if (
1325*da0073e9SAndroid Build Coastguard Worker                input.is_cuda and input.get_device() == i
1326*da0073e9SAndroid Build Coastguard Worker            ):  # test not copying on same device
1327*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.data_ptr(), input.data_ptr())
1328*da0073e9SAndroid Build Coastguard Worker        # test out=
1329*da0073e9SAndroid Build Coastguard Worker        for inplace in [True, False]:
1330*da0073e9SAndroid Build Coastguard Worker            if inplace:
1331*da0073e9SAndroid Build Coastguard Worker                outputs = [
1332*da0073e9SAndroid Build Coastguard Worker                    torch.empty_like(input, device=0),
1333*da0073e9SAndroid Build Coastguard Worker                    torch.empty_like(input, device=1),
1334*da0073e9SAndroid Build Coastguard Worker                ]
1335*da0073e9SAndroid Build Coastguard Worker            else:
1336*da0073e9SAndroid Build Coastguard Worker                outputs = [input.cuda(0), torch.empty_like(input, device=1)]
1337*da0073e9SAndroid Build Coastguard Worker            results = comm.broadcast(input, out=outputs)
1338*da0073e9SAndroid Build Coastguard Worker            for r, o in zip(results, outputs):
1339*da0073e9SAndroid Build Coastguard Worker                self.assertIs(r, o)
1340*da0073e9SAndroid Build Coastguard Worker            for i, t in enumerate(results):
1341*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.get_device(), i)
1342*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t, input)
1343*da0073e9SAndroid Build Coastguard Worker        # test error msg
1344*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1345*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Exactly one of 'devices' and 'out'"
1346*da0073e9SAndroid Build Coastguard Worker        ):
1347*da0073e9SAndroid Build Coastguard Worker            comm.broadcast(input, (0, 1), out=outputs)
1348*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1349*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1350*da0073e9SAndroid Build Coastguard Worker            r"Expected all output tensors to be CUDA tensors, but output tensor at index 1",
1351*da0073e9SAndroid Build Coastguard Worker        ):
1352*da0073e9SAndroid Build Coastguard Worker            comm.broadcast(input, out=[input.cuda(0), input.cpu()])
1353*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1354*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1355*da0073e9SAndroid Build Coastguard Worker            r"Expected all output tensors to have same shape as the source .+ at index 1",
1356*da0073e9SAndroid Build Coastguard Worker        ):
1357*da0073e9SAndroid Build Coastguard Worker            comm.broadcast(input, out=[input.cuda(0), input.cuda(1).unsqueeze(0)])
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_cpu(self):
1360*da0073e9SAndroid Build Coastguard Worker        self._test_broadcast(torch.randn(5, 5))
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_gpu(self):
1363*da0073e9SAndroid Build Coastguard Worker        self._test_broadcast(torch.randn(5, 5).cuda())
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker    def _test_broadcast_coalesced(self, tensors, buffer_size):
1366*da0073e9SAndroid Build Coastguard Worker        b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
1367*da0073e9SAndroid Build Coastguard Worker        for (_, bt), t in zip(b_tensors, tensors):
1368*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bt.get_device(), 1)
1369*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bt, t)
1370*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(bt, type(t))
1371*da0073e9SAndroid Build Coastguard Worker
1372*da0073e9SAndroid Build Coastguard Worker        bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size)
1373*da0073e9SAndroid Build Coastguard Worker        bc_tensors_t = list(zip(*bc_tensors))
1374*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_tensors, bc_tensors_t)
1375*da0073e9SAndroid Build Coastguard Worker        for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
1376*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bt.get_device(), bct.get_device())
1377*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(bct, type(bt))
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker        # check that tensors on device[0] are returned as-is
1380*da0073e9SAndroid Build Coastguard Worker        for out_tensors in (b_tensors, bc_tensors_t):
1381*da0073e9SAndroid Build Coastguard Worker            for inp_t, (out_t, _) in zip(tensors, out_tensors):
1382*da0073e9SAndroid Build Coastguard Worker                self.assertIs(inp_t, out_t)
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker        # check that the copies not on device[0] have their own (unshared) version counters
1385*da0073e9SAndroid Build Coastguard Worker        # NOTE [ Version Counter in comm.*_coalesced ]
1386*da0073e9SAndroid Build Coastguard Worker        versions = [t._version for _, t in bc_tensors_t]
1387*da0073e9SAndroid Build Coastguard Worker        for old_version, (_, t) in zip(versions, bc_tensors_t):
1388*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t._version, old_version)
1389*da0073e9SAndroid Build Coastguard Worker            t.zero_()
1390*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t._version, old_version + 1)
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1393*da0073e9SAndroid Build Coastguard Worker    # Note: flaky on CI, but passes consistently on a dual-gfx906 machine
1394*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_coalesced(self):
1395*da0073e9SAndroid Build Coastguard Worker        numel = 5
1396*da0073e9SAndroid Build Coastguard Worker        num_bytes = numel * 8
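        # 8 bytes per element for the float64 / int64 tensors below; the buffer
        # size passed to the helper (num_bytes * 5 // 2) is smaller than the
        # total payload, presumably so the broadcast has to split the tensors
        # across several coalesced buffers.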
1397*da0073e9SAndroid Build Coastguard Worker        tensors = [
1398*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0],
1399*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1400*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1401*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0],
1402*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0],
1403*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0],
1404*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0],
1405*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1406*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1407*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0],
1408*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel * 2).int().cuda(),  # int32 is half the width, so 2x the elements for the same byte size
1409*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1410*da0073e9SAndroid Build Coastguard Worker        ]
1411*da0073e9SAndroid Build Coastguard Worker        self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1414*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_coalesced_dense_only(self):
1415*da0073e9SAndroid Build Coastguard Worker        numel = 5
1416*da0073e9SAndroid Build Coastguard Worker        num_bytes = numel * 8
1417*da0073e9SAndroid Build Coastguard Worker        tensors = [
1418*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1419*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1420*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1421*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1422*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel * 2).int().cuda(),  # int32 is half the width, so 2x the elements for the same byte size
1423*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1424*da0073e9SAndroid Build Coastguard Worker        ]
1425*da0073e9SAndroid Build Coastguard Worker        self._test_broadcast_coalesced(tensors, num_bytes * 5 // 2)
1426*da0073e9SAndroid Build Coastguard Worker
1427*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1428*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_coalesced_empty_tensors(self):
1429*da0073e9SAndroid Build Coastguard Worker        tensors = [
1430*da0073e9SAndroid Build Coastguard Worker            torch.tensor([]).byte().cuda(),
1431*da0073e9SAndroid Build Coastguard Worker            torch.randn(5).cuda(),
1432*da0073e9SAndroid Build Coastguard Worker            torch.randn(5).double().cuda(),
1433*da0073e9SAndroid Build Coastguard Worker        ]
1434*da0073e9SAndroid Build Coastguard Worker        self._test_broadcast_coalesced(tensors, 256)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1437*da0073e9SAndroid Build Coastguard Worker    def test_reduce_add(self):
1438*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
1439*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5)
1440*da0073e9SAndroid Build Coastguard Worker        x_cuda = x.cuda(0)
1441*da0073e9SAndroid Build Coastguard Worker        y_cuda = y.cuda(1)
1442*da0073e9SAndroid Build Coastguard Worker        result = comm.reduce_add((x_cuda, y_cuda))
1443*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.get_device(), 0)
1444*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.cpu(), x + y)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker    def _test_reduce_add_coalesced(self, tensors, buffer_size):
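        # Helper: pair each tensor with a copy on cuda:1, reduce with reduce_add
        # and reduce_add_coalesced, and check that both paths yield t * 2 with
        # matching types and that the coalesced outputs are new tensors with
        # working version counters.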
1447*da0073e9SAndroid Build Coastguard Worker        dup_tensors = [tensors, [t.cuda(1) for t in tensors]]
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker        r_tensors = [comm.reduce_add(t) for t in zip(*dup_tensors)]
1450*da0073e9SAndroid Build Coastguard Worker        for r, t in zip(r_tensors, tensors):
1451*da0073e9SAndroid Build Coastguard Worker            self.assertEqualTypeString(r, t)
1452*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r.coalesce() if r.is_sparse else r, t * 2)
1453*da0073e9SAndroid Build Coastguard Worker
1454*da0073e9SAndroid Build Coastguard Worker        rc_tensors = comm.reduce_add_coalesced(dup_tensors, buffer_size=buffer_size)
1455*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r_tensors, rc_tensors)
1456*da0073e9SAndroid Build Coastguard Worker        for r, rc in zip(r_tensors, rc_tensors):
1457*da0073e9SAndroid Build Coastguard Worker            self.assertEqualTypeString(rc, r)
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker        # Since we have both cuda:0 and cuda:1 inputs, the outputs must be new.
1460*da0073e9SAndroid Build Coastguard Worker        # We can check that they have different version counters.
1461*da0073e9SAndroid Build Coastguard Worker        # NOTE [ Version Counter in comm.*_coalesced ]
1462*da0073e9SAndroid Build Coastguard Worker        versions = [t._version for t in rc_tensors]
1463*da0073e9SAndroid Build Coastguard Worker        for old_version, t in zip(versions, rc_tensors):
1464*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t._version, old_version)
1465*da0073e9SAndroid Build Coastguard Worker            t.zero_()
1466*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t._version, old_version + 1)
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1469*da0073e9SAndroid Build Coastguard Worker    def test_reduce_add_coalesced(self):
1470*da0073e9SAndroid Build Coastguard Worker        numel = 5
1471*da0073e9SAndroid Build Coastguard Worker        num_bytes = numel * 8
1472*da0073e9SAndroid Build Coastguard Worker        tensors = [
1473*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 1, False, "cuda", torch.float64)[0],
1474*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1475*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1476*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 10, False, "cuda", torch.float64)[0],
1477*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 5, False, "cuda", torch.float64)[0],
1478*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((3, 3), 2, 7, False, "cuda", torch.int64)[0],
1479*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 3), 2, 2, False, "cuda", torch.float32)[0],
1480*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1481*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1482*da0073e9SAndroid Build Coastguard Worker            self.genSparseTensor((2, 7), 2, 3, False, "cuda", torch.int64)[0],
1483*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel * 2).int().cuda(),  # int32 is half the width, so 2x the elements for the same byte size
1484*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1485*da0073e9SAndroid Build Coastguard Worker        ]
1486*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1489*da0073e9SAndroid Build Coastguard Worker    def test_reduce_add_coalesced_dense_only(self):
1490*da0073e9SAndroid Build Coastguard Worker        numel = 5
1491*da0073e9SAndroid Build Coastguard Worker        num_bytes = numel * 8
1492*da0073e9SAndroid Build Coastguard Worker        tensors = [
1493*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1494*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1495*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1496*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).long().cuda(),
1497*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel * 2).int().cuda(),  # int32 is half the width, so 2x the elements for the same byte size
1498*da0073e9SAndroid Build Coastguard Worker            torch.randn(numel).cuda(),
1499*da0073e9SAndroid Build Coastguard Worker        ]
1500*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_add_coalesced(tensors, num_bytes * 5 // 2)
1501*da0073e9SAndroid Build Coastguard Worker
1502*da0073e9SAndroid Build Coastguard Worker    def _test_scatter(self, input, chunk_sizes=None, dim=0):
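        # Helper: scatter `input` across devices (0, 1) along `dim`, with and
        # without preallocated `out` tensors, check every chunk against the
        # matching slice of the input (a chunk kept on the source device should
        # be a view), and exercise the error messages for bad argument
        # combinations.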
1503*da0073e9SAndroid Build Coastguard Worker        if not TEST_MULTIGPU:
1504*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("only one GPU detected")
1505*da0073e9SAndroid Build Coastguard Worker        if chunk_sizes is None:
1506*da0073e9SAndroid Build Coastguard Worker            ref_chunk_sizes = tuple(repeat(input.size(dim) // 2, 2))
1507*da0073e9SAndroid Build Coastguard Worker        else:
1508*da0073e9SAndroid Build Coastguard Worker            ref_chunk_sizes = chunk_sizes
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker        # test regular
1511*da0073e9SAndroid Build Coastguard Worker        result = comm.scatter(input, (0, 1), chunk_sizes, dim)
1512*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(result), 2)
1513*da0073e9SAndroid Build Coastguard Worker        chunk_start = 0
1514*da0073e9SAndroid Build Coastguard Worker        for i, r in enumerate(result):
1515*da0073e9SAndroid Build Coastguard Worker            chunk_end = chunk_start + ref_chunk_sizes[i]
1516*da0073e9SAndroid Build Coastguard Worker            index = [slice(None, None) for _ in range(input.dim())]
1517*da0073e9SAndroid Build Coastguard Worker            index[dim] = slice(chunk_start, chunk_end)
1518*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
1519*da0073e9SAndroid Build Coastguard Worker            chunk_start = chunk_end
1520*da0073e9SAndroid Build Coastguard Worker            if r.device == input.device:
1521*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1522*da0073e9SAndroid Build Coastguard Worker                    r.data_ptr(), input.data_ptr()
1523*da0073e9SAndroid Build Coastguard Worker                )  # a chunk kept on the source device should be a view of the input
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker        # test out
1526*da0073e9SAndroid Build Coastguard Worker        out = [torch.empty_like(t) for t in result]
1527*da0073e9SAndroid Build Coastguard Worker        result = comm.scatter(input, dim=dim, out=out)
1528*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(result), 2)
1529*da0073e9SAndroid Build Coastguard Worker        chunk_start = 0
1530*da0073e9SAndroid Build Coastguard Worker        for i, r in enumerate(result):
1531*da0073e9SAndroid Build Coastguard Worker            self.assertIs(r, out[i])
1532*da0073e9SAndroid Build Coastguard Worker            chunk_end = chunk_start + ref_chunk_sizes[i]
1533*da0073e9SAndroid Build Coastguard Worker            index = [slice(None, None) for _ in range(input.dim())]
1534*da0073e9SAndroid Build Coastguard Worker            index[dim] = slice(chunk_start, chunk_end)
1535*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
1536*da0073e9SAndroid Build Coastguard Worker            chunk_start = chunk_end
1537*da0073e9SAndroid Build Coastguard Worker
1538*da0073e9SAndroid Build Coastguard Worker        # test error msg
1539*da0073e9SAndroid Build Coastguard Worker        if chunk_sizes is not None:
1540*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
1541*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"Expected devices and chunk_sizes to be of same length"
1542*da0073e9SAndroid Build Coastguard Worker            ):
1543*da0073e9SAndroid Build Coastguard Worker                comm.scatter(
1544*da0073e9SAndroid Build Coastguard Worker                    input,
1545*da0073e9SAndroid Build Coastguard Worker                    [0 for _ in range(len(chunk_sizes) + 1)],
1546*da0073e9SAndroid Build Coastguard Worker                    dim=dim,
1547*da0073e9SAndroid Build Coastguard Worker                    chunk_sizes=chunk_sizes,
1548*da0073e9SAndroid Build Coastguard Worker                )
1549*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"'devices' must not be specified"):
1550*da0073e9SAndroid Build Coastguard Worker            comm.scatter(input, (0, 1), dim=dim, out=out)
1551*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1552*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Expected at least one device to scatter to"
1553*da0073e9SAndroid Build Coastguard Worker        ):
1554*da0073e9SAndroid Build Coastguard Worker            comm.scatter(input, (), dim=dim)
1555*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1556*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Expected at least one output tensor to scatter to"
1557*da0073e9SAndroid Build Coastguard Worker        ):
1558*da0073e9SAndroid Build Coastguard Worker            comm.scatter(input, dim=dim, out=[])
1559*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1560*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1561*da0073e9SAndroid Build Coastguard Worker            r"Expected all output tensors to be CUDA tensors, but output tensor at index 0",
1562*da0073e9SAndroid Build Coastguard Worker        ):
1563*da0073e9SAndroid Build Coastguard Worker            comm.scatter(input, dim=dim, out=([out[0].cpu()] + out[1:]))
1564*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1565*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Output tensor at index 0 has incorrect shape"
1566*da0073e9SAndroid Build Coastguard Worker        ):
1567*da0073e9SAndroid Build Coastguard Worker            comm.scatter(input, dim=dim, out=([out[0].unsqueeze(0)] + out[1:]))
1568*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1569*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1570*da0073e9SAndroid Build Coastguard Worker            r"Total size for output tensors along scatter dim \d+ does not match",
1571*da0073e9SAndroid Build Coastguard Worker        ):
1572*da0073e9SAndroid Build Coastguard Worker            index = [slice(None, None) for _ in range(input.dim())]
1573*da0073e9SAndroid Build Coastguard Worker            index[dim] = slice(1, None)
1574*da0073e9SAndroid Build Coastguard Worker            comm.scatter(input, dim=dim, out=([out[0][tuple(index)]] + out[1:]))
1575*da0073e9SAndroid Build Coastguard Worker
1576*da0073e9SAndroid Build Coastguard Worker    def test_scatter_cpu(self):
1577*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(4, 4), dim=0)
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker    def test_scatter_cpu_dim(self):
1580*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(4, 4), dim=1)
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker    def test_scatter_cpu_neg_dim(self):
1583*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(4, 4), dim=-2)
1584*da0073e9SAndroid Build Coastguard Worker
1585*da0073e9SAndroid Build Coastguard Worker    def test_scatter_cpu_sizes(self):
1586*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(6, 4), chunk_sizes=(2, 4))
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker    def test_scatter_gpu(self):
1589*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(4, 4).cuda(), dim=0)
1590*da0073e9SAndroid Build Coastguard Worker
1591*da0073e9SAndroid Build Coastguard Worker    def test_scatter_gpu_dim(self):
1592*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(4, 4).cuda(), dim=1)
1593*da0073e9SAndroid Build Coastguard Worker
1594*da0073e9SAndroid Build Coastguard Worker    def test_scatter_gpu_neg_dim(self):
1595*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(4, 4).cuda(), dim=-2)
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker    def test_scatter_gpu_sizes(self):
1598*da0073e9SAndroid Build Coastguard Worker        self._test_scatter(torch.randn(6, 4).cuda(), chunk_sizes=(2, 4))
1599*da0073e9SAndroid Build Coastguard Worker
1600*da0073e9SAndroid Build Coastguard Worker    def _test_gather(self, dim):
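        # Helper: gather a cuda:0 and a cuda:1 tensor along `dim` into each
        # supported destination (current device, cuda:0, cpu, and a third GPU
        # when available), with and without an `out` tensor, then exercise the
        # error messages for invalid inputs.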
1601*da0073e9SAndroid Build Coastguard Worker        if not TEST_MULTIGPU:
1602*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("only one GPU detected")
1603*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 5, device=0)
1604*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 5, device=1)
1605*da0073e9SAndroid Build Coastguard Worker        expected_size = list(x.size())
1606*da0073e9SAndroid Build Coastguard Worker        expected_size[dim] += y.size(dim)
1607*da0073e9SAndroid Build Coastguard Worker        expected_size = torch.Size(expected_size)
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker        destinations = [None, torch.device("cuda:0"), torch.device("cpu")]
1610*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.device_count() > 2:
1611*da0073e9SAndroid Build Coastguard Worker            destinations.append(torch.device("cuda:2"))
1612*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(1):
1613*da0073e9SAndroid Build Coastguard Worker            for destination in destinations:
1614*da0073e9SAndroid Build Coastguard Worker                if destination is None:
1615*da0073e9SAndroid Build Coastguard Worker                    expected_device = torch.device("cuda", torch.cuda.current_device())
1616*da0073e9SAndroid Build Coastguard Worker                else:
1617*da0073e9SAndroid Build Coastguard Worker                    expected_device = destination
1618*da0073e9SAndroid Build Coastguard Worker                for use_out in [True, False]:
1619*da0073e9SAndroid Build Coastguard Worker                    if use_out:
1620*da0073e9SAndroid Build Coastguard Worker                        out = torch.empty(expected_size, device=expected_device)
1621*da0073e9SAndroid Build Coastguard Worker                        result = comm.gather((x, y), dim, out=out)
1622*da0073e9SAndroid Build Coastguard Worker                        self.assertIs(out, result)
1623*da0073e9SAndroid Build Coastguard Worker                    else:
1624*da0073e9SAndroid Build Coastguard Worker                        result = comm.gather((x, y), dim, destination=destination)
1625*da0073e9SAndroid Build Coastguard Worker
1626*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result.device, expected_device)
1627*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result.size(), expected_size)
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker                    index = [slice(None, None), slice(None, None)]
1630*da0073e9SAndroid Build Coastguard Worker                    index[dim] = slice(0, x.size(dim))
1631*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result[tuple(index)], x)
1632*da0073e9SAndroid Build Coastguard Worker                    index[dim] = slice(x.size(dim), x.size(dim) + y.size(dim))
1633*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result[tuple(index)], y)
1634*da0073e9SAndroid Build Coastguard Worker
1635*da0073e9SAndroid Build Coastguard Worker        # test error msg
1636*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1637*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"'destination' must not be specified"
1638*da0073e9SAndroid Build Coastguard Worker        ):
1639*da0073e9SAndroid Build Coastguard Worker            comm.gather(
1640*da0073e9SAndroid Build Coastguard Worker                (x, y),
1641*da0073e9SAndroid Build Coastguard Worker                dim,
1642*da0073e9SAndroid Build Coastguard Worker                destination="cpu",
1643*da0073e9SAndroid Build Coastguard Worker                out=torch.empty(expected_size, device="cpu"),
1644*da0073e9SAndroid Build Coastguard Worker            )
1645*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1646*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Expected at least one tensor to gather from"
1647*da0073e9SAndroid Build Coastguard Worker        ):
1648*da0073e9SAndroid Build Coastguard Worker            comm.gather(())
1649*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1650*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Expected all input tensors to be CUDA tensors, "
1651*da0073e9SAndroid Build Coastguard Worker        ):
1652*da0073e9SAndroid Build Coastguard Worker            comm.gather((x.cpu(), y))
1653*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1654*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1655*da0073e9SAndroid Build Coastguard Worker            r"Expected all input tensors to have the same number of dimensions",
1656*da0073e9SAndroid Build Coastguard Worker        ):
1657*da0073e9SAndroid Build Coastguard Worker            comm.gather((x, y.unsqueeze(0)))
1658*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1659*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Input tensor at index 1 has invalid shape"
1660*da0073e9SAndroid Build Coastguard Worker        ):
1661*da0073e9SAndroid Build Coastguard Worker            if dim in [0, -2]:
1662*da0073e9SAndroid Build Coastguard Worker                comm.gather((x, y[:, 1:]), dim=dim)
1663*da0073e9SAndroid Build Coastguard Worker            elif dim in [1, -1]:
1664*da0073e9SAndroid Build Coastguard Worker                comm.gather((x, y[1:, :]), dim=dim)
1665*da0073e9SAndroid Build Coastguard Worker
1666*da0073e9SAndroid Build Coastguard Worker    def test_gather(self):
1667*da0073e9SAndroid Build Coastguard Worker        self._test_gather(0)
1668*da0073e9SAndroid Build Coastguard Worker
1669*da0073e9SAndroid Build Coastguard Worker    def test_gather_dim(self):
1670*da0073e9SAndroid Build Coastguard Worker        self._test_gather(1)
1671*da0073e9SAndroid Build Coastguard Worker
1672*da0073e9SAndroid Build Coastguard Worker    def test_gather_neg_dim(self):
1673*da0073e9SAndroid Build Coastguard Worker        self._test_gather(-1)
1674*da0073e9SAndroid Build Coastguard Worker
1675*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
1676*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_scatter_gather(self):
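        # channels_last (NHWC) layout should survive a scatter / gather round trip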
1677*da0073e9SAndroid Build Coastguard Worker        nhwc = torch.randn((10, 3, 32, 32), device="cpu").contiguous(
1678*da0073e9SAndroid Build Coastguard Worker            memory_format=torch.channels_last
1679*da0073e9SAndroid Build Coastguard Worker        )
1680*da0073e9SAndroid Build Coastguard Worker        results = torch.cuda.comm.scatter(nhwc, (0, 1), None, 0)
1681*da0073e9SAndroid Build Coastguard Worker        for result in results:
1682*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(result.is_contiguous())
1683*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(result.is_contiguous(memory_format=torch.channels_last))
1684*da0073e9SAndroid Build Coastguard Worker
1685*da0073e9SAndroid Build Coastguard Worker        gathered = torch.cuda.comm.gather(results)
1686*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gathered.is_contiguous(memory_format=torch.channels_last))
1687*da0073e9SAndroid Build Coastguard Worker
1688*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
1689*da0073e9SAndroid Build Coastguard Worker    def test_scatter_namedtuple(self):
1690*da0073e9SAndroid Build Coastguard Worker        # tests ability to scatter namedtuples and retrieve a list where each
1691*da0073e9SAndroid Build Coastguard Worker        # element is of the expected namedtuple type.
1692*da0073e9SAndroid Build Coastguard Worker        fields = ("a", "b")
1693*da0073e9SAndroid Build Coastguard Worker        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
1694*da0073e9SAndroid Build Coastguard Worker        num_gpus = torch.cuda.device_count()
1695*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(num_gpus * 2, device=0)
1696*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(num_gpus * 2, device=0)
1697*da0073e9SAndroid Build Coastguard Worker        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1698*da0073e9SAndroid Build Coastguard Worker        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1699*da0073e9SAndroid Build Coastguard Worker
1700*da0073e9SAndroid Build Coastguard Worker        inp = TestNamedTupleInput_0(a, b)
1701*da0073e9SAndroid Build Coastguard Worker        target_gpus = [torch.device(i) for i in range(num_gpus)]
1702*da0073e9SAndroid Build Coastguard Worker        scatter_out = scatter_gather.scatter(inp, target_gpus)
1703*da0073e9SAndroid Build Coastguard Worker
1704*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(scatter_out):
1705*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(x, type(inp)))
1706*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x._fields, fields)
1707*da0073e9SAndroid Build Coastguard Worker            expected_a = a_tensors_for_gpu[i]
1708*da0073e9SAndroid Build Coastguard Worker            expected_b = b_tensors_for_gpu[i]
1709*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_a, x.a)
1710*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_b, x.b)
1711*da0073e9SAndroid Build Coastguard Worker
1712*da0073e9SAndroid Build Coastguard Worker        class TestNamedTupleInput_1(NamedTuple):
1713*da0073e9SAndroid Build Coastguard Worker            a: torch.Tensor
1714*da0073e9SAndroid Build Coastguard Worker            b: torch.Tensor
1715*da0073e9SAndroid Build Coastguard Worker
1716*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(num_gpus * 2, device=0)
1717*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(num_gpus * 2, device=0)
1718*da0073e9SAndroid Build Coastguard Worker        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1719*da0073e9SAndroid Build Coastguard Worker        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
1720*da0073e9SAndroid Build Coastguard Worker        inp = TestNamedTupleInput_1(a, b)
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker        scatter_out = scatter_gather.scatter(inp, target_gpus)
1723*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(scatter_out):
1724*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(x, type(inp)))
1725*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x._fields, fields)
1726*da0073e9SAndroid Build Coastguard Worker            expected_a = a_tensors_for_gpu[i]
1727*da0073e9SAndroid Build Coastguard Worker            expected_b = b_tensors_for_gpu[i]
1728*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_a, x.a)
1729*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_b, x.b)
1730*da0073e9SAndroid Build Coastguard Worker
1731*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
1732*da0073e9SAndroid Build Coastguard Worker    def test_gather_namedtuple(self):
1733*da0073e9SAndroid Build Coastguard Worker        # tests ability to gather a list of namedtuples and return a namedtuple where each
1734*da0073e9SAndroid Build Coastguard Worker        # element is of the expected tensor type.
1735*da0073e9SAndroid Build Coastguard Worker        fields = ["a", "b"]
1736*da0073e9SAndroid Build Coastguard Worker        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
1737*da0073e9SAndroid Build Coastguard Worker
1738*da0073e9SAndroid Build Coastguard Worker        num_gpus = torch.cuda.device_count()
1739*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(num_gpus * 2, device=0)
1740*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(num_gpus * 2, device=1)
1741*da0073e9SAndroid Build Coastguard Worker        out1 = TestNamedTupleInput_0(a, b)
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(num_gpus * 2, device=1)
1744*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(num_gpus * 2, device=0)
1745*da0073e9SAndroid Build Coastguard Worker        out2 = TestNamedTupleInput_0(a, b)
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Worker        outputs = [out1, out2]
1748*da0073e9SAndroid Build Coastguard Worker
1749*da0073e9SAndroid Build Coastguard Worker        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
1750*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(out):
1751*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(x, type(out2[-1])))  # x must be a tensor
1752*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
1753*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(x, cat))
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker        out = scatter_gather.gather(outputs, 0)  # test on GPU
1756*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(out):
1757*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(x, type(out2[-1])))
1758*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
1759*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(x, cat))
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker        class TestNamedTupleInput_1(NamedTuple):
1762*da0073e9SAndroid Build Coastguard Worker            a: torch.Tensor
1763*da0073e9SAndroid Build Coastguard Worker            b: torch.Tensor
1764*da0073e9SAndroid Build Coastguard Worker
1765*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(num_gpus * 2, device=0)
1766*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(num_gpus * 2, device=1)
1767*da0073e9SAndroid Build Coastguard Worker        out1 = TestNamedTupleInput_1(a, b)
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(num_gpus * 2, device=1)
1770*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(num_gpus * 2, device=0)
1771*da0073e9SAndroid Build Coastguard Worker        out2 = TestNamedTupleInput_1(a, b)
1772*da0073e9SAndroid Build Coastguard Worker
1773*da0073e9SAndroid Build Coastguard Worker        outputs = [out1, out2]
1774*da0073e9SAndroid Build Coastguard Worker
1775*da0073e9SAndroid Build Coastguard Worker        out = scatter_gather.gather(outputs, 0)  # test on GPU
1776*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(out):
1777*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(x, type(out2[-1])))
1778*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat((outputs[0][i].to(0), outputs[1][i].to(0)))
1779*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(x, cat))
1780*da0073e9SAndroid Build Coastguard Worker
1781*da0073e9SAndroid Build Coastguard Worker        out = scatter_gather.gather(outputs, "cpu")  # test on CPU
1782*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(out):
1783*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(x, type(out2[-1])))
1784*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat((outputs[0][i].to("cpu"), outputs[1][i].to("cpu")))
1785*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(x, cat))
1786*da0073e9SAndroid Build Coastguard Worker
1787*da0073e9SAndroid Build Coastguard Worker
1788*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestCudaMultiGPU)
1789*da0073e9SAndroid Build Coastguard Worker
1790*da0073e9SAndroid Build Coastguard Worker
1791*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
1792*da0073e9SAndroid Build Coastguard Worker    run_tests()