1""" 2CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, 3which share the same memory pool. Sharing a memory pool is an extremely 4important optimization when chaining multiple CUDA graphs together, as it 5prevents you from needing to copy intermediate tensors from one graph to the 6next, and reduces overall memory usage by allowing dead memory from the first 7pool to be reused in the second. 8 9The standard graph/make_graph_callables support sharing memory pool, but 10with a lot of caveats. CUDA graph trees remove these restrictions: 11 12* Previously, if you recorded graphs A, B, you had to replay A, B in that 13 order. With CUDA graph trees, after replaying A, you can change your 14 mind and record/replay a different graph B'; we will support efficient 15 execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In 16 other words: we support arbitrary trees of CUDA graph operations, not just 17 sequences (this is why this feature is called CUDA graph trees.) 18 19* Previously, if you executed graph A, some non-CUDA graph code, and then 20 graph B, after executing graph B, it was not safe to retain any references 21 to intermediates produced by A. With CUDA graph trees, we track if any 22outputs of graph A are still live by the time graph B is run, and make 23 sure graph B doesn't clobber there memory when reusing the CUDA graphs 24 pool. You'll get a separate recording of B depending on what tensors 25 stay live or dead. 26 27CUDA graph trees are flexible enough to be used in Dynamo across graph breaks, 28which is their primary use case. 29 30The ability to switch from replay to record is fairly nontrivial: remember that 31when you replay a CUDA graph, you only replay CUDA operations; no CPU side state 32is updated. In particular, the CPU-side book-keeping for the allocator is not 33reconstructed. However, to record a new child CUDA graph, we must restore this 34book-keeping. This is what checkpoint pool state is used for. 
35""" 36 37from __future__ import annotations 38 39import contextlib 40import dataclasses 41import functools 42import gc 43import itertools 44import operator 45import sys 46import threading 47import traceback 48import warnings 49import weakref 50from collections import defaultdict 51from enum import auto, Enum 52from typing import ( 53 Any, 54 Callable, 55 cast, 56 ContextManager, 57 Dict, 58 Generator, 59 Iterator, 60 List, 61 Optional, 62 Sequence, 63 Set, 64 Tuple, 65 Type, 66 TYPE_CHECKING, 67 TypeVar, 68 Union, 69) 70 71import torch.fx 72from torch import Tensor 73from torch._dynamo.mutation_guard import GenerationTracker 74from torch._dynamo.utils import counters, preserve_rng_state 75from torch._inductor.compile_fx import ( 76 align_inputs_from_check_idxs, 77 copy_misaligned_inputs, 78 get_expanded_dims, 79 get_input_idxs_to_check, 80 index_expanded_dims, 81 remove_unaligned_input_idxs, 82 static_input, 83) 84from torch._inductor.cudagraph_utils import ( 85 check_for_mutation, 86 CheckInvariantStatus, 87 FunctionID, 88 log_cudagraph_skip_and_bump_counter, 89 log_data_ptr_mismatch, 90 maybe_warning_due_to_dynamic_shape, 91 ModelType, 92 OutputType, 93 PlaceholderInfo, 94 WrappedFunction, 95) 96from torch.multiprocessing.reductions import StorageWeakRef 97from torch.storage import UntypedStorage 98from torch.utils import _pytree as pytree 99from torch.utils.weak import TensorWeakRef 100 101 102if TYPE_CHECKING: 103 from torch._inductor.utils import InputType 104 from torch.types import _bool 105 106StorageWeakRefPointer = int 107StorageDataPtr = int 108NBytes = int 109S = TypeVar("S", bound="StorageWeakRefWrapper") 110 111 112if torch.backends.cuda.is_built(): 113 from torch._C import ( 114 _cuda_CUDAAllocator_AllocatorState as AllocatorState, 115 _set_cached_tensors_enabled as _set_cached_tensors_enabled, 116 ) 117else: 118 119 class AllocatorState: # type: ignore[no-redef] 120 pass 121 122 def _set_cached_tensors_enabled(enabled: _bool) -> None: 123 pass 124 125 126log = torch._logging.getArtifactLogger(__name__, "cudagraphs") 127 128 129from . import config 130 131 132@dataclasses.dataclass(frozen=True) 133class GraphID: 134 "Unique counter of a cuda graph recording" 135 id: int 136 137 138def clear_cublass_cache() -> None: 139 """ 140 Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for 141 doing warmup within a CUDAGraph private pool because we do not want persistent allocations from 142 one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors 143 from the previous generation are freed. This frees them the memory pool, but not elsewhere. 144 A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated 145 in the next run. The memory would be in use in two places. 146 147 To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required 148 it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the 149 program. There is no overhead to this on replay since cudagraphs removes allocation overhead. 
150 """ 151 torch._C._cuda_clearCublasWorkspaces() 152 153 154@contextlib.contextmanager 155def clear_cublas_manager() -> Generator[None, None, None]: 156 "Context manager around clearing cublas caches that will clear on enter and exit" 157 clear_cublass_cache() 158 try: 159 yield 160 finally: 161 clear_cublass_cache() 162 163 164@contextlib.contextmanager 165def disable_conv_cache_emptying() -> Generator[None, None, None]: 166 prev = torch._C._cuda_get_conv_benchmark_empty_cache() 167 torch._C._cudnn_set_conv_benchmark_empty_cache(False) 168 try: 169 yield 170 finally: 171 torch._C._cudnn_set_conv_benchmark_empty_cache(prev) 172 173 174@contextlib.contextmanager 175def enable_history_recording() -> Generator[None, None, None]: 176 "Turns on history recording in the CUDA Caching Allocator" 177 enabled = torch._C._cuda_isHistoryEnabled() 178 try: 179 if not enabled: 180 torch.cuda.memory._record_memory_history() 181 yield 182 finally: 183 if not enabled: 184 torch.cuda.memory._record_memory_history(None) 185 186 187def get_history_recording() -> ContextManager[None]: 188 # TODO - remove, prevents cleanup 189 if not config.triton.cudagraph_trees_history_recording: 190 return contextlib.nullcontext() 191 return enable_history_recording() 192 193 194class TreeManagerContainer: 195 """ 196 Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator, 197 the tree and its corresponding memory pool should be kept alive as long as any outstanding 198 graph or tensor which is an output of a graph remains alive. 199 200 There is a single tree manager container per device. 201 202 The lifecycle of a tree_manager is: 203 - Is constructed, no graph, no fns, no tensors 204 - Tree manager is fetched, resulting in tree manager being allocated 205 - We generate a bunch of functions, calling add_strong_reference 206 - These functions die, calling finalize_reference 207 - When all the functions die, we finalize_tree_manager. 208 209 TODO: in the future, we would like to do the following once storage weak refs land 210 - We look for all the live storages and add references to THOSE 211 - We count as storages die 212 - All the storages are dead, we deallocate the tree manager 213 """ 214 215 def __init__(self, device_index: int) -> None: 216 # This class keeps a strong reference to tree_manager, 217 # but upon all other strong references to the tree_manager will reset it to None. 218 # We need a strong reference so that we can still access its attributes upon cleanup. 219 self.tree_manager: Optional[CUDAGraphTreeManager] = None 220 221 # Number of outstanding references to the current tree manager 222 self.live_cudagraphify_fns = 0 223 224 self.device_index = device_index 225 226 # Following two objects are only set in the case that Tensor outputs outlive 227 # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from 228 # deallocation. 
        self.live_storages_count = 0
        self.graph: Optional[torch.cuda.CUDAGraph] = None

        self.lock = threading.Lock()

    def _finalize_tensor(self) -> None:
        with self.lock:
            self.live_storages_count -= 1
            if self.live_storages_count == 0:
                self.graph = None

                # the manager may have been used again after the earlier cleanup;
                # in that case we shouldn't set it to None
                if self.live_cudagraphify_fns == 0:
                    self.tree_manager = None

    def finalize_cudagraphify_fn(self) -> None:
        with self.lock:
            self.live_cudagraphify_fns -= 1
            if self.live_cudagraphify_fns == 0:
                self._finalize_tree_manager()

    def _finalize_tree_manager(self) -> None:
        assert self.lock.locked()
        self.tree_manager = None

        # TODO - when issue #91395 is landed, we can set a weakref on
        # storages and trigger a deallocation when all outputs of the
        # cudagraph are dead.

        # live_storages = list(
        #     tree_manager.live_cudagraph_pool_storages_in_curr_execution()
        # )

        # # Maintain reference to graph to keep tensors alive
        # assert len(tree_manager.roots) > 0, "expected at least one use"
        # root = next(tree_manager.get_roots())
        # self.graph = root.graph
        # seen_storages = set()
        # for stor in live_storages:
        #     if stor in seen_storages:
        #         continue
        #     seen_storages.add(stor)
        #     self.live_storages_count += 1
        #     weakref.finalize(stor, self._finalize_tensor)

    def add_strong_reference(self, fn: Callable[..., Any]) -> None:
        with self.lock:
            self.live_cudagraphify_fns += 1

        weakref.finalize(fn, self.finalize_cudagraphify_fn)

    def get_tree_manager(self) -> CUDAGraphTreeManager:
        with self.lock:
            if self.tree_manager is None:
                self.tree_manager = CUDAGraphTreeManager(self.device_index)
            return self.tree_manager


local = threading.local()

# one tree manager per device
local.tree_manager_containers = {}
local.tree_manager_locks = defaultdict(threading.Lock)


# only incremented by user call of mark_step_begin
class MarkStepBox:
    mark_step_counter = 0


# We need to register this as an object that will be copied over as TLS when new
# threads are created in autograd
torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers)
torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)


def mark_step_begin() -> None:
    "Indicates that a new iteration of inference or training is about to begin."

    # iterate down to distinguish from GenerationTracking counter
    MarkStepBox.mark_step_counter -= 1
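
# Hedged usage sketch: user code does not normally call mark_step_begin() from this
# module; the public entry point is torch.compiler.cudagraph_mark_step_begin(). It is
# useful when outputs of a previous iteration are intentionally not kept alive, e.g.
# (illustrative only):
#
#     compiled = torch.compile(model, mode="reduce-overhead")
#     for batch in loader:
#         torch.compiler.cudagraph_mark_step_begin()
#         out = compiled(batch)  # outputs of prior steps may now be overwritten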


def reset_cudagraph_trees() -> None:
    "Clear all cudagraph trees"
    # see shutdown below for why this is necessary
    container_dict = get_obj(local, "tree_manager_containers")
    locks_dict = get_obj(local, "tree_manager_locks")
    for device, lock in locks_dict.items():
        with lock:
            container = container_dict.get(device)
            if not container or not container.tree_manager:
                continue

            container.tree_manager.shutdown()

    _set_cached_tensors_enabled(False)
    container_dict.clear()

    MarkStepBox.mark_step_counter = 0


def get_obj(local: Any, attr_name: str) -> Any:
    if hasattr(local, attr_name):
        return getattr(local, attr_name)
    else:
        assert torch._C._is_key_in_tls(attr_name)
        return torch._C._get_obj_in_tls(attr_name)


def get_container(device_index: int) -> TreeManagerContainer:
    container_dict = get_obj(local, "tree_manager_containers")
    lock = get_obj(local, "tree_manager_locks")[device_index]

    with lock:
        if device_index not in container_dict:
            container_dict[device_index] = TreeManagerContainer(device_index)

        return container_dict[device_index]


def get_manager(
    device_index: int, create_if_none_exists: bool = True
) -> Optional[CUDAGraphTreeManager]:
    if create_if_none_exists:
        return get_container(device_index).get_tree_manager()
    return get_container(device_index).tree_manager


def cudagraphify_impl(
    model: ModelType,
    inputs: List[InputType],
    static_input_idxs: Sequence[int],
    *args: Any,
    **kwargs: Any,
) -> ModelType:
    fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {}

    # Detect int inputs: we need to index on these
    int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
    get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None

    has_warn = False

    del inputs

    def deferred_cudagraphify(inputs: List[InputType]) -> OutputType:
        nonlocal has_warn

        int_key = get_ints(inputs)
        fn = fn_cache.get(int_key)
        if fn is not None:
            return fn(inputs)

        if int_key is None:
            log.info("recording cudagraph tree for graph without symints")
        else:
            log.info("recording cudagraph tree for symint key %s", int_key)

        if not has_warn:
            has_warn = maybe_warning_due_to_dynamic_shape(fn_cache, int_key)

        # first get indices we need to check to align, then update our static inputs,
        # and finally copy
        check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
        new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
        copy_misaligned_inputs(inputs, check_input_idxs)

        fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
        fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs)
        fn_cache[int_key] = fn

        return out

    return deferred_cudagraphify
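
# Illustrative note (hypothetical values): for inputs like [t0, 5, t1] the int key is
# gathered with operator.itemgetter(1), i.e. int_key == 5, and a distinct cudagraph
# tree recording is cached per observed symint value; with no int inputs the single
# cache key is None.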


def cudagraphify(
    model: ModelType,
    inputs: List[InputType],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    is_backward: bool,
    is_inference: bool,
    stack_traces: Optional[StackTraces] = None,
    constants: Tuple[torch.Tensor, ...] = (),
    placeholders: Tuple[PlaceholderInfo, ...] = (),
    mutated_input_idxs: Tuple[int, ...] = (),
) -> Tuple[ModelType, OutputType]:
    manager = get_container(device_index).get_tree_manager()
    assert not (is_backward and is_inference)
    mode = (
        CompilationMode.BACKWARD
        if is_backward
        else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD)
    )

    return manager.add_function(
        model,
        inputs,
        static_input_idxs,
        stack_traces,
        mode,
        constants,
        placeholders,
        mutated_input_idxs,
    )


class StorageWeakRefWrapper:
    """
    Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
    """

    __slots__ = ["ref", "_data_ptr", "extra_ref_check"]

    storage_ref: Optional[StorageWeakRef]

    def __init__(
        self,
        inp: Union[Tensor, UntypedStorage],
        extra_ref_check: Optional[Callable[[], bool]] = None,
    ) -> None:
        """
        extra_ref_check is an additional check we need to run to check if the
        weak ref has expired. In checking the storage use count we assume extra_ref_check
        will hold an additional reference to the storage.
        """
        if isinstance(inp, Tensor):
            stor = inp.untyped_storage()
        else:
            assert isinstance(inp, UntypedStorage)
            stor = inp
        self.ref = StorageWeakRef(stor)
        self._data_ptr = stor.data_ptr()
        self.extra_ref_check = extra_ref_check

    @classmethod
    def from_weakref_and_data_ptr(
        cls: Type[S],
        cdata: Any,
        data_ptr: int,
        extra_ref_check: Optional[Callable[[], bool]] = None,
    ) -> StorageWeakRefWrapper:
        instance = cls.__new__(cls)
        instance._data_ptr = data_ptr
        instance.ref = StorageWeakRef.from_weakref(cdata)
        instance.extra_ref_check = extra_ref_check
        return instance

    def __call__(self) -> Optional[StorageWeakRefPointer]:
        if self.expired():
            return None

        return self.ref.cdata

    def swap_weakref(self, cdata: Any) -> None:
        self.ref.__del__()
        self.ref.cdata = cdata

    def data_ptr(self) -> int:
        "NB: returns the data ptr even if the storage has expired"
        return self._data_ptr

    def remove_extra_reference(self) -> None:
        self.extra_ref_check = None

    def expired(self) -> bool:
        if self.extra_ref_check is not None and not self.extra_ref_check():
            return False

        # if extra_ref_check is not None we expect an additional reference
        stor_count = torch._C._storage_Use_Count(self.ref.cdata)
        return (stor_count - (self.extra_ref_check is not None)) == 0

    def __repr__(self) -> str:
        if self.ref is None or self.ref.expired():
            return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
        else:
            return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"


def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool:
    return maybe_deref(weak_ref) is not None


def maybe_deref(
    weak_ref: Optional[StorageWeakRefWrapper],
) -> Optional[Tuple[StorageWeakRefPointer, int]]:
    if weak_ref is None:
        return None
    r = weak_ref()
    if r is None:
        return None
    # NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr()
    return r, weak_ref.data_ptr()
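
# Hedged illustration of the liveness tracking these helpers provide (sketch only,
# assumes a CUDA device is available):
#
#     t = torch.ones(4, device="cuda")
#     ref = StorageWeakRefWrapper(t)
#     assert is_live(ref)          # the storage still has a strong reference (t)
#     del t
#     assert not is_live(ref)      # use count dropped to zero; the weak ref expired
#     ref.data_ptr()               # still returns the recorded data pointer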


@contextlib.contextmanager
def _use_cuda_memory_pool_manager(
    device: int, mem_pool: Tuple[int, int], stream: torch.cuda.Stream
) -> Generator[None, None, None]:
    """
    Context manager to use the cuda graph pool for new allocations. If you use this manager,
    all cudagraph tensors in use should be reflected in the allocator or they will be overwritten.
    existing_graph should already have been used in a capture, and the mem_pool must already exist,
    because this manager will not preserve a reference to the pool which keeps it alive.
    """
    torch.cuda.synchronize()
    stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(stream), torch.device(device):
        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
        try:
            yield
        finally:
            torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
            torch._C._cuda_releasePool(device, mem_pool)

    torch.cuda.current_stream().wait_stream(stream)
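
# Hedged usage sketch (pool handle and stream names are illustrative): allocations made
# on `stream` inside the context are routed to the cudagraph private pool rather than
# the general caching allocator pool.
#
#     pool = torch.cuda.graph_pool_handle()
#     stream = torch.cuda.Stream()
#     with _use_cuda_memory_pool_manager(device=0, mem_pool=pool, stream=stream):
#         buf = torch.empty(1024, device="cuda")  # lands in the private pool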


def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
    if not isinstance(t, torch.Tensor):
        assert t is None
        return None
    return StorageWeakRefWrapper(t)


# A path index of (depth, offset) indices into a graph that is `depth` number of nodes from the root
# at graph output offset
PathOutputIndex = Tuple[int, int]

# For each node in the path, for each output, is the output alive
PathLiveness = List[List[bool]]

StackTraces = List[Optional[str]]


class CUDAWarmupNode:
    """
    Simplified wrapper around a CUDA model that wraps outputs in storage refs and exposes
    APIs to get the live storages in the current chain of warmup.

    A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have
    CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable
    memory addresses.

    CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes.
    - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of the first recording.
      In the first instance of warmup, these are not finalized yet.
    - All inputs to the RecordedFunction must be copied over to the cuda graph memory pool; this is unnecessary
      in warmup.
    - CUDAWarmupNode is only used once and so does not need to optimize as much bookkeeping. It is much simpler.

    NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and
    `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility.
    """

    def __init__(
        self,
        wrapped_function: WrappedFunction,
        parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]],
        cuda_graphs_pool: Tuple[int, int],
        existing_cuda_graph: Optional[torch.cuda.CUDAGraph],
        device_index: int,
        stack_traces: Optional[StackTraces],
        stream: torch.cuda.Stream,
        already_warm: bool,
        id: GraphID,
    ) -> None:
        self.wrapped_function = wrapped_function
        self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent
        self.cuda_graphs_pool = cuda_graphs_pool
        self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
        self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
        self.existing_cuda_graph = existing_cuda_graph
        self.has_run = False
        self.device_index = device_index
        self.stack_traces = stack_traces
        self.stream = stream
        self.already_warm = already_warm
        self.id = id

    def run(self, new_inputs: Any) -> OutputType:
        assert not self.has_run, "Wrapped function should never be run twice"

        # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created
        # storages in path_live_weakrefs.
        existing_path_data_ptrs = {
            t.data_ptr() for t in self.path_live_weakrefs() if t()
        }

        def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
            non_cudagraph_inps = []
            for t in itertools.chain(new_inputs, self.wrapped_function.constants):
                if (
                    isinstance(t, torch.Tensor)
                    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
                ):
                    non_cudagraph_inps.append(weakref.ref(t.untyped_storage()))
            return non_cudagraph_inps

        non_cudagraph_inps_storages = get_non_cudagraph_inps()

        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            refs = list(self.path_live_weakrefs())
            check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)

        with torch.cuda.device(
            self.device_index
        ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
            self.device_index, self.cuda_graphs_pool, self.stream
        ), get_history_recording():
            out = self.wrapped_function.model(new_inputs)

        # We need to know which outputs are allocated within the cudagraph pool
        # so that we can deallocate them at the beginning of the next cudagraph step,
        # and set their access to error.
        # We use a weakref to the inputs storage, in case a block which was previously
        # allocated to the general caching allocator pool gets reallocated to a private pool.

        non_cudagraph_inps_storage_ptrs = set()
        for storage in non_cudagraph_inps_storages:
            s = storage()
            if s is not None:
                non_cudagraph_inps_storage_ptrs.add(s._cdata)

        assert len(new_inputs) == 0

        # sdpa returns cpu tensors when not recording cuda graph
        def add_ref(o: Any) -> bool:
            return (
                isinstance(o, torch.Tensor)
                and o.is_cuda
                and o.untyped_storage()._cdata not in non_cudagraph_inps_storage_ptrs
                and o.untyped_storage().data_ptr() != 0
            )

        self.outputs_weakrefs.extend(
            [map_to_ref(o) if add_ref(o) else None for o in out]
        )
        self.tensor_weakrefs.extend(
            [TensorWeakRef(o) if add_ref(o) else None for o in out]
        )

        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            out_refs = list(self.path_live_weakrefs())
            check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)

        return out

    @property
    def _path_from_root(
        self,
    ) -> Generator[Union[CUDAGraphNode, CUDAWarmupNode], None, None]:
        nodes = []
        node: Union[CUDAGraphNode, CUDAWarmupNode] = self
        while node:
            nodes.append(node)
            node = node.parent  # type: ignore[assignment]

        yield from reversed(nodes)

    def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
        "Returns all live storage weakrefs created by nodes in this path"
        for node in self._path_from_root:
            for output in node.outputs_weakrefs:
                if is_live(output):
                    yield output  # type: ignore[misc]

    def all_outputs_are_dead(self) -> bool:
        return not list(self.path_live_weakrefs())

    def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool:
        for storage_weak_ref in self.path_live_weakrefs():
            if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr():
                return True
        return False


# Aliases for List that say what the indices denote
InputList = List  # input indexes
OutputList = List  # output indexes
LevelList = List  # levels (distance from root of tree)


class OutputAliasInfo:
    pass


class _UnaliasedStorage(OutputAliasInfo):
    "Singleton to mark that the graph output constructs a new alias or is None"


UnaliasedStorage = _UnaliasedStorage()


class AliasesPriorGraphOutput(OutputAliasInfo):
    "Marks that the graph output aliases an output of a prior graph"

    __slots__ = ["index"]

    index: PathOutputIndex

    def __init__(self, index: PathOutputIndex) -> None:
        assert isinstance(index, tuple)
        self.index = index


class AliasesNewOutput(OutputAliasInfo):
    "Marks that the graph output aliases an index in the new, returned outputs"

    __slots__ = ["index"]

    index: int

    def __init__(self, index: int) -> None:
        assert isinstance(index, int)
        self.index = index
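
# Hedged illustration of the classification above (hypothetical outputs of a recording):
#
#     out0 = torch.empty(4, device="cuda")   # fresh allocation        -> UnaliasedStorage
#     out1 = prior_graph_output.view(-1)     # aliases an earlier graph's output
#                                            #   -> AliasesPriorGraphOutput((depth, index))
#     out2 = out0[1:]                        # aliases out0 within this same output list
#                                            #   -> AliasesNewOutput(0)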


class CUDAGraphNode:
    """
    A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool
    and are structured into a tree, where there is a single recording that can precede it (parent) and multiple
    subsequent recordings that may follow (children). A node will have no parent if it is the first recording
    in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which
    would force a dependency.

    On first recording, all of the live tensors in the current CUDA Graph Node path will be
    reflected in the corresponding private pool. On subsequent executions, the caching allocator
    is unaffected when the graph is replayed.

    In order to support a subsequent cuda graph recording after execution of this graph,
    we checkpoint the state of the memory pool so that it may later be resumed.

    WrappedFunction should have already been warmed up prior to invocation.

    See [setCheckpointPoolState] for further explanation, as well as
    https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png
    """

    def __init__(
        self,
        wrapped_function: WrappedFunction,
        id: GraphID,
        parent: Optional[CUDAGraphNode],
        inputs: List[InputType],
        cuda_graphs_pool: Tuple[int, int],
        device_index: int,
        stack_traces: Optional[StackTraces],
        stream: torch.cuda.Stream,
    ) -> None:
        assert isinstance(inputs, (list, tuple))

        self.wrapped_function = wrapped_function
        self.id = id
        self.device = device_index
        self.stack_traces = stack_traces
        self.stream = stream

        # Re-record a cudagraph when a static tensor address changes;
        # if not, we should error when it changes.
        self.rerecord_if_static_inputs_change = (
            torch._dynamo.config.inline_inbuilt_nn_modules
            or torch._inductor.config.triton.cudagraph_support_input_mutation
        )

        # if this is a root, parent will be None. use weakref to prevent reference cycle
        self._parent = weakref.ref(parent) if parent is not None else None
        # reference to the shared memory pool for the entire cuda graphs tree
        self.cuda_graphs_pool = cuda_graphs_pool

        # A single wrapped function may be recorded multiple times if memory patterns or
        # invariants change from one execution to the next
        self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)

        # StorageWeakRef maintains whether the Storage C++ object remains allocated,
        # not whether the corresponding memory has been deallocated. In order
        # to use them to track memory deallocations we must maintain a single StorageWeakRef
        # for all Storages that reference that memory (even if we are constructing Storages
        # that do not have a deallocator function). We maintain one single storage_cache
        # as we execute any tree path. When we retrieve a storage from the cache we
        # check that it is still alive, and we hash based on observed recording data ptr
        # and storage cdata.

        # we preserve a single reference to executed outputs that is then referenced
        # in children to avoid children having to chase parent pointers in the hot path
        # DO NOT reassign output_weakrefs, only call `clear()`
        # Path is a series of nodes from root to the current node
        self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = []
        self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
            node.outputs_weakrefs for node in self._path_from_root
        ]
        self.path_stacktraces: LevelList[Optional[StackTraces]] = [
            node.stack_traces for node in self._path_from_root
        ]
        self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []

        # tensors which are outputs of previous graphs in the tree
        self.cudagraph_managed_idxs: List[int] = [
            idx
            for idx, t in enumerate(inputs)
            if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
        ]

        self.static_input_idxs: List[int] = list(
            set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
        )

        self.non_static_input_idx: LevelList[int] = [
            i for i in range(len(inputs)) if i not in self.static_input_idxs
        ]

        counters["inductor"]["cudagraph_recorded_non_static_inputs"] += len(
            self.non_static_input_idx
        )

        self.non_managed_static_input_idxs: LevelList[int] = [
            i
            for i in wrapped_function.static_input_idxs
            if i not in self.cudagraph_managed_idxs
        ]

        def maybe_get_static_data_ptr(
            idx: int,
            inputs: List[Union[torch.Tensor, int]],
            static_input_idxs: List[int],
        ) -> Optional[int]:
            inp = inputs[idx]
            if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
                return inp.data_ptr()
            return None

        self.static_input_data_ptrs: InputList[Optional[int]] = [
            maybe_get_static_data_ptr(i, inputs, self.static_input_idxs)
            for i in range(len(inputs))
        ]

        # When we checkpoint, and free generations, we will be manually freeing the outputs
        # of CUDAGraphNodes. We should not be freeing parameters, nor do we need to account for
        # their liveness (they are static), so we need to compute which outputs are aliases of
        # parameters. Some static inputs are saved tensors from the forward that die in the backward.
        # Their locations are static but their lifetimes are not. We only include the persistent static
        # data ptrs below because the non persistent data ptrs may be outputs of this record and
        # fresh allocations.

        # precompute expanded dims to avoid computing in the hot path
        self.expanded_dims: List[List[int]] = [
            get_expanded_dims(x)
            if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
            else []
            for idx, x in enumerate(inputs)
        ]

        # For each node in path, which outputs were observed to be live
        # before invoking graph recording, and after graph recording
        self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = []
        self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = []

        # List of Tuples of (depth, output_index) that index into node at depth
        # number of nodes from root and output_index of outputs. Will index into
        # path_weakrefs.
        self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
        self.expected_dead_indices_after_graph: List[PathOutputIndex] = []

        # all live indices after graph recording
        self.live_indices_after_graph: List[PathOutputIndex] = []

        if self.parent is not None:
            previous_liveness = self.parent.recorded_liveness_after_graph
            curr_liveness = self._get_liveness(self.path_weakrefs)

            different_indices = self._get_different_indices(
                previous_liveness, curr_liveness
            )

            self.recorded_liveness_before_graph = curr_liveness
            self.expected_dead_indices_before_graph = different_indices

        recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
        # recording inputs will copy over memory, so we can free non recording inputs
        inputs.clear()
        del inputs

        # graph used for recording model invocation
        self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()

        # we allocate non-static inputs within the same memory pool as the CUDAGraph
        # which we will record the model with. For memory efficiency, it is important
        # to reclaim the input memory when the inputs are no longer live. To accomplish this,
        # we reconstruct tensors at the correct data pointers of our inputs which are
        # non owning and do not prevent deallocation. On subsequent executions, input values
        # will be copied over to these tensors.
        self.reconstructed_inputs: List[InputType] = [
            self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
            if isinstance(x, torch.Tensor)
            else x
            for x in recording_inputs
        ]

        # DO THE RECORDING!!!
        # We record the CUDA graph in the constructor of CUDAGraphNode, which
        # gives you what the CPU side compute of the function would do. We
        # don't throw the recording outputs away: their memory is
        # correctly accounted for in the CUDAGraphs caching allocator. This
        # means on the very FIRST run of the CUDA graph node, we can directly
        # do more recording, because we have a valid caching allocator state.
        # NB: This relies on run() being called immediately after the
        # constructor, otherwise this optimization would not be valid.

        # initialized below in _record

        self.checkpointed_caching_state: Optional[AllocatorState] = None

        # Output Storage Alias information, can be:
        # - A new, unaliased storage, or the output is None
        # - An alias of an output of a prior graph
        # - An alias of an output already created in the reconstructed outputs
        # This is None if the output in question is an int
        self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = []

        # is the output Storage unaliased in subsequent outputs, of all subsequent paths
        # if it is, we cache the output tensor and adjust storage liveness tracking to also
        # check if the output tensor does not have an additional python reference.
        # If a descendant node discovers it has an alias of a prior output, then the output
        # will no longer be cached in the ancestor.
        # The large majority of tensors are unaliased, and preserving aliased output tensors would add
        # significant additional complexity with marginal gains.
        # The cached tensor outputs are added on the first execution, and cleared whenever we need
        # to do subsequent recording
        self.unaliased_in_all_paths: OutputList[bool] = []
        self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []

        # if an output aliases a static, persistent input then the corresponding Tensor will
        # be set here. These are different than cached tensors, because they are tensors that
        # are aliases of parameters that are always live.
        self.static_output_tensors: OutputList[Optional[Tensor]] = []

        # Cleared after recording
        self.recording_outputs: Optional[OutputType] = self._record(
            wrapped_function.model, recording_inputs
        )
        self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []

        # As with inputs, we do not want to keep the outputs permanently alive because that would prevent
        # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
        # needed to reconstruct instead.
        assert self.recording_outputs is not None
        for out in self.recording_outputs:
            if isinstance(out, torch.Tensor):
                self.outputs_metadata.append(
                    self._tensor_metadata(out, ignore_storage_offset=False)
                )
            else:
                assert isinstance(out, (int, type(None))), type(out)
                self.outputs_metadata.append(out)

        self.graph.replay()

    def _copy_inputs_and_remove_from_src(
        self, dsts: List[InputType], srcs: List[InputType]
    ) -> None:
        dst_tensors = []
        src_tensors = []
        for idx in self.non_static_input_idx:
            if not isinstance(srcs[idx], torch.Tensor):
                continue
            expanded_dims = self.expanded_dims[idx]
            dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims))  # type: ignore[arg-type]
            src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims))  # type: ignore[arg-type]
            srcs[idx] = None  # type: ignore[call-overload]
        # Fails on empty lists
        if dst_tensors:
            torch._foreach_copy_(dst_tensors, src_tensors)

    def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None:
        # avoid checking managed tensor static data ptrs since we already checked those in check_invariants
        if (
            not self.rerecord_if_static_inputs_change
            and not torch._C._tensors_data_ptrs_at_indices_equal(
                new_inputs,  # type: ignore[arg-type]
                self.static_input_data_ptrs,
                self.non_managed_static_input_idxs,
            )
        ):
            # this should error
            error_msg = log_data_ptr_mismatch(
                self.wrapped_function.placeholders,
                new_inputs,
                self.static_input_data_ptrs,
                self.non_managed_static_input_idxs,
                CheckInvariantStatus.StaticInputIdxMismatch,
            )
            torch._check(False, lambda: error_msg)

    def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType:
        if config.triton.fast_path_cudagraph_asserts:
            self.debug_check_invariants_before_invocation()

        # graph is already invoked in the __init__
        # inputs are copied over in _allocate_recording_inputs and subsequently cleared
        assert len(new_inputs) == 0
        outputs = self.recording_outputs
        self.recording_outputs = None
        assert outputs is not None
        return outputs

    def run(self, new_inputs: List[InputType]) -> OutputType:
        self.check_static_inputs_are_stable(new_inputs)

        self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
        new_inputs.clear()

        self.run_graph()

        outputs = self.reconstruct_outputs()

        if config.triton.fast_path_cudagraph_asserts:
            self.debug_check_invariants_after_invocation()

        if config.triton.force_cudagraph_sync:
            torch.cuda.synchronize()

        # Reset this to run the check in the future
        self.static_inputs_stable = False

        return outputs

    def reconstruct_outputs(self) -> OutputType:
        "Reconstruct output tensors according to their saved metadata and alias information"

        # Cached tensors will not yet be set on the first execution
        # They are also cleared in checkpointing, so if we checkpoint this node
        # and then execute it again we will need to repopulate cached tensors
        if not self.cached_tensor_outputs:
            self._initialize_cached_tensors()

        outputs: OutputType = []

        for i, (storage_info, metadata) in enumerate(
            zip(self.output_storage_alias, self.outputs_metadata)
        ):
            if not isinstance(metadata, dict):  # tensor metadata
                assert isinstance(metadata, (int, type(None)))
                outputs.append(metadata)
                continue

            cached_t = self.cached_tensor_outputs[i]
            if cached_t is not None:
                # this output represents a freshly allocated tensor.
                # We return the same TensorImpl from run to run to avoid overhead.
                # autograd.Function will reset the Autograd meta of output tensors
                # as part of aot_autograd, but _backward_hooks are stored on tensors separately,
                # so we need to manually reset hooks.
                if cached_t._backward_hooks is not None:
                    cached_t._backward_hooks = None

                # No need to update weakrefs, already correctly initialized
                outputs.append(cached_t)
                continue

            static_t = self.static_output_tensors[i]
            if static_t is not None:
                assert self.outputs_weakrefs[i] is None
                outputs.append(static_t)
                continue

            storage = self.prepare_alias_info_for_tensor_construction(
                storage_info, metadata
            )

            if isinstance(storage, UntypedStorage) or storage is None:
                out = self._reconstruct_from_tensor_metadata(metadata, storage)
            else:
                assert isinstance(storage, int)
                out = self._reconstruct_from_tensor_metadata(
                    metadata, cast(torch.Tensor, outputs[storage]).untyped_storage()
                )

            outputs.append(out)
            w = self.outputs_weakrefs[i]
            assert w is not None
            w.swap_weakref(out.untyped_storage()._weak_ref())

        return outputs

    def prepare_alias_info_for_tensor_construction(
        self,
        out_alias_info: Optional[OutputAliasInfo],
        metadata: Union[Dict[str, Any], int, None],
    ) -> Union[UntypedStorage, None, int]:
        if (
            isinstance(metadata, (int, type(None)))
            or out_alias_info is UnaliasedStorage
        ):
            return None

        if isinstance(out_alias_info, AliasesPriorGraphOutput):
            depth, existing_output_index = out_alias_info.index
            ref = self.path_weakrefs[depth][existing_output_index]
            assert ref is not None
            return torch.UntypedStorage._new_with_weak_ptr(ref())

        assert isinstance(out_alias_info, AliasesNewOutput)
        return out_alias_info.index

    def prepare_storages_for_construction(
        self,
    ) -> List[Union[UntypedStorage, None, int]]:
        output_storages = []
        for output_storage_alias, metadata in zip(
            self.output_storage_alias, self.outputs_metadata
        ):
            output_storages.append(
                self.prepare_alias_info_for_tensor_construction(
                    output_storage_alias, metadata
                )
            )

        return output_storages

    def run_graph(self) -> None:
        assert self.graph is not None
        self.graph.replay()

    def all_outputs_are_dead(self) -> bool:
        "All outputs of the path from this node to its root are dead"
        for depth, output_index in self.live_indices_after_graph:
            if is_live(self.path_weakrefs[depth][output_index]):
                return False
        return True

    def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType:
        "Record the model"

        def static_input_iter() -> Generator[torch.Tensor, None, None]:
            for i in self.wrapped_function.static_input_idxs:
                _inp = inputs[i]
                if isinstance(
                    _inp, torch.Tensor
                ) and not self._is_cuda_graph_recorded_tensor(_inp):
                    yield _inp

        # see: output_is_alias_of_persistent_static_inputs above
        static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
            inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
            for inp in itertools.chain(
                static_input_iter(), self.wrapped_function.constants
            )
        }

        if config.triton.slow_path_cudagraph_asserts:
            # need to use parent live weakrefs because live_indices isn't set yet
            memory = (
                [] if self.parent is None else list(self.parent.path_live_weakrefs())
            )
            memory += [
                StorageWeakRefWrapper(elem)
                for i, elem in enumerate(inputs)
                if isinstance(elem, torch.Tensor)
                and i not in self.wrapped_function.static_input_idxs
                and elem.untyped_storage().data_ptr() != 0
            ]
            check_memory_pool(self.device, self.cuda_graphs_pool, memory)

        with preserve_rng_state(), torch.cuda.device(
            self.device
        ), clear_cublas_manager(), torch.cuda.graph(
            self.graph,
            stream=self.stream,
            pool=self.cuda_graphs_pool,
            capture_error_mode="thread_local",
        ), get_history_recording():
            static_outputs = model(inputs)

        # running model should reclaim memory
        assert len(inputs) == 0

        if not isinstance(static_outputs, (list, tuple)):
            static_outputs = (static_outputs,)

        self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs)

        return static_outputs

    def _add_first_outputs(
        self,
        outputs: OutputType,
        static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
    ) -> None:
        "Add the outputs from the first invocation of the node and set up metadata"

        # getting liveness before we have added the outputs to path, so the length
        # of the two lists is equal
        prev_liveness = self.recorded_liveness_before_graph
        curr_liveness = self._get_liveness(self.path_weakrefs)

        delta = self._get_different_indices(prev_liveness, curr_liveness)
        self.expected_dead_indices_after_graph = delta

        assert len(self.outputs_weakrefs) == 0
        # index from data pointer to index in outputs
        output_new_storages_index: Dict[StorageDataPtr, int] = {}

        self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
        self.static_output_tensors = [None for _ in range(len(outputs))]

        for i, o in enumerate(outputs):
            if o is None or not isinstance(o, torch.Tensor):
                self.output_storage_alias.append(UnaliasedStorage)
                continue

            torch._check(
                o.is_cuda or o.untyped_storage().data_ptr() == 0,
                lambda: (
                    "Expected all cuda outputs in cuda graph recording. Non cuda output "
                    f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
                ),
            )

            ref = static_input_persistent_storage_ptrs.get(
                o.untyped_storage().data_ptr(), None
            )
            # also treat empty storages as static outputs because we do not need to manage their lifetime
            # and they should not participate in checkpointing
            is_empty_storage = o.untyped_storage().data_ptr() == 0
            if (ref and ref() is not None) or is_empty_storage:
                self.output_storage_alias.append(None)
                self.static_output_tensors[i] = o
                continue

            path_ref = self._is_alias_of_live_recorded_tensor(o)
            if path_ref is not None:
                self._mark_prior_graph_output_as_aliased(path_ref)
                self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
                continue

            if o.untyped_storage().data_ptr() in output_new_storages_index:
                index = output_new_storages_index[o.untyped_storage().data_ptr()]
                self.unaliased_in_all_paths[index] = False
                self.output_storage_alias.append(AliasesNewOutput(index))
                continue

            output_new_storages_index[o.untyped_storage().data_ptr()] = i
            self.output_storage_alias.append(UnaliasedStorage)
            self.unaliased_in_all_paths[i] = True

        if self.stack_traces is None:
            self.stack_traces = [None for _ in range(len(outputs))]
        else:
            assert len(self.stack_traces) == len(
                outputs
            ), "Wrong number of stack traces passed in"

        assert not self.outputs_weakrefs
        for out, static_output_tensor in zip(outputs, self.static_output_tensors):
            if not isinstance(out, torch.Tensor) or static_output_tensor is not None:
                self.outputs_weakrefs.append(None)
                self.tensor_weakrefs.append(None)
            else:
                self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
                self.tensor_weakrefs.append(TensorWeakRef(out))

        self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
        self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
            self.device, self.cuda_graphs_pool
        )

        # now, get liveness with outputs added
        for depth in range(len(self.path_weakrefs)):
            for output_index in range(len(self.path_weakrefs[depth])):
                if is_live(self.path_weakrefs[depth][output_index]):
                    self.live_indices_after_graph.append((depth, output_index))

        self.debug_check_invariants_after_invocation()
        if config.triton.slow_path_cudagraph_asserts:
            check_memory_pool(
                self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs())
            )

    def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex) -> None:
        "Remove a graph output from the unaliased, cached tensors in an ancestor node"
        depth, output_index = index
        node = list(self._path_from_root)[depth]
        node.unaliased_in_all_paths[output_index] = False
        x = self.path_weakrefs[depth][output_index]
        assert x is not None
        x.remove_extra_reference()

    def _initialize_cached_tensors(self) -> None:
        # we should not be clearing output_weakrefs, and they should be set in the first
        # record run
        assert len(self.outputs_weakrefs) == len(self.outputs_metadata)

        for i, (storage_info, metadata, make_cached) in enumerate(
            zip(
                self.output_storage_alias,
                self.outputs_metadata,
                self.unaliased_in_all_paths,
            )
        ):
            if not make_cached:
                self.cached_tensor_outputs.append(None)
                continue

            assert storage_info is UnaliasedStorage
            assert isinstance(metadata, dict)
            s = self.create_storage(metadata)
            out = self._reconstruct_from_tensor_metadata(metadata, storage=s)  # type: ignore[arg-type]

            # XXX: let autograd know that there will be an additional reference to the tensor
            # that can be ignored when deciding whether to do gradient buffer inplacing.
            # Otherwise, inplacing could differ between tracing and subsequent execution.
            # For some models we tested this led to inputs no longer being in cudagraph pools,
            # leading to spurious re-recordings.
            # It also tells the AMP cache that these tensor impls cannot be cached
            # in dtype conversions.

            torch._C._add_cached_tensor(out)

            self_ref = weakref.ref(self)

            # one reference in our array, and calling sys.getrefcount bumps the refcount by one
            def check_refcount(i: int) -> bool:
                self_loc = self_ref()
                if self_loc is None:
                    return False
                return self_loc.get_output_refcount(i) == 2

            check = functools.partial(check_refcount, i=i)

            self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check)
            self.cached_tensor_outputs.append(out)

    def get_output_refcount(self, index: int) -> int:
        return sys.getrefcount(self.cached_tensor_outputs[index])

    @property
    def parent(self) -> Optional[CUDAGraphNode]:
        "unwraps the weakref to _parent"
        return self._parent() if self._parent is not None else None

    @property
    def _path_to_root(self) -> Generator[CUDAGraphNode, None, None]:
        "Returns all nodes in the path starting at self and ending at root"
        node = self
        while node:
            yield node
            node = node.parent  # type: ignore[assignment]

    @property
    def _path_from_root(self) -> Generator[CUDAGraphNode, None, None]:
        "Returns all nodes in the path starting at the root and ending at self"
        nodes = reversed(list(self._path_to_root))
        yield from nodes

    def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool:
        "Is this tensor an output of a node in this path"
        for output_refs in self.path_weakrefs:
            for storage_weak_ref in output_refs:
                if storage_weak_ref is None:
                    continue
                # don't need to check liveness of storage since the cuda graph managed
                # memory is never released.
                data_ptr = storage_weak_ref.data_ptr()
                if t.untyped_storage().data_ptr() == data_ptr:
                    return True

        return False

    def _is_alias_of_live_recorded_tensor(
        self, t: torch.Tensor
    ) -> Optional[PathOutputIndex]:
        for depth, output_refs in enumerate(self.path_weakrefs):
            for output_index, storage_ref in enumerate(output_refs):
                if (storage_and_ptr := maybe_deref(storage_ref)) is not None:
                    storage, ptr = storage_and_ptr
                    if ptr == t.untyped_storage().data_ptr():
                        return (depth, output_index)

        return None

    @staticmethod
    def _check_liveness(
        indices: List[PathOutputIndex],
        output_refs: List[List[Optional[StorageWeakRefWrapper]]],
    ) -> bool:
        "Check that all of the indices specified are dead references"
        for depth, output_index in indices:
            w = output_refs[depth][output_index]
            assert w is not None
            if w() is not None:
                return False
        return True

    def add_child(self, function_id: FunctionID, node: CUDAGraphNode) -> None:
        "Adds node as a child of self"
        self.children[function_id].append(node)

    @staticmethod
    def _get_different_indices(
        prev: List[List[bool]], curr: List[List[bool]]
    ) -> List[PathOutputIndex]:
        "Find indices where the two lists differ."
        dead_indices = []
        assert len(prev) <= len(curr)
        for i, (outputs1, outputs2) in enumerate(zip(prev, curr)):
            assert len(outputs1) == len(outputs2)
            for j, (output1, output2) in enumerate(zip(outputs1, outputs2)):
                if output1 != output2:
                    dead_indices.append((i, j))

        return dead_indices

    @staticmethod
    def _get_liveness(
        weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
    ) -> List[List[bool]]:
        "Maps weakrefs to true if the reference is alive and false otherwise"
        if len(weakrefs) == 0:
            return []

        return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
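
    # Hedged worked example of the two helpers above (values are illustrative):
    # with prev = [[True, False]] and curr = [[False, False]], only output (0, 0)
    # changed from live to dead, so _get_different_indices returns [(0, 0)];
    # _get_liveness simply maps each StorageWeakRefWrapper in the path to is_live(ref).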

    def debug_assert_invariants(
        self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
    ) -> None:
        if not config.triton.fast_path_cudagraph_asserts:
            return

        for i, node in enumerate(self._path_from_root):
            assert self.path_weakrefs[i] is node.outputs_weakrefs

        nodes = list(self._path_from_root)

        live_blocks = get_block_addrs(self.cuda_graphs_pool)

        live_storage_data_ptrs = set()
        live_storage_weak_ptrs = set()

        for depth, outputs_liveness in enumerate(expected_liveness):
            for output_idx, output_liveness in enumerate(outputs_liveness):
                # tensor can die early, but it can't be alive when it should be dead
                w = self.path_weakrefs[depth][output_idx]
                if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None:
                    assert output_liveness
                    stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr
                    assert (stor_data_ptr in live_storage_data_ptrs) == (
                        stor_weak_ptr in live_storage_weak_ptrs
                    )
                    live_storage_data_ptrs.add(stor_data_ptr)
                    live_storage_weak_ptrs.add(stor_weak_ptr)

                    is_persistent_alias = (
                        nodes[depth].static_output_tensors[output_idx] is not None
                    )

                    if is_persistent_alias:
                        assert stor_data_ptr not in live_blocks

        for depth, output_index in newly_dead:
            assert not is_live(self.path_weakrefs[depth][output_index])

    def debug_check_invariants_before_invocation(self) -> None:
        self.debug_assert_invariants(
            self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph
        )

    def debug_check_invariants_after_invocation(self) -> None:
        self.debug_assert_invariants(
            self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
        )

    def data_ptrs_dead_since_invocation(self) -> List[int]:
        """
        Since this node was invoked, return data ptrs of all tensor outputs that have died
        in the currently executing tree path.
        """
        curr_liveness = self._get_liveness(self.path_weakrefs)
        _get_different_indices = self._get_different_indices(
            self.recorded_liveness_after_graph, curr_liveness
        )

        path = list(self._path_from_root)
        ptrs_to_deallocate = []
        for depth, output_index in _get_different_indices:
            ptrs_to_deallocate.append(
                path[depth].outputs_metadata[output_index]["data_ptr"]  # type: ignore[index]
            )

        return ptrs_to_deallocate

    def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
        for i, j in self.live_indices_after_graph:
            out = self.path_weakrefs[i][j]
            if out is not None and is_live(out):
                yield out

    def remove_node_cached_tensors(self) -> None:
        for t in self.cached_tensor_outputs:
            if t is not None:
                torch._C._remove_cached_tensor(t)
        self.cached_tensor_outputs.clear()

        for i, unaliased in enumerate(self.unaliased_in_all_paths):
            if unaliased:
                n = self.outputs_weakrefs[i]
                assert n is not None
                n.remove_extra_reference()

    def remove_path_cached_tensors(self) -> None:
        for node in self._path_from_root:
            node.remove_node_cached_tensors()

    def clear_path_state(self) -> None:
        "Clear the path state in this current executing node"
        # this doesn't actually do anything right now, leaving it as a placeholder

    @staticmethod
    def _tensor_metadata(
        x: torch.Tensor, ignore_storage_offset: bool = True
    ) -> Dict[str, Any]:
        assert isinstance(x, torch.Tensor)
        # We ignore the storage offset for inputs, but not for outputs
        # TODO: - should we make the storage resizable?
        return {
            "nbytes": x.untyped_storage().nbytes(),
            "data_ptr": x.untyped_storage().data_ptr(),
            "size": x.shape,
            "stride": x.stride(),
            "dtype": x.dtype,
            "device": x.device,
            "storage_offset": x.storage_offset() if not ignore_storage_offset else 0,
        }

    def _reconstruct_from_tensor_metadata(
        self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None
    ) -> Tensor:
        s = self.create_storage(metadata) if storage is None else storage
        return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s)  # type: ignore[arg-type]

    def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage:
        return torch._C._construct_storage_from_data_pointer(
            metadata["data_ptr"], metadata["device"], metadata["nbytes"]
        )
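
    # Hedged sketch of the metadata round trip used for inputs and outputs (x is an
    # illustrative CUDA tensor; y views the same memory but its storage is built from
    # the raw data pointer, so it does not own the allocation or keep it alive):
    #
    #     meta = self._tensor_metadata(x)
    #     y = self._reconstruct_from_tensor_metadata(meta)
    #     assert y.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()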
    def _allocate_and_copy_recording_inputs(
        self, inputs: List[InputType]
    ) -> List[Union[torch.Tensor, int]]:
        """
        Allocate inputs for non-static, non-cudagraph-managed tensors in the memory pool
        and copy over the tensor values.
        """

        torch.cuda.synchronize()
        self.stream.wait_stream(torch.cuda.current_stream())
        recording_inputs: List[InputType] = []

        with warnings.catch_warnings(record=True), torch.cuda.device(
            self.device
        ), _use_cuda_memory_pool_manager(
            self.device,
            mem_pool=self.cuda_graphs_pool,
            stream=self.stream,
        ):
            for i, inp in enumerate(inputs):
                if not isinstance(inp, torch.Tensor):
                    assert isinstance(inp, int)
                    recording_inputs.append(inp)
                elif i not in self.static_input_idxs:
                    # static_input does an allocation!
                    recording_inputs.append(static_input(inp))
                else:
                    recording_inputs.append(inp)

            self._copy_inputs_and_remove_from_src(recording_inputs, inputs)

        return recording_inputs

    def check_invariants(
        self, inputs: List[InputType]
    ) -> Tuple[CheckInvariantStatus, Callable[..., str]]:
        """
        Checks if this node can be run. The same pattern of tensor liveness, static inputs,
        and tensors managed in the cudagraph private pool must remain stable.
        """

        _logger = functools.partial(
            log_data_ptr_mismatch,
            self.wrapped_function.placeholders,
            inputs,
            self.static_input_data_ptrs,
        )

        # previously managed data pointers remain stable
        # this is on the hot path so moved to C++. equivalent to:
        # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs))
        if not torch._C._tensors_data_ptrs_at_indices_equal(
            inputs,  # type: ignore[arg-type]
            self.static_input_data_ptrs,
            self.cudagraph_managed_idxs,
        ):
            status = CheckInvariantStatus.CudagraphManagedIdxMismatch
            _logger = functools.partial(
                _logger,
                self.cudagraph_managed_idxs,
                status,
            )
            return status, _logger

        if not self._check_liveness(
            self.expected_dead_indices_before_graph, self.path_weakrefs
        ):
            status = CheckInvariantStatus.ExpectedDeadIndicesBeforeGraphMismatch
            return status, lambda: f"{status}"

        # static input data pointers should remain stable
        # if we are inlining builtin nn modules we re-record in this case
        # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable
        # and error if they are not stable
        if (
            self.rerecord_if_static_inputs_change
            and not torch._C._tensors_data_ptrs_at_indices_equal(
                inputs,  # type: ignore[arg-type]
                self.static_input_data_ptrs,
                self.static_input_idxs,
            )
        ):
            status = CheckInvariantStatus.StaticInputIdxMismatch
            _logger = functools.partial(
                _logger,
                self.static_input_idxs,
                status,
            )
            return status, _logger

        # the cudagraph managed tensors which died upon recording must also die upon
        # this invocation. it is too late to check after we've replayed the graph,
        # because we would have already written over their memory.
        for idx in self.cudagraph_managed_idxs:
            inputs[idx] = None  # type: ignore[call-overload]

        torch._check(
            self._check_liveness(
                self.expected_dead_indices_after_graph, self.path_weakrefs
            ),
            lambda: "TODO: graph recording observed an input tensor deallocate during graph "
            "recording that did not occur during replay. Please file an issue.",
        )
        return CheckInvariantStatus.SUCCESS, lambda: f"{CheckInvariantStatus.SUCCESS}"

    def num_descendants(self) -> int:
        "Total number of descendants of this node"
        num_desc = 0
        for children in self.children.values():
            for child in children:
                num_desc += 1
                num_desc += child.num_descendants()
        return num_desc


def get_cudagraph_segments(pool_id: Tuple[int, int]) -> Any:
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] == pool_id]


def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[int]:
    blocks = []

    for segment in get_cudagraph_segments(pool_id):
        addr = segment["address"]
        for block in segment["blocks"]:
            if block["state"] == "active_allocated" or not live_only:
                blocks.append(addr)

            addr += block["size"]

    return blocks


def format_tb(frames: List[Any]) -> str:
    formatted_traceback = []

    for entry in frames:
        formatted_traceback.append(
            traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
        )

    return "".join(traceback.format_list(formatted_traceback))


def check_memory_pool(
    device: int,
    pool_id: Tuple[int, int],
    live_storages_ptrs: List[StorageWeakRefWrapper],
) -> None:
    assert all(
        isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
    )  # noqa: C419
    unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}

    # check if there is a divergence first, then do the expensive snapshot call after
    # we know it will error
    if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages):
        return

    # at this point we are past the fast-path. we have seen rare cases where a tensor
    # is dead but hasn't been gc'd yet, and gives a false positive for
    # allocated_not_in_live_storages
    gc.collect()

    segments = get_cudagraph_segments(pool_id)

    allocated_not_in_live_storages = {}

    for segment in segments:
        addr = segment["address"]
        for block in segment["blocks"]:
            if block["state"] == "active_allocated":
                if addr not in unique_storages:
                    allocated_not_in_live_storages[addr] = block
                else:
                    unique_storages.remove(addr)

            addr += block["size"]

    torch._check(
        len(unique_storages) == 0,
        lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
    )

    if len(allocated_not_in_live_storages) != 0:
        formatted = []
        for dp, block in allocated_not_in_live_storages.items():
            trace = format_tb(block.get("frames", []))
            formatted.append(f"Data Pointer: {dp}, history: \n{trace}")
        formatted_s = "\n".join(formatted)
        msg = (
            f"These live storage data ptrs are in the cudagraph pool but not "
            f"accounted for as an output of cudagraph trees: \n\n{formatted_s}"
        )
        raise RuntimeError(msg)
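

# For reference, the allocator snapshot consumed by get_cudagraph_segments /
# get_block_addrs / check_memory_pool has roughly the following shape (the
# values here are made up for illustration):
#
#   torch.cuda.memory_snapshot() -> [
#       {
#           "segment_pool_id": (0, 1),   # matches the cudagraph pool handle
#           "address": 139_800_000_000,
#           "blocks": [
#               {"size": 512, "state": "active_allocated", "frames": [...]},
#               {"size": 2048, "state": "inactive"},
#           ],
#       },
#       ...
#   ]
#
# Block addresses are recovered by walking the blocks of each segment and
# accumulating their sizes onto the segment's base address, which is what
# get_block_addrs does above.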


class ExecutionState(Enum):
    """
    Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated
    in the cuda graph pool. Otherwise will reflect the state of the most recently executed node.
    """

    NONE = auto()
    WARMUP = auto()
    RECORDING = auto()
    EXECUTION = auto()


class CompilationMode(Enum):
    FORWARD = auto()
    BACKWARD = auto()
    INFERENCE = auto()


class CUDAGraphTreeManager:
    """
    Groups individual recordings or executions of cuda graphs into a tree of recordings,
    checks required invariants, and manages warmups of graphs.

    When graphs are recorded in the same tree, it enforces subsequent execution
    to follow the same order and have the same output tensor lifespans. To remove
    unnecessary coupling of cuda graphs (and additional imposed invariants),
    the tree manager will end a currently recording tree whenever it is valid - when
    the memory pool no longer has any live allocations.

    We ignore outputs from a previous generation that correspond to prior model outputs.
    Currently this is hardcoded to `GenerationTracker.generation`, tracked in torch dynamo.
    # TODO: make generation increment configurable, warn on overwrite.

    We run graph warmups in the cudagraph memory pool and return the result on the first invocation
    of a function. For many models it is important to reclaim activations as you run the backward.
    If we were to warm up the model and keep an extra copy of the inputs around to subsequently
    use for recording, we would incur a memory penalty. Additionally, if we are part way through training
    your model and need to recompile, memory will be allocated to the cuda graph pool, so we run the
    warmup in the cuda graph memory pool. As with recording, warmup needs the state of live tensors
    to be accurately reflected, so we checkpoint the allocator state if we need to warm up following graph
    replay.
    """

    def __init__(self, device_index: int) -> None:
        # roots are functions which have no dependencies on another node. I.e.,
        # when they are first invoked, none of their inputs are outputs
        # of another node, nor are there any live outputs of another node whose
        # liveness would create a dependency.
        self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)

        # mapping from function id to wrapped function
        self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}

        self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {}

        self.warmed_up_functions: Set[FunctionID] = set()
        # if we fail to increment generation, and are stuck warming up,
        # only warn on each function once
        self.warned_functions: Set[FunctionID] = set()
        torch._C._set_cached_tensors_enabled(True)

        # warn only once if a function mutates inputs
        self.warned_mutation: Set[FunctionID] = set()

        # NB: the cuda caching allocator will remember the stream a segment is allocated to
        # and only allocate that segment to the same stream. we need to use a single stream
        # for all allocations to the memory pool, otherwise the allocations to separate streams
        # will not be reused; separate recordings would use the same memory pool, but not
        # the same memory.
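        # Illustrative consequence of the NB above (hypothetical scenario): if
        # recording A allocated its segments on stream_1 and a later recording B
        # allocated on stream_2 while sharing this pool, the allocator would keep
        # A's segments reserved for stream_1, so B could not reuse them and the
        # pool would grow instead of being shared. This is why the single stream
        # created below is handed to every warmup and recording node (see
        # record_function / run_eager).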

        with torch.cuda.device(device_index):
            torch.cuda.synchronize()
            self.stream = torch.cuda.Stream()
            self.stream.wait_stream(torch.cuda.current_stream())

            # Keeps Memory Pool Alive
            self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
            self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()

            with warnings.catch_warnings(record=True), torch.cuda.graph(
                self.graph,
                pool=self.cuda_graphs_thread_pool,
                stream=self.stream,
                capture_error_mode="thread_local",
            ):
                pass

        self.graph_counter = itertools.count(0)
        self.func_counter = itertools.count(0)

        # mapping from graph_id to (function id to mutation type hint) since we are
        # specializing on a particular combination of Parent Node -> Function ID.
        self.non_cudagraph_managed_mutation_hint: Dict[
            Optional[GraphID], Dict[FunctionID, bool]
        ] = defaultdict(dict)
        self.warmup_node_counter = itertools.count(start=-1, step=-1)

        # mapping from graph_id to (function id to re-record count). We fall back to
        # eager function if a function is re-recorded frequently on a node.
        self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict(
            lambda: defaultdict(lambda: 0)
        )

        # whether the current node is in a state of warmup, recording, or execution. If
        # there is no current node the state will be ExecutionState.NONE.
        self.path_state = ExecutionState.NONE
        self.device_index = device_index

        # the most recently invoked cudagraph wrapping of a function. Will be None
        # when there is no output from a previous recording or execution whose memory
        # we need to respect in the cuda caching allocator. If you incremented the
        # generation, this will also be None, as we ignore those allocations.
        self.current_node: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = None

        # current generation of cudagraph invocations. when torch.compile is run
        # we increment the current generation, and we are willing to ignore live outputs
        # of a previous generation in checking liveness.
        self.current_gen: int = -1

        # number of instances we are in execution and failed to match to an
        # existing child
        self.debug_fail_counter = 0
        # number of instances we had to checkpoint the function
        self.debug_checkpointing_counter = 0

        self.id_to_mode: Dict[FunctionID, CompilationMode] = {}

        # Note: [Backward Generation Handling]
        # We generally perform a sequence of forward executions followed by backward executions.
        # If multiple torch.compile wrapped forwards are executed with their backwards pending,
        # we should not disregard the outputs from a prior torch.compile since the entire training
        # loop hasn't completed. Occasionally, a backward pass corresponding to a forward pass may
        # not be executed, so we cannot wait for all pending backwards to have been invoked.
        # Instead we wait for a single backward invocation. Triggering a backward pass typically
        # doesn't lead to another torch.compile invocation, making it less likely for the
        # generation to increase between multiple backward calls. The following use case is
        # covered by this approach:
        # mod1 = torch.compile(...)
        # mod2 = torch.compile(...)
        # mod2(mod1(x)).sum().backward()
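        #
        # In that example both compiled forwards run before any backward, so
        # running_forwards_with_pending_backwards stays True and the prior
        # outputs are kept; the single .backward() call flips it back, after
        # which a new generation may begin on the next torch.compile invocation.
        # A user can also start a new generation explicitly, e.g. in inference
        # loops that never call backward (hypothetical usage):
        #
        #   torch.compiler.cudagraph_mark_step_begin()
        #   out = mod2(mod1(x))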

        self.running_forwards_with_pending_backwards = False

    def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
        assert self.graph is not None, "Running CUDAGraph after shutdown"
        out = self._run(new_inputs, function_id)

        # The forwards are only pending following invocation, not before
        mode = self.id_to_mode[function_id]
        if mode == CompilationMode.FORWARD:
            self.running_forwards_with_pending_backwards = True
        elif mode == CompilationMode.BACKWARD:
            self.running_forwards_with_pending_backwards = False

        return out

    def set_to_running_backward(self) -> None:
        self.running_forwards_with_pending_backwards = False

    def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]:
        return (
            self.current_node._is_cuda_graph_recorded_tensor
            if isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode))
            else lambda _: False
        )

    def new_warmup_node_id(self) -> GraphID:
        return GraphID(next(self.warmup_node_counter))

    def _update_non_cudagraph_managed_mutation(
        self, function_id: FunctionID, inputs: List[InputType]
    ) -> None:
        node_id = self._get_node_id()
        if maybe_mutation_str := check_for_mutation(
            self.ids_to_funcs[function_id],
            inputs,
            self._get_cuda_graph_recorded_tensor_checker(),
        ):
            self.non_cudagraph_managed_mutation_hint[node_id][function_id] = True
            # warn once per function_id
            if function_id in self.warned_mutation:
                return
            self.warned_mutation.add(function_id)
            log_cudagraph_skip_and_bump_counter(maybe_mutation_str)
        else:
            self.non_cudagraph_managed_mutation_hint[node_id][function_id] = False

    def _get_node_id(self) -> Optional[GraphID]:
        if self.current_node is None:
            return None
        elif isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)):
            return self.current_node.id
        else:
            raise RuntimeError(f"Unknown node type {type(self.current_node)}")

    def exceed_rerecord_limit(
        self, node_id: Optional[GraphID], function_id: FunctionID
    ) -> bool:
        if torch._dynamo.config.inline_inbuilt_nn_modules:
            return False

        return (
            self.num_rerecord[node_id][function_id]
            > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit
        )

    def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
        # we will try to end the current execution lazily, since
        # we don't want to do unnecessary checking of the existing outputs
        # on the hot path, but both recording and warmup only happen once
        # so we check up front
        if self.in_recording:
            self.try_end_curr_recording(function_id)

        if self.in_warmup:
            self.try_end_curr_warmup(function_id)

        node_id = self._get_node_id()
        if function_id not in self.non_cudagraph_managed_mutation_hint[node_id]:
            self._update_non_cudagraph_managed_mutation(function_id, new_inputs)

        # Early exit if the function mutates inputs which are neither parameters/buffers nor
        # cudagraph recorded tensors. This check should happen after `try_end_curr_recording`
        # and `try_end_curr_warmup` which may change self.current_node.
        if self.non_cudagraph_managed_mutation_hint[node_id][
            function_id
        ] or self.exceed_rerecord_limit(node_id, function_id):
            return self.ids_to_funcs[function_id].model(new_inputs)

        # warming up a function and subsequently recording may use different memory addresses,
        # because both depend on the state of the caching allocator. if we warm up graph A,
        # then warm up graph B and make more allocations, the subsequent recording of A will not
        # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only
        # be followed by warm up runs.
        if (
            (
                not (
                    function_id in self.warmed_up_functions
                    or config.triton.skip_cudagraph_warmup
                )
            )
            or self.in_warmup
            or config.triton.force_cudagraphs_warmup
        ):
            # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state.
            # Both Recording and Warmup will be reflected in the allocator and don't need changes
            if self.path_state == ExecutionState.EXECUTION:
                self.apply_checkpoint_execution_state_in_allocator()

            return self.run_eager(new_inputs, function_id)

        assert not isinstance(self.current_node, CUDAWarmupNode)
        child_nodes = (
            self.roots if self.current_node is None else self.current_node.children
        )

        if not self.in_recording:
            unexpected_rerecord, unexpected_rerecord_reason = False, lambda: ""
            for child in child_nodes[function_id]:
                # here we are checking memory consistency between recording and execution,
                # as well as things like stability of tensor locations, etc.
                status, status_logger = child.check_invariants(new_inputs)
                if status == CheckInvariantStatus.SUCCESS:
                    return self.execute_node(child, new_inputs)

                if (
                    status == CheckInvariantStatus.StaticInputIdxMismatch
                    or status == CheckInvariantStatus.CudagraphManagedIdxMismatch
                ):
                    unexpected_rerecord = True
                    unexpected_rerecord_reason = status_logger

            # now that we know the new function can't be run as a child of the
            # current node, if it is a root, try to end the current execution.
            # as noted above, we want to do this lazily to avoid having to
            # check all existing outputs
            if self.current_node is not None and function_id in self.roots:
                self.try_end_curr_execution()

                # run again to hit the root matching case which must succeed
                if self.current_node is None:
                    return self.run(new_inputs, function_id)

            if len(self.ids_to_funcs[function_id].mutated_input_idxs) > 0:
                self._update_non_cudagraph_managed_mutation(function_id, new_inputs)
                if self.non_cudagraph_managed_mutation_hint[self._get_node_id()][
                    function_id
                ]:
                    return self.ids_to_funcs[function_id].model(new_inputs)

            # nb: run before checkpointing because checkpointing is slow, and we will
            # be using the eager caching allocator pool which does not require live
            # accounting of tensors in cudagraph allocator
            if unexpected_rerecord:
                curr_node_id = self._get_node_id()
                self.num_rerecord[curr_node_id][function_id] += 1
                if self.exceed_rerecord_limit(curr_node_id, function_id):
                    _id = curr_node_id.id if curr_node_id else None
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraph due to function {function_id.id} exceeding max "
                        f"re-recording limit "
                        f"(={torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit}) "
                        f"on cudagraph node {_id} due to {unexpected_rerecord_reason()}."
                    )
                    return self.ids_to_funcs[function_id].model(new_inputs)

            # at this point, we necessarily will do a new recording
            self.debug_fail_counter += 1

            self.try_end_curr_execution()
            if self.current_node is not None:
                self.apply_checkpoint_execution_state_in_allocator()

        # now, we are in a recording state !
        return self.record_function(new_inputs, function_id)
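    # Rough decision sketch for _run above (informal summary, not exhaustive):
    #
    #   mutates unmanaged inputs or hit the re-record limit -> run the eager model
    #   function not yet warmed up, or we are already warming
    #   up, or warmup is forced                              -> run_eager, after
    #       checkpointing the allocator if we were mid-execution
    #   an existing child passes check_invariants            -> execute_node (replay)
    #   otherwise                                            -> end/checkpoint the
    #       current execution and record_function (record a new child graph)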
    def shutdown(self) -> None:
        """
        Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
        might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown
        to avoid a reference cycle.
        """
        nodes = []
        for roots in self.roots.values():
            nodes.extend(roots)

        while nodes:
            node = nodes.pop()
            for children in node.children.values():
                nodes.extend(children)
            node.remove_node_cached_tensors()
            node.graph = None

        self.graph = None
        self.roots = None  # type: ignore[assignment]
        self.current_node = None

    def record_function(
        self, new_inputs: List[InputType], function_id: FunctionID
    ) -> OutputType:
        assert not isinstance(self.current_node, CUDAWarmupNode)
        graph_id = self.new_graph_id()
        log.debug(
            "Recording function %d of graph recording id %d",
            function_id.id,
            graph_id.id,
        )
        torch.cuda.synchronize()
        node = CUDAGraphNode(
            self.ids_to_funcs[function_id],
            graph_id,
            self.current_node,
            new_inputs,
            self.cuda_graphs_thread_pool,
            self.device_index,
            self.ids_to_stack_traces[function_id],
            self.stream,
        )
        if self.current_node is None:
            self.roots[function_id].append(node)
        else:
            self.current_node.add_child(function_id, node)
        self.current_node = node
        self.path_state = ExecutionState.RECORDING
        self.update_generation()
        torch.cuda.synchronize()
        return node.run_first_inputs(new_inputs)

    def execute_node(
        self, node: CUDAGraphNode, new_inputs: List[InputType]
    ) -> OutputType:
        self.current_node = node
        self.path_state = ExecutionState.EXECUTION
        self.update_generation()
        return node.run(new_inputs)

    def run_eager(
        self, new_inputs: List[InputType], function_id: FunctionID
    ) -> OutputType:
        # this is only stored on the current node, because when we start a new path,
        # we will deallocate it
        already_warm = function_id in self.warmed_up_functions
        if not already_warm:
            log.debug("Running warmup of function %d", function_id.id)
        else:
            log.debug(
                "Running eager of function %d because ancestor needed to warm up",
                function_id.id,
            )
        self.warmed_up_functions.add(function_id)
        node = CUDAWarmupNode(
            self.ids_to_funcs[function_id],
            self.current_node,
            self.cuda_graphs_thread_pool,
            self.graph,
            self.device_index,
            self.ids_to_stack_traces[function_id],
            self.stream,
            already_warm,
            self.new_warmup_node_id(),
        )
        self.current_node = node
        self.path_state = ExecutionState.WARMUP
        self.update_generation()
        return node.run(new_inputs)

    def new_graph_id(self) -> GraphID:
        return GraphID(next(self.graph_counter))

    def new_func_id(self) -> FunctionID:
        return FunctionID(next(self.func_counter))

    def add_function(
        self,
        model: ModelType,
        inputs: List[InputType],
        static_input_idxs: Sequence[int],
        stack_traces: Optional[StackTraces],
        mode: CompilationMode,
        constants: Tuple[torch.Tensor, ...],
        placeholders: Tuple[PlaceholderInfo, ...],
        mutated_input_idxs: Tuple[int, ...],
    ) -> Tuple[ModelType, OutputType,]:
        id = self.new_func_id()
        self.ids_to_stack_traces[id] = stack_traces
        self.ids_to_funcs[id] = WrappedFunction(
            model,
            list(static_input_idxs),
            id,
            tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
            placeholders,
            mutated_input_idxs,
        )
        self.id_to_mode[id] = mode
        fn = functools.partial(self.run, function_id=id)

        # the container needs to set up cleanup for when fn dies
        get_container(self.device_index).add_strong_reference(fn)
        return fn, fn(inputs)
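    # Sketch of the expected calling convention for add_function (hypothetical
    # caller shown here; in practice it is invoked by the inductor cudagraph
    # wrapper): the function is registered once, and the returned callable is
    # then used for every subsequent invocation, routing through self.run with
    # the bound FunctionID.
    #
    #   fn, first_outputs = manager.add_function(
    #       model, inputs, static_input_idxs, stack_traces,
    #       CompilationMode.FORWARD, constants, placeholders, mutated_input_idxs,
    #   )
    #   ...
    #   later_outputs = fn(new_inputs)   # -> CUDAGraphTreeManager.run(...)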
    @property
    def in_recording(self) -> bool:
        return self.path_state == ExecutionState.RECORDING

    @property
    def in_warmup(self) -> bool:
        return self.path_state == ExecutionState.WARMUP

    def get_roots(self) -> Iterator[CUDAGraphNode]:
        for nodes in self.roots.values():
            yield from nodes

    @property
    def current_node(self) -> Optional[Union[CUDAGraphNode, CUDAWarmupNode]]:
        return self._current_node

    @current_node.setter
    def current_node(
        self, value: Optional[Union[CUDAGraphNode, CUDAWarmupNode]]
    ) -> None:
        self._current_node = value
        if value is None:
            self.path_state = ExecutionState.NONE

    def update_generation(self) -> None:
        self.current_gen = self.get_curr_generation()

    @staticmethod
    def get_curr_generation() -> int:
        if MarkStepBox.mark_step_counter != 0:
            return MarkStepBox.mark_step_counter

        return GenerationTracker.generation

    @staticmethod
    def user_invoked_mark_step() -> bool:
        return MarkStepBox.mark_step_counter != 0

    def can_start_new_generation(self) -> bool:
        if not self.in_new_torch_compile_invocation():
            return False

        if self.user_invoked_mark_step():
            return True

        return not self.running_forwards_with_pending_backwards

    def in_new_torch_compile_invocation(self) -> bool:
        return self.current_gen != self.get_curr_generation()

    def try_end_curr_recording(self, function_id: FunctionID) -> None:
        """
        Check if the current recording can be terminated, either because all outputs of the
        previously recorded node are dead or because it was executed in a different
        generation. Will set current_node to None and in_recording to False if successful.
        """
        assert self.in_recording
        assert self.current_node is not None

        # multiple invocations, allow overwriting the previous generation
        if self.can_start_new_generation():
            self.dealloc_current_path_weakrefs()
            self.clear_current_path_state_and_set_to_none()
            return

        if self.current_node.all_outputs_are_dead():
            self.clear_current_path_state_and_set_to_none()
            return

        self.check_warn_on_unable_to_start_executing(function_id)
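    # Note on the generation checks above: get_curr_generation() prefers the
    # user-driven MarkStepBox counter (presumably bumped by
    # torch.compiler.cudagraph_mark_step_begin()) once it is nonzero, and
    # otherwise falls back to dynamo's GenerationTracker. A new generation can
    # only start when the generation recorded at the last run differs from the
    # current one (in_new_torch_compile_invocation) and either the user
    # explicitly marked a step or no backwards are pending.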
    def try_end_curr_execution(self) -> None:
        """
        Check if the current executing node can be terminated, either because all outputs of the
        previously executed node are dead or because it was executed in a different generation.
        Will set current_node to None if successful.
        """

        assert not self.in_recording
        if self.current_node is None:
            return

        if self.can_start_new_generation():
            self.clear_current_path_state_and_set_to_none()
            return

        if self.current_node.all_outputs_are_dead():
            self.clear_current_path_state_and_set_to_none()

    def try_end_curr_warmup(self, function_id: FunctionID) -> None:
        if self.can_start_new_generation():
            self.dealloc_current_path_weakrefs()
            self.current_node = None
            return

        assert self.current_node is not None
        if self.current_node.all_outputs_are_dead():
            self.current_node = None
            return

        self.check_warn_on_unable_to_start_executing(function_id)

    def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None:
        "Warn if we are in a potential loop where we are unable to hit the fast path"
        if (
            function_id in self.warned_functions
            or not self.in_new_torch_compile_invocation()
        ):
            return

        assert self.current_node is not None
        existing_nodes = [
            node
            for node in self.current_node._path_from_root
            if node.wrapped_function.id == function_id
        ]

        if len(existing_nodes) <= 1:
            return

        # repeated same pattern
        parents = {
            n.parent.wrapped_function.id
            for n in itertools.chain(existing_nodes, (self.current_node,))
            if n.parent is not None
        }
        if len(parents) == len(existing_nodes):
            return

        self.warned_functions.add(function_id)
        warnings.warn(
            "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. "
            "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() "
            "before each model invocation"
        )

    def dealloc_current_path_weakrefs(self) -> None:
        assert self.current_node is not None
        # TODO: we could also allow these weak refs to continue to be allocated,
        # but that adds some complications.
        for node in self.current_node._path_from_root:
            assert node.stack_traces is not None
            assert len(node.tensor_weakrefs) == len(node.stack_traces)
            for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces):
                ten = None if t is None else t()
                if ten is None:
                    continue

                stack_trace = (
                    stack_trace.strip()
                    if stack_trace
                    else "[Could not find stack trace]"
                )
                msg = (
                    "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
                    f"Stack trace: {stack_trace}. "
                    "To prevent overwriting, clone the tensor outside of torch.compile() "
                    "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
                )
                torch._C._set_storage_access_error_msg(ten, msg)

        deleted = set()
        for storage_ref in self.current_node.path_live_weakrefs():
            _storage_deref = storage_ref()
            if _storage_deref and storage_ref.data_ptr() not in deleted:
                deleted.add(storage_ref.data_ptr())
                torch._C._free_And_Remove_DeleterFn(_storage_deref)

    def clear_current_path_state_and_set_to_none(self) -> None:
        assert isinstance(self.current_node, CUDAGraphNode)
        self.current_node.clear_path_state()
        self.current_node = None

    def apply_checkpoint_execution_state_in_allocator(self) -> None:
        """
        Checkpoint the current execution state in the caching allocator so that
        additional cudagraph recordings can be made respecting existing live storages.
        """
        assert isinstance(self.current_node, CUDAGraphNode)
        self.debug_checkpointing_counter += 1
        log.debug(
            "Checkpointing cuda caching allocator state. Number of checkpoints %d",
            self.debug_checkpointing_counter,
        )

        state = self.current_node.checkpointed_caching_state
        device = self.current_node.device
        assert state is not None and device is not None

        # currently we deallocate instead of allowing stale recordings
        stale_storages: List[int] = []

        # remove cached tensors, otherwise they would prevent memory from being
        # reclaimed in subsequent recordings
        self.current_node.remove_path_cached_tensors()
        live_storages_wrappers = list(self.current_node.path_live_weakrefs())

        # path_live_weakrefs guarantees that t() will not be None
        live_storages_weak_refs: list[int] = [t() for t in live_storages_wrappers]  # type: ignore[misc]
        ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation()
        torch._C._cuda_setCheckpointPoolState(
            device, state, stale_storages, live_storages_weak_refs
        )

        # NB: deduplicate aliased outputs
        for ptr in set(ptrs_to_deallocate):
            torch._C._cuda_cudaCachingAllocator_raw_delete(ptr)

        # Now the live blocks should be exactly equal to the live storages in private pool
        if config.triton.slow_path_cudagraph_asserts:
            check_memory_pool(
                self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers
            )
            for wrapper in live_storages_wrappers:
                storage_ptr = wrapper()
                assert storage_ptr is not None
                assert torch._C._has_Standard_Deleter(storage_ptr)
                assert wrapper.data_ptr() not in ptrs_to_deallocate

    def live_cudagraph_pool_storages_in_curr_execution(
        self,
    ) -> List[StorageWeakRefPointer]:
        if self.current_node is None:
            return []
        # explicitly ignoring previous recorded outputs from past path
        # path_live_weakrefs() guarantees that t() will not be None
        return [t() for t in self.current_node.path_live_weakrefs()]  # type: ignore[misc]
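    # Sketch of the checkpoint-and-reuse flow implemented by
    # apply_checkpoint_execution_state_in_allocator above (informal ordering):
    #
    #   1. drop cached tensor wrappers on the path so they don't pin memory
    #   2. gather the storages still live on the path, plus the data ptrs of
    #      outputs that have died since the last invocation
    #   3. _cuda_setCheckpointPoolState restores the allocator's CPU-side
    #      bookkeeping for the private pool to the recorded state, keeping the
    #      live storages
    #   4. raw_delete the dead data ptrs so a new recording can reuse them
    #
    # After this, a new child graph can be recorded into the same pool without
    # clobbering storages the user still holds references to.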