# Owner(s): ["oncall: jit"]

import gc
import os
import sys
import unittest
from typing import NamedTuple

import torch
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
    NoTest,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
    TEST_CUDA,
)
from torch.testing._internal.jit_utils import JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# If GPU is not available, then do not run the tests
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    JitTestCase = NoTest  # noqa: F811

TEST_LARGE_TENSOR = TEST_CUDA

# If GPU is available, then initialize the cuda context and check
# if there is memory available to allocate for LARGE Tensors.
if TEST_CUDA:
    torch.ones(1).cuda()  # initialize cuda context
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestCUDA(JitTestCase):
    """
    A suite of tests for the CUDA API in TorchScript.
    """

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        super().tearDown()

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_cuda_synchronize(self):
        # Test device synchronization.

        @torch.jit.script
        def test_device_synchronize():
            prev_current_device_index = torch.cuda.current_device()
            torch.cuda.synchronize()
            torch.cuda.synchronize("cuda")
            torch.cuda.synchronize("cuda:0")
            torch.cuda.synchronize(0)
            torch.cuda.synchronize(torch.device("cuda:1"))
            after_current_device_index = torch.cuda.current_device()

            # Check if the current device index is same as the device index before
            # synchronizing the device.
            return prev_current_device_index == after_current_device_index

        @torch.jit.script
        def test_multi_device_synchronize():
            torch.cuda.synchronize(torch.device("cuda:0"))
            prev_current_device_index = torch.cuda.current_device()
            torch.cuda.synchronize(1)
            after_current_device_index = torch.cuda.current_device()

            # Check if the current device index is same as the device index before
            # synchronizing the device.
            return prev_current_device_index == after_current_device_index

        self.assertTrue(test_device_synchronize)
        FileCheck().check("cuda::synchronize(").run(test_device_synchronize.graph)
        self.assertTrue(test_multi_device_synchronize)
        FileCheck().check("cuda::synchronize(").run(test_multi_device_synchronize.graph)

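    # Note: the assertTrue(...) checks in the next two tests (and in
    # test_cuda_synchronize above) only verify that scripting succeeded; the scripted
    # functions are not invoked here, and a ScriptFunction object is always truthy.
    # The positional Stream arguments below are assumed to mirror the eager
    # torch.cuda.Stream(device, priority) signature.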
    def test_stream_args(self):
        # Test stream creation with default arguments
        @torch.jit.script
        def stream_default_args() -> bool:
            s = torch.cuda.Stream()
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_device() -> bool:
            s = torch.cuda.Stream(priority=0)
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_priority() -> bool:
            d = torch.device("cuda:1")
            s = torch.cuda.Stream(d)
            return s.device_index() == 1

        @torch.jit.script
        def stream_args_all() -> bool:
            d = torch.device("cuda:0")
            s = torch.cuda.Stream(d, 0)
            return s.device_index() == 0

        self.assertTrue(stream_default_args)
        self.assertTrue(stream_default_args_for_device)
        self.assertTrue(stream_default_args_for_priority)
        self.assertTrue(stream_args_all)

    def test_event_args(self):
        # Test Event creation with default arguments
        @torch.jit.script
        def event_default_args() -> bool:
            e = torch.cuda.Event()
            return e is not None

        self.assertTrue(event_default_args)

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        # Test the current stream on the device and check if the stream device index
        # matches the device ID
        @torch.jit.script
        def fn():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.current_stream(torch.device("cuda:1"))
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

        # Test the current_stream API by passing a device ID as an argument and
        # check if the stream device index matches the device ID
        @torch.jit.script
        def fn_with_device_index_args():
            device_index = torch.cuda.current_device()
            s0 = torch.cuda.current_stream(device_index)
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(0)

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn_with_device_index_args()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

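    # The following test exercises most of the TorchScript CUDA stream and event API
    # in one place: default and current streams, the stream and device context
    # managers, user-created streams and events, cross-stream data dependencies, and
    # event timing.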
    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @skipCUDANonDefaultStreamIf(True)
    def test_streams_and_events(self):
        # Test the default_stream API by passing a device ID as an argument and
        # check if the stream device index matches the device ID
        @torch.jit.script
        def test_default_streams_with_device_index_args():
            s0 = torch.cuda.default_stream(0)
            s1 = torch.cuda.default_stream(1)
            return s0.device_index(), s1.device_index()

        d0, d1 = test_default_streams_with_device_index_args()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)

        # This test checks that the default stream ID is set to 0 on the device
        @torch.jit.script
        def test_default_streams():
            s0 = torch.cuda.default_stream(torch.device("cuda:0"))
            s1 = torch.cuda.default_stream(torch.device("cuda:1"))

            d = torch.device("cuda:1")

            # Check that the current stream ID and the default stream ID are the same
            # on the current device. The current device ID is 0 by default.
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))
            check_s2 = s2.id() == s0.id()
            check_d0 = torch.cuda.current_device() == s2.device_index()

            # Set the current device to d1 and check if the stream
            # has been set to the default stream on d1
            with torch.cuda.device(d):
                s3 = torch.cuda.current_stream(d)
                check_s3 = s3.id() == s1.id()
                check_d1 = torch.cuda.current_device() == s3.device_index()

            # Check if the current device was reset to 0
            is_device_d0 = torch.cuda.current_device() == s2.device_index()

            return (
                s0.device_index(),
                s1.device_index(),
                check_s2,
                check_s3,
                check_d0,
                check_d1,
                is_device_d0,
            )

        (
            d0,
            d1,
            check_s2,
            check_s3,
            check_d0,
            check_d1,
            is_device_d0,
        ) = test_default_streams()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)
        self.assertTrue(check_s2)
        self.assertTrue(check_s3)
        self.assertTrue(check_d0)
        self.assertTrue(check_d1)
        self.assertTrue(is_device_d0)

        # This test checks that the stream context manager is a no-op
        # when the stream is None for `with torch.cuda.stream`
        @torch.jit.script
        def test_set_none_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)

            # When the stream is None, check if this operation is a no-op
            with torch.cuda.stream(None):
                cur_device_index = torch.cuda.current_device()
                is_device_index_same = cur_device_index == device_index
                is_current_stream_same = (
                    torch.cuda.current_stream(device).id() == current_stream.id()
                )
                is_default_stream_same = (
                    torch.cuda.default_stream(device).id() == default_stream.id()
                )

            # Check that the device index, the current stream and the default stream have not changed
            are_streams_same = (
                is_device_index_same
                and is_current_stream_same
                and is_default_stream_same
            )
            return are_streams_same

        self.assertTrue(test_set_none_stream())

        # This test checks that the device context manager is a no-op
        # when the device is None for `with torch.cuda.device`
        @torch.jit.script
        def test_set_device_none():
            device_index = torch.cuda.current_device()
            # When the device is None, check if this operation is a no-op
            with torch.cuda.device(None):
                # Check if the current device is the same
                is_device_same = torch.cuda.current_device() == device_index
            return is_device_same

        self.assertTrue(test_set_device_none())

        # Check if a CUDA JIT stream is created
        # on the current device
        @torch.jit.script
        def test_simple_stream():
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            return device_index == s.device_index()

        self.assertTrue(test_simple_stream(), "Could not create Stream!")

        # Class used to store results for the test: test_get_stream.
        class Result(NamedTuple):
            t1: torch.Tensor
            t2: torch.Tensor
            is_current_and_default_stream_same: bool
            is_default_and_user_stream_not_same: bool
            is_stream_set: bool
            is_stream_reset: bool
            default_stream_query: bool
            default_stream_id: int
            user_stream_id: int

        # This test checks several stream properties.
        @torch.jit.script
        def test_get_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)
            user_stream = torch.cuda.Stream()

            # Check if the current and default streams are the same on the device
            is_current_and_default_stream_same = (
                current_stream.id() == default_stream.id()
            )
            # Check if the user stream and the default stream are not the same on the device
            is_default_and_user_stream_not_same = (
                default_stream.id() != user_stream.id()
            )

            with torch.cuda.stream(user_stream):
                is_stream_set = (
                    torch.cuda.current_stream(device).id() == user_stream.id()
                )

            # Check if the stream was reset to current_stream
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            tensor1 = torch.rand(10000, 10000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            default_stream.synchronize()
            default_stream_query = default_stream.query()

            # Capture all the results in the class Result
            res = Result(
                tensor1,
                tensor2,
                is_current_and_default_stream_same,
                is_default_and_user_stream_not_same,
                is_stream_set,
                is_stream_reset,
                default_stream_query,
                default_stream.id(),
                user_stream.id(),
            )
            return res

        result = test_get_stream()

        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
        self.assertTrue(result.is_current_and_default_stream_same)
        self.assertTrue(result.is_default_and_user_stream_not_same)
        self.assertTrue(result.is_stream_set)
        self.assertTrue(result.is_stream_reset)
        self.assertTrue(result.default_stream_query)
        self.assertEqual(
            result.default_stream_id, 0
        )  # Check if the default stream ID is always 0
        self.assertNotEqual(
            result.user_stream_id, 0
        )  # Check if the user stream ID is always nonzero

        # Test the stream context manager. This test checks if the stream is switched
        # to the user stream when using the stream context manager.
        @torch.jit.script
        def test_stream_context():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            user_stream = torch.cuda.Stream()
            A = torch.rand(1000, 1000, device="cuda")

            with torch.cuda.stream(user_stream):
                check = torch.cuda.current_stream(device).id() == user_stream.id()
                B = torch.mm(A, A).to("cuda")
            # Wait for B to be computed
            user_stream.synchronize()
            # Check if the stream has been reset on the current device
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            return A, B, check, is_stream_reset

        A, B, is_stream_set, is_stream_reset = test_stream_context()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertTrue(
            is_stream_set, "Error: Current stream was not set to user stream!"
        )
        self.assertTrue(
            is_stream_reset, "Error: The stream was not restored to previous stream!"
        )

        # Test multiple nested streams. Check if the operations are computed as expected on the streams.
        # This test has been adapted from the eager mode tests available at test/test_cuda.py
        @torch.jit.script
        def test_multiple_stream():
            prev_device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(prev_device_index))
            prev_current_stream = torch.cuda.current_stream(device)
            d1 = torch.device("cuda:0")
            d2 = torch.device("cuda:1")
            s1 = torch.cuda.Stream(d1, 0)
            s2 = torch.cuda.Stream(d2, 0)

            A = torch.rand(1000, 1000, device="cuda")
            B = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                C = torch.mm(A, A).to("cuda")
                # Check if the stream and device have been set to s1
                is_stream_s1 = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1 = torch.cuda.current_device() == s1.device_index()
                with torch.cuda.stream(s2):
                    # Check if the stream and device have been set to s2
                    is_stream_s2 = torch.cuda.current_stream(d2).id() == s2.id()
                    is_device_s2 = torch.cuda.current_device() == s2.device_index()
                    D = torch.mm(B, B).to("cuda")
                # Check if the stream and device have been reset to s1
                is_stream_s1_after = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1_after = torch.cuda.current_device() == s1.device_index()
                # Wait for D to be computed
                s2.synchronize()
            # Wait for C to be computed on s1
            s1.synchronize()

            # Check if the stream and device have been restored to the previous stream and device
            is_device_current = torch.cuda.current_device() == prev_device_index
            is_stream_current = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_stream_s1
                and is_stream_s2
                and is_stream_s1_after
                and is_stream_current
            )
            check_device = (
                is_device_s1
                and is_device_s2
                and is_device_s1_after
                and is_device_current
            )
            return A, B, C, D, check_stream, check_device

        A, B, C, D, check_stream, check_device = test_multiple_stream()

        self.assertEqual(torch.matmul(A, A), C)
        self.assertEqual(torch.matmul(B, B), D)
        self.assertTrue(check_stream)
        self.assertTrue(check_device)

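        # The CUDA events below are constructed with positional arguments, which are
        # assumed to correspond to the eager torch.cuda.Event signature:
        # Event(enable_timing, blocking, interprocess).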
        # Test multiple streams waiting on each other for the operations to be completed.
        @torch.jit.script
        def test_data_dependency_between_streams():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            prev_current_stream = torch.cuda.current_stream(device)
            d = torch.device("cuda:0")
            s1 = torch.cuda.Stream(d, 0)
            s2 = torch.cuda.Stream(d, 0)
            event = torch.cuda.Event(False, False, False)

            A = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                is_stream_s1 = torch.cuda.current_stream(device).id() == s1.id()
                B = torch.mm(A, A).to("cuda")
                s1.record_event(event)
            # Check if the current_stream is reset
            is_current_stream_1 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )
            # Wait for ops on s1 to be computed
            s2.wait_event(event)
            with torch.cuda.stream(s2):
                is_stream_s2 = torch.cuda.current_stream(device).id() == s2.id()
                C = torch.mm(B, B).to("cuda")
                # Wait for C to be computed
                s2.synchronize()
            # Check if the current_stream is reset
            is_current_stream_2 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_current_stream_1
                and is_current_stream_2
                and is_stream_s1
                and is_stream_s2
            )
            return A, B, C, check_stream

        A, B, C, check_stream = test_data_dependency_between_streams()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertEqual(torch.matmul(B, B), C)
        self.assertTrue(check_stream)

        # Test a simple CUDA event. Check if the CUDA event was created successfully.
        @torch.jit.script
        def test_simple_event():
            e = torch.cuda.Event(True, False, False)
            return e is not None

        self.assertTrue(test_simple_event(), "Could not create CUDA Event!")

        # Record a CUDA event for the torch.mm operation on the current stream
        # and then check that the elapsed time is greater than 0. This test is also
        # an adaptation of the eager mode CUDA tests available at test/test_cuda.py
        @torch.jit.script
        def test_event():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            stream = torch.cuda.current_stream(device)
            event = torch.cuda.Event(True, False, False)
            is_true_event_query = event.query()
            start_event = torch.cuda.Event(True, False, False)
            stream.record_event(start_event)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            stream.record_event(event)
            event.synchronize()
            is_again_true_event_query = event.query()

            if not (is_true_event_query and is_again_true_event_query):
                return -1.0
            return start_event.elapsed_time(event)

        self.assertGreater(test_event(), 0)

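        # The e_tik/e_tok events below pass True as the first constructor argument
        # (assumed to enable timing); elapsed_time() requires both events to have been
        # created with timing enabled and to have been recorded.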
        # Check for stream synchronization when a large tensor multiplication is
        # computed on the stream. stream.query() should be True once the synchronization is done.
        @torch.jit.script
        def test_stream_synchronize() -> float:
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s):
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            s.synchronize()
            e_tok.record(s)
            e_tok.synchronize()

            if not s.query():
                return -1.0

            # It is not necessary to check e_tik and e_tok, as elapsed_time would
            # throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_stream_synchronize(), 0)

        # Test event synchronization for an event that records a stream doing
        # a large tensor multiplication. Check that the elapsed time is greater than 0
        # and that stream.query() evaluates to True.
        @torch.jit.script
        def test_event_synchronize() -> float:
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s):
                tensor = torch.mm(tensor1, tensor1).to("cuda")
            s.record_event(e_tok)
            e_tok.synchronize()
            s.synchronize()

            if not s.query():
                return -1.0

            # It is not necessary to check e_tik and e_tok, as elapsed_time would
            # throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_synchronize(), 0)

        # Test event wait. Check that the event waits for all the operations on
        # the stream to be done. Check for synchronizations and queries on the streams
        # and events. This test is adapted from the eager mode tests for CUDA; please
        # refer to test/test_cuda.py
        @torch.jit.script
        def test_event_wait() -> float:
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, True, False)
            e_tok = torch.cuda.Event(True, True, False)

            e_tik.record(s0)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s0):
                tensor2 = torch.mm(tensor1, tensor1).cuda()
                e_sync = torch.cuda.Event(True, False, False)
                e_sync.record(torch.cuda.current_stream(device))
                e_sync.wait(s1)
                with torch.cuda.stream(s1):
                    tensor3 = torch.rand(1000000000, 1000000000, device="cuda")
                    tensor4 = torch.mm(tensor3, tensor3).cuda()
                s1.synchronize()
            e_tok.record(torch.cuda.current_stream(device))
            e_tok.synchronize()
            s0.synchronize()

            if not s0.query() or not s1.query() or not e_sync.query():
                return -1.0

            # It is not necessary to check e_tik and e_tok, as elapsed_time would
            # throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_wait(), 0)

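        # The next test records an event on the current stream of cuda:1 and makes the
        # current stream on cuda:0 wait on it, exercising cross-device synchronization
        # via Stream.wait_event.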
        # Test for stream wait_event. Checks if the stream waits on the event
        @torch.jit.script
        def test_wait_event():
            d1 = torch.device("cuda:1")

            with torch.cuda.device(d1):
                s0 = torch.cuda.current_stream(d1)
                tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
                e0 = torch.cuda.Event(False, False, False)
                s0.record_event(e0)

            s1 = torch.cuda.current_stream(torch.device("cuda:0"))
            s1.wait_event(e0)
            s1.synchronize()

            return e0.query() and s0.query() and s1.query()

        self.assertTrue(test_wait_event())

    # Test if a scripted module with cuda streams can be saved, loaded and executed
    def test_save_load(self):
        class Model(torch.nn.Module):
            def forward(self):
                s = torch.cuda.Stream()
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.cuda.stream(s):
                    is_stream_s = torch.cuda.current_stream(s.device).id() == s.id()
                    c = torch.cat((a, b), 0).cuda()
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and save
        script_model = torch.jit.script(model)
        is_stream_s, a, b, c = script_model()
        # Verify that the output is correct
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a, b), 0), c)

        # Save and load the scripted model
        load_model = self.getExportImportCopy(script_model)
        is_stream_s, a_load, b_load, c_load = load_model()
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a_load, b_load), 0), c_load)

    # Make sure that cuda._exchange_device doesn't get DCE'ed
    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__exchange_device_op(self):
        def fn(device: int, tensor):
            torch.cuda._exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to
        # run this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_exchange_device(").run(g)

    # Make sure that cuda._maybe_exchange_device doesn't get DCE'ed
    @unittest.skipIf(not TEST_CUDA, "Cuda not available")
    def test__maybe_exchange_device_op(self):
        def fn(device: int, tensor):
            torch.cuda._maybe_exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to
        # run this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)