xref: /aosp_15_r20/external/pytorch/test/test_cuda_sanitizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: cuda"]
2
import sys
import textwrap
import traceback
from typing import List, Optional

import torch
import torch.cuda._sanitizer as csan
from torch.cuda._sanitizer import DataPtr, EventId, StreamId
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase
12
13
# Without a CUDA device none of the tests below can run; rebinding TestCase to
# NoTest makes every test class in this file invisible to test discovery
# (the noqa silences the intentional redefinition of the imported name).
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811
17
18
class TestArgumentHandler(TestCase):
    """Checks that csan.ArgumentHandler classifies operator arguments into
    read and written data pointers according to the operator schema."""

    def test_add(self):
        schema = torch.ops.aten.add.Tensor._schema
        lhs = torch.ones(5, 3, device="cuda")
        rhs = torch.randn(5, 3, device="cuda")

        handler = csan.ArgumentHandler()
        handler.parse_inputs(schema, (lhs, rhs), {})
        result = torch.add(lhs, rhs)
        handler.parse_outputs(result)

        # Both operands are only read; the fresh output is the only write.
        self.assertEqual({lhs.data_ptr(), rhs.data_ptr()}, handler.dataptrs_read)
        self.assertEqual({result.data_ptr()}, handler.dataptrs_written)

    def test_cat(self):
        schema = torch.ops.aten.cat.default._schema
        first = torch.ones(2, 4, 5, device="cuda")
        second = torch.zeros(2, 1, 5, device="cuda")
        third = torch.rand(2, 7, 5, device="cuda")

        handler = csan.ArgumentHandler()
        handler.parse_inputs(schema, ([first, second, third], 1), {})
        result = torch.cat((first, second, third), dim=1)
        handler.parse_outputs(result)

        # Tensors nested inside a list argument must all be seen as reads.
        expected_reads = {t.data_ptr() for t in (first, second, third)}
        self.assertEqual(expected_reads, handler.dataptrs_read)
        self.assertEqual({result.data_ptr()}, handler.dataptrs_written)

    def test_split(self):
        schema = torch.ops.aten.split.Tensor._schema
        src = torch.arange(10, device="cuda").reshape(5, 2)

        handler = csan.ArgumentHandler()
        handler.parse_inputs(schema, (src, 2), {})
        chunks = torch.split(src, 2)
        handler.parse_outputs(chunks)

        # Every chunk of the tuple output is registered as written.
        written = {chunk.data_ptr() for chunk in chunks}
        self.assertEqual({src.data_ptr()}, handler.dataptrs_read)
        self.assertEqual(written, handler.dataptrs_written)

    def test_inplace(self):
        schema = torch.ops.aten.add_.Tensor._schema
        target = torch.rand(4, 2, device="cuda")

        handler = csan.ArgumentHandler()
        handler.parse_inputs(schema, (target, 5), {})
        target.add_(5)
        handler.parse_outputs(target)

        # The in-place self argument counts purely as a write, never a read.
        self.assertEqual(set(), handler.dataptrs_read)
        self.assertEqual({target.data_ptr()}, handler.dataptrs_written)

    def test_out(self):
        schema = torch.ops.aten.mul.out._schema
        src = torch.arange(8, device="cuda")
        dest = torch.empty(8, device="cuda")

        handler = csan.ArgumentHandler()
        handler.parse_inputs(schema, (src, 3), {"out": dest})
        torch.mul(src, 3, out=dest)
        handler.parse_outputs(dest)

        # An `out=` keyword tensor is classified as written, not read.
        self.assertEqual({src.data_ptr()}, handler.dataptrs_read)
        self.assertEqual({dest.data_ptr()}, handler.dataptrs_written)

    def test_nonzero(self):
        schema = torch.ops.aten.nonzero.default._schema
        src = torch.ones(5, 3, 2, device="cuda")

        handler = csan.ArgumentHandler()
        handler.parse_inputs(schema, (src,), {"as_tuple": True})
        result = torch.nonzero(src, as_tuple=True)
        handler.parse_outputs(result)

        # as_tuple=True yields one index tensor per dimension; all are writes.
        written = {coords.data_ptr() for coords in result}
        self.assertEqual({src.data_ptr()}, handler.dataptrs_read)
        self.assertEqual(written, handler.dataptrs_written)

    def test_tensor_names(self):
        schema = torch.ops.aten.addr.default._schema
        vec = torch.arange(1, 4, device="cuda")
        M = torch.zeros(3, 3, device="cuda")

        handler = csan.ArgumentHandler()
        handler.parse_inputs(schema, (M, vec, vec), {})
        result = torch.addr(M, vec, vec)
        handler.parse_outputs(result)

        # The same tensor passed twice collects both schema argument names;
        # the output tensor has no input aliases.
        expected_aliases = {
            M.data_ptr(): ["self"],
            vec.data_ptr(): ["vec1", "vec2"],
            result.data_ptr(): [],
        }
        self.assertEqual(handler.tensor_aliases, expected_aliases)
        self.assertEqual({result.data_ptr()}, handler.outputs)
119
120
def tensor_id(i: int) -> DataPtr:
    """Turn a small test index into a fake tensor data pointer."""
    return i
123
124
def stream_id(i: int) -> StreamId:
    """Turn a small test index into a fake stream id (1000-based range)."""
    return i + 1000
127
128
def event_id(i: int) -> EventId:
    """Turn a small test index into a fake event id (2000-based range)."""
    return i + 2000
131
132
class TestEventHandler(TestCase):
    """Tests for csan.EventHandler's stream/event synchronization tracking.

    Each test drives the handler's private ``_handle_*`` hooks directly with
    synthetic stream/event/tensor ids (no real CUDA work) and checks whether
    a simulated kernel launch is reported as properly synchronized.
    """

    def setUp(self):
        # Fresh handler per test: all stream/event/tensor state is per-instance.
        self.handler = csan.EventHandler()

    def kernel_launch(
        self,
        stream: StreamId,
        read_only: Optional[List[DataPtr]] = None,
        read_write: Optional[List[DataPtr]] = None,
    ) -> List[csan.SynchronizationError]:
        """Simulate a kernel launch on ``stream`` and return detected errors.

        ``read_only``/``read_write`` list the data pointers the kernel reads
        or writes; ``None`` means none (defaults are created per call to avoid
        shared mutable state).
        """
        if read_only is None:
            read_only = []
        if read_write is None:
            read_write = []
        return self.handler._handle_kernel_launch(
            stream,
            read_only,
            read_write,
            {},
            "",
            # Minimal alias metadata: one anonymous argument name per pointer.
            {k: [""] for k in read_only + read_write},
        )

    def assert_good_kernel_launch(
        self,
        stream: StreamId,
        read_only: Optional[List[DataPtr]] = None,
        read_write: Optional[List[DataPtr]] = None,
    ) -> None:
        """Assert the launch is well-synchronized (no errors reported)."""
        self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])

    def assert_bad_kernel_launch(
        self,
        number_of_errors: int,
        stream: StreamId,
        read_only: Optional[List[DataPtr]] = None,
        read_write: Optional[List[DataPtr]] = None,
    ) -> None:
        """Assert the launch produces exactly ``number_of_errors`` errors."""
        errors = self.kernel_launch(stream, read_only, read_write)
        self.assertEqual(len(errors), number_of_errors)

    def test_empty_kernel_launch(self):
        # A launch touching no tensors can never be a data race.
        self.assert_good_kernel_launch(stream_id(0))

    def test_simple_passing(self):
        # Concurrent reads of the same tensor from two streams are safe.
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])

    def test_simple_error(self):
        # An unsynchronized write after a read on another stream is a race.
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_simple_sync(self):
        # Record-then-wait orders stream 2 after stream 1, so the write is safe.
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])

    def test_reads_check_last_write(self):
        # Tests that not only the first read operation checks if it is in conflict
        # with the last write operation, but all read operations do.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])

        # Stream 3 never synchronized with the write on stream 1.
        self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)])

    def test_branch_sync(self):
        # Tests that two streams can read after both waiting for a third, but they
        # cannot write without further synchronization.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.handler._handle_event_wait(event_id(0), stream_id(3))
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

        # Stream 2's write conflicts with stream 3's unsynchronized read.
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_chain_sync(self):
        # Synchronization is transitive across a chain of record/wait pairs.
        iterations = 10

        self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)])
        for i in range(iterations):
            self.handler._handle_event_record(event_id(i), stream_id(i))
            self.handler._handle_event_wait(event_id(i), stream_id(i + 1))
        self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)])

    def test_expired_record(self):
        # Waiting on an event only synchronizes work launched BEFORE the record;
        # the second read on stream 1 happens after it, so the write races.
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_wait(event_id(0), stream_id(2))

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_deleted_record(self):
        # Deleting and/or re-creating an event clears its recorded state, so a
        # subsequent wait provides no synchronization.
        for should_delete, should_create in [
            (True, True),
            (True, False),
            (False, True),
        ]:
            # Manual setUp: each subTest needs a pristine handler.
            self.setUp()
            with self.subTest(should_delete=should_delete, should_create=should_create):
                self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
                self.handler._handle_event_record(event_id(0), stream_id(1))

                if should_delete:
                    self.handler._handle_event_deletion(event_id(0))
                if should_create:
                    self.handler._handle_event_creation(event_id(0))

                self.handler._handle_event_wait(event_id(0), stream_id(2))
                self.assert_bad_kernel_launch(
                    1, stream_id(2), read_write=[tensor_id(1)]
                )

    def test_all_reads_checked_failing(self):
        # A write must be checked against EVERY outstanding read, not just the
        # ones the writing stream synchronized with.
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))

        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)])
        # NOTE(review): records on stream_id(i) where i is the last loop value
        # (iterations - 1), not stream_id(iterations) which did the final read.
        # Possibly intended to be stream_id(iterations) — the test still passes
        # because stream 0 never waits on this event; confirm against upstream.
        self.handler._handle_event_record(event_id(iterations), stream_id(i))

        # Does not synchronize with the last read.
        self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)])

    def test_all_reads_checked_passing(self):
        # After waiting on every reader's event, the write is safe.
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))

        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])

    def test_multiple_errors(self):
        # One error is reported per conflicting tensor in a single launch.
        iterations = 10
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(iterations)]
        )
        self.assert_bad_kernel_launch(
            iterations,
            stream_id(1),
            read_write=[tensor_id(i) for i in range(iterations)],
        )

    def test_correct_state_merging(self):
        # Tests that after waiting for an event, a stream's state is indeed set
        # to the pointwise maximum of its old state and the recorded state.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(2), stream_id(2))

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(2), stream_id(1))

        self.handler._handle_event_record(event_id(3), stream_id(2))
        self.handler._handle_event_wait(event_id(3), stream_id(1))
        self.assert_good_kernel_launch(
            stream_id(1), read_write=[tensor_id(1), tensor_id(2)]
        )

    def test_record_override(self):
        # Re-recording an event replaces its previous state: after the second
        # record, waiting no longer synchronizes with stream 1's read.
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(1), stream_id(2))

        self.handler._handle_event_wait(event_id(1), stream_id(3))
        self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)])

    def test_multiple_wait(self):
        # Tests that a wait operation can be performed multiple times on the same event
        # by different streams.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(1), stream_id(3))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

    def test_device_synchronize(self):
        # Tests that a device synchronization does correctly cause all streams
        # to synchronize with each other.

        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)])

        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)]
        )

    def test_device_synchronization_expired(self):
        # Tests that a device synchronization is a one-time synchronization.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])

        # The write after the device sync is not covered by it.
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_new_stream_is_synchronized(self):
        # Tests that after synchronizing operations with the host, any newly created
        # stream is guaranteed to be synchronized with them as well.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.handler._handle_stream_creation(stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])

    def test_stream_synchronize(self):
        # Tests that a stream synchronization does correctly cause all streams to wait
        # for one specific stream, but does not synchronize all streams with each other.

        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
        self.handler._handle_stream_synchronization(stream_id(0))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)])

    def test_event_synchronize(self):
        # Tests that an event synchronization does correctly cause all streams to wait
        # for a recorded event, but does not guarantee synchronization with the current
        # state of the stream that recorded the event.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])

        self.handler._handle_event_synchronization(event_id(1))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)])
385
386
class TestMessages(TestCase):
    """Tests for the exact text of EventHandler's log warnings and the
    formatted UnsynchronizedAccessError report."""

    def setUp(self):
        # Fresh handler per test so no events/streams pre-exist.
        self.handler = csan.EventHandler()

    def test_ensure_exists(self):
        """Deleting/deallocating an id never seen before logs a backfill warning."""
        ARG = 0
        for func, out in [
            (
                self.handler._handle_event_deletion,
                f"Found Event with id: {ARG}, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_memory_deallocation,
                f"Found tensor with pointer: {ARG}, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_ensure_does_not_exist(self):
        """Creating an already-existing event/stream id logs a duplicate warning."""
        ARG = 0
        self.handler._handle_event_creation(ARG)
        self.handler._handle_stream_creation(ARG)
        for func, out in [
            (
                self.handler._handle_event_creation,
                "Found duplicate event creation in the trace for event with "
                f"id: {ARG}. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_stream_creation,
                "Found duplicate Stream creation in the trace for Stream with "
                f"id: {ARG}. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_error_message(self):
        """str(UnsynchronizedAccessError) renders both accesses, their stack
        traces, and the allocation trace in the exact expected layout."""
        current_access = csan.Access(
            type=csan.AccessType.WRITE,
            seq_num=1,
            stream=stream_id(1),
            operator="schema",
            aliases=["b"],
            is_output=True,
            # Synthetic single-frame traces keep the expected text short.
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace a")]
            ),
        )
        previous_access = csan.Access(
            type=csan.AccessType.READ,
            seq_num=2,
            stream=stream_id(0),
            operator="schema",
            aliases=["a"],
            is_output=False,
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace b")]
            ),
        )
        error = csan.UnsynchronizedAccessError(
            data_ptr=tensor_id(1),
            allocation_stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "alloc")]
            ),
            current_access=current_access,
            previous_access=previous_access,
        )
        self.assertEqual(
            str(error),
            textwrap.dedent(
                """\
                ============================
                CSAN detected a possible data race on tensor with data pointer 1
                Access by stream 1001 during kernel:
                schema
                writing to argument(s) b, and to the output
                With stack trace:
                  File "file", line 0, in name
                    trace a

                Previous access by stream 1000 during kernel:
                schema
                reading from argument(s) a
                With stack trace:
                  File "file", line 0, in name
                    trace b

                Tensor was allocated with stack trace:
                  File "file", line 0, in name
                    alloc
                """
            ),
        )
493
494
495if __name__ == "__main__":
496    run_tests()
497