1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: cuda"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport sys 4*da0073e9SAndroid Build Coastguard Workerimport textwrap 5*da0073e9SAndroid Build Coastguard Workerimport traceback 6*da0073e9SAndroid Build Coastguard Workerfrom typing import List 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch 9*da0073e9SAndroid Build Coastguard Workerimport torch.cuda._sanitizer as csan 10*da0073e9SAndroid Build Coastguard Workerfrom torch.cuda._sanitizer import DataPtr, EventId, StreamId 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerif not TEST_CUDA: 15*da0073e9SAndroid Build Coastguard Worker print("CUDA not available, skipping tests", file=sys.stderr) 16*da0073e9SAndroid Build Coastguard Worker TestCase = NoTest # noqa: F811 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerclass TestArgumentHandler(TestCase): 20*da0073e9SAndroid Build Coastguard Worker def test_add(self): 21*da0073e9SAndroid Build Coastguard Worker add_func = torch.ops.aten.add.Tensor 22*da0073e9SAndroid Build Coastguard Worker a = torch.ones(5, 3, device="cuda") 23*da0073e9SAndroid Build Coastguard Worker b = torch.randn(5, 3, device="cuda") 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker argument_handler = csan.ArgumentHandler() 26*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_inputs(add_func._schema, (a, b), {}) 27*da0073e9SAndroid Build Coastguard Worker c = torch.add(a, b) 28*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_outputs(c) 29*da0073e9SAndroid Build Coastguard Worker 
30*da0073e9SAndroid Build Coastguard Worker self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read) 31*da0073e9SAndroid Build Coastguard Worker self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker def test_cat(self): 34*da0073e9SAndroid Build Coastguard Worker cat_func = torch.ops.aten.cat.default 35*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, 4, 5, device="cuda") 36*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(2, 1, 5, device="cuda") 37*da0073e9SAndroid Build Coastguard Worker c = torch.rand(2, 7, 5, device="cuda") 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker argument_handler = csan.ArgumentHandler() 40*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {}) 41*da0073e9SAndroid Build Coastguard Worker d = torch.cat((a, b, c), dim=1) 42*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_outputs(d) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 45*da0073e9SAndroid Build Coastguard Worker {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read 46*da0073e9SAndroid Build Coastguard Worker ) 47*da0073e9SAndroid Build Coastguard Worker self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker def test_split(self): 50*da0073e9SAndroid Build Coastguard Worker split_func = torch.ops.aten.split.Tensor 51*da0073e9SAndroid Build Coastguard Worker a = torch.arange(10, device="cuda").reshape(5, 2) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker argument_handler = csan.ArgumentHandler() 54*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_inputs(split_func._schema, (a, 2), {}) 55*da0073e9SAndroid 
Build Coastguard Worker out = torch.split(a, 2) 56*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_outputs(out) 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} 59*da0073e9SAndroid Build Coastguard Worker self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) 60*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, argument_handler.dataptrs_written) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker def test_inplace(self): 63*da0073e9SAndroid Build Coastguard Worker add_inplace_func = torch.ops.aten.add_.Tensor 64*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4, 2, device="cuda") 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker argument_handler = csan.ArgumentHandler() 67*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {}) 68*da0073e9SAndroid Build Coastguard Worker a.add_(5) 69*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_outputs(a) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker self.assertEqual(set(), argument_handler.dataptrs_read) 72*da0073e9SAndroid Build Coastguard Worker self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def test_out(self): 75*da0073e9SAndroid Build Coastguard Worker mul_out_func = torch.ops.aten.mul.out 76*da0073e9SAndroid Build Coastguard Worker a = torch.arange(8, device="cuda") 77*da0073e9SAndroid Build Coastguard Worker b = torch.empty(8, device="cuda") 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker argument_handler = csan.ArgumentHandler() 80*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": 
b}) 81*da0073e9SAndroid Build Coastguard Worker torch.mul(a, 3, out=b) 82*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_outputs(b) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) 85*da0073e9SAndroid Build Coastguard Worker self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker def test_nonzero(self): 88*da0073e9SAndroid Build Coastguard Worker nonzero_func = torch.ops.aten.nonzero.default 89*da0073e9SAndroid Build Coastguard Worker a = torch.ones(5, 3, 2, device="cuda") 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker argument_handler = csan.ArgumentHandler() 92*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True}) 93*da0073e9SAndroid Build Coastguard Worker out = torch.nonzero(a, as_tuple=True) 94*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_outputs(out) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} 97*da0073e9SAndroid Build Coastguard Worker self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) 98*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, argument_handler.dataptrs_written) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker def test_tensor_names(self): 101*da0073e9SAndroid Build Coastguard Worker addr_func = torch.ops.aten.addr.default 102*da0073e9SAndroid Build Coastguard Worker vec = torch.arange(1, 4, device="cuda") 103*da0073e9SAndroid Build Coastguard Worker M = torch.zeros(3, 3, device="cuda") 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker argument_handler = csan.ArgumentHandler() 
106*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {}) 107*da0073e9SAndroid Build Coastguard Worker out = torch.addr(M, vec, vec) 108*da0073e9SAndroid Build Coastguard Worker argument_handler.parse_outputs(out) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 111*da0073e9SAndroid Build Coastguard Worker argument_handler.tensor_aliases, 112*da0073e9SAndroid Build Coastguard Worker { 113*da0073e9SAndroid Build Coastguard Worker M.data_ptr(): ["self"], 114*da0073e9SAndroid Build Coastguard Worker vec.data_ptr(): ["vec1", "vec2"], 115*da0073e9SAndroid Build Coastguard Worker out.data_ptr(): [], 116*da0073e9SAndroid Build Coastguard Worker }, 117*da0073e9SAndroid Build Coastguard Worker ) 118*da0073e9SAndroid Build Coastguard Worker self.assertEqual({out.data_ptr()}, argument_handler.outputs) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Workerdef tensor_id(i: int) -> DataPtr: 122*da0073e9SAndroid Build Coastguard Worker return i 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Workerdef stream_id(i: int) -> StreamId: 126*da0073e9SAndroid Build Coastguard Worker return 1000 + i 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Workerdef event_id(i: int) -> EventId: 130*da0073e9SAndroid Build Coastguard Worker return 2000 + i 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Workerclass TestEventHandler(TestCase): 134*da0073e9SAndroid Build Coastguard Worker def setUp(self): 135*da0073e9SAndroid Build Coastguard Worker self.handler = csan.EventHandler() 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker def 
kernel_launch( 138*da0073e9SAndroid Build Coastguard Worker self, 139*da0073e9SAndroid Build Coastguard Worker stream: StreamId, 140*da0073e9SAndroid Build Coastguard Worker read_only: List[DataPtr] = None, 141*da0073e9SAndroid Build Coastguard Worker read_write: List[DataPtr] = None, 142*da0073e9SAndroid Build Coastguard Worker ) -> List[csan.SynchronizationError]: 143*da0073e9SAndroid Build Coastguard Worker if read_only is None: 144*da0073e9SAndroid Build Coastguard Worker read_only = [] 145*da0073e9SAndroid Build Coastguard Worker if read_write is None: 146*da0073e9SAndroid Build Coastguard Worker read_write = [] 147*da0073e9SAndroid Build Coastguard Worker return self.handler._handle_kernel_launch( 148*da0073e9SAndroid Build Coastguard Worker stream, 149*da0073e9SAndroid Build Coastguard Worker read_only, 150*da0073e9SAndroid Build Coastguard Worker read_write, 151*da0073e9SAndroid Build Coastguard Worker {}, 152*da0073e9SAndroid Build Coastguard Worker "", 153*da0073e9SAndroid Build Coastguard Worker {k: [""] for k in read_only + read_write}, 154*da0073e9SAndroid Build Coastguard Worker ) 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker def assert_good_kernel_launch( 157*da0073e9SAndroid Build Coastguard Worker self, 158*da0073e9SAndroid Build Coastguard Worker stream: StreamId, 159*da0073e9SAndroid Build Coastguard Worker read_only: List[DataPtr] = None, 160*da0073e9SAndroid Build Coastguard Worker read_write: List[DataPtr] = None, 161*da0073e9SAndroid Build Coastguard Worker ) -> None: 162*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.kernel_launch(stream, read_only, read_write), []) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker def assert_bad_kernel_launch( 165*da0073e9SAndroid Build Coastguard Worker self, 166*da0073e9SAndroid Build Coastguard Worker number_of_errors: int, 167*da0073e9SAndroid Build Coastguard Worker stream: StreamId, 168*da0073e9SAndroid 
Build Coastguard Worker read_only: List[DataPtr] = None, 169*da0073e9SAndroid Build Coastguard Worker read_write: List[DataPtr] = None, 170*da0073e9SAndroid Build Coastguard Worker ) -> None: 171*da0073e9SAndroid Build Coastguard Worker errors = self.kernel_launch(stream, read_only, read_write) 172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(errors), number_of_errors) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def test_empty_kernel_launch(self): 175*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(0)) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker def test_simple_passing(self): 178*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 179*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker def test_simple_error(self): 182*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 183*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker def test_simple_sync(self): 186*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 187*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(0), stream_id(1)) 188*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(0), stream_id(2)) 189*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)]) 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker def 
test_reads_check_last_write(self): 192*da0073e9SAndroid Build Coastguard Worker # Tests that not only the first read operation checks if it is in conflict 193*da0073e9SAndroid Build Coastguard Worker # with the last write operation, but all read operations do. 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 196*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(0), stream_id(1)) 197*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(0), stream_id(2)) 198*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)]) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker def test_branch_sync(self): 203*da0073e9SAndroid Build Coastguard Worker # Tests that two streams can read after both waiting for a third, but they 204*da0073e9SAndroid Build Coastguard Worker # cannot write without further synchronization. 
205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 207*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(0), stream_id(1)) 208*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(0), stream_id(2)) 209*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(0), stream_id(3)) 210*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 211*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)]) 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker def test_chain_sync(self): 216*da0073e9SAndroid Build Coastguard Worker iterations = 10 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)]) 219*da0073e9SAndroid Build Coastguard Worker for i in range(iterations): 220*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(i), stream_id(i)) 221*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(i), stream_id(i + 1)) 222*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)]) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker def test_expired_record(self): 225*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 226*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(0), stream_id(1)) 227*da0073e9SAndroid Build 
Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 228*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(0), stream_id(2)) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def test_deleted_record(self): 233*da0073e9SAndroid Build Coastguard Worker for should_delete, should_create in [ 234*da0073e9SAndroid Build Coastguard Worker (True, True), 235*da0073e9SAndroid Build Coastguard Worker (True, False), 236*da0073e9SAndroid Build Coastguard Worker (False, True), 237*da0073e9SAndroid Build Coastguard Worker ]: 238*da0073e9SAndroid Build Coastguard Worker self.setUp() 239*da0073e9SAndroid Build Coastguard Worker with self.subTest(should_delete=should_delete, should_create=should_create): 240*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 241*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(0), stream_id(1)) 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker if should_delete: 244*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_deletion(event_id(0)) 245*da0073e9SAndroid Build Coastguard Worker if should_create: 246*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_creation(event_id(0)) 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(0), stream_id(2)) 249*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch( 250*da0073e9SAndroid Build Coastguard Worker 1, stream_id(2), read_write=[tensor_id(1)] 251*da0073e9SAndroid Build Coastguard Worker ) 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker 
def test_all_reads_checked_failing(self): 254*da0073e9SAndroid Build Coastguard Worker iterations = 10 255*da0073e9SAndroid Build Coastguard Worker for i in range(1, iterations): 256*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)]) 257*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(i), stream_id(i)) 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker for i in range(1, iterations): 260*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(i), stream_id(0)) 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)]) 263*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(iterations), stream_id(i)) 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker # Does not synchronize with the last read. 
266*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)]) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker def test_all_reads_checked_passing(self): 269*da0073e9SAndroid Build Coastguard Worker iterations = 10 270*da0073e9SAndroid Build Coastguard Worker for i in range(1, iterations): 271*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)]) 272*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(i), stream_id(i)) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker for i in range(1, iterations): 275*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(i), stream_id(0)) 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)]) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker def test_multiple_errors(self): 280*da0073e9SAndroid Build Coastguard Worker iterations = 10 281*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch( 282*da0073e9SAndroid Build Coastguard Worker stream_id(0), read_write=[tensor_id(i) for i in range(iterations)] 283*da0073e9SAndroid Build Coastguard Worker ) 284*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch( 285*da0073e9SAndroid Build Coastguard Worker iterations, 286*da0073e9SAndroid Build Coastguard Worker stream_id(1), 287*da0073e9SAndroid Build Coastguard Worker read_write=[tensor_id(i) for i in range(iterations)], 288*da0073e9SAndroid Build Coastguard Worker ) 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker def test_correct_state_merging(self): 291*da0073e9SAndroid Build Coastguard Worker # Tests that after waiting for an event, a stream's state is indeed set 
292*da0073e9SAndroid Build Coastguard Worker # to the pointwise maximum of its old state and the recorded state. 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 295*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)]) 296*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(1), stream_id(1)) 297*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(2), stream_id(2)) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 300*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)]) 301*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(1), stream_id(2)) 302*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(2), stream_id(1)) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(3), stream_id(2)) 305*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(3), stream_id(1)) 306*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch( 307*da0073e9SAndroid Build Coastguard Worker stream_id(1), read_write=[tensor_id(1), tensor_id(2)] 308*da0073e9SAndroid Build Coastguard Worker ) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker def test_record_override(self): 311*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)]) 312*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)]) 313*da0073e9SAndroid Build Coastguard Worker 
self.handler._handle_event_record(event_id(1), stream_id(1)) 314*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(1), stream_id(2)) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(1), stream_id(3)) 317*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)]) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker def test_multiple_wait(self): 320*da0073e9SAndroid Build Coastguard Worker # Tests that a wait operation can be performed multiple times on the same event 321*da0073e9SAndroid Build Coastguard Worker # by different streams. 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 324*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_record(event_id(1), stream_id(1)) 325*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(1), stream_id(2)) 326*da0073e9SAndroid Build Coastguard Worker self.handler._handle_event_wait(event_id(1), stream_id(3)) 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 329*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)]) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker def test_device_synchronize(self): 332*da0073e9SAndroid Build Coastguard Worker # Tests that a device synchronization does correctly cause all streams 333*da0073e9SAndroid Build Coastguard Worker # to synchronize with each other. 
334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker iterations = 10 336*da0073e9SAndroid Build Coastguard Worker for i in range(1, iterations): 337*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)]) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker self.handler._handle_device_synchronization() 340*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch( 341*da0073e9SAndroid Build Coastguard Worker stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)] 342*da0073e9SAndroid Build Coastguard Worker ) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker def test_device_synchronization_expired(self): 345*da0073e9SAndroid Build Coastguard Worker # Tests that a device synchronization is a one-time synchronization. 346*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 347*da0073e9SAndroid Build Coastguard Worker self.handler._handle_device_synchronization() 348*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)]) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker def test_new_stream_is_synchronized(self): 353*da0073e9SAndroid Build Coastguard Worker # Tests that after synchronizing operations with the host, any newly created 354*da0073e9SAndroid Build Coastguard Worker # stream is guaranteed to be synchronized with them as well. 
355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)]) 357*da0073e9SAndroid Build Coastguard Worker self.handler._handle_device_synchronization() 358*da0073e9SAndroid Build Coastguard Worker self.handler._handle_stream_creation(stream_id(2)) 359*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)]) 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def test_stream_synchronize(self): 362*da0073e9SAndroid Build Coastguard Worker # Tests that a stream synchronization does correctly cause all streams to wait 363*da0073e9SAndroid Build Coastguard Worker # for one specific stream, but does not synchronize all streams with each other. 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)]) 366*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)]) 367*da0073e9SAndroid Build Coastguard Worker self.handler._handle_stream_synchronization(stream_id(0)) 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)]) 370*da0073e9SAndroid Build Coastguard Worker self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)]) 371*da0073e9SAndroid Build Coastguard Worker self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)]) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker def test_event_synchronize(self): 374*da0073e9SAndroid Build Coastguard Worker # Tests that an event synchronization does correctly cause all streams to wait 375*da0073e9SAndroid Build Coastguard Worker # for a recorded event, but does not guarantee synchronization with the current 
class TestMessages(TestCase):
    """Pin the exact wording of the log messages and error report built by CSAN."""

    def setUp(self):
        # Each test gets its own freshly constructed handler.
        self.handler = csan.EventHandler()

    def _check_first_log_message(self, handler_fn, arg, expected):
        """Call ``handler_fn(arg)`` and compare the first captured log record to ``expected``.

        ``assertLogs`` is used without arguments, exactly as the individual
        tests did, and the comparison happens after the capture context exits.
        """
        with self.assertLogs() as captured:
            handler_fn(arg)
        first_record = captured.records[0]
        self.assertEqual(first_record.getMessage(), expected)

    def test_ensure_exists(self):
        # Nothing has been registered on the handler, so deleting an event or
        # deallocating memory with this id should log the "backfilling" text.
        unknown_id = 0
        cases = [
            (
                self.handler._handle_event_deletion,
                f"Found Event with id: {unknown_id}, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_memory_deallocation,
                f"Found tensor with pointer: {unknown_id}, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
        ]
        for handler_fn, expected in cases:
            # Keyword names are kept as ``func`` / ``out`` so sub-test
            # failure reports look the same as before.
            with self.subTest(func=handler_fn, out=expected):
                self._check_first_log_message(handler_fn, unknown_id, expected)

    def test_ensure_does_not_exist(self):
        # Register the event and the stream once up front; the second
        # creation with the same id is the call whose log text is checked.
        duplicate_id = 0
        self.handler._handle_event_creation(duplicate_id)
        self.handler._handle_stream_creation(duplicate_id)
        cases = [
            (
                self.handler._handle_event_creation,
                "Found duplicate event creation in the trace for event with "
                f"id: {duplicate_id}. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_stream_creation,
                "Found duplicate Stream creation in the trace for Stream with "
                f"id: {duplicate_id}. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
            ),
        ]
        for handler_fn, expected in cases:
            with self.subTest(func=handler_fn, out=expected):
                self._check_first_log_message(handler_fn, duplicate_id, expected)

    def test_error_message(self):
        # Build an UnsynchronizedAccessError by hand and compare its string
        # form with the expected multi-section report.
        # NOTE(review): ``stream_id`` / ``tensor_id`` are helpers defined
        # elsewhere in this file; the expected text below implies that
        # stream_id(1) renders as 1001 and tensor_id(1) as 1 -- confirm there.
        def single_frame_trace(source_line):
            # One synthetic frame: file "file", line 0, function "name".
            return traceback.StackSummary.from_list(
                [("file", 0, "name", source_line)]
            )

        writing_access = csan.Access(
            type=csan.AccessType.WRITE,
            seq_num=1,
            stream=stream_id(1),
            operator="schema",
            aliases=["b"],
            is_output=True,
            stack_trace=single_frame_trace("trace a"),
        )
        reading_access = csan.Access(
            type=csan.AccessType.READ,
            seq_num=2,
            stream=stream_id(0),
            operator="schema",
            aliases=["a"],
            is_output=False,
            stack_trace=single_frame_trace("trace b"),
        )
        error = csan.UnsynchronizedAccessError(
            data_ptr=tensor_id(1),
            allocation_stack_trace=single_frame_trace("alloc"),
            current_access=writing_access,
            previous_access=reading_access,
        )

        expected_report = textwrap.dedent(
            """\
            ============================
            CSAN detected a possible data race on tensor with data pointer 1
            Access by stream 1001 during kernel:
            schema
            writing to argument(s) b, and to the output
            With stack trace:
              File "file", line 0, in name
                trace a

            Previous access by stream 1000 during kernel:
            schema
            reading from argument(s) a
            With stack trace:
              File "file", line 0, in name
                trace b

            Tensor was allocated with stack trace:
              File "file", line 0, in name
                alloc
            """
        )
        self.assertEqual(str(error), expected_report)


if __name__ == "__main__":
    run_tests()