# Owner(s): ["oncall: jit"]

import gc
import os
import sys
import unittest
from typing import NamedTuple

import torch
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
    NoTest,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    TEST_CUDA,
)
from torch.testing._internal.jit_utils import JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# If GPU is not available, then do not run the tests
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    JitTestCase = NoTest  # noqa: F811

TEST_LARGE_TENSOR = TEST_CUDA

# If GPU is available, then initialize the cuda context and check
# if there is memory available to allocate for LARGE Tensors.
if TEST_CUDA:
    torch.ones(1).cuda()  # initialize cuda context
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestCUDA(JitTestCase):
    """
    A suite of tests for the CUDA API in TorchScript.
    """

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        super().tearDown()

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_cuda_synchronize(self):
        # Test device synchronization.

        @torch.jit.script
        def test_device_synchronize():
            prev_current_device_index = torch.cuda.current_device()
            torch.cuda.synchronize()
            torch.cuda.synchronize("cuda")
            torch.cuda.synchronize("cuda:0")
            torch.cuda.synchronize(0)
            torch.cuda.synchronize(torch.device("cuda:1"))
            after_current_device_index = torch.cuda.current_device()

            # Check if the current device index is the same as the device index
            # before synchronizing the device.
            return prev_current_device_index == after_current_device_index

        @torch.jit.script
        def test_multi_device_synchronize():
            torch.cuda.synchronize(torch.device("cuda:0"))
            prev_current_device_index = torch.cuda.current_device()
            torch.cuda.synchronize(1)
            after_current_device_index = torch.cuda.current_device()

            # Check if the current device index is the same as the device index
            # before synchronizing the device.
            return prev_current_device_index == after_current_device_index

        self.assertTrue(test_device_synchronize)
        FileCheck().check("cuda::synchronize(").run(test_device_synchronize.graph)
        self.assertTrue(test_multi_device_synchronize)
        FileCheck().check("cuda::synchronize(").run(test_multi_device_synchronize.graph)

    def test_stream_args(self):
        # Test stream creation with default arguments
        @torch.jit.script
        def stream_default_args() -> bool:
            s = torch.cuda.Stream()
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_device() -> bool:
            s = torch.cuda.Stream(priority=0)
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_priority() -> bool:
            d = torch.device("cuda:1")
            s = torch.cuda.Stream(d)
            return s.device_index() == 1

        @torch.jit.script
        def stream_args_all() -> bool:
            d = torch.device("cuda:0")
            s = torch.cuda.Stream(d, 0)
            return s.device_index() == 0

        self.assertTrue(stream_default_args)
        self.assertTrue(stream_default_args_for_device)
        self.assertTrue(stream_default_args_for_priority)
        self.assertTrue(stream_args_all)

    def test_event_args(self):
        # Test Event creation with default arguments
        @torch.jit.script
        def event_default_args() -> bool:
            e = torch.cuda.Event()
            return e is not None

        self.assertTrue(event_default_args)

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        # Test the current stream on the device and check if the stream device index
        # matches the device ID
        @torch.jit.script
        def fn():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.current_stream(torch.device("cuda:1"))
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

        # Test the current_stream API by passing a device ID as an argument and
        # check if the stream device index matches the device ID
        @torch.jit.script
        def fn_with_device_index_args():
            device_index = torch.cuda.current_device()
            s0 = torch.cuda.current_stream(device_index)
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(0)

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn_with_device_index_args()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @skipCUDANonDefaultStreamIf(True)
    def test_streams_and_events(self):
        # Test the default_stream API by passing a device ID as an argument and
        # check if the stream device index matches the device ID
        @torch.jit.script
        def test_default_streams_with_device_index_args():
            s0 = torch.cuda.default_stream(0)
            s1 = torch.cuda.default_stream(1)
            return s0.device_index(), s1.device_index()

        d0, d1 = test_default_streams_with_device_index_args()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)

        # This test checks that the default stream ID is set to 0 on the device
        @torch.jit.script
        def test_default_streams():
            s0 = torch.cuda.default_stream(torch.device("cuda:0"))
            s1 = torch.cuda.default_stream(torch.device("cuda:1"))

            d = torch.device("cuda:1")

            # Check that the current stream id and the default stream id are the same
            # on the current device. The current device ID by default is 0.
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))
            check_s2 = s2.id() == s0.id()
            check_d0 = torch.cuda.current_device() == s2.device_index()

            # Set the current device to d1 and check if the stream
            # has been set to the default stream on d1
            with torch.cuda.device(d):
                s3 = torch.cuda.current_stream(d)
                check_s3 = s3.id() == s1.id()
                check_d1 = torch.cuda.current_device() == s3.device_index()

            # Check if the current device was reset to 0
            is_device_d0 = torch.cuda.current_device() == s2.device_index()

            return (
                s0.device_index(),
                s1.device_index(),
                check_s2,
                check_s3,
                check_d0,
                check_d1,
                is_device_d0,
            )

        (
            d0,
            d1,
            check_s2,
            check_s3,
            check_d0,
            check_d1,
            is_device_d0,
        ) = test_default_streams()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)
        self.assertTrue(check_s2)
        self.assertTrue(check_s3)
        self.assertTrue(check_d0)
        self.assertTrue(check_d1)
        self.assertTrue(is_device_d0)

        # This test checks that the stream context manager is a no-op
        # when the stream is None for `with torch.cuda.stream`
        @torch.jit.script
        def test_set_none_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)
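            # Note: current_stream reflects whichever stream is active for the
            # device at this point, while default_stream is the device's
            # always-present stream (its ID is checked to be 0 later in this test).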

            # When the stream is None, check if this operation is a no-op
            with torch.cuda.stream(None):
                cur_device_index = torch.cuda.current_device()
                is_device_index_same = cur_device_index == device_index
                is_current_stream_same = (
                    torch.cuda.current_stream(device).id() == current_stream.id()
                )
                is_default_stream_same = (
                    torch.cuda.default_stream(device).id() == default_stream.id()
                )

            # Check if the device index, current stream and default stream have not changed
            are_streams_same = (
                is_device_index_same
                and is_current_stream_same
                and is_default_stream_same
            )
            return are_streams_same

        self.assertTrue(test_set_none_stream())

        # This test checks that the device context manager is a no-op
        # when the device is None for `with torch.cuda.device`
        @torch.jit.script
        def test_set_device_none():
            device_index = torch.cuda.current_device()
            # When the device is None, check if this operation is a no-op
            with torch.cuda.device(None):
                # Check if the current device is the same
                is_device_same = torch.cuda.current_device() == device_index
            return is_device_same

        self.assertTrue(test_set_device_none())

        # Check if a CUDA JIT stream is created
        # on the current_device
        @torch.jit.script
        def test_simple_stream():
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            return device_index == s.device_index()

        self.assertTrue(test_simple_stream(), "Could not create Stream!")

        # Class used to store results for the test: test_get_stream.
        class Result(NamedTuple):
            t1: torch.Tensor
            t2: torch.Tensor
            is_current_and_default_stream_same: bool
            is_default_and_user_stream_not_same: bool
            is_stream_set: bool
            is_stream_reset: bool
            default_stream_query: bool
            default_stream_id: int
            user_stream_id: int

        # The test aims at checking different stream properties.
        @torch.jit.script
        def test_get_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)
            user_stream = torch.cuda.Stream()

            # Check if the current and default streams are the same on the device
            is_current_and_default_stream_same = (
                current_stream.id() == default_stream.id()
            )
            # Check if the user stream and default stream are not the same on the device
            is_default_and_user_stream_not_same = (
                default_stream.id() != user_stream.id()
            )

            with torch.cuda.stream(user_stream):
                is_stream_set = (
                    torch.cuda.current_stream(device).id() == user_stream.id()
                )

            # Check if the stream was reset to current_stream
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            tensor1 = torch.rand(10000, 10000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            default_stream.synchronize()
            default_stream_query = default_stream.query()

            # Capture all the results in the class Result
            res = Result(
                tensor1,
                tensor2,
                is_current_and_default_stream_same,
                is_default_and_user_stream_not_same,
                is_stream_set,
                is_stream_reset,
                default_stream_query,
                default_stream.id(),
                user_stream.id(),
            )
            return res

        result = test_get_stream()

        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
        self.assertTrue(result.is_current_and_default_stream_same)
        self.assertTrue(result.is_default_and_user_stream_not_same)
        self.assertTrue(result.is_stream_set)
        self.assertTrue(result.is_stream_reset)
        self.assertTrue(result.default_stream_query)
        self.assertEqual(
            result.default_stream_id, 0
        )  # Check if the default stream ID is always 0
        self.assertNotEqual(
            result.user_stream_id, 0
        )  # Check if the user stream ID is always non-zero

        # Test the stream context manager. This test checks if the stream is switched
        # to the user stream when using the stream context manager.
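        # On exiting the context manager, the current stream is expected to be
        # restored to the stream that was active before entering it; the
        # is_stream_reset check below verifies this.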
        @torch.jit.script
        def test_stream_context():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            user_stream = torch.cuda.Stream()
            A = torch.rand(1000, 1000, device="cuda")

            with torch.cuda.stream(user_stream):
                check = torch.cuda.current_stream(device).id() == user_stream.id()
                B = torch.mm(A, A).to("cuda")
            # Wait for B to be computed
            user_stream.synchronize()
            # Check if the stream has been reset on the current device
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            return A, B, check, is_stream_reset

        A, B, is_stream_set, is_stream_reset = test_stream_context()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertTrue(
            is_stream_set, "Error: Current stream was not set to user stream!"
        )
        self.assertTrue(
            is_stream_reset, "Error: The stream was not restored to previous stream!"
        )

        # Test multiple nested streams. Check if the operations are computed as
        # expected on the streams.
        # This test has been adapted from the eager mode tests available at test/test_cuda.py
        @torch.jit.script
        def test_multiple_stream():
            prev_device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(prev_device_index))
            prev_current_stream = torch.cuda.current_stream(device)
            d1 = torch.device("cuda:0")
            d2 = torch.device("cuda:1")
            s1 = torch.cuda.Stream(d1, 0)
            s2 = torch.cuda.Stream(d2, 0)

            A = torch.rand(1000, 1000, device="cuda")
            B = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                C = torch.mm(A, A).to("cuda")
                # Check if the stream and device have been set to s1
                is_stream_s1 = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1 = torch.cuda.current_device() == s1.device_index()
                with torch.cuda.stream(s2):
                    # Check if the stream and device have been set to s2
                    is_stream_s2 = torch.cuda.current_stream(d2).id() == s2.id()
                    is_device_s2 = torch.cuda.current_device() == s2.device_index()
                    D = torch.mm(B, B).to("cuda")
                # Check if the stream and device have been set back to s1
                is_stream_s1_after = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1_after = torch.cuda.current_device() == s1.device_index()
                # Wait for D to be computed
                s2.synchronize()
            # Wait for C to be computed on s1
            s1.synchronize()

            # Check if the stream and device have been restored to the previous
            # stream and device
            is_device_current = torch.cuda.current_device() == prev_device_index
            is_stream_current = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_stream_s1
                and is_stream_s2
                and is_stream_s1_after
                and is_stream_current
            )
            check_device = (
                is_device_s1
                and is_device_s2
                and is_device_s1_after
                and is_device_current
            )
            return A, B, C, D, check_stream, check_device

        A, B, C, D, check_stream, check_device = test_multiple_stream()

        self.assertEqual(torch.matmul(A, A), C)
        self.assertEqual(torch.matmul(B, B), D)
        self.assertTrue(check_stream)
        self.assertTrue(check_device)

        # Test multiple streams waiting on each other for the operations to be completed.
        @torch.jit.script
        def test_data_dependency_between_streams():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            prev_current_stream = torch.cuda.current_stream(device)
            d = torch.device("cuda:0")
            s1 = torch.cuda.Stream(d, 0)
            s2 = torch.cuda.Stream(d, 0)
            event = torch.cuda.Event(False, False, False)

            A = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                is_stream_s1 = torch.cuda.current_stream(device).id() == s1.id()
                B = torch.mm(A, A).to("cuda")
            s1.record_event(event)
            # Check if the current_stream is reset
            is_current_stream_1 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )
            # Wait for ops on s1 to be computed
            s2.wait_event(event)
            with torch.cuda.stream(s2):
                is_stream_s2 = torch.cuda.current_stream(device).id() == s2.id()
                C = torch.mm(B, B).to("cuda")
            # Wait for C to be computed
            s2.synchronize()
            # Check if the current_stream is reset
            is_current_stream_2 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_current_stream_1
                and is_current_stream_2
                and is_stream_s1
                and is_stream_s2
            )
            return A, B, C, check_stream

        A, B, C, check_stream = test_data_dependency_between_streams()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertEqual(torch.matmul(B, B), C)
        self.assertTrue(check_stream)

        # Test a simple CUDA event. Test if the CUDA event was created successfully
        @torch.jit.script
        def test_simple_event():
            e = torch.cuda.Event(True, False, False)
            return e is not None

        self.assertTrue(test_simple_event(), "Could not create CUDA Event!")

        # Record the CUDA event for the operation torch.mm on the current stream
        # and then test if the elapsed time is greater than 0. This test is also
        # an adaptation of the eager mode CUDA tests available at test/test_cuda.py
        @torch.jit.script
        def test_event():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            stream = torch.cuda.current_stream(device)
            event = torch.cuda.Event(True, False, False)
            is_true_event_query = event.query()
            start_event = torch.cuda.Event(True, False, False)
            stream.record_event(start_event)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            stream.record_event(event)
            event.synchronize()
            is_again_true_event_query = event.query()

            if not (is_true_event_query and is_again_true_event_query):
                return -1.0
            return start_event.elapsed_time(event)

        self.assertGreater(test_event(), 0)
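
        # Note: elapsed_time reports time in milliseconds and is only valid when
        # both events were created with timing enabled (the first constructor
        # argument used above) and have been recorded.
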
        # Check for stream synchronization, when a large tensor multiplication is
        # computed on the stream. The stream.query should be true once the
        # synchronization is done.
        @torch.jit.script
        def test_stream_synchronize() -> float:
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s):
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            s.synchronize()
            e_tok.record(s)
            e_tok.synchronize()

            if not s.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would throw
            # exception if otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_stream_synchronize(), 0)

        # Test event synchronization for the event that records a stream doing
        # a large tensor multiplication. Check if the elapsed time is greater than 0
        # and the stream.query evaluates to true.
        @torch.jit.script
        def test_event_synchronize() -> float:
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s):
                tensor = torch.mm(tensor1, tensor1).to("cuda")
            s.record_event(e_tok)
            e_tok.synchronize()
            s.synchronize()

            if not s.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would throw
            # exception if otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_synchronize(), 0)

        # Test for event wait. Check if the event waits for all the operations on
        # the stream to be done. Check for synchronizations and query on the streams
        # and events. This test is adapted from the eager mode tests for CUDA. Please refer
        # to test/test_cuda.py
        @torch.jit.script
        def test_event_wait() -> float:
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, True, False)
            e_tok = torch.cuda.Event(True, True, False)

            e_tik.record(s0)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s0):
                tensor2 = torch.mm(tensor1, tensor1).cuda()
                e_sync = torch.cuda.Event(True, False, False)
                e_sync.record(torch.cuda.current_stream(device))
                e_sync.wait(s1)
                with torch.cuda.stream(s1):
                    tensor3 = torch.rand(1000000000, 1000000000, device="cuda")
                    tensor4 = torch.mm(tensor3, tensor3).cuda()
                s1.synchronize()
            e_tok.record(torch.cuda.current_stream(device))
            e_tok.synchronize()
            s0.synchronize()

            if not s0.query() or not s1.query() or not e_sync.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would throw
            # exception if otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_wait(), 0)

        # Test for stream wait_event. Check if the stream waits on the event.
        @torch.jit.script
        def test_wait_event():
            d1 = torch.device("cuda:1")

            with torch.cuda.device(d1):
                s0 = torch.cuda.current_stream(d1)
                tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
                e0 = torch.cuda.Event(False, False, False)
                s0.record_event(e0)

            s1 = torch.cuda.current_stream(torch.device("cuda:0"))
            s1.wait_event(e0)
            s1.synchronize()

            return e0.query() and s0.query() and s1.query()

        self.assertTrue(test_wait_event())

    # Test if a scripted module with cuda streams can be saved, loaded and executed
    def test_save_load(self):
        class Model(torch.nn.Module):
            def forward(self):
                s = torch.cuda.Stream()
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.cuda.stream(s):
                    is_stream_s = torch.cuda.current_stream(s.device).id() == s.id()
                    c = torch.cat((a, b), 0).cuda()
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and save
        script_model = torch.jit.script(model)
        is_stream_s, a, b, c = script_model()
        # Verify if the output is correct
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a, b), 0), c)

        # Save and load scripted model
        load_model = self.getExportImportCopy(script_model)
        is_stream_s, a_load, b_load, c_load = load_model()
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a_load, b_load), 0), c_load)

    # Make sure that cuda._exchange_device doesn't get DCE'ed
    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__exchange_device_op(self):
        def fn(device: int, tensor):
            torch.cuda._exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to
        # run this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_exchange_device(").run(g)

    # Make sure that cuda._maybe_exchange_device doesn't get DCE'ed
    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__maybe_exchange_device_op(self):
        def fn(device: int, tensor):
            torch.cuda._maybe_exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to
        # run this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)