1# Owner(s): ["module: cuda"] 2 3import sys 4import textwrap 5import traceback 6from typing import List 7 8import torch 9import torch.cuda._sanitizer as csan 10from torch.cuda._sanitizer import DataPtr, EventId, StreamId 11from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase 12 13 14if not TEST_CUDA: 15 print("CUDA not available, skipping tests", file=sys.stderr) 16 TestCase = NoTest # noqa: F811 17 18 19class TestArgumentHandler(TestCase): 20 def test_add(self): 21 add_func = torch.ops.aten.add.Tensor 22 a = torch.ones(5, 3, device="cuda") 23 b = torch.randn(5, 3, device="cuda") 24 25 argument_handler = csan.ArgumentHandler() 26 argument_handler.parse_inputs(add_func._schema, (a, b), {}) 27 c = torch.add(a, b) 28 argument_handler.parse_outputs(c) 29 30 self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read) 31 self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written) 32 33 def test_cat(self): 34 cat_func = torch.ops.aten.cat.default 35 a = torch.ones(2, 4, 5, device="cuda") 36 b = torch.zeros(2, 1, 5, device="cuda") 37 c = torch.rand(2, 7, 5, device="cuda") 38 39 argument_handler = csan.ArgumentHandler() 40 argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {}) 41 d = torch.cat((a, b, c), dim=1) 42 argument_handler.parse_outputs(d) 43 44 self.assertEqual( 45 {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read 46 ) 47 self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written) 48 49 def test_split(self): 50 split_func = torch.ops.aten.split.Tensor 51 a = torch.arange(10, device="cuda").reshape(5, 2) 52 53 argument_handler = csan.ArgumentHandler() 54 argument_handler.parse_inputs(split_func._schema, (a, 2), {}) 55 out = torch.split(a, 2) 56 argument_handler.parse_outputs(out) 57 58 outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} 59 self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) 60 self.assertEqual(outputs, argument_handler.dataptrs_written) 61 62 def test_inplace(self): 63 add_inplace_func = torch.ops.aten.add_.Tensor 64 a = torch.rand(4, 2, device="cuda") 65 66 argument_handler = csan.ArgumentHandler() 67 argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {}) 68 a.add_(5) 69 argument_handler.parse_outputs(a) 70 71 self.assertEqual(set(), argument_handler.dataptrs_read) 72 self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written) 73 74 def test_out(self): 75 mul_out_func = torch.ops.aten.mul.out 76 a = torch.arange(8, device="cuda") 77 b = torch.empty(8, device="cuda") 78 79 argument_handler = csan.ArgumentHandler() 80 argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b}) 81 torch.mul(a, 3, out=b) 82 argument_handler.parse_outputs(b) 83 84 self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) 85 self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written) 86 87 def test_nonzero(self): 88 nonzero_func = torch.ops.aten.nonzero.default 89 a = torch.ones(5, 3, 2, device="cuda") 90 91 argument_handler = csan.ArgumentHandler() 92 argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True}) 93 out = torch.nonzero(a, as_tuple=True) 94 argument_handler.parse_outputs(out) 95 96 outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} 97 self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) 98 self.assertEqual(outputs, argument_handler.dataptrs_written) 99 100 def test_tensor_names(self): 101 addr_func = torch.ops.aten.addr.default 102 vec = torch.arange(1, 4, device="cuda") 103 M = torch.zeros(3, 3, device="cuda") 104 105 argument_handler = csan.ArgumentHandler() 106 argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {}) 107 out = torch.addr(M, vec, vec) 108 argument_handler.parse_outputs(out) 109 110 self.assertEqual( 111 argument_handler.tensor_aliases, 112 { 113 M.data_ptr(): ["self"], 114 vec.data_ptr(): ["vec1", "vec2"], 115 out.data_ptr(): [], 116 }, 117 ) 118 self.assertEqual({out.data_ptr()}, argument_handler.outputs) 119 120 121def tensor_id(i: int) -> DataPtr: 122 return i 123 124 125def stream_id(i: int) -> StreamId: 126 return 1000 + i 127 128 129def event_id(i: int) -> EventId: 130 return 2000 + i 131 132 133class TestEventHandler(TestCase): 134 def setUp(self): 135 self.handler = csan.EventHandler() 136 137 def kernel_launch( 138 self, 139 stream: StreamId, 140 read_only: List[DataPtr] = None, 141 read_write: List[DataPtr] = None, 142 ) -> List[csan.SynchronizationError]: 143 if read_only is None: 144 read_only = [] 145 if read_write is None: 146 read_write = [] 147 return self.handler._handle_kernel_launch( 148 stream, 149 read_only, 150 read_write, 151 {}, 152 "", 153 {k: [""] for k in read_only + read_write}, 154 ) 155 156 def assert_good_kernel_launch( 157 self, 158 stream: StreamId, 159 read_only: List[DataPtr] = None, 160 read_write: List[DataPtr] = None, 161 ) -> None: 162 self.assertEqual(self.kernel_launch(stream, read_only, read_write), []) 163 164 def assert_bad_kernel_launch( 165 self, 166 number_of_errors: int, 167 stream: StreamId, 168 read_only: List[DataPtr] = None, 169 read_write: List[DataPtr] = None, 170 ) -> None: 171 errors = self.kernel_launch(stream, read_only, read_write) 172 self.assertEqual(len(errors), number_of_errors) 173 174 def test_empty_kernel_launch(self): 175 self.assert_good_kernel_launch(stream_id(0)) 176 177 def test_simple_passing(self): 178 self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 179 self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 180 181 def test_simple_error(self): 182 self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 183 self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 184 185 def test_simple_sync(self): 186 self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 187 self.handler._handle_event_record(event_id(0), stream_id(1)) 188 self.handler._handle_event_wait(event_id(0), stream_id(2)) 189 self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)]) 190 191 def test_reads_check_last_write(self): 192 # Tests that not only the first read operation checks if it is in conflict 193 # with the last write operation, but all read operations do. 194 195 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 196 self.handler._handle_event_record(event_id(0), stream_id(1)) 197 self.handler._handle_event_wait(event_id(0), stream_id(2)) 198 self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 199 200 self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)]) 201 202 def test_branch_sync(self): 203 # Tests that two streams can read after both waiting for a third, but they 204 # cannot write without further synchronization. 205 206 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 207 self.handler._handle_event_record(event_id(0), stream_id(1)) 208 self.handler._handle_event_wait(event_id(0), stream_id(2)) 209 self.handler._handle_event_wait(event_id(0), stream_id(3)) 210 self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 211 self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)]) 212 213 self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 214 215 def test_chain_sync(self): 216 iterations = 10 217 218 self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)]) 219 for i in range(iterations): 220 self.handler._handle_event_record(event_id(i), stream_id(i)) 221 self.handler._handle_event_wait(event_id(i), stream_id(i + 1)) 222 self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)]) 223 224 def test_expired_record(self): 225 self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 226 self.handler._handle_event_record(event_id(0), stream_id(1)) 227 self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 228 self.handler._handle_event_wait(event_id(0), stream_id(2)) 229 230 self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 231 232 def test_deleted_record(self): 233 for should_delete, should_create in [ 234 (True, True), 235 (True, False), 236 (False, True), 237 ]: 238 self.setUp() 239 with self.subTest(should_delete=should_delete, should_create=should_create): 240 self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 241 self.handler._handle_event_record(event_id(0), stream_id(1)) 242 243 if should_delete: 244 self.handler._handle_event_deletion(event_id(0)) 245 if should_create: 246 self.handler._handle_event_creation(event_id(0)) 247 248 self.handler._handle_event_wait(event_id(0), stream_id(2)) 249 self.assert_bad_kernel_launch( 250 1, stream_id(2), read_write=[tensor_id(1)] 251 ) 252 253 def test_all_reads_checked_failing(self): 254 iterations = 10 255 for i in range(1, iterations): 256 self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)]) 257 self.handler._handle_event_record(event_id(i), stream_id(i)) 258 259 for i in range(1, iterations): 260 self.handler._handle_event_wait(event_id(i), stream_id(0)) 261 262 self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)]) 263 self.handler._handle_event_record(event_id(iterations), stream_id(i)) 264 265 # Does not synchronize with the last read. 266 self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)]) 267 268 def test_all_reads_checked_passing(self): 269 iterations = 10 270 for i in range(1, iterations): 271 self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)]) 272 self.handler._handle_event_record(event_id(i), stream_id(i)) 273 274 for i in range(1, iterations): 275 self.handler._handle_event_wait(event_id(i), stream_id(0)) 276 277 self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)]) 278 279 def test_multiple_errors(self): 280 iterations = 10 281 self.assert_good_kernel_launch( 282 stream_id(0), read_write=[tensor_id(i) for i in range(iterations)] 283 ) 284 self.assert_bad_kernel_launch( 285 iterations, 286 stream_id(1), 287 read_write=[tensor_id(i) for i in range(iterations)], 288 ) 289 290 def test_correct_state_merging(self): 291 # Tests that after waiting for an event, a stream's state is indeed set 292 # to the pointwise maximum of its old state and the recorded state. 293 294 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 295 self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)]) 296 self.handler._handle_event_record(event_id(1), stream_id(1)) 297 self.handler._handle_event_record(event_id(2), stream_id(2)) 298 299 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 300 self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)]) 301 self.handler._handle_event_wait(event_id(1), stream_id(2)) 302 self.handler._handle_event_wait(event_id(2), stream_id(1)) 303 304 self.handler._handle_event_record(event_id(3), stream_id(2)) 305 self.handler._handle_event_wait(event_id(3), stream_id(1)) 306 self.assert_good_kernel_launch( 307 stream_id(1), read_write=[tensor_id(1), tensor_id(2)] 308 ) 309 310 def test_record_override(self): 311 self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 312 self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)]) 313 self.handler._handle_event_record(event_id(1), stream_id(1)) 314 self.handler._handle_event_record(event_id(1), stream_id(2)) 315 316 self.handler._handle_event_wait(event_id(1), stream_id(3)) 317 self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)]) 318 319 def test_multiple_wait(self): 320 # Tests that a wait operation can be performed multiple times on the same event 321 # by different streams. 322 323 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 324 self.handler._handle_event_record(event_id(1), stream_id(1)) 325 self.handler._handle_event_wait(event_id(1), stream_id(2)) 326 self.handler._handle_event_wait(event_id(1), stream_id(3)) 327 328 self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 329 self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)]) 330 331 def test_device_synchronize(self): 332 # Tests that a device synchronization does correctly cause all streams 333 # to synchronize with each other. 334 335 iterations = 10 336 for i in range(1, iterations): 337 self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)]) 338 339 self.handler._handle_device_synchronization() 340 self.assert_good_kernel_launch( 341 stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)] 342 ) 343 344 def test_device_synchronization_expired(self): 345 # Tests that a device synchronization is a one-time synchronization. 346 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 347 self.handler._handle_device_synchronization() 348 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 349 350 self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 351 352 def test_new_stream_is_synchronized(self): 353 # Tests that after synchronizing operations with the host, any newly created 354 # stream is guaranteed to be synchronized with them as well. 355 356 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 357 self.handler._handle_device_synchronization() 358 self.handler._handle_stream_creation(stream_id(2)) 359 self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)]) 360 361 def test_stream_synchronize(self): 362 # Tests that a stream synchronization does correctly cause all streams to wait 363 # for one specific stream, but does not synchronize all streams with each other. 364 365 self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)]) 366 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)]) 367 self.handler._handle_stream_synchronization(stream_id(0)) 368 369 self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 370 self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)]) 371 self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)]) 372 373 def test_event_synchronize(self): 374 # Tests that an event synchronization does correctly cause all streams to wait 375 # for a recorded event, but does not guarantee synchronization with the current 376 # state of the stream that recorded the event. 377 378 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 379 self.handler._handle_event_record(event_id(1), stream_id(1)) 380 self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)]) 381 382 self.handler._handle_event_synchronization(event_id(1)) 383 self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)]) 384 self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)]) 385 386 387class TestMessages(TestCase): 388 def setUp(self): 389 self.handler = csan.EventHandler() 390 391 def test_ensure_exists(self): 392 ARG = 0 393 for func, out in [ 394 ( 395 self.handler._handle_event_deletion, 396 f"Found Event with id: {ARG}, but no matching event " 397 "creation in the trace. Backfilling the trace now. " 398 "Perhaps the sanitizer was enabled after some torch operations?", 399 ), 400 ( 401 self.handler._handle_memory_deallocation, 402 f"Found tensor with pointer: {ARG}, but no matching tensor " 403 "allocation in the trace. Backfilling the trace now. " 404 "Perhaps the sanitizer was enabled after some torch operations?", 405 ), 406 ]: 407 with self.subTest(func=func, out=out): 408 with self.assertLogs() as captured: 409 func(ARG) 410 self.assertEqual(captured.records[0].getMessage(), out) 411 412 def test_ensure_does_not_exist(self): 413 ARG = 0 414 self.handler._handle_event_creation(ARG) 415 self.handler._handle_stream_creation(ARG) 416 for func, out in [ 417 ( 418 self.handler._handle_event_creation, 419 "Found duplicate event creation in the trace for event with " 420 f"id: {ARG}. Assuming the trace for event deletion wasn't caught " 421 "and backfilling it now. " 422 "Perhaps the sanitizer was enabled after some torch operations?", 423 ), 424 ( 425 self.handler._handle_stream_creation, 426 "Found duplicate Stream creation in the trace for Stream with " 427 f"id: {ARG}. PyTorch Streams are only created once, so this " 428 "trace entry is ignored.", 429 ), 430 ]: 431 with self.subTest(func=func, out=out): 432 with self.assertLogs() as captured: 433 func(ARG) 434 self.assertEqual(captured.records[0].getMessage(), out) 435 436 def test_error_message(self): 437 current_access = csan.Access( 438 type=csan.AccessType.WRITE, 439 seq_num=1, 440 stream=stream_id(1), 441 operator="schema", 442 aliases=["b"], 443 is_output=True, 444 stack_trace=traceback.StackSummary.from_list( 445 [("file", 0, "name", "trace a")] 446 ), 447 ) 448 previous_access = csan.Access( 449 type=csan.AccessType.READ, 450 seq_num=2, 451 stream=stream_id(0), 452 operator="schema", 453 aliases=["a"], 454 is_output=False, 455 stack_trace=traceback.StackSummary.from_list( 456 [("file", 0, "name", "trace b")] 457 ), 458 ) 459 error = csan.UnsynchronizedAccessError( 460 data_ptr=tensor_id(1), 461 allocation_stack_trace=traceback.StackSummary.from_list( 462 [("file", 0, "name", "alloc")] 463 ), 464 current_access=current_access, 465 previous_access=previous_access, 466 ) 467 self.assertEqual( 468 str(error), 469 textwrap.dedent( 470 """\ 471 ============================ 472 CSAN detected a possible data race on tensor with data pointer 1 473 Access by stream 1001 during kernel: 474 schema 475 writing to argument(s) b, and to the output 476 With stack trace: 477 File "file", line 0, in name 478 trace a 479 480 Previous access by stream 1000 during kernel: 481 schema 482 reading from argument(s) a 483 With stack trace: 484 File "file", line 0, in name 485 trace b 486 487 Tensor was allocated with stack trace: 488 File "file", line 0, in name 489 alloc 490 """ 491 ), 492 ) 493 494 495if __name__ == "__main__": 496 run_tests() 497