# Owner(s): ["oncall: jit"]

import gc
import os
import sys
import unittest
from typing import NamedTuple

import torch
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
    NoTest,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    TEST_CUDA,
)
from torch.testing._internal.jit_utils import JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# If GPU is not available, then do not run the tests
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    JitTestCase = NoTest  # noqa: F811

TEST_LARGE_TENSOR = TEST_CUDA

# If GPU is available, then initialize the cuda context and check
# if there is memory available to allocate for LARGE Tensors.
if TEST_CUDA:
    torch.ones(1).cuda()  # initialize cuda context
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9
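    # Note: 5e9 bytes (~5 GB) leaves headroom for the roughly 400 MB square
    # matrices used by the stream/event timing tests below.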

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestCUDA(JitTestCase):
    """
    A suite of tests for the CUDA API in TorchScript.
    """

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        super().tearDown()
    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_cuda_synchronize(self):
        # Test device synchronization.

        @torch.jit.script
        def test_device_synchronize():
            prev_current_device_index = torch.cuda.current_device()
            torch.cuda.synchronize()
            torch.cuda.synchronize("cuda")
            torch.cuda.synchronize("cuda:0")
            torch.cuda.synchronize(0)
            torch.cuda.synchronize(torch.device("cuda:1"))
            after_current_device_index = torch.cuda.current_device()

            # Check that the current device index is the same as it was before
            # synchronizing the devices.
            return prev_current_device_index == after_current_device_index

        @torch.jit.script
        def test_multi_device_synchronize():
            torch.cuda.synchronize(torch.device("cuda:0"))
            prev_current_device_index = torch.cuda.current_device()
            torch.cuda.synchronize(1)
            after_current_device_index = torch.cuda.current_device()

            # Check that the current device index is the same as it was before
            # synchronizing the devices.
            return prev_current_device_index == after_current_device_index

        self.assertTrue(test_device_synchronize())
        FileCheck().check("cuda::synchronize(").run(test_device_synchronize.graph)
        self.assertTrue(test_multi_device_synchronize())
        FileCheck().check("cuda::synchronize(").run(test_multi_device_synchronize.graph)

    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_args(self):
        # Test stream creation with default arguments
        @torch.jit.script
        def stream_default_args() -> bool:
            s = torch.cuda.Stream()
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_device() -> bool:
            s = torch.cuda.Stream(priority=0)
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_priority() -> bool:
            d = torch.device("cuda:1")
            s = torch.cuda.Stream(d)
            return s.device_index() == 1

        @torch.jit.script
        def stream_args_all() -> bool:
            d = torch.device("cuda:0")
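            # The positional arguments are (device, priority), mirroring the
            # keyword form used above.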
            s = torch.cuda.Stream(d, 0)
            return s.device_index() == 0

        self.assertTrue(stream_default_args())
        self.assertTrue(stream_default_args_for_device())
        self.assertTrue(stream_default_args_for_priority())
        self.assertTrue(stream_args_all())

    def test_event_args(self):
        # Test Event creation with default arguments
        @torch.jit.script
        def event_default_args() -> bool:
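            # Event() takes (enable_timing, blocking, interprocess), all of
            # which default to False.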
            e = torch.cuda.Event()
            return e is not None

        self.assertTrue(event_default_args())

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        # Test the current stream on the device and check if the stream device
        # index matches the device ID
        @torch.jit.script
        def fn():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.current_stream(torch.device("cuda:1"))
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

        # Test the current_stream API by passing the device ID as an argument
        # and check if the stream device index matches the device ID
        @torch.jit.script
        def fn_with_device_index_args():
            device_index = torch.cuda.current_device()
            s0 = torch.cuda.current_stream(device_index)
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(0)

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn_with_device_index_args()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

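    # The stream/event checks below compare against each device's default
    # stream, so they are skipped when the suite runs with a non-default
    # stream (skipCUDANonDefaultStreamIf(True)).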
    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @skipCUDANonDefaultStreamIf(True)
    def test_streams_and_events(self):
        # Test the default_stream API by passing the device ID as an argument
        # and check if the stream device index matches the device ID
        @torch.jit.script
        def test_default_streams_with_device_index_args():
            s0 = torch.cuda.default_stream(0)
            s1 = torch.cuda.default_stream(1)
            return s0.device_index(), s1.device_index()

        d0, d1 = test_default_streams_with_device_index_args()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)

        # This test checks that the default stream ID is set to 0 on the device
        @torch.jit.script
        def test_default_streams():
            s0 = torch.cuda.default_stream(torch.device("cuda:0"))
            s1 = torch.cuda.default_stream(torch.device("cuda:1"))

            d = torch.device("cuda:1")

            # Check that the current stream id and the default stream id are
            # the same on the current device (device 0 by default)
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))
            check_s2 = s2.id() == s0.id()
            check_d0 = torch.cuda.current_device() == s2.device_index()

            # Set the current device to d1 and check if the stream
            # has been set to the default stream on d1
            with torch.cuda.device(d):
                s3 = torch.cuda.current_stream(d)
                check_s3 = s3.id() == s1.id()
                check_d1 = torch.cuda.current_device() == s3.device_index()

            # Check if the current device was reset to 0
            is_device_d0 = torch.cuda.current_device() == s2.device_index()

            return (
                s0.device_index(),
                s1.device_index(),
                check_s2,
                check_s3,
                check_d0,
                check_d1,
                is_device_d0,
            )

        (
            d0,
            d1,
            check_s2,
            check_s3,
            check_d0,
            check_d1,
            is_device_d0,
        ) = test_default_streams()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)
        self.assertTrue(check_s2)
        self.assertTrue(check_s3)
        self.assertTrue(check_d0)
        self.assertTrue(check_d1)
        self.assertTrue(is_device_d0)

        # This test checks that the stream context manager is a no-op
        # when the stream is None for `with torch.cuda.stream`
        @torch.jit.script
        def test_set_none_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)

            # When the stream is None, check that this operation is a no-op
            with torch.cuda.stream(None):
                cur_device_index = torch.cuda.current_device()
                is_device_index_same = cur_device_index == device_index
                is_current_stream_same = (
                    torch.cuda.current_stream(device).id() == current_stream.id()
                )
                is_default_stream_same = (
                    torch.cuda.default_stream(device).id() == default_stream.id()
                )

            # Check that the device index, current stream and default stream
            # have not changed
            are_streams_same = (
                is_device_index_same
                and is_current_stream_same
                and is_default_stream_same
            )
            return are_streams_same

        self.assertTrue(test_set_none_stream())

        # This test checks that the device context manager is a no-op
        # when the device is None for `with torch.cuda.device`
        @torch.jit.script
        def test_set_device_none():
            device_index = torch.cuda.current_device()
            # When the device is None, check that this operation is a no-op
            with torch.cuda.device(None):
                # Check if the current device is the same
                is_device_same = torch.cuda.current_device() == device_index
            return is_device_same

        self.assertTrue(test_set_device_none())

        # Check that a CUDA JIT stream is created
        # on the current device
        @torch.jit.script
        def test_simple_stream():
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            return device_index == s.device_index()

        self.assertTrue(test_simple_stream(), "Could not create Stream!")

        # Class used to store the results of test_get_stream.
        class Result(NamedTuple):
            t1: torch.Tensor
            t2: torch.Tensor
            is_current_and_default_stream_same: bool
            is_default_and_user_stream_not_same: bool
            is_stream_set: bool
            is_stream_reset: bool
            default_stream_query: bool
            default_stream_id: int
            user_stream_id: int

        # This test checks different stream properties.
        @torch.jit.script
        def test_get_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)
            user_stream = torch.cuda.Stream()

            # Check if the current and default streams are the same on the device
            is_current_and_default_stream_same = (
                current_stream.id() == default_stream.id()
            )
            # Check if user stream and default stream are not the same on the device
            is_default_and_user_stream_not_same = (
                default_stream.id() != user_stream.id()
            )

            with torch.cuda.stream(user_stream):
                is_stream_set = (
                    torch.cuda.current_stream(device).id() == user_stream.id()
                )

            # Check if the stream was reset to current_stream
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            tensor1 = torch.rand(10000, 10000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            default_stream.synchronize()
            default_stream_query = default_stream.query()

            # Capture all the results in the class Result
            res = Result(
                tensor1,
                tensor2,
                is_current_and_default_stream_same,
                is_default_and_user_stream_not_same,
                is_stream_set,
                is_stream_reset,
                default_stream_query,
                default_stream.id(),
                user_stream.id(),
            )
            return res

        result = test_get_stream()

        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
        self.assertTrue(result.is_current_and_default_stream_same)
        self.assertTrue(result.is_default_and_user_stream_not_same)
        self.assertTrue(result.is_stream_set)
        self.assertTrue(result.is_stream_reset)
        self.assertTrue(result.default_stream_query)
        self.assertEqual(
            result.default_stream_id, 0
        )  # Check if the default stream ID is always 0
        self.assertNotEqual(
            result.user_stream_id, 0
        )  # Check if the user stream is always non-zero

        # Test the stream context manager. This test checks if the stream is
        # switched to the user stream when using the stream context manager.
        @torch.jit.script
        def test_stream_context():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            user_stream = torch.cuda.Stream()
            A = torch.rand(1000, 1000, device="cuda")

            with torch.cuda.stream(user_stream):
                check = torch.cuda.current_stream(device).id() == user_stream.id()
                B = torch.mm(A, A).to("cuda")
            # Wait for B to be computed
            user_stream.synchronize()
            # Check if the stream has been reset on the current device
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            return A, B, check, is_stream_reset

        A, B, is_stream_set, is_stream_reset = test_stream_context()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertTrue(
            is_stream_set, "Error: Current stream was not set to user stream!"
        )
        self.assertTrue(
            is_stream_reset, "Error: The stream was not restored to the previous stream!"
        )

        # Test multiple nested streams. Check if the operations are computed as
        # expected on the streams. This test has been adapted from the eager
        # mode tests available at test/test_cuda.py
        @torch.jit.script
        def test_multiple_stream():
            prev_device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(prev_device_index))
            prev_current_stream = torch.cuda.current_stream(device)
            d1 = torch.device("cuda:0")
            d2 = torch.device("cuda:1")
            s1 = torch.cuda.Stream(d1, 0)
            s2 = torch.cuda.Stream(d2, 0)

            A = torch.rand(1000, 1000, device="cuda")
            B = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                C = torch.mm(A, A).to("cuda")
                # Check if the stream and device have been set to s1
                is_stream_s1 = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1 = torch.cuda.current_device() == s1.device_index()
                with torch.cuda.stream(s2):
                    # Check if the stream and device have been set to s2
                    is_stream_s2 = torch.cuda.current_stream(d2).id() == s2.id()
                    is_device_s2 = torch.cuda.current_device() == s2.device_index()
                    D = torch.mm(B, B).to("cuda")
                # Check if the stream and device have been restored to s1
                is_stream_s1_after = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1_after = torch.cuda.current_device() == s1.device_index()
                # Wait for D to be computed
                s2.synchronize()
            # Wait for C to be computed on s1
            s1.synchronize()

            # Check if the stream and device have been restored to the previous
            # stream and device
            is_device_current = torch.cuda.current_device() == prev_device_index
            is_stream_current = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_stream_s1
                and is_stream_s2
                and is_stream_s1_after
                and is_stream_current
            )
            check_device = (
                is_device_s1
                and is_device_s2
                and is_device_s1_after
                and is_device_current
            )
            return A, B, C, D, check_stream, check_device

        A, B, C, D, check_stream, check_device = test_multiple_stream()

        self.assertEqual(torch.matmul(A, A), C)
        self.assertEqual(torch.matmul(B, B), D)
        self.assertTrue(check_stream)
        self.assertTrue(check_device)

        # Test multiple streams waiting on each other for the operations to be completed.
        @torch.jit.script
        def test_data_dependency_between_streams():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            prev_current_stream = torch.cuda.current_stream(device)
            d = torch.device("cuda:0")
            s1 = torch.cuda.Stream(d, 0)
            s2 = torch.cuda.Stream(d, 0)
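            # Timing and blocking are disabled; the event is used only to order
            # s2 behind the work recorded on s1.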
            event = torch.cuda.Event(False, False, False)

            A = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                is_stream_s1 = torch.cuda.current_stream(device).id() == s1.id()
                B = torch.mm(A, A).to("cuda")
            s1.record_event(event)
            # Check if the current_stream is reset
            is_current_stream_1 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )
            # Wait for ops on s1 to be computed
            s2.wait_event(event)
            with torch.cuda.stream(s2):
                is_stream_s2 = torch.cuda.current_stream(device).id() == s2.id()
                C = torch.mm(B, B).to("cuda")
            # Wait for C to be computed
            s2.synchronize()
            # Check if the current_stream is reset
            is_current_stream_2 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_current_stream_1
                and is_current_stream_2
                and is_stream_s1
                and is_stream_s2
            )
            return A, B, C, check_stream

        A, B, C, check_stream = test_data_dependency_between_streams()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertEqual(torch.matmul(B, B), C)
        self.assertTrue(check_stream)

        # Test a simple CUDA event. Check that the CUDA event was created successfully
        @torch.jit.script
        def test_simple_event():
            e = torch.cuda.Event(True, False, False)
            return e is not None

        self.assertTrue(test_simple_event(), "Could not create CUDA Event!")

        # Record CUDA events around a torch.mm on the current stream and then
        # check that the elapsed time is greater than 0. This test is an
        # adaptation of the eager mode CUDA tests available at test/test_cuda.py
        @torch.jit.script
        def test_event():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            stream = torch.cuda.current_stream(device)
            event = torch.cuda.Event(True, False, False)
            is_true_event_query = event.query()
            start_event = torch.cuda.Event(True, False, False)
            stream.record_event(start_event)
            tensor1 = torch.rand(10000, 10000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            stream.record_event(event)
            event.synchronize()
            is_again_true_event_query = event.query()

            if not (is_true_event_query and is_again_true_event_query):
                return -1.0
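            # Both events were created with enable_timing=True and have been
            # recorded, so elapsed_time is valid here.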
            return start_event.elapsed_time(event)

        self.assertGreater(test_event(), 0)

        # Check stream synchronization when a large tensor multiplication is
        # computed on the stream. stream.query() should be True once the
        # synchronization is done
        @torch.jit.script
        def test_stream_synchronize() -> float:
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(10000, 10000, device="cuda")
            with torch.cuda.stream(s):
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            s.synchronize()
            e_tok.record(s)
            e_tok.synchronize()

            if not s.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would
            # throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_stream_synchronize(), 0)

        # Test event synchronization for an event that records a stream doing
        # a large tensor multiplication. Check that the elapsed time is greater
        # than 0 and that stream.query() evaluates to True.
        @torch.jit.script
        def test_event_synchronize() -> float:
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(10000, 10000, device="cuda")
            with torch.cuda.stream(s):
                tensor = torch.mm(tensor1, tensor1).to("cuda")
            s.record_event(e_tok)
            e_tok.synchronize()
            s.synchronize()

            if not s.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would
            # throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_synchronize(), 0)

        # Test event wait. Check that the event waits for all the operations on
        # the stream to finish. Check synchronization and query on the streams
        # and events. This test is adapted from the eager mode CUDA tests; see
        # test/test_cuda.py
        @torch.jit.script
        def test_event_wait() -> float:
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, True, False)
            e_tok = torch.cuda.Event(True, True, False)

            e_tik.record(s0)
            tensor1 = torch.rand(10000, 10000, device="cuda")
            with torch.cuda.stream(s0):
                tensor2 = torch.mm(tensor1, tensor1).cuda()
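            # Record an event on the current stream (s0) and make s1 wait on
            # it, so the matmul above completes before s1 starts its work.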
            e_sync = torch.cuda.Event(True, False, False)
            e_sync.record(torch.cuda.current_stream(device))
            e_sync.wait(s1)
            with torch.cuda.stream(s1):
                tensor3 = torch.rand(10000, 10000, device="cuda")
                tensor4 = torch.mm(tensor3, tensor3).cuda()
            s1.synchronize()
            e_tok.record(torch.cuda.current_stream(device))
            e_tok.synchronize()
            s0.synchronize()

            if not s0.query() or not s1.query() or not e_sync.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would
            # throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_wait(), 0)

        # Test stream wait_event. Check that the stream waits on the event
        @torch.jit.script
        def test_wait_event():
            d1 = torch.device("cuda:1")

            with torch.cuda.device(d1):
                s0 = torch.cuda.current_stream(d1)
                tensor1 = torch.rand(10000, 10000, device="cuda")
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
                e0 = torch.cuda.Event(False, False, False)
                s0.record_event(e0)

            s1 = torch.cuda.current_stream(torch.device("cuda:0"))
            s1.wait_event(e0)
            s1.synchronize()

            return e0.query() and s0.query() and s1.query()

        self.assertTrue(test_wait_event())

    # Test if a scripted module with cuda streams can be saved, loaded and executed
    def test_save_load(self):
        class Model(torch.nn.Module):
            def forward(self):
                s = torch.cuda.Stream()
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.cuda.stream(s):
                    is_stream_s = torch.cuda.current_stream(s.device).id() == s.id()
                    c = torch.cat((a, b), 0).cuda()
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and save
        script_model = torch.jit.script(model)
        is_stream_s, a, b, c = script_model()
        # Verify that the output is correct
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a, b), 0), c)

        # Save and load the scripted model
        load_model = self.getExportImportCopy(script_model)
        is_stream_s, a_load, b_load, c_load = load_model()
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a_load, b_load), 0), c_load)

    # Make sure that cuda._exchange_device doesn't get DCE'ed
    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__exchange_device_op(self):
        def fn(device: int, tensor):
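            # The result of _exchange_device is intentionally unused; the
            # FileCheck below verifies the call still survives dead code
            # elimination.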
            torch.cuda._exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to
        # run this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_exchange_device(").run(g)

    # Make sure that cuda._maybe_exchange_device doesn't get DCE'ed
    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__maybe_exchange_device_op(self):
        def fn(device: int, tensor):
            torch.cuda._maybe_exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to
        # run this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
