1# mypy: allow-untyped-defs 2r""" 3This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams. 4 5It stores information on accesses to tensors to determine if they are synchronized 6or not. When enabled in a python program and a possible data race is detected, a 7detailed warning will be printed and the program will exit. 8 9It can be enabled either by importing this module and calling 10:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER`` 11environment variable. 12""" 13 14import enum 15import functools 16import inspect 17import io 18import logging 19import sys 20import textwrap 21import traceback 22from dataclasses import dataclass, field 23from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar 24 25import torch 26import torch.cuda._gpu_trace as gpu_trace 27from torch.utils import _pytree as pytree 28from torch.utils._python_dispatch import TorchDispatchMode 29 30 31DEFAULT_STREAM_ID = 0 32 33TK = TypeVar("TK") 34TVa = TypeVar("TVa") 35TVb = TypeVar("TVb") 36 37DataPtr = int 38StreamId = int 39EventId = int 40SeqNum = int 41 42logger = logging.getLogger(__name__) 43 44 45class AccessType(enum.Enum): 46 READ = enum.auto() 47 WRITE = enum.auto() 48 49 def __str__(self): 50 return "reading from" if self is AccessType.READ else "writing to" 51 52 53@dataclass 54class Access: 55 r"""Stores information about a single access to a tensor by a kernel. 56 57 Args: 58 type: either AccessType.READ or AccessType.Write. 59 seq_num: the sequential number of the kernel performing the access. 60 stream: the stream id of the stream executing the kernel. 61 operator: the schema of the launched kernel, which lists the 62 arguments and return type. 63 aliases: the arguments in the schema this access corresponds to. 64 is_output: Whether the tensor was an output of the kernel. 65 stack_trace: the stack summary object captured during access. 66 """ 67 68 type: AccessType 69 seq_num: SeqNum 70 stream: StreamId 71 operator: str 72 aliases: List[str] 73 is_output: bool 74 stack_trace: traceback.StackSummary 75 76 77class SynchronizationError(Exception): 78 """Base class for errors detected by CUDA Sanitizer.""" 79 80 81class UnsynchronizedAccessError(SynchronizationError): 82 """Stores information about two unsynchronized accesses to one data pointer.""" 83 84 def __init__( 85 self, 86 data_ptr: DataPtr, 87 allocation_stack_trace: Optional[traceback.StackSummary], 88 current_access: Access, 89 previous_access: Access, 90 ): 91 self.data_ptr = data_ptr 92 self.allocation_stack_trace = allocation_stack_trace 93 self.current_access = current_access 94 self.previous_access = previous_access 95 96 def __str__(self): 97 def format_access(access: Access): 98 message.write(f"{access.operator}\n{access.type}") 99 if access.aliases: 100 message.write(" argument(s) " + ", ".join(access.aliases)) 101 if access.is_output: 102 message.write(", and to") 103 if access.is_output: 104 message.write(" the output") 105 message.write( 106 f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n" 107 ) 108 109 with io.StringIO() as message: 110 message.write( 111 textwrap.dedent( 112 f"""\ 113 ============================ 114 CSAN detected a possible data race on tensor with data pointer {self.data_ptr} 115 Access by stream {self.current_access.stream} during kernel: 116 """ 117 ) 118 ) 119 format_access(self.current_access) 120 121 message.write( 122 f"Previous access by stream {self.previous_access.stream} during kernel:\n" 123 ) 124 format_access(self.previous_access) 125 126 if self.allocation_stack_trace: 127 message.write( 128 "Tensor was allocated with stack trace:\n" 129 f"{''.join(self.allocation_stack_trace.format())}" 130 ) 131 else: 132 message.write("Trace for tensor allocation not found.") 133 return message.getvalue() 134 135 136class CUDASanitizerErrors(Exception): 137 """Wrapper class for errors reported by CUDA Sanitizer.""" 138 139 def __init__(self, errors: List[SynchronizationError]): 140 self.errors = errors 141 142 def __str__(self): 143 return f"detected {len(self.errors)} errors" 144 145 146@dataclass 147class TensorInfo: 148 r"""Stores information about a single tensor and recent accesses to it. 149 150 Args: 151 allocation_stack_trace: the stack summary object captured during tensor 152 allocation. Can be ``None`` if the allocation wasn't caught by CSAN. 153 reads: list of read accesses to the tensor that were performed since 154 the last write. 155 write: the last write access to the tensor. 156 """ 157 158 allocation_stack_trace: Optional[traceback.StackSummary] 159 reads: List[Access] = field(default_factory=list) 160 write: Optional[Access] = None 161 162 163class _TensorsAccessed: 164 def __init__(self) -> None: 165 self.accesses: Dict[DataPtr, TensorInfo] = {} 166 167 def ensure_tensor_exists(self, data_ptr: DataPtr) -> None: 168 if data_ptr not in self.accesses: 169 logger.info( 170 "Found tensor with pointer: %s, but no matching tensor " 171 "allocation in the trace. Backfilling the trace now. " 172 "Perhaps the sanitizer was enabled after some torch operations?", 173 data_ptr, 174 ) 175 self.create_tensor(data_ptr, None) 176 177 def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None: 178 if data_ptr in self.accesses: 179 logger.info( 180 "Found duplicate tensor allocation in the trace for tensor with " 181 "pointer: %s. Assuming the trace for tensor deallocation " 182 "wasn't caught and backfilling it now. " 183 "Perhaps the sanitizer was enabled after some torch operations?", 184 data_ptr, 185 ) 186 self.delete_tensor(data_ptr) 187 188 def create_tensor( 189 self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary] 190 ) -> None: 191 self.accesses[data_ptr] = TensorInfo(stack_trace) 192 193 def delete_tensor(self, data_ptr: DataPtr) -> None: 194 del self.accesses[data_ptr] 195 196 def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool: 197 return True if self.accesses[data_ptr].reads else False 198 199 def get_allocation_stack_trace( 200 self, data_ptr: DataPtr 201 ) -> Optional[traceback.StackSummary]: 202 return self.accesses[data_ptr].allocation_stack_trace 203 204 def get_write(self, data_ptr: DataPtr) -> Optional[Access]: 205 return self.accesses[data_ptr].write 206 207 def get_reads(self, data_ptr: DataPtr) -> List[Access]: 208 return self.accesses[data_ptr].reads 209 210 def add_read(self, data_ptr: DataPtr, access: Access) -> None: 211 self.accesses[data_ptr].reads.append(access) 212 213 def set_write(self, data_ptr: DataPtr, access: Access) -> None: 214 self.accesses[data_ptr].write = access 215 self.accesses[data_ptr].reads = [] 216 217 218class StreamSynchronizations: 219 def __init__(self) -> None: 220 self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {} 221 self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {} 222 self.host_sync_state: Dict[StreamId, SeqNum] = {} 223 self.create_stream(DEFAULT_STREAM_ID) 224 225 def _ensure_stream_exists(self, stream: StreamId) -> None: 226 if stream not in self.current_sync_states: 227 logger.info( 228 "Found Stream with id: %s, but no matching stream " 229 "creation in the trace. Backfilling the trace now. " 230 "Perhaps the sanitizer was enabled after some torch operations?", 231 stream, 232 ) 233 self.create_stream(stream) 234 235 def _ensure_event_exists(self, event: EventId) -> None: 236 if event not in self.recorded_sync_states: 237 logger.info( 238 "Found Event with id: %s, but no matching event " 239 "creation in the trace. Backfilling the trace now. " 240 "Perhaps the sanitizer was enabled after some torch operations?", 241 event, 242 ) 243 self.create_event(event) 244 245 def _ensure_event_does_not_exist(self, event: EventId) -> None: 246 if event in self.recorded_sync_states: 247 logger.info( 248 "Found duplicate event creation in the trace for event with " 249 "id: %s. Assuming the trace for event deletion wasn't caught " 250 "and backfilling it now. " 251 "Perhaps the sanitizer was enabled after some torch operations?", 252 event, 253 ) 254 self.delete_event(event) 255 256 def create_stream(self, stream: StreamId) -> None: 257 if stream in self.current_sync_states: 258 logger.info( 259 "Found duplicate Stream creation in the trace for Stream with " 260 "id: %s. PyTorch Streams are only created once, so this " 261 "trace entry is ignored.", 262 stream, 263 ) 264 else: 265 self.host_sync_state[stream] = 0 266 self.current_sync_states[stream] = self.host_sync_state.copy() 267 268 def create_event(self, event: EventId) -> None: 269 self._ensure_event_does_not_exist(event) 270 self.recorded_sync_states[event] = {} 271 272 def delete_event(self, event: EventId) -> None: 273 self._ensure_event_exists(event) 274 del self.recorded_sync_states[event] 275 276 def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None: 277 self._ensure_stream_exists(stream) 278 self.current_sync_states[stream][stream] = seq_num 279 280 def record_state(self, event: EventId, stream: StreamId) -> None: 281 self._ensure_event_exists(event) 282 self._ensure_stream_exists(stream) 283 self.recorded_sync_states[event] = self.current_sync_states[stream].copy() 284 285 def _state_wait_for_other( 286 self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum] 287 ) -> None: 288 for stream, seq_num in other.items(): 289 state[stream] = max(state.get(stream, -1), seq_num) 290 291 def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None: 292 self._ensure_stream_exists(stream) 293 self._ensure_event_exists(event) 294 self._state_wait_for_other( 295 self.current_sync_states[stream], self.recorded_sync_states[event] 296 ) 297 298 def all_streams_wait_for_event(self, event: EventId) -> None: 299 self._ensure_event_exists(event) 300 for stream in self.current_sync_states.keys(): 301 self.stream_wait_for_event(stream, event) 302 303 self._state_wait_for_other( 304 self.host_sync_state, self.recorded_sync_states[event] 305 ) 306 307 def all_streams_wait_for_stream(self, stream: StreamId) -> None: 308 self._ensure_stream_exists(stream) 309 for state in self.current_sync_states.values(): 310 self._state_wait_for_other(state, self.current_sync_states[stream]) 311 312 self._state_wait_for_other( 313 self.host_sync_state, self.current_sync_states[stream] 314 ) 315 316 def sync_all_streams(self) -> None: 317 for stream, state in self.current_sync_states.items(): 318 self.host_sync_state[stream] = state[stream] 319 320 for state in self.current_sync_states.values(): 321 self._state_wait_for_other(state, self.host_sync_state) 322 323 def is_ordered_after( 324 self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId 325 ) -> bool: 326 self._ensure_stream_exists(current_stream) 327 self._ensure_stream_exists(other_stream) 328 return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1) 329 330 331class EventHandler: 332 """Analyzes CSAN trace for synchronization errors. 333 334 Stores information on each stream's synchronizations with other streams as well 335 as tensor accesses to determine whether a given kernel launch might cause a 336 data race. 337 """ 338 339 def __init__(self) -> None: 340 self.tensors_accessed = _TensorsAccessed() 341 self.syncs = StreamSynchronizations() 342 self.seq_num: SeqNum = 0 343 344 def _handle_kernel_launch( 345 self, 346 stream: StreamId, 347 read_only: Set[DataPtr], 348 read_write: Set[DataPtr], 349 outputs: Set[DataPtr], 350 operator: str, 351 tensor_aliases: Dict[int, List[str]], 352 ) -> List[SynchronizationError]: 353 def check_conflict( 354 data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access] 355 ) -> None: 356 if previous_access is None: 357 return 358 if not self.syncs.is_ordered_after( 359 current_access.stream, previous_access.seq_num, previous_access.stream 360 ): 361 error_list.append( 362 UnsynchronizedAccessError( 363 data_ptr, 364 self.tensors_accessed.get_allocation_stack_trace(data_ptr), 365 current_access, 366 previous_access, 367 ) 368 ) 369 370 error_list: List[SynchronizationError] = [] 371 self.seq_num += 1 372 self.syncs.update_seq_num(stream, self.seq_num) 373 stack_trace = traceback.StackSummary.extract( 374 traceback.walk_stack(inspect.currentframe()), lookup_lines=False 375 ) 376 # The stack trace generated in this way is in the inverse order, so it must be 377 # reversed. 378 stack_trace.reverse() 379 380 for data_ptr in read_only: 381 self.tensors_accessed.ensure_tensor_exists(data_ptr) 382 current_access = Access( 383 AccessType.READ, 384 self.seq_num, 385 stream, 386 operator, 387 tensor_aliases[data_ptr], 388 data_ptr in outputs, 389 stack_trace, 390 ) 391 check_conflict( 392 data_ptr, current_access, self.tensors_accessed.get_write(data_ptr) 393 ) 394 self.tensors_accessed.add_read(data_ptr, current_access) 395 396 for data_ptr in read_write: 397 self.tensors_accessed.ensure_tensor_exists(data_ptr) 398 current_access = Access( 399 AccessType.WRITE, 400 self.seq_num, 401 stream, 402 operator, 403 tensor_aliases[data_ptr], 404 data_ptr in outputs, 405 stack_trace, 406 ) 407 if self.tensors_accessed.were_there_reads_since_last_write(data_ptr): 408 for previous_access in self.tensors_accessed.get_reads(data_ptr): 409 check_conflict(data_ptr, current_access, previous_access) 410 else: 411 check_conflict( 412 data_ptr, current_access, self.tensors_accessed.get_write(data_ptr) 413 ) 414 self.tensors_accessed.set_write(data_ptr, current_access) 415 416 return error_list 417 418 def _handle_event_creation(self, event: EventId) -> None: 419 self.syncs.create_event(event) 420 421 def _handle_event_deletion(self, event: EventId) -> None: 422 self.syncs.delete_event(event) 423 424 def _handle_event_record(self, event: EventId, stream: StreamId) -> None: 425 self.syncs.record_state(event, stream) 426 427 def _handle_event_wait(self, event: EventId, stream: StreamId) -> None: 428 self.syncs.stream_wait_for_event(stream, event) 429 430 def _handle_memory_allocation(self, data_ptr: DataPtr) -> None: 431 self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr) 432 stack_trace = traceback.StackSummary.extract( 433 traceback.walk_stack(inspect.currentframe()), lookup_lines=False 434 ) 435 # The stack trace generated in this way is in the inverse order, so it must be 436 # reversed. 437 stack_trace.reverse() 438 self.tensors_accessed.create_tensor( 439 data_ptr, 440 stack_trace, 441 ) 442 443 def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None: 444 self.tensors_accessed.ensure_tensor_exists(data_ptr) 445 self.tensors_accessed.delete_tensor(data_ptr) 446 447 def _handle_stream_creation(self, stream: StreamId) -> None: 448 self.syncs.create_stream(stream) 449 450 def _handle_device_synchronization(self) -> None: 451 self.syncs.sync_all_streams() 452 453 def _handle_stream_synchronization(self, stream: StreamId) -> None: 454 self.syncs.all_streams_wait_for_stream(stream) 455 456 def _handle_event_synchronization(self, event: EventId) -> None: 457 self.syncs.all_streams_wait_for_event(event) 458 459 460def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]: 461 for arg, value in a.items(): 462 if arg in b: 463 yield arg, value, b[arg] 464 465 466def zip_arguments( 467 schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any] 468) -> Iterator[Tuple[torch.Argument, Any]]: 469 schema_args = schema.arguments[: len(args)] 470 schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]} 471 472 yield from zip(schema_args, args) 473 474 for _, argument, value in zip_by_key(schema_kwargs, kwargs): 475 yield (argument, value) 476 477 478class ArgumentHandler: 479 def __init__(self) -> None: 480 self.dataptrs_read: Set[DataPtr] = set() 481 self.dataptrs_written: Set[DataPtr] = set() 482 self.tensor_aliases: Dict[DataPtr, List[str]] = {} 483 self.outputs: Set[DataPtr] = set() 484 485 def _handle_argument( 486 self, 487 value: Any, 488 is_write: bool, 489 name: Optional[str] = None, 490 is_output: bool = False, 491 ) -> None: 492 if isinstance(value, torch.Tensor) and value.is_cuda: 493 data_ptr = value.data_ptr() 494 if is_write: 495 self.dataptrs_written.add(data_ptr) 496 else: 497 self.dataptrs_read.add(data_ptr) 498 499 self.tensor_aliases.setdefault(data_ptr, []) 500 if name is not None: 501 self.tensor_aliases[data_ptr].append(name) 502 if is_output: 503 self.outputs.add(data_ptr) 504 505 def parse_inputs( 506 self, 507 schema: torch.FunctionSchema, 508 args: Tuple[Any, ...], 509 kwargs: Dict[str, Any], 510 ) -> None: 511 for argument, value in zip_arguments(schema, args, kwargs): 512 is_write = argument.alias_info is not None and argument.alias_info.is_write 513 pytree.tree_map_( 514 functools.partial( 515 self._handle_argument, is_write=is_write, name=argument.name 516 ), 517 value, 518 ) 519 520 def parse_outputs(self, outputs: Any) -> None: 521 pytree.tree_map_( 522 functools.partial(self._handle_argument, is_write=True, is_output=True), 523 outputs, 524 ) 525 526 527class CUDASanitizerDispatchMode(TorchDispatchMode): 528 def __init__(self) -> None: 529 self.event_handler = EventHandler() 530 torch._C._activate_gpu_trace() 531 gpu_trace.register_callback_for_event_creation( 532 self.event_handler._handle_event_creation 533 ) 534 gpu_trace.register_callback_for_event_deletion( 535 self.event_handler._handle_event_deletion 536 ) 537 gpu_trace.register_callback_for_event_record( 538 self.event_handler._handle_event_record 539 ) 540 gpu_trace.register_callback_for_event_wait( 541 self.event_handler._handle_event_wait 542 ) 543 gpu_trace.register_callback_for_memory_allocation( 544 self.event_handler._handle_memory_allocation 545 ) 546 gpu_trace.register_callback_for_memory_deallocation( 547 self.event_handler._handle_memory_deallocation 548 ) 549 gpu_trace.register_callback_for_stream_creation( 550 self.event_handler._handle_stream_creation 551 ) 552 gpu_trace.register_callback_for_device_synchronization( 553 self.event_handler._handle_device_synchronization 554 ) 555 gpu_trace.register_callback_for_stream_synchronization( 556 self.event_handler._handle_stream_synchronization 557 ) 558 gpu_trace.register_callback_for_event_synchronization( 559 self.event_handler._handle_event_synchronization 560 ) 561 562 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 563 if kwargs is None: 564 kwargs = {} 565 566 argument_handler = ArgumentHandler() 567 argument_handler.parse_inputs(func._schema, args, kwargs) 568 569 outputs = func(*args, **kwargs) 570 571 argument_handler.parse_outputs(outputs) 572 errors = self.event_handler._handle_kernel_launch( 573 torch.cuda.current_stream().cuda_stream, 574 argument_handler.dataptrs_read - argument_handler.dataptrs_written, 575 argument_handler.dataptrs_written, 576 argument_handler.outputs, 577 func._schema, 578 argument_handler.tensor_aliases, 579 ) 580 if errors: 581 for error in errors: 582 print(error, file=sys.stderr) 583 raise CUDASanitizerErrors(errors) 584 585 return outputs 586 587 588class CUDASanitizer: 589 """Manages the lifetime of a CUDASanitizer dispatch mode object. 590 591 The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode 592 context manager in the enable function/destructor, respectively. This is to 593 explicitly set the lifetime of the dispatch mode object to that of the application. 594 This approach was deemed more elegant than using the atexit module. 595 """ 596 597 def __init__(self) -> None: 598 self.dispatch = CUDASanitizerDispatchMode() 599 self.enabled = False 600 601 def enable(self): 602 self.dispatch.__enter__() 603 self.enabled = True 604 605 def __del__(self): 606 if self.enabled: 607 self.dispatch.__exit__(None, None, None) 608 609 610def enable_cuda_sanitizer(): 611 """Enable CUDA Sanitizer. 612 613 The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions 614 for synchronization errors. All data races found will be printed to the standard 615 error output along with stack traces of suspected causes. For best results, the 616 sanitizer should be enabled at the very beginning of the program. 617 """ 618 cuda_sanitizer.enable() 619 620 621cuda_sanitizer = CUDASanitizer() 622