xref: /aosp_15_r20/external/pytorch/torch/cuda/_sanitizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
r"""
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels run on different streams.

It stores information on accesses to tensors to determine if they are synchronized
or not. When enabled in a Python program, it prints a detailed warning and exits
as soon as a possible data race is detected.

It can be enabled either by importing this module and calling
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
environment variable.
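
A minimal sketch of the kind of unsynchronized access the sanitizer reports
(illustrative; it assumes a CUDA device is available)::

    import torch
    from torch.cuda._sanitizer import enable_cuda_sanitizer

    enable_cuda_sanitizer()

    a = torch.rand(10000, device="cuda")          # written on the default stream
    with torch.cuda.stream(torch.cuda.Stream()):
        a.mul_(2)                                 # unsynchronized write on a side stream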
12"""
13
14import enum
15import functools
16import inspect
17import io
18import logging
19import sys
20import textwrap
21import traceback
22from dataclasses import dataclass, field
23from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
24
25import torch
26import torch.cuda._gpu_trace as gpu_trace
27from torch.utils import _pytree as pytree
28from torch.utils._python_dispatch import TorchDispatchMode
29
30
31DEFAULT_STREAM_ID = 0
32
33TK = TypeVar("TK")
34TVa = TypeVar("TVa")
35TVb = TypeVar("TVb")
36
37DataPtr = int
38StreamId = int
39EventId = int
40SeqNum = int
41
42logger = logging.getLogger(__name__)
43
44
45class AccessType(enum.Enum):
46    READ = enum.auto()
47    WRITE = enum.auto()
48
49    def __str__(self):
50        return "reading from" if self is AccessType.READ else "writing to"
51
52
53@dataclass
54class Access:
55    r"""Stores information about a single access to a tensor by a kernel.
56
57    Args:
        type: either AccessType.READ or AccessType.WRITE.
        seq_num: the sequential number of the kernel performing the access.
        stream: the stream id of the stream executing the kernel.
        operator: the schema of the launched kernel, which lists the
            arguments and return type.
        aliases: the arguments in the schema this access corresponds to.
        is_output: whether the tensor was an output of the kernel.
        stack_trace: the stack summary object captured during access.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary


class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer."""


class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        def format_access(access: Access):
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()

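# For orientation, the report built by UnsynchronizedAccessError.__str__ looks
# roughly like this (pointer, stream ids, and schema values are hypothetical):
#
#   ============================
#   CSAN detected a possible data race on tensor with data pointer 139972814218240
#   Access by stream 94646435460352 during kernel:
#   aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
#   writing to argument(s) self, out, and to the output
#   With stack trace:
#   ...
#   Previous access by stream 0 during kernel:
#   aten::rand(...) -> Tensor
#   writing to the output
#   With stack trace:
#   ...
#   Tensor was allocated with stack trace:
#   ...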

class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        return f"detected {len(self.errors)} errors"


@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None


class _TensorsAccessed:
    def __init__(self) -> None:
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        return True if self.accesses[data_ptr].reads else False

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []

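# Illustrative _TensorsAccessed bookkeeping sketch (the pointer value and the
# Access objects are hypothetical):
#
#   tensors = _TensorsAccessed()
#   tensors.create_tensor(0xDEAD, None)
#   tensors.add_read(0xDEAD, read_access)    # tracked until the next write
#   tensors.set_write(0xDEAD, write_access)  # becomes the last write, clears reads
#   tensors.were_there_reads_since_last_write(0xDEAD)  # -> False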

class StreamSynchronizations:
    def __init__(self) -> None:
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        if stream not in self.current_sync_states:
            logger.info(
                "Found Stream with id: %s, but no matching stream "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                stream,
            )
            self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        if event not in self.recorded_sync_states:
            logger.info(
                "Found Event with id: %s, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        if event in self.recorded_sync_states:
            logger.info(
                "Found duplicate event creation in the trace for event with "
                "id: %s. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
        else:
            self.host_sync_state[stream] = 0
            self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        for stream, seq_num in other.items():
            state[stream] = max(state.get(stream, -1), seq_num)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        self._state_wait_for_other(
            self.current_sync_states[stream], self.recorded_sync_states[event]
        )

    def all_streams_wait_for_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        for stream in self.current_sync_states.keys():
            self.stream_wait_for_event(stream, event)

        self._state_wait_for_other(
            self.host_sync_state, self.recorded_sync_states[event]
        )

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        for stream, state in self.current_sync_states.items():
            self.host_sync_state[stream] = state[stream]

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)

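# Illustrative sketch of the StreamSynchronizations bookkeeping (stream, event,
# and sequence numbers are hypothetical):
#
#   syncs = StreamSynchronizations()
#   syncs.create_stream(1)
#   syncs.create_stream(2)
#   syncs.update_seq_num(1, 5)           # kernel #5 launched on stream 1
#   syncs.create_event(7)
#   syncs.record_state(7, 1)             # event 7 snapshots stream 1 at seq_num 5
#   syncs.stream_wait_for_event(2, 7)    # stream 2 is now ordered after kernel #5
#   syncs.is_ordered_after(2, 5, 1)      # -> True
#   syncs.is_ordered_after(2, 6, 1)      # -> False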

class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self) -> None:
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()

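        # A read is only checked for a conflict against the last write to the
        # same data pointer.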
        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

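        # A write is checked against every read recorded since the last write;
        # if there are no such reads, it is checked against the last write.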
        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)


def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    for arg, value in a.items():
        if arg in b:
            yield arg, value, b[arg]


def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    schema_args = schema.arguments[: len(args)]
    schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}

    yield from zip(schema_args, args)

    for _, argument, value in zip_by_key(schema_kwargs, kwargs):
        yield (argument, value)

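# Example sketch: for a call like torch.add(a, b, alpha=2), whose schema is
# roughly `aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor`,
# zip_arguments yields (self, a) and (other, b) from the positional zip and
# (alpha, 2) from the keyword lookup.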

class ArgumentHandler:
    def __init__(self) -> None:
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = {}
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        if isinstance(value, torch.Tensor) and value.is_cuda:
            data_ptr = value.data_ptr()
            if is_write:
                self.dataptrs_written.add(data_ptr)
            else:
                self.dataptrs_read.add(data_ptr)

            self.tensor_aliases.setdefault(data_ptr, [])
            if name is not None:
                self.tensor_aliases[data_ptr].append(name)
            if is_output:
                self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        for argument, value in zip_arguments(schema, args, kwargs):
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map_(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        pytree.tree_map_(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )


class CUDASanitizerDispatchMode(TorchDispatchMode):
    def __init__(self) -> None:
        self.event_handler = EventHandler()
        torch._C._activate_gpu_trace()
        gpu_trace.register_callback_for_event_creation(
            self.event_handler._handle_event_creation
        )
        gpu_trace.register_callback_for_event_deletion(
            self.event_handler._handle_event_deletion
        )
        gpu_trace.register_callback_for_event_record(
            self.event_handler._handle_event_record
        )
        gpu_trace.register_callback_for_event_wait(
            self.event_handler._handle_event_wait
        )
        gpu_trace.register_callback_for_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        gpu_trace.register_callback_for_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        gpu_trace.register_callback_for_stream_creation(
            self.event_handler._handle_stream_creation
        )
        gpu_trace.register_callback_for_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        gpu_trace.register_callback_for_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        gpu_trace.register_callback_for_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
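        # Only pointers that were exclusively read are passed as read-only;
        # anything that was written (including all outputs) is treated as
        # read-write.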
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs


class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self) -> None:
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        if self.enabled:
            self.dispatch.__exit__(None, None, None)


def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
    for synchronization errors. All data races found will be printed to the standard
    error output along with stack traces of suspected causes. For best results, the
    sanitizer should be enabled at the very beginning of the program.
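
    The sanitizer can also be enabled without modifying the program by exporting
    the ``TORCH_CUDA_SANITIZER`` environment variable before the program starts.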
    """
    cuda_sanitizer.enable()


cuda_sanitizer = CUDASanitizer()