# Owner(s): ["module: cuda"] import sys import textwrap import traceback from typing import List import torch import torch.cuda._sanitizer as csan from torch.cuda._sanitizer import DataPtr, EventId, StreamId from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) TestCase = NoTest # noqa: F811 class TestArgumentHandler(TestCase): def test_add(self): add_func = torch.ops.aten.add.Tensor a = torch.ones(5, 3, device="cuda") b = torch.randn(5, 3, device="cuda") argument_handler = csan.ArgumentHandler() argument_handler.parse_inputs(add_func._schema, (a, b), {}) c = torch.add(a, b) argument_handler.parse_outputs(c) self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read) self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written) def test_cat(self): cat_func = torch.ops.aten.cat.default a = torch.ones(2, 4, 5, device="cuda") b = torch.zeros(2, 1, 5, device="cuda") c = torch.rand(2, 7, 5, device="cuda") argument_handler = csan.ArgumentHandler() argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {}) d = torch.cat((a, b, c), dim=1) argument_handler.parse_outputs(d) self.assertEqual( {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read ) self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written) def test_split(self): split_func = torch.ops.aten.split.Tensor a = torch.arange(10, device="cuda").reshape(5, 2) argument_handler = csan.ArgumentHandler() argument_handler.parse_inputs(split_func._schema, (a, 2), {}) out = torch.split(a, 2) argument_handler.parse_outputs(out) outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) self.assertEqual(outputs, argument_handler.dataptrs_written) def test_inplace(self): add_inplace_func = torch.ops.aten.add_.Tensor a = torch.rand(4, 2, device="cuda") argument_handler = csan.ArgumentHandler() argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {}) a.add_(5) argument_handler.parse_outputs(a) self.assertEqual(set(), argument_handler.dataptrs_read) self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written) def test_out(self): mul_out_func = torch.ops.aten.mul.out a = torch.arange(8, device="cuda") b = torch.empty(8, device="cuda") argument_handler = csan.ArgumentHandler() argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b}) torch.mul(a, 3, out=b) argument_handler.parse_outputs(b) self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written) def test_nonzero(self): nonzero_func = torch.ops.aten.nonzero.default a = torch.ones(5, 3, 2, device="cuda") argument_handler = csan.ArgumentHandler() argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True}) out = torch.nonzero(a, as_tuple=True) argument_handler.parse_outputs(out) outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) self.assertEqual(outputs, argument_handler.dataptrs_written) def test_tensor_names(self): addr_func = torch.ops.aten.addr.default vec = torch.arange(1, 4, device="cuda") M = torch.zeros(3, 3, device="cuda") argument_handler = csan.ArgumentHandler() argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {}) out = torch.addr(M, vec, vec) argument_handler.parse_outputs(out) self.assertEqual( argument_handler.tensor_aliases, { 
M.data_ptr(): ["self"], vec.data_ptr(): ["vec1", "vec2"], out.data_ptr(): [], }, ) self.assertEqual({out.data_ptr()}, argument_handler.outputs) def tensor_id(i: int) -> DataPtr: return i def stream_id(i: int) -> StreamId: return 1000 + i def event_id(i: int) -> EventId: return 2000 + i class TestEventHandler(TestCase): def setUp(self): self.handler = csan.EventHandler() def kernel_launch( self, stream: StreamId, read_only: List[DataPtr] = None, read_write: List[DataPtr] = None, ) -> List[csan.SynchronizationError]: if read_only is None: read_only = [] if read_write is None: read_write = [] return self.handler._handle_kernel_launch( stream, read_only, read_write, {}, "", {k: [""] for k in read_only + read_write}, ) def assert_good_kernel_launch( self, stream: StreamId, read_only: List[DataPtr] = None, read_write: List[DataPtr] = None, ) -> None: self.assertEqual(self.kernel_launch(stream, read_only, read_write), []) def assert_bad_kernel_launch( self, number_of_errors: int, stream: StreamId, read_only: List[DataPtr] = None, read_write: List[DataPtr] = None, ) -> None: errors = self.kernel_launch(stream, read_only, read_write) self.assertEqual(len(errors), number_of_errors) def test_empty_kernel_launch(self): self.assert_good_kernel_launch(stream_id(0)) def test_simple_passing(self): self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) def test_simple_error(self): self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) def test_simple_sync(self): self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) self.handler._handle_event_record(event_id(0), stream_id(1)) self.handler._handle_event_wait(event_id(0), stream_id(2)) self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)]) def test_reads_check_last_write(self): # Tests that not only the first read operation checks if it is in conflict # with the last write operation, but all read operations do. self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) self.handler._handle_event_record(event_id(0), stream_id(1)) self.handler._handle_event_wait(event_id(0), stream_id(2)) self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)]) def test_branch_sync(self): # Tests that two streams can read after both waiting for a third, but they # cannot write without further synchronization. 
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.handler._handle_event_wait(event_id(0), stream_id(3))
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_chain_sync(self):
        iterations = 10

        self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)])
        for i in range(iterations):
            self.handler._handle_event_record(event_id(i), stream_id(i))
            self.handler._handle_event_wait(event_id(i), stream_id(i + 1))
        self.assert_good_kernel_launch(
            stream_id(iterations), read_write=[tensor_id(1)]
        )

    def test_expired_record(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_wait(event_id(0), stream_id(2))

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_deleted_record(self):
        for should_delete, should_create in [
            (True, True),
            (True, False),
            (False, True),
        ]:
            self.setUp()
            with self.subTest(
                should_delete=should_delete, should_create=should_create
            ):
                self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
                self.handler._handle_event_record(event_id(0), stream_id(1))

                if should_delete:
                    self.handler._handle_event_deletion(event_id(0))
                if should_create:
                    self.handler._handle_event_creation(event_id(0))

                self.handler._handle_event_wait(event_id(0), stream_id(2))
                self.assert_bad_kernel_launch(
                    1, stream_id(2), read_write=[tensor_id(1)]
                )

    def test_all_reads_checked_failing(self):
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))
        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(iterations), stream_id(i))

        # Does not synchronize with the last read.
        self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)])

    def test_all_reads_checked_passing(self):
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))
        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])

    def test_multiple_errors(self):
        iterations = 10
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(iterations)]
        )
        self.assert_bad_kernel_launch(
            iterations,
            stream_id(1),
            read_write=[tensor_id(i) for i in range(iterations)],
        )

    def test_correct_state_merging(self):
        # Tests that after waiting for an event, a stream's state is indeed set
        # to the pointwise maximum of its old state and the recorded state.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(2), stream_id(2))

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(2), stream_id(1))

        self.handler._handle_event_record(event_id(3), stream_id(2))
        self.handler._handle_event_wait(event_id(3), stream_id(1))
        self.assert_good_kernel_launch(
            stream_id(1), read_write=[tensor_id(1), tensor_id(2)]
        )

    def test_record_override(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(1), stream_id(2))

        self.handler._handle_event_wait(event_id(1), stream_id(3))
        self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)])

    def test_multiple_wait(self):
        # Tests that a wait operation can be performed multiple times on the same event
        # by different streams.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(1), stream_id(3))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

    def test_device_synchronize(self):
        # Tests that a device synchronization does correctly cause all streams
        # to synchronize with each other.
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)])

        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)]
        )

    def test_device_synchronization_expired(self):
        # Tests that a device synchronization is a one-time synchronization.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_new_stream_is_synchronized(self):
        # Tests that after synchronizing operations with the host, any newly created
        # stream is guaranteed to be synchronized with them as well.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.handler._handle_stream_creation(stream_id(2))

        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])

    def test_stream_synchronize(self):
        # Tests that a stream synchronization does correctly cause all streams to wait
        # for one specific stream, but does not synchronize all streams with each other.
        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
        self.handler._handle_stream_synchronization(stream_id(0))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)])

    def test_event_synchronize(self):
        # Tests that an event synchronization does correctly cause all streams to wait
        # for a recorded event, but does not guarantee synchronization with the current
        # state of the stream that recorded the event.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
        self.handler._handle_event_synchronization(event_id(1))

        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)])


class TestMessages(TestCase):
    def setUp(self):
        self.handler = csan.EventHandler()

    def test_ensure_exists(self):
        ARG = 0
        for func, out in [
            (
                self.handler._handle_event_deletion,
                f"Found Event with id: {ARG}, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_memory_deallocation,
                f"Found tensor with pointer: {ARG}, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_ensure_does_not_exist(self):
        ARG = 0
        self.handler._handle_event_creation(ARG)
        self.handler._handle_stream_creation(ARG)

        for func, out in [
            (
                self.handler._handle_event_creation,
                "Found duplicate event creation in the trace for event with "
                f"id: {ARG}. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_stream_creation,
                "Found duplicate Stream creation in the trace for Stream with "
                f"id: {ARG}. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_error_message(self):
        current_access = csan.Access(
            type=csan.AccessType.WRITE,
            seq_num=1,
            stream=stream_id(1),
            operator="schema",
            aliases=["b"],
            is_output=True,
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace a")]
            ),
        )
        previous_access = csan.Access(
            type=csan.AccessType.READ,
            seq_num=2,
            stream=stream_id(0),
            operator="schema",
            aliases=["a"],
            is_output=False,
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace b")]
            ),
        )
        error = csan.UnsynchronizedAccessError(
            data_ptr=tensor_id(1),
            allocation_stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "alloc")]
            ),
            current_access=current_access,
            previous_access=previous_access,
        )

        self.assertEqual(
            str(error),
            textwrap.dedent(
                """\
                ============================
                CSAN detected a possible data race on tensor with data pointer 1
                Access by stream 1001 during kernel:
                schema
                writing to argument(s) b, and to the output
                With stack trace:
                  File "file", line 0, in name
                    trace a

                Previous access by stream 1000 during kernel:
                schema
                reading from argument(s) a
                With stack trace:
                  File "file", line 0, in name
                    trace b

                Tensor was allocated with stack trace:
                  File "file", line 0, in name
                    alloc
                """
            ),
        )


if __name__ == "__main__":
    run_tests()