xref: /aosp_15_r20/external/pytorch/test/jit/test_cuda.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport gc
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Workerfrom typing import NamedTuple
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_MULTIGPU
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
13*da0073e9SAndroid Build Coastguard Worker    NoTest,
14*da0073e9SAndroid Build Coastguard Worker    skipCUDANonDefaultStreamIf,
15*da0073e9SAndroid Build Coastguard Worker    skipIfRocm,
16*da0073e9SAndroid Build Coastguard Worker    TEST_CUDA,
17*da0073e9SAndroid Build Coastguard Worker)
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable from this test module.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# If GPU is not available, then do not run the tests.
# Rebinding JitTestCase to NoTest makes the TestCUDA class below collect
# zero tests instead of failing at import time.
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    JitTestCase = NoTest  # noqa: F811

# Provisional value; refined below once the device memory is inspected.
TEST_LARGE_TENSOR = TEST_CUDA

# If GPU is available, then initialize the cuda context and check
# if there is memory available to allocate for LARGE Tensors (>= ~5 GB).
if TEST_CUDA:
    torch.ones(1).cuda()  # initialize cuda context
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9

# This file is meant to be collected via test/test_jit.py, never run directly.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Workerclass TestCUDA(JitTestCase):
47*da0073e9SAndroid Build Coastguard Worker    """
48*da0073e9SAndroid Build Coastguard Worker    A suite of tests for the CUDA API in TorchScript.
49*da0073e9SAndroid Build Coastguard Worker    """
50*da0073e9SAndroid Build Coastguard Worker
    def tearDown(self):
        """Release CUDA memory between tests.

        Runs Python garbage collection first so tensors kept alive only by
        reference cycles are freed, then returns the cached device memory to
        the driver before deferring to the base class teardown.
        """
        gc.collect()
        torch.cuda.empty_cache()
        super().tearDown()
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
57*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
58*da0073e9SAndroid Build Coastguard Worker    def test_cuda_synchronize(self):
59*da0073e9SAndroid Build Coastguard Worker        # Test device synchronization.
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
62*da0073e9SAndroid Build Coastguard Worker        def test_device_synchronize():
63*da0073e9SAndroid Build Coastguard Worker            prev_current_device_index = torch.cuda.current_device()
64*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
65*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize("cuda")
66*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize("cuda:0")
67*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize(0)
68*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize(torch.device("cuda:1"))
69*da0073e9SAndroid Build Coastguard Worker            after_current_device_index = torch.cuda.current_device()
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker            # Check if the current device index is same as the device index before
72*da0073e9SAndroid Build Coastguard Worker            # synchronizing the device.
73*da0073e9SAndroid Build Coastguard Worker            return prev_current_device_index == after_current_device_index
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
76*da0073e9SAndroid Build Coastguard Worker        def test_multi_device_synchronize():
77*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize(torch.device("cuda:0"))
78*da0073e9SAndroid Build Coastguard Worker            prev_current_device_index = torch.cuda.current_device()
79*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize(1)
80*da0073e9SAndroid Build Coastguard Worker            after_current_device_index = torch.cuda.current_device()
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker            # Check if the current device index is same as the device index before
83*da0073e9SAndroid Build Coastguard Worker            # synchronizing the device.
84*da0073e9SAndroid Build Coastguard Worker            return prev_current_device_index == after_current_device_index
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_device_synchronize)
87*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("cuda::synchronize(").run(test_device_synchronize.graph)
88*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_multi_device_synchronize)
89*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("cuda::synchronize(").run(test_multi_device_synchronize.graph)
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker    def test_stream_args(self):
92*da0073e9SAndroid Build Coastguard Worker        # Test stream creation with default arguments
93*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
94*da0073e9SAndroid Build Coastguard Worker        def stream_default_args() -> bool:
95*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream()
96*da0073e9SAndroid Build Coastguard Worker            return s.device_index() == torch.cuda.current_device()
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
99*da0073e9SAndroid Build Coastguard Worker        def stream_default_args_for_device() -> bool:
100*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream(priority=0)
101*da0073e9SAndroid Build Coastguard Worker            return s.device_index() == torch.cuda.current_device()
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
104*da0073e9SAndroid Build Coastguard Worker        def stream_default_args_for_priority() -> bool:
105*da0073e9SAndroid Build Coastguard Worker            d = torch.device("cuda:1")
106*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream(d)
107*da0073e9SAndroid Build Coastguard Worker            return s.device_index() == 1
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
110*da0073e9SAndroid Build Coastguard Worker        def stream_args_all() -> bool:
111*da0073e9SAndroid Build Coastguard Worker            d = torch.device("cuda:0")
112*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream(d, 0)
113*da0073e9SAndroid Build Coastguard Worker            return s.device_index() == 0
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(stream_default_args)
116*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(stream_default_args_for_device)
117*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(stream_default_args_for_priority)
118*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(stream_args_all)
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker    def test_event_args(self):
121*da0073e9SAndroid Build Coastguard Worker        # Test Event creation with default arguments
122*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
123*da0073e9SAndroid Build Coastguard Worker        def event_default_args() -> bool:
124*da0073e9SAndroid Build Coastguard Worker            e = torch.cuda.Event()
125*da0073e9SAndroid Build Coastguard Worker            return e is not None
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(event_default_args)
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
130*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
131*da0073e9SAndroid Build Coastguard Worker    def test_current_stream(self):
132*da0073e9SAndroid Build Coastguard Worker        # Test current stream on the device and check if the stream device index
133*da0073e9SAndroid Build Coastguard Worker        # matches with the device ID
134*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
135*da0073e9SAndroid Build Coastguard Worker        def fn():
136*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
137*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(device_index))
138*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream(device)
139*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream(torch.device("cuda:1"))
140*da0073e9SAndroid Build Coastguard Worker            s2 = torch.cuda.current_stream(torch.device("cuda:0"))
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker            return s0.device_index(), s1.device_index(), s2.device_index()
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker        d0, d1, d2 = fn()
145*da0073e9SAndroid Build Coastguard Worker        # By default, the current device ID is 0.
146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, d0)
147*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, d1)
148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, d2)
149*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d0, d2)
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker        # Test current_stream API by passing device ID as an argument and
152*da0073e9SAndroid Build Coastguard Worker        # and check if the stream device index matches with the device ID
153*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
154*da0073e9SAndroid Build Coastguard Worker        def fn_with_device_index_args():
155*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
156*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream(device_index)
157*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream(1)
158*da0073e9SAndroid Build Coastguard Worker            s2 = torch.cuda.current_stream(0)
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker            return s0.device_index(), s1.device_index(), s2.device_index()
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker        d0, d1, d2 = fn_with_device_index_args()
163*da0073e9SAndroid Build Coastguard Worker        # By default, the current device ID is 0.
164*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, d0)
165*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, d1)
166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, d2)
167*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d0, d2)
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
170*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
171*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
172*da0073e9SAndroid Build Coastguard Worker    @skipCUDANonDefaultStreamIf(True)
173*da0073e9SAndroid Build Coastguard Worker    def test_streams_and_events(self):
174*da0073e9SAndroid Build Coastguard Worker        # Test default_stream API by passing device ID as an argument and
175*da0073e9SAndroid Build Coastguard Worker        # and check if the stream device index matches with the device ID
176*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
177*da0073e9SAndroid Build Coastguard Worker        def test_default_streams_with_device_index_args():
178*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.default_stream(0)
179*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.default_stream(1)
180*da0073e9SAndroid Build Coastguard Worker            return s0.device_index(), s1.device_index()
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker        d0, d1 = test_default_streams_with_device_index_args()
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d0, 0)
185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d1, 1)
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker        # This test checks for the default stream ID is set to 0 on the device
188*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
189*da0073e9SAndroid Build Coastguard Worker        def test_default_streams():
190*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.default_stream(torch.device("cuda:0"))
191*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.default_stream(torch.device("cuda:1"))
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker            d = torch.device("cuda:1")
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker            # Check the current stream id and default id are same
196*da0073e9SAndroid Build Coastguard Worker            # on the current device. The current device id by default is 0
197*da0073e9SAndroid Build Coastguard Worker            s2 = torch.cuda.current_stream(torch.device("cuda:0"))
198*da0073e9SAndroid Build Coastguard Worker            check_s2 = s2.id() == s0.id()
199*da0073e9SAndroid Build Coastguard Worker            check_d0 = torch.cuda.current_device() == s2.device_index()
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker            # Set the current device to d1 and check if the stream
202*da0073e9SAndroid Build Coastguard Worker            # has been set to the default stream on d1
203*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.device(d):
204*da0073e9SAndroid Build Coastguard Worker                s3 = torch.cuda.current_stream(d)
205*da0073e9SAndroid Build Coastguard Worker                check_s3 = s3.id() == s1.id()
206*da0073e9SAndroid Build Coastguard Worker                check_d1 = torch.cuda.current_device() == s3.device_index()
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker            # Check if the current device was reset to 0
209*da0073e9SAndroid Build Coastguard Worker            is_device_d0 = torch.cuda.current_device() == s2.device_index()
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker            return (
212*da0073e9SAndroid Build Coastguard Worker                s0.device_index(),
213*da0073e9SAndroid Build Coastguard Worker                s1.device_index(),
214*da0073e9SAndroid Build Coastguard Worker                check_s2,
215*da0073e9SAndroid Build Coastguard Worker                check_s3,
216*da0073e9SAndroid Build Coastguard Worker                check_d0,
217*da0073e9SAndroid Build Coastguard Worker                check_d1,
218*da0073e9SAndroid Build Coastguard Worker                is_device_d0,
219*da0073e9SAndroid Build Coastguard Worker            )
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker        (
222*da0073e9SAndroid Build Coastguard Worker            d0,
223*da0073e9SAndroid Build Coastguard Worker            d1,
224*da0073e9SAndroid Build Coastguard Worker            check_s2,
225*da0073e9SAndroid Build Coastguard Worker            check_s3,
226*da0073e9SAndroid Build Coastguard Worker            check_d0,
227*da0073e9SAndroid Build Coastguard Worker            check_d1,
228*da0073e9SAndroid Build Coastguard Worker            is_device_d0,
229*da0073e9SAndroid Build Coastguard Worker        ) = test_default_streams()
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d0, 0)
232*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d1, 1)
233*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check_s2)
234*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check_s3)
235*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check_d0)
236*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check_d1)
237*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_device_d0)
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker        # This test checks if the Stream Context manager is a no op
240*da0073e9SAndroid Build Coastguard Worker        # when the stream is none for `with torch.cuda.stream`
241*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
242*da0073e9SAndroid Build Coastguard Worker        def test_set_none_stream():
243*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
244*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(device_index))
245*da0073e9SAndroid Build Coastguard Worker            current_stream = torch.cuda.current_stream(device)
246*da0073e9SAndroid Build Coastguard Worker            default_stream = torch.cuda.default_stream(device)
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker            # When stream is none, check if this operation is a no-op
249*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(None):
250*da0073e9SAndroid Build Coastguard Worker                cur_device_index = torch.cuda.current_device()
251*da0073e9SAndroid Build Coastguard Worker                is_device_index_same = cur_device_index == device_index
252*da0073e9SAndroid Build Coastguard Worker                is_current_stream_same = (
253*da0073e9SAndroid Build Coastguard Worker                    torch.cuda.current_stream(device).id() == current_stream.id()
254*da0073e9SAndroid Build Coastguard Worker                )
255*da0073e9SAndroid Build Coastguard Worker                is_default_stream_same = (
256*da0073e9SAndroid Build Coastguard Worker                    torch.cuda.default_stream(device).id() == default_stream.id()
257*da0073e9SAndroid Build Coastguard Worker                )
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker            # Check if the device index, current stream and default streams have not changed
260*da0073e9SAndroid Build Coastguard Worker            are_streams_same = (
261*da0073e9SAndroid Build Coastguard Worker                is_device_index_same
262*da0073e9SAndroid Build Coastguard Worker                and is_current_stream_same
263*da0073e9SAndroid Build Coastguard Worker                and is_default_stream_same
264*da0073e9SAndroid Build Coastguard Worker            )
265*da0073e9SAndroid Build Coastguard Worker            return are_streams_same
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_set_none_stream())
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker        # This test checks if the Device Context manager is a no op
270*da0073e9SAndroid Build Coastguard Worker        # when the device is none for `with torch.cuda.device`
271*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
272*da0073e9SAndroid Build Coastguard Worker        def test_set_device_none():
273*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
274*da0073e9SAndroid Build Coastguard Worker            # When device is none, check if this operation is a no-op
275*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.device(None):
276*da0073e9SAndroid Build Coastguard Worker                # Check if the current device is the same
277*da0073e9SAndroid Build Coastguard Worker                is_device_same = torch.cuda.current_device() == device_index
278*da0073e9SAndroid Build Coastguard Worker            return is_device_same
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_set_device_none())
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker        # Check if a CUDA JIT stream is created
283*da0073e9SAndroid Build Coastguard Worker        # on the current_device
284*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
285*da0073e9SAndroid Build Coastguard Worker        def test_simple_stream():
286*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
287*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream()
288*da0073e9SAndroid Build Coastguard Worker            return device_index == s.device_index()
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_simple_stream(), "Could not create Stream!")
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        # Class used to store results for the test: test_get_stream.
293*da0073e9SAndroid Build Coastguard Worker        class Result(NamedTuple):
294*da0073e9SAndroid Build Coastguard Worker            t1: torch.Tensor
295*da0073e9SAndroid Build Coastguard Worker            t2: torch.Tensor
296*da0073e9SAndroid Build Coastguard Worker            is_current_and_default_stream_same: bool
297*da0073e9SAndroid Build Coastguard Worker            is_default_and_user_stream_not_same: bool
298*da0073e9SAndroid Build Coastguard Worker            is_stream_set: bool
299*da0073e9SAndroid Build Coastguard Worker            is_stream_reset: bool
300*da0073e9SAndroid Build Coastguard Worker            default_stream_query: bool
301*da0073e9SAndroid Build Coastguard Worker            default_stream_id: int
302*da0073e9SAndroid Build Coastguard Worker            user_stream_id: int
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        # The test aims at checking different stream proporties.
305*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
306*da0073e9SAndroid Build Coastguard Worker        def test_get_stream():
307*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
308*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(device_index))
309*da0073e9SAndroid Build Coastguard Worker            current_stream = torch.cuda.current_stream(device)
310*da0073e9SAndroid Build Coastguard Worker            default_stream = torch.cuda.default_stream(device)
311*da0073e9SAndroid Build Coastguard Worker            user_stream = torch.cuda.Stream()
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker            # Check if the current and default streams are the same on the device
314*da0073e9SAndroid Build Coastguard Worker            is_current_and_default_stream_same = (
315*da0073e9SAndroid Build Coastguard Worker                current_stream.id() == default_stream.id()
316*da0073e9SAndroid Build Coastguard Worker            )
317*da0073e9SAndroid Build Coastguard Worker            # Check if user stream and default stream are not the same on the device
318*da0073e9SAndroid Build Coastguard Worker            is_default_and_user_stream_not_same = (
319*da0073e9SAndroid Build Coastguard Worker                default_stream.id() != user_stream.id()
320*da0073e9SAndroid Build Coastguard Worker            )
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(user_stream):
323*da0073e9SAndroid Build Coastguard Worker                is_stream_set = (
324*da0073e9SAndroid Build Coastguard Worker                    torch.cuda.current_stream(device).id() == user_stream.id()
325*da0073e9SAndroid Build Coastguard Worker                )
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker            # Check if the stream was reset to current_stream
328*da0073e9SAndroid Build Coastguard Worker            is_stream_reset = (
329*da0073e9SAndroid Build Coastguard Worker                torch.cuda.current_stream(device).id() == current_stream.id()
330*da0073e9SAndroid Build Coastguard Worker            )
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker            tensor1 = torch.rand(10000, 10000, device="cuda")
333*da0073e9SAndroid Build Coastguard Worker            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
334*da0073e9SAndroid Build Coastguard Worker            default_stream.synchronize()
335*da0073e9SAndroid Build Coastguard Worker            default_stream_query = default_stream.query()
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker            # Capture all the results in the class Result
338*da0073e9SAndroid Build Coastguard Worker            res = Result(
339*da0073e9SAndroid Build Coastguard Worker                tensor1,
340*da0073e9SAndroid Build Coastguard Worker                tensor2,
341*da0073e9SAndroid Build Coastguard Worker                is_current_and_default_stream_same,
342*da0073e9SAndroid Build Coastguard Worker                is_default_and_user_stream_not_same,
343*da0073e9SAndroid Build Coastguard Worker                is_stream_set,
344*da0073e9SAndroid Build Coastguard Worker                is_stream_reset,
345*da0073e9SAndroid Build Coastguard Worker                default_stream_query,
346*da0073e9SAndroid Build Coastguard Worker                default_stream.id(),
347*da0073e9SAndroid Build Coastguard Worker                user_stream.id(),
348*da0073e9SAndroid Build Coastguard Worker            )
349*da0073e9SAndroid Build Coastguard Worker            return res
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker        result = test_get_stream()
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
354*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result.is_current_and_default_stream_same)
355*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result.is_default_and_user_stream_not_same)
356*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result.is_stream_set)
357*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result.is_stream_reset)
358*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result.default_stream_query)
359*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
360*da0073e9SAndroid Build Coastguard Worker            result.default_stream_id, 0
361*da0073e9SAndroid Build Coastguard Worker        )  # Check if the default stream ID is always 0
362*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(
363*da0073e9SAndroid Build Coastguard Worker            result.user_stream_id, 0
364*da0073e9SAndroid Build Coastguard Worker        )  # Check if the user stream is always non zero
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker        # Test the stream context manager. This test checks if the stream is switched
367*da0073e9SAndroid Build Coastguard Worker        # to the user stream on using the stream context manager.
368*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
369*da0073e9SAndroid Build Coastguard Worker        def test_stream_context():
370*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
371*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(device_index))
372*da0073e9SAndroid Build Coastguard Worker            current_stream = torch.cuda.current_stream(device)
373*da0073e9SAndroid Build Coastguard Worker            user_stream = torch.cuda.Stream()
374*da0073e9SAndroid Build Coastguard Worker            A = torch.rand(1000, 1000, device="cuda")
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(user_stream):
377*da0073e9SAndroid Build Coastguard Worker                check = torch.cuda.current_stream(device).id() == user_stream.id()
378*da0073e9SAndroid Build Coastguard Worker                B = torch.mm(A, A).to("cuda")
379*da0073e9SAndroid Build Coastguard Worker            # Wait for B to be computed
380*da0073e9SAndroid Build Coastguard Worker            user_stream.synchronize()
381*da0073e9SAndroid Build Coastguard Worker            # Check if the stream has been reset on the current device
382*da0073e9SAndroid Build Coastguard Worker            is_stream_reset = (
383*da0073e9SAndroid Build Coastguard Worker                torch.cuda.current_stream(device).id() == current_stream.id()
384*da0073e9SAndroid Build Coastguard Worker            )
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker            return A, B, check, is_stream_reset
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker        A, B, is_stream_set, is_stream_reset = test_stream_context()
389*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.matmul(A, A), B)
390*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
391*da0073e9SAndroid Build Coastguard Worker            is_stream_set, "Error: Current stream was not set to user stream!"
392*da0073e9SAndroid Build Coastguard Worker        )
393*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
394*da0073e9SAndroid Build Coastguard Worker            is_stream_reset, "Error: The stream was not restored to previous stream!"
395*da0073e9SAndroid Build Coastguard Worker        )
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker        # Test multiple nested streams. Check if the operations are computed as expected on the streams
398*da0073e9SAndroid Build Coastguard Worker        # This test has been adapted from the eager mode tests available at test/test_cuda.py
399*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
400*da0073e9SAndroid Build Coastguard Worker        def test_multiple_stream():
401*da0073e9SAndroid Build Coastguard Worker            prev_device_index = torch.cuda.current_device()
402*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(prev_device_index))
403*da0073e9SAndroid Build Coastguard Worker            prev_current_stream = torch.cuda.current_stream(device)
404*da0073e9SAndroid Build Coastguard Worker            d1 = torch.device("cuda:0")
405*da0073e9SAndroid Build Coastguard Worker            d2 = torch.device("cuda:1")
406*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.Stream(d1, 0)
407*da0073e9SAndroid Build Coastguard Worker            s2 = torch.cuda.Stream(d2, 0)
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker            A = torch.rand(1000, 1000, device="cuda")
410*da0073e9SAndroid Build Coastguard Worker            B = torch.rand(1000, 1000, device="cuda")
411*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s1):
412*da0073e9SAndroid Build Coastguard Worker                C = torch.mm(A, A).to("cuda")
413*da0073e9SAndroid Build Coastguard Worker                # Check if the stream and device have been set to s1
414*da0073e9SAndroid Build Coastguard Worker                is_stream_s1 = torch.cuda.current_stream(d1).id() == s1.id()
415*da0073e9SAndroid Build Coastguard Worker                is_device_s1 = torch.cuda.current_device() == s1.device_index()
416*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.stream(s2):
417*da0073e9SAndroid Build Coastguard Worker                    # Check if the stream and device have been set to s2
418*da0073e9SAndroid Build Coastguard Worker                    is_stream_s2 = torch.cuda.current_stream(d2).id() == s2.id()
419*da0073e9SAndroid Build Coastguard Worker                    is_device_s2 = torch.cuda.current_device() == s2.device_index()
420*da0073e9SAndroid Build Coastguard Worker                    D = torch.mm(B, B).to("cuda")
421*da0073e9SAndroid Build Coastguard Worker                # Check if the stream and device have been set to s1
422*da0073e9SAndroid Build Coastguard Worker                is_stream_s1_after = torch.cuda.current_stream(d1).id() == s1.id()
423*da0073e9SAndroid Build Coastguard Worker                is_device_s1_after = torch.cuda.current_device() == s1.device_index()
424*da0073e9SAndroid Build Coastguard Worker                # Wait for D to be computed
425*da0073e9SAndroid Build Coastguard Worker                s2.synchronize()
426*da0073e9SAndroid Build Coastguard Worker            # Wait for C to be computed on S1
427*da0073e9SAndroid Build Coastguard Worker            s1.synchronize()
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker            # Check if the stream and device has been restored to previous stream and device
430*da0073e9SAndroid Build Coastguard Worker            is_device_current = torch.cuda.current_device() == prev_device_index
431*da0073e9SAndroid Build Coastguard Worker            is_stream_current = (
432*da0073e9SAndroid Build Coastguard Worker                torch.cuda.current_stream(device).id() == prev_current_stream.id()
433*da0073e9SAndroid Build Coastguard Worker            )
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker            check_stream = (
436*da0073e9SAndroid Build Coastguard Worker                is_stream_s1
437*da0073e9SAndroid Build Coastguard Worker                and is_stream_s2
438*da0073e9SAndroid Build Coastguard Worker                and is_stream_s1_after
439*da0073e9SAndroid Build Coastguard Worker                and is_stream_current
440*da0073e9SAndroid Build Coastguard Worker            )
441*da0073e9SAndroid Build Coastguard Worker            check_device = (
442*da0073e9SAndroid Build Coastguard Worker                is_device_s1
443*da0073e9SAndroid Build Coastguard Worker                and is_device_s2
444*da0073e9SAndroid Build Coastguard Worker                and is_device_s1_after
445*da0073e9SAndroid Build Coastguard Worker                and is_device_current
446*da0073e9SAndroid Build Coastguard Worker            )
447*da0073e9SAndroid Build Coastguard Worker            return A, B, C, D, check_stream, check_device
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Worker        A, B, C, D, check_stream, check_device = test_multiple_stream()
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.matmul(A, A), C)
452*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.matmul(B, B), D)
453*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check_stream)
454*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check_device)
455*da0073e9SAndroid Build Coastguard Worker
456*da0073e9SAndroid Build Coastguard Worker        # Test multiple streams waiting on each other for the operations to be completed.
457*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
458*da0073e9SAndroid Build Coastguard Worker        def test_data_dependency_between_streams():
459*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
460*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(device_index))
461*da0073e9SAndroid Build Coastguard Worker            prev_current_stream = torch.cuda.current_stream(device)
462*da0073e9SAndroid Build Coastguard Worker            d = torch.device("cuda:0")
463*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.Stream(d, 0)
464*da0073e9SAndroid Build Coastguard Worker            s2 = torch.cuda.Stream(d, 0)
465*da0073e9SAndroid Build Coastguard Worker            event = torch.cuda.Event(False, False, False)
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker            A = torch.rand(1000, 1000, device="cuda")
468*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s1):
469*da0073e9SAndroid Build Coastguard Worker                is_stream_s1 = torch.cuda.current_stream(device).id() == s1.id()
470*da0073e9SAndroid Build Coastguard Worker                B = torch.mm(A, A).to("cuda")
471*da0073e9SAndroid Build Coastguard Worker            s1.record_event(event)
472*da0073e9SAndroid Build Coastguard Worker            # Check if the current_stream is reset
473*da0073e9SAndroid Build Coastguard Worker            is_current_stream_1 = (
474*da0073e9SAndroid Build Coastguard Worker                torch.cuda.current_stream(device).id() == prev_current_stream.id()
475*da0073e9SAndroid Build Coastguard Worker            )
476*da0073e9SAndroid Build Coastguard Worker            # Wait for ops on s1 to be computed
477*da0073e9SAndroid Build Coastguard Worker            s2.wait_event(event)
478*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s2):
479*da0073e9SAndroid Build Coastguard Worker                is_stream_s2 = torch.cuda.current_stream(device).id() == s2.id()
480*da0073e9SAndroid Build Coastguard Worker                C = torch.mm(B, B).to("cuda")
481*da0073e9SAndroid Build Coastguard Worker            # Wait for C to be computed
482*da0073e9SAndroid Build Coastguard Worker            s2.synchronize()
483*da0073e9SAndroid Build Coastguard Worker            # Check if the current_stream is reset
484*da0073e9SAndroid Build Coastguard Worker            is_current_stream_2 = (
485*da0073e9SAndroid Build Coastguard Worker                torch.cuda.current_stream(device).id() == prev_current_stream.id()
486*da0073e9SAndroid Build Coastguard Worker            )
487*da0073e9SAndroid Build Coastguard Worker
488*da0073e9SAndroid Build Coastguard Worker            check_stream = (
489*da0073e9SAndroid Build Coastguard Worker                is_current_stream_1
490*da0073e9SAndroid Build Coastguard Worker                and is_current_stream_2
491*da0073e9SAndroid Build Coastguard Worker                and is_stream_s1
492*da0073e9SAndroid Build Coastguard Worker                and is_stream_s2
493*da0073e9SAndroid Build Coastguard Worker            )
494*da0073e9SAndroid Build Coastguard Worker            return A, B, C, check_stream
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker        A, B, C, check_stream = test_data_dependency_between_streams()
497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.matmul(A, A), B)
498*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.matmul(B, B), C)
499*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check_stream)
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker        # Test a simple CUDA event. Test if the CUDA event was created successfully
502*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
503*da0073e9SAndroid Build Coastguard Worker        def test_simple_event():
504*da0073e9SAndroid Build Coastguard Worker            e = torch.cuda.Event(True, False, False)
505*da0073e9SAndroid Build Coastguard Worker            return e is not None
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_simple_event(), "Could not create CUDA Event!")
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker        # Record the CUDA event for operation torch.mm on the current stream
510*da0073e9SAndroid Build Coastguard Worker        # and then test if the elapsed time is greater than 0. This test is also
511*da0073e9SAndroid Build Coastguard Worker        # an adaption from eager mdoe CUDA tests available at test/test_cuda.py
512*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
513*da0073e9SAndroid Build Coastguard Worker        def test_event():
514*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
515*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(device_index))
516*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.current_stream(device)
517*da0073e9SAndroid Build Coastguard Worker            event = torch.cuda.Event(True, False, False)
518*da0073e9SAndroid Build Coastguard Worker            is_true_event_query = event.query()
519*da0073e9SAndroid Build Coastguard Worker            start_event = torch.cuda.Event(True, False, False)
520*da0073e9SAndroid Build Coastguard Worker            stream.record_event(start_event)
521*da0073e9SAndroid Build Coastguard Worker            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
522*da0073e9SAndroid Build Coastguard Worker            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
523*da0073e9SAndroid Build Coastguard Worker            stream.record_event(event)
524*da0073e9SAndroid Build Coastguard Worker            event.synchronize()
525*da0073e9SAndroid Build Coastguard Worker            is_again_true_event_query = event.query()
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker            if not (is_true_event_query and is_again_true_event_query):
528*da0073e9SAndroid Build Coastguard Worker                return -1.0
529*da0073e9SAndroid Build Coastguard Worker            return start_event.elapsed_time(event)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(test_event(), 0)
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker        # Check for stream synchronization , when a large tensor multiplication is
534*da0073e9SAndroid Build Coastguard Worker        # computed on the stream. The stream.query should be true once the synchroniztion is done
535*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
536*da0073e9SAndroid Build Coastguard Worker        def test_stream_synchronize() -> float:
537*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
538*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream()
539*da0073e9SAndroid Build Coastguard Worker            e_tik = torch.cuda.Event(True, False, False)
540*da0073e9SAndroid Build Coastguard Worker            e_tok = torch.cuda.Event(True, False, False)
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker            e_tik.record(s)
543*da0073e9SAndroid Build Coastguard Worker            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
544*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
545*da0073e9SAndroid Build Coastguard Worker                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
546*da0073e9SAndroid Build Coastguard Worker            s.synchronize()
547*da0073e9SAndroid Build Coastguard Worker            e_tok.record(s)
548*da0073e9SAndroid Build Coastguard Worker            e_tok.synchronize()
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker            if not s.query():
551*da0073e9SAndroid Build Coastguard Worker                return -1.0
552*da0073e9SAndroid Build Coastguard Worker
553*da0073e9SAndroid Build Coastguard Worker            # not necessary to check e_tik and e_tok, as elapsed_time would throw
554*da0073e9SAndroid Build Coastguard Worker            # exception if otherwise.
555*da0073e9SAndroid Build Coastguard Worker            return e_tik.elapsed_time(e_tok)
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(test_stream_synchronize(), 0)
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker        # Test event synchronization for the event that records a stream doing
560*da0073e9SAndroid Build Coastguard Worker        # a large tensor multiplication. Check if the elapsed time is greater than 0
561*da0073e9SAndroid Build Coastguard Worker        # and the stream.query evaluates to true.
562*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
563*da0073e9SAndroid Build Coastguard Worker        def test_event_synchronize() -> float:
564*da0073e9SAndroid Build Coastguard Worker            s = torch.cuda.Stream()
565*da0073e9SAndroid Build Coastguard Worker            e_tik = torch.cuda.Event(True, False, False)
566*da0073e9SAndroid Build Coastguard Worker            e_tok = torch.cuda.Event(True, False, False)
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker            e_tik.record(s)
569*da0073e9SAndroid Build Coastguard Worker            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
570*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s):
571*da0073e9SAndroid Build Coastguard Worker                tensor = torch.mm(tensor1, tensor1).to("cuda")
572*da0073e9SAndroid Build Coastguard Worker            s.record_event(e_tok)
573*da0073e9SAndroid Build Coastguard Worker            e_tok.synchronize()
574*da0073e9SAndroid Build Coastguard Worker            s.synchronize()
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker            if not s.query():
577*da0073e9SAndroid Build Coastguard Worker                return -1.0
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker            # not necessary to check e_tik and e_tok, as elapsed_time would throw
580*da0073e9SAndroid Build Coastguard Worker            # exception if otherwise.
581*da0073e9SAndroid Build Coastguard Worker            return e_tik.elapsed_time(e_tok)
582*da0073e9SAndroid Build Coastguard Worker
583*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(test_event_synchronize(), 0)
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker        # Test for event wait. Check if event waits for the all the operations on
586*da0073e9SAndroid Build Coastguard Worker        # the stream to be done. Check for synchronizations and query on the streams
587*da0073e9SAndroid Build Coastguard Worker        # and events. This test is adapted from eager mode tests for CUDA. Please refer
588*da0073e9SAndroid Build Coastguard Worker        # test/test_cuda.py
589*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
590*da0073e9SAndroid Build Coastguard Worker        def test_event_wait() -> float:
591*da0073e9SAndroid Build Coastguard Worker            device_index = torch.cuda.current_device()
592*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:" + str(device_index))
593*da0073e9SAndroid Build Coastguard Worker            s0 = torch.cuda.current_stream(device)
594*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.Stream()
595*da0073e9SAndroid Build Coastguard Worker            e_tik = torch.cuda.Event(True, True, False)
596*da0073e9SAndroid Build Coastguard Worker            e_tok = torch.cuda.Event(True, True, False)
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker            e_tik.record(s0)
599*da0073e9SAndroid Build Coastguard Worker            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
600*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s0):
601*da0073e9SAndroid Build Coastguard Worker                tensor2 = torch.mm(tensor1, tensor1).cuda()
602*da0073e9SAndroid Build Coastguard Worker            e_sync = torch.cuda.Event(True, False, False)
603*da0073e9SAndroid Build Coastguard Worker            e_sync.record(torch.cuda.current_stream(device))
604*da0073e9SAndroid Build Coastguard Worker            e_sync.wait(s1)
605*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(s1):
606*da0073e9SAndroid Build Coastguard Worker                tensor3 = torch.rand(1000000000, 1000000000, device="cuda")
607*da0073e9SAndroid Build Coastguard Worker                tensor4 = torch.mm(tensor3, tensor3).cuda()
608*da0073e9SAndroid Build Coastguard Worker            s1.synchronize()
609*da0073e9SAndroid Build Coastguard Worker            e_tok.record(torch.cuda.current_stream(device))
610*da0073e9SAndroid Build Coastguard Worker            e_tok.synchronize()
611*da0073e9SAndroid Build Coastguard Worker            s0.synchronize()
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker            if not s0.query() or not s1.query() or not e_sync.query():
614*da0073e9SAndroid Build Coastguard Worker                return -1.0
615*da0073e9SAndroid Build Coastguard Worker
616*da0073e9SAndroid Build Coastguard Worker            # not necessary to check e_tik and e_tok, as elapsed_time would throw
617*da0073e9SAndroid Build Coastguard Worker            # exception if otherwise.
618*da0073e9SAndroid Build Coastguard Worker            return e_tik.elapsed_time(e_tok)
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(test_event_wait(), 0)
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker        # Test for stream wait_event. Checks if the stream waits on the event
623*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
624*da0073e9SAndroid Build Coastguard Worker        def test_wait_event():
625*da0073e9SAndroid Build Coastguard Worker            d1 = torch.device("cuda:1")
626*da0073e9SAndroid Build Coastguard Worker
627*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.device(d1):
628*da0073e9SAndroid Build Coastguard Worker                s0 = torch.cuda.current_stream(d1)
629*da0073e9SAndroid Build Coastguard Worker                tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
630*da0073e9SAndroid Build Coastguard Worker                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
631*da0073e9SAndroid Build Coastguard Worker                e0 = torch.cuda.Event(False, False, False)
632*da0073e9SAndroid Build Coastguard Worker                s0.record_event(e0)
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker            s1 = torch.cuda.current_stream(torch.device("cuda:0"))
635*da0073e9SAndroid Build Coastguard Worker            s1.wait_event(e0)
636*da0073e9SAndroid Build Coastguard Worker            s1.synchronize()
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker            return e0.query() and s0.query() and s1.query()
639*da0073e9SAndroid Build Coastguard Worker
640*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_wait_event())
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker        # Test if a scripted module with cuda streams can be saved, loaded and executed
643*da0073e9SAndroid Build Coastguard Worker        def test_save_load(self):
644*da0073e9SAndroid Build Coastguard Worker            class Model(torch.nn.Module):
645*da0073e9SAndroid Build Coastguard Worker                def forward(self):
646*da0073e9SAndroid Build Coastguard Worker                    s = torch.cuda.Stream()
647*da0073e9SAndroid Build Coastguard Worker                    a = torch.rand(3, 4, device="cuda")
648*da0073e9SAndroid Build Coastguard Worker                    b = torch.rand(3, 4, device="cuda")
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker                    with torch.cuda.stream(s):
651*da0073e9SAndroid Build Coastguard Worker                        is_stream_s = torch.cuda.current_stream(s.device).id() == s.id()
652*da0073e9SAndroid Build Coastguard Worker                        c = torch.cat((a, b), 0).cuda()
653*da0073e9SAndroid Build Coastguard Worker                    s.synchronize()
654*da0073e9SAndroid Build Coastguard Worker                    return is_stream_s, a, b, c
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker            model = Model()
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker            # Script the model and save
659*da0073e9SAndroid Build Coastguard Worker            script_model = torch.jit.script(model)
660*da0073e9SAndroid Build Coastguard Worker            is_stream_s, a, b, c = script_model()
661*da0073e9SAndroid Build Coastguard Worker            # Verify if the output is correct
662*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(is_stream_s)
663*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cat((a, b), 0), c)
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker            # Save and load scripted model
666*da0073e9SAndroid Build Coastguard Worker            load_model = self.getExportImportCopy(script_model)
667*da0073e9SAndroid Build Coastguard Worker            is_stream_s, a_load, b_load, c_load = load_model()
668*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(is_stream_s)
669*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cat((a_load, b_load), 0), c_load)
670*da0073e9SAndroid Build Coastguard Worker
671*da0073e9SAndroid Build Coastguard Worker    # Make sure that cuda._exchange_device doesn't get DCE'ed
672*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
673*da0073e9SAndroid Build Coastguard Worker    def test__exchange_device_op(self):
674*da0073e9SAndroid Build Coastguard Worker        def fn(device: int, tensor):
675*da0073e9SAndroid Build Coastguard Worker            torch.cuda._exchange_device(device)
676*da0073e9SAndroid Build Coastguard Worker            return tensor.cos().relu()
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
679*da0073e9SAndroid Build Coastguard Worker        # Just check the graph, don't run it. Otherwise, we'd  need to
680*da0073e9SAndroid Build Coastguard Worker        # run this test on a multi-gpu CI runner, which is overkill.
681*da0073e9SAndroid Build Coastguard Worker        g = fn_s.graph
682*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("cuda::_exchange_device(").run(g)
683*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_inline(g)
684*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("cuda::_exchange_device(").run(g)
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker    # Make sure that cuda._maybe_exchange_device doesn't get DCE'ed
687*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
688*da0073e9SAndroid Build Coastguard Worker    def test__maybe_exchange_device_op(self):
689*da0073e9SAndroid Build Coastguard Worker        def fn(device: int, tensor):
690*da0073e9SAndroid Build Coastguard Worker            torch.cuda._maybe_exchange_device(device)
691*da0073e9SAndroid Build Coastguard Worker            return tensor.cos().relu()
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
694*da0073e9SAndroid Build Coastguard Worker        # Just check the graph, don't run it. Otherwise, we'd  need to
695*da0073e9SAndroid Build Coastguard Worker        # run this test on a multi-gpu CI runner, which is overkill.
696*da0073e9SAndroid Build Coastguard Worker        g = fn_s.graph
697*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
698*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_inline(g)
699*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
700