xref: /aosp_15_r20/external/pytorch/test/test_cuda_sanitizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: cuda"]
2*da0073e9SAndroid Build Coastguard Worker
import sys
import textwrap
import traceback
from typing import List, Optional

import torch
import torch.cuda._sanitizer as csan
from torch.cuda._sanitizer import DataPtr, EventId, StreamId
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
if not TEST_CUDA:
    # The test classes below all create tensors on the "cuda" device or derive
    # from TestCase, so when CUDA is unavailable TestCase is rebound to NoTest
    # before any class is defined, and the reason is reported on stderr.
    # NOTE(review): NoTest is presumably a placeholder base that makes the
    # classes collect no tests -- confirm against common_utils.
    TestCase = NoTest  # noqa: F811
    print("CUDA not available, skipping tests", file=sys.stderr)
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Workerclass TestArgumentHandler(TestCase):
20*da0073e9SAndroid Build Coastguard Worker    def test_add(self):
21*da0073e9SAndroid Build Coastguard Worker        add_func = torch.ops.aten.add.Tensor
22*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(5, 3, device="cuda")
23*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 3, device="cuda")
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker        argument_handler = csan.ArgumentHandler()
26*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_inputs(add_func._schema, (a, b), {})
27*da0073e9SAndroid Build Coastguard Worker        c = torch.add(a, b)
28*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_outputs(c)
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read)
31*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker    def test_cat(self):
34*da0073e9SAndroid Build Coastguard Worker        cat_func = torch.ops.aten.cat.default
35*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 4, 5, device="cuda")
36*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(2, 1, 5, device="cuda")
37*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(2, 7, 5, device="cuda")
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        argument_handler = csan.ArgumentHandler()
40*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {})
41*da0073e9SAndroid Build Coastguard Worker        d = torch.cat((a, b, c), dim=1)
42*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_outputs(d)
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
45*da0073e9SAndroid Build Coastguard Worker            {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read
46*da0073e9SAndroid Build Coastguard Worker        )
47*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written)
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker    def test_split(self):
50*da0073e9SAndroid Build Coastguard Worker        split_func = torch.ops.aten.split.Tensor
51*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(10, device="cuda").reshape(5, 2)
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker        argument_handler = csan.ArgumentHandler()
54*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_inputs(split_func._schema, (a, 2), {})
55*da0073e9SAndroid Build Coastguard Worker        out = torch.split(a, 2)
56*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_outputs(out)
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
59*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
60*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputs, argument_handler.dataptrs_written)
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker    def test_inplace(self):
63*da0073e9SAndroid Build Coastguard Worker        add_inplace_func = torch.ops.aten.add_.Tensor
64*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(4, 2, device="cuda")
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker        argument_handler = csan.ArgumentHandler()
67*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {})
68*da0073e9SAndroid Build Coastguard Worker        a.add_(5)
69*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_outputs(a)
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(set(), argument_handler.dataptrs_read)
72*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker    def test_out(self):
75*da0073e9SAndroid Build Coastguard Worker        mul_out_func = torch.ops.aten.mul.out
76*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(8, device="cuda")
77*da0073e9SAndroid Build Coastguard Worker        b = torch.empty(8, device="cuda")
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker        argument_handler = csan.ArgumentHandler()
80*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b})
81*da0073e9SAndroid Build Coastguard Worker        torch.mul(a, 3, out=b)
82*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_outputs(b)
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
85*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker    def test_nonzero(self):
88*da0073e9SAndroid Build Coastguard Worker        nonzero_func = torch.ops.aten.nonzero.default
89*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(5, 3, 2, device="cuda")
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker        argument_handler = csan.ArgumentHandler()
92*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True})
93*da0073e9SAndroid Build Coastguard Worker        out = torch.nonzero(a, as_tuple=True)
94*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_outputs(out)
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
97*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
98*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputs, argument_handler.dataptrs_written)
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker    def test_tensor_names(self):
101*da0073e9SAndroid Build Coastguard Worker        addr_func = torch.ops.aten.addr.default
102*da0073e9SAndroid Build Coastguard Worker        vec = torch.arange(1, 4, device="cuda")
103*da0073e9SAndroid Build Coastguard Worker        M = torch.zeros(3, 3, device="cuda")
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker        argument_handler = csan.ArgumentHandler()
106*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {})
107*da0073e9SAndroid Build Coastguard Worker        out = torch.addr(M, vec, vec)
108*da0073e9SAndroid Build Coastguard Worker        argument_handler.parse_outputs(out)
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
111*da0073e9SAndroid Build Coastguard Worker            argument_handler.tensor_aliases,
112*da0073e9SAndroid Build Coastguard Worker            {
113*da0073e9SAndroid Build Coastguard Worker                M.data_ptr(): ["self"],
114*da0073e9SAndroid Build Coastguard Worker                vec.data_ptr(): ["vec1", "vec2"],
115*da0073e9SAndroid Build Coastguard Worker                out.data_ptr(): [],
116*da0073e9SAndroid Build Coastguard Worker            },
117*da0073e9SAndroid Build Coastguard Worker        )
118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({out.data_ptr()}, argument_handler.outputs)
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker
def tensor_id(i: int) -> DataPtr:
    """Return ``i`` unchanged, typed as the ``DataPtr`` of a fake tensor."""
    data_ptr: DataPtr = i
    return data_ptr
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker
def stream_id(i: int) -> StreamId:
    """Return the fake ``StreamId`` for test index ``i``, i.e. ``i`` shifted by 1000."""
    # NOTE(review): the offset presumably keeps stream ids apart from the
    # tensor and event ids built from the same small indices -- confirm.
    base = 1000
    return base + i
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker
def event_id(i: int) -> EventId:
    """Return the fake ``EventId`` for test index ``i``, i.e. ``i`` shifted by 2000."""
    # NOTE(review): the offset presumably keeps event ids apart from the
    # tensor and stream ids built from the same small indices -- confirm.
    base = 2000
    return base + i
131*da0073e9SAndroid Build Coastguard Worker
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Workerclass TestEventHandler(TestCase):
134*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
135*da0073e9SAndroid Build Coastguard Worker        self.handler = csan.EventHandler()
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker    def kernel_launch(
138*da0073e9SAndroid Build Coastguard Worker        self,
139*da0073e9SAndroid Build Coastguard Worker        stream: StreamId,
140*da0073e9SAndroid Build Coastguard Worker        read_only: List[DataPtr] = None,
141*da0073e9SAndroid Build Coastguard Worker        read_write: List[DataPtr] = None,
142*da0073e9SAndroid Build Coastguard Worker    ) -> List[csan.SynchronizationError]:
143*da0073e9SAndroid Build Coastguard Worker        if read_only is None:
144*da0073e9SAndroid Build Coastguard Worker            read_only = []
145*da0073e9SAndroid Build Coastguard Worker        if read_write is None:
146*da0073e9SAndroid Build Coastguard Worker            read_write = []
147*da0073e9SAndroid Build Coastguard Worker        return self.handler._handle_kernel_launch(
148*da0073e9SAndroid Build Coastguard Worker            stream,
149*da0073e9SAndroid Build Coastguard Worker            read_only,
150*da0073e9SAndroid Build Coastguard Worker            read_write,
151*da0073e9SAndroid Build Coastguard Worker            {},
152*da0073e9SAndroid Build Coastguard Worker            "",
153*da0073e9SAndroid Build Coastguard Worker            {k: [""] for k in read_only + read_write},
154*da0073e9SAndroid Build Coastguard Worker        )
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker    def assert_good_kernel_launch(
157*da0073e9SAndroid Build Coastguard Worker        self,
158*da0073e9SAndroid Build Coastguard Worker        stream: StreamId,
159*da0073e9SAndroid Build Coastguard Worker        read_only: List[DataPtr] = None,
160*da0073e9SAndroid Build Coastguard Worker        read_write: List[DataPtr] = None,
161*da0073e9SAndroid Build Coastguard Worker    ) -> None:
162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker    def assert_bad_kernel_launch(
165*da0073e9SAndroid Build Coastguard Worker        self,
166*da0073e9SAndroid Build Coastguard Worker        number_of_errors: int,
167*da0073e9SAndroid Build Coastguard Worker        stream: StreamId,
168*da0073e9SAndroid Build Coastguard Worker        read_only: List[DataPtr] = None,
169*da0073e9SAndroid Build Coastguard Worker        read_write: List[DataPtr] = None,
170*da0073e9SAndroid Build Coastguard Worker    ) -> None:
171*da0073e9SAndroid Build Coastguard Worker        errors = self.kernel_launch(stream, read_only, read_write)
172*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(errors), number_of_errors)
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker    def test_empty_kernel_launch(self):
175*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(0))
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker    def test_simple_passing(self):
178*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
179*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
180*da0073e9SAndroid Build Coastguard Worker
181*da0073e9SAndroid Build Coastguard Worker    def test_simple_error(self):
182*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
183*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker    def test_simple_sync(self):
186*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
187*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(0), stream_id(1))
188*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(0), stream_id(2))
189*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker    def test_reads_check_last_write(self):
192*da0073e9SAndroid Build Coastguard Worker        # Tests that not only the first read operation checks if it is in conflict
193*da0073e9SAndroid Build Coastguard Worker        # with the last write operation, but all read operations do.
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
196*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(0), stream_id(1))
197*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(0), stream_id(2))
198*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)])
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker    def test_branch_sync(self):
203*da0073e9SAndroid Build Coastguard Worker        # Tests that two streams can read after both waiting for a third, but they
204*da0073e9SAndroid Build Coastguard Worker        # cannot write without further synchronization.
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
207*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(0), stream_id(1))
208*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(0), stream_id(2))
209*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(0), stream_id(3))
210*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
211*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
212*da0073e9SAndroid Build Coastguard Worker
213*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker    def test_chain_sync(self):
216*da0073e9SAndroid Build Coastguard Worker        iterations = 10
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)])
219*da0073e9SAndroid Build Coastguard Worker        for i in range(iterations):
220*da0073e9SAndroid Build Coastguard Worker            self.handler._handle_event_record(event_id(i), stream_id(i))
221*da0073e9SAndroid Build Coastguard Worker            self.handler._handle_event_wait(event_id(i), stream_id(i + 1))
222*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)])
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker    def test_expired_record(self):
225*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
226*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(0), stream_id(1))
227*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
228*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(0), stream_id(2))
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker    def test_deleted_record(self):
233*da0073e9SAndroid Build Coastguard Worker        for should_delete, should_create in [
234*da0073e9SAndroid Build Coastguard Worker            (True, True),
235*da0073e9SAndroid Build Coastguard Worker            (True, False),
236*da0073e9SAndroid Build Coastguard Worker            (False, True),
237*da0073e9SAndroid Build Coastguard Worker        ]:
238*da0073e9SAndroid Build Coastguard Worker            self.setUp()
239*da0073e9SAndroid Build Coastguard Worker            with self.subTest(should_delete=should_delete, should_create=should_create):
240*da0073e9SAndroid Build Coastguard Worker                self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
241*da0073e9SAndroid Build Coastguard Worker                self.handler._handle_event_record(event_id(0), stream_id(1))
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker                if should_delete:
244*da0073e9SAndroid Build Coastguard Worker                    self.handler._handle_event_deletion(event_id(0))
245*da0073e9SAndroid Build Coastguard Worker                if should_create:
246*da0073e9SAndroid Build Coastguard Worker                    self.handler._handle_event_creation(event_id(0))
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker                self.handler._handle_event_wait(event_id(0), stream_id(2))
249*da0073e9SAndroid Build Coastguard Worker                self.assert_bad_kernel_launch(
250*da0073e9SAndroid Build Coastguard Worker                    1, stream_id(2), read_write=[tensor_id(1)]
251*da0073e9SAndroid Build Coastguard Worker                )
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker    def test_all_reads_checked_failing(self):
254*da0073e9SAndroid Build Coastguard Worker        iterations = 10
255*da0073e9SAndroid Build Coastguard Worker        for i in range(1, iterations):
256*da0073e9SAndroid Build Coastguard Worker            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
257*da0073e9SAndroid Build Coastguard Worker            self.handler._handle_event_record(event_id(i), stream_id(i))
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker        for i in range(1, iterations):
260*da0073e9SAndroid Build Coastguard Worker            self.handler._handle_event_wait(event_id(i), stream_id(0))
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)])
263*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(iterations), stream_id(i))
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker        # Does not synchronize with the last read.
266*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)])
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker    def test_all_reads_checked_passing(self):
269*da0073e9SAndroid Build Coastguard Worker        iterations = 10
270*da0073e9SAndroid Build Coastguard Worker        for i in range(1, iterations):
271*da0073e9SAndroid Build Coastguard Worker            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
272*da0073e9SAndroid Build Coastguard Worker            self.handler._handle_event_record(event_id(i), stream_id(i))
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker        for i in range(1, iterations):
275*da0073e9SAndroid Build Coastguard Worker            self.handler._handle_event_wait(event_id(i), stream_id(0))
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker    def test_multiple_errors(self):
280*da0073e9SAndroid Build Coastguard Worker        iterations = 10
281*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(
282*da0073e9SAndroid Build Coastguard Worker            stream_id(0), read_write=[tensor_id(i) for i in range(iterations)]
283*da0073e9SAndroid Build Coastguard Worker        )
284*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(
285*da0073e9SAndroid Build Coastguard Worker            iterations,
286*da0073e9SAndroid Build Coastguard Worker            stream_id(1),
287*da0073e9SAndroid Build Coastguard Worker            read_write=[tensor_id(i) for i in range(iterations)],
288*da0073e9SAndroid Build Coastguard Worker        )
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker    def test_correct_state_merging(self):
291*da0073e9SAndroid Build Coastguard Worker        # Tests that after waiting for an event, a stream's state is indeed set
292*da0073e9SAndroid Build Coastguard Worker        # to the pointwise maximum of its old state and the recorded state.
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
295*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
296*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(1), stream_id(1))
297*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(2), stream_id(2))
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
300*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
301*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(1), stream_id(2))
302*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(2), stream_id(1))
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(3), stream_id(2))
305*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(3), stream_id(1))
306*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(
307*da0073e9SAndroid Build Coastguard Worker            stream_id(1), read_write=[tensor_id(1), tensor_id(2)]
308*da0073e9SAndroid Build Coastguard Worker        )
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker    def test_record_override(self):
311*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
312*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)])
313*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(1), stream_id(1))
314*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(1), stream_id(2))
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(1), stream_id(3))
317*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)])
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker    def test_multiple_wait(self):
320*da0073e9SAndroid Build Coastguard Worker        # Tests that a wait operation can be performed multiple times on the same event
321*da0073e9SAndroid Build Coastguard Worker        # by different streams.
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
324*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(1), stream_id(1))
325*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(1), stream_id(2))
326*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_wait(event_id(1), stream_id(3))
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
329*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker    def test_device_synchronize(self):
332*da0073e9SAndroid Build Coastguard Worker        # Tests that a device synchronization does correctly cause all streams
333*da0073e9SAndroid Build Coastguard Worker        # to synchronize with each other.
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker        iterations = 10
336*da0073e9SAndroid Build Coastguard Worker        for i in range(1, iterations):
337*da0073e9SAndroid Build Coastguard Worker            self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)])
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_device_synchronization()
340*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(
341*da0073e9SAndroid Build Coastguard Worker            stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)]
342*da0073e9SAndroid Build Coastguard Worker        )
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker    def test_device_synchronization_expired(self):
345*da0073e9SAndroid Build Coastguard Worker        # Tests that a device synchronization is a one-time synchronization.
346*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
347*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_device_synchronization()
348*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker    def test_new_stream_is_synchronized(self):
353*da0073e9SAndroid Build Coastguard Worker        # Tests that after synchronizing operations with the host, any newly created
354*da0073e9SAndroid Build Coastguard Worker        # stream is guaranteed to be synchronized with them as well.
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
357*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_device_synchronization()
358*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_stream_creation(stream_id(2))
359*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker    def test_stream_synchronize(self):
362*da0073e9SAndroid Build Coastguard Worker        # Tests that a stream synchronization does correctly cause all streams to wait
363*da0073e9SAndroid Build Coastguard Worker        # for one specific stream, but does not synchronize all streams with each other.
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
366*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
367*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_stream_synchronization(stream_id(0))
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
370*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
371*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)])
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker    def test_event_synchronize(self):
374*da0073e9SAndroid Build Coastguard Worker        # Tests that an event synchronization does correctly cause all streams to wait
375*da0073e9SAndroid Build Coastguard Worker        # for a recorded event, but does not guarantee synchronization with the current
376*da0073e9SAndroid Build Coastguard Worker        # state of the stream that recorded the event.
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
379*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_record(event_id(1), stream_id(1))
380*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker        self.handler._handle_event_synchronization(event_id(1))
383*da0073e9SAndroid Build Coastguard Worker        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
384*da0073e9SAndroid Build Coastguard Worker        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)])
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker
class TestMessages(TestCase):
    """Checks the exact text of log and error messages emitted by the sanitizer."""

    def setUp(self):
        # Every test starts from a fresh, empty event handler.
        self.handler = csan.EventHandler()

    def _check_first_log(self, func, arg, out):
        """Call ``func(arg)`` in a sub-test and compare the first log record to ``out``."""
        with self.subTest(func=func, out=out):
            with self.assertLogs() as captured:
                func(arg)
            self.assertEqual(captured.records[0].getMessage(), out)

    def test_ensure_exists(self):
        # Deleting an event / deallocating memory that was never registered
        # must log a "backfilling" message.
        ARG = 0
        event_message = (
            f"Found Event with id: {ARG}, but no matching event "
            "creation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?"
        )
        tensor_message = (
            f"Found tensor with pointer: {ARG}, but no matching tensor "
            "allocation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?"
        )
        cases = (
            (self.handler._handle_event_deletion, event_message),
            (self.handler._handle_memory_deallocation, tensor_message),
        )
        for func, out in cases:
            self._check_first_log(func, ARG, out)

    def test_ensure_does_not_exist(self):
        # Creating an event / stream a second time must log a duplicate message.
        ARG = 0
        self.handler._handle_event_creation(ARG)
        self.handler._handle_stream_creation(ARG)

        event_message = (
            "Found duplicate event creation in the trace for event with "
            f"id: {ARG}. Assuming the trace for event deletion wasn't caught "
            "and backfilling it now. "
            "Perhaps the sanitizer was enabled after some torch operations?"
        )
        stream_message = (
            "Found duplicate Stream creation in the trace for Stream with "
            f"id: {ARG}. PyTorch Streams are only created once, so this "
            "trace entry is ignored."
        )
        cases = (
            (self.handler._handle_event_creation, event_message),
            (self.handler._handle_stream_creation, stream_message),
        )
        for func, out in cases:
            self._check_first_log(func, ARG, out)

    def test_error_message(self):
        # str() of an UnsynchronizedAccessError must render both accesses and
        # the allocation stack trace in the layout pinned below.

        def single_frame(source_line):
            # One-frame summary: file "file", line 0, function "name".
            return traceback.StackSummary.from_list(
                [("file", 0, "name", source_line)]
            )

        writing_access = csan.Access(
            type=csan.AccessType.WRITE,
            seq_num=1,
            stream=stream_id(1),
            operator="schema",
            aliases=["b"],
            is_output=True,
            stack_trace=single_frame("trace a"),
        )
        reading_access = csan.Access(
            type=csan.AccessType.READ,
            seq_num=2,
            stream=stream_id(0),
            operator="schema",
            aliases=["a"],
            is_output=False,
            stack_trace=single_frame("trace b"),
        )
        error = csan.UnsynchronizedAccessError(
            data_ptr=tensor_id(1),
            allocation_stack_trace=single_frame("alloc"),
            current_access=writing_access,
            previous_access=reading_access,
        )
        self.assertEqual(
            str(error),
            textwrap.dedent(
                """\
                ============================
                CSAN detected a possible data race on tensor with data pointer 1
                Access by stream 1001 during kernel:
                schema
                writing to argument(s) b, and to the output
                With stack trace:
                  File "file", line 0, in name
                    trace a

                Previous access by stream 1000 during kernel:
                schema
                reading from argument(s) a
                With stack trace:
                  File "file", line 0, in name
                    trace b

                Tensor was allocated with stack trace:
                  File "file", line 0, in name
                    alloc
                """
            ),
        )
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker
# Script entry point: when this file is executed directly (not imported),
# hand control to run_tests() from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
497