xref: /aosp_15_r20/external/pytorch/torch/_inductor/cudagraph_trees.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graphed_callables,
3which share the same memory pool.  Sharing a memory pool is an extremely
4important optimization when chaining multiple CUDA graphs together, as it
5prevents you from needing to copy intermediate tensors from one graph to the
6next, and reduces overall memory usage by allowing dead memory from the first
7pool to be reused in the second.
8
9The standard graph/make_graphed_callables APIs support sharing a memory pool, but
10with a lot of caveats.  CUDA graph trees remove these restrictions:
11
12* Previously, if you recorded graphs A, B, you had to replay A, B in that
13  order.  With CUDA graph trees, after replaying A, you can change your
14  mind and record/replay a different graph B'; we will support efficient
15  execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')).  In
16  other words: we support arbitrary trees of CUDA graph operations, not just
17  sequences (this is why this feature is called CUDA graph trees.)
18
19* Previously, if you executed graph A, some non-CUDA graph code, and then
20  graph B, after executing graph B, it was not safe to retain any references
21  to intermediates produced by A.  With CUDA graph trees, we track if any
22  outputs of graph A are still live by the time graph B is run, and make
23  sure graph B doesn't clobber their memory when reusing the CUDA graphs
24  pool.  You'll get a separate recording of B depending on what tensors
25  stay live or dead.
26
27CUDA graph trees are flexible enough to be used in Dynamo across graph breaks,
28which is their primary use case.
29
30The ability to switch from replay to record is fairly nontrivial: remember that
31when you replay a CUDA graph, you only replay CUDA operations; no CPU side state
32is updated.  In particular, the CPU-side book-keeping for the allocator is not
33reconstructed.  However, to record a new child CUDA graph, we must restore this
34book-keeping.  This is what checkpoint pool state is used for.
35"""
36
37from __future__ import annotations
38
39import contextlib
40import dataclasses
41import functools
42import gc
43import itertools
44import operator
45import sys
46import threading
47import traceback
48import warnings
49import weakref
50from collections import defaultdict
51from enum import auto, Enum
52from typing import (
53    Any,
54    Callable,
55    cast,
56    ContextManager,
57    Dict,
58    Generator,
59    Iterator,
60    List,
61    Optional,
62    Sequence,
63    Set,
64    Tuple,
65    Type,
66    TYPE_CHECKING,
67    TypeVar,
68    Union,
69)
70
71import torch.fx
72from torch import Tensor
73from torch._dynamo.mutation_guard import GenerationTracker
74from torch._dynamo.utils import counters, preserve_rng_state
75from torch._inductor.compile_fx import (
76    align_inputs_from_check_idxs,
77    copy_misaligned_inputs,
78    get_expanded_dims,
79    get_input_idxs_to_check,
80    index_expanded_dims,
81    remove_unaligned_input_idxs,
82    static_input,
83)
84from torch._inductor.cudagraph_utils import (
85    check_for_mutation,
86    CheckInvariantStatus,
87    FunctionID,
88    log_cudagraph_skip_and_bump_counter,
89    log_data_ptr_mismatch,
90    maybe_warning_due_to_dynamic_shape,
91    ModelType,
92    OutputType,
93    PlaceholderInfo,
94    WrappedFunction,
95)
96from torch.multiprocessing.reductions import StorageWeakRef
97from torch.storage import UntypedStorage
98from torch.utils import _pytree as pytree
99from torch.utils.weak import TensorWeakRef
100
101
102if TYPE_CHECKING:
103    from torch._inductor.utils import InputType
104    from torch.types import _bool
105
106StorageWeakRefPointer = int
107StorageDataPtr = int
108NBytes = int
109S = TypeVar("S", bound="StorageWeakRefWrapper")
110
111
112if torch.backends.cuda.is_built():
113    from torch._C import (
114        _cuda_CUDAAllocator_AllocatorState as AllocatorState,
115        _set_cached_tensors_enabled as _set_cached_tensors_enabled,
116    )
117else:
118
119    class AllocatorState:  # type: ignore[no-redef]
120        pass
121
122    def _set_cached_tensors_enabled(enabled: _bool) -> None:
123        pass
124
125
126log = torch._logging.getArtifactLogger(__name__, "cudagraphs")
127
128
129from . import config
130
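# Illustrative sketch (not part of the implementation): the eager CUDA graph
# API that this file builds on. Capturing two graphs with the same `pool`
# argument makes them share a memory pool, which is the property that the
# tree manager below generalizes to arbitrary trees of recordings. Requires a
# CUDA device; names and shapes here are made up.
def _example_shared_pool_capture() -> None:
    pool = torch.cuda.graph_pool_handle()
    x = torch.zeros(4, device="cuda")

    graph_a = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph_a, pool=pool):
        y = x + 1

    graph_b = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph_b, pool=pool):
        z = y * 2

    # With the raw API, replays must follow recording order (A then B);
    # relaxing that restriction is what CUDA graph trees provide.
    graph_a.replay()
    graph_b.replay()
    assert z.shape == y.shape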
131
132@dataclasses.dataclass(frozen=True)
133class GraphID:
134    "Unique counter of a cuda graph recording"
135    id: int
136
137
138def clear_cublass_cache() -> None:
139    """
140    Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for
141    doing warmup within a CUDAGraph private pool because we do not want persistent allocations from
142    one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors
143    from the previous generation are freed. This frees their memory within the memory pool, but not elsewhere.
144    A tensor in the cublas workspace would continue to be in use in the workspace, but its memory would also
145    get allocated in the next run. The memory would be in use in two places.
146
147    To solve this, we clear the cublas caches before and after warming up or recording. If a workspace is
148    required, it will be allocated to the cudagraph private pool and accounted for in the allocator for the
149    duration of the program. There is no overhead to this on replay since cudagraphs removes allocation overhead.
150    """
151    torch._C._cuda_clearCublasWorkspaces()
152
153
154@contextlib.contextmanager
155def clear_cublas_manager() -> Generator[None, None, None]:
156    "Context manager around clearing cublas caches that will clear on enter and exit"
157    clear_cublass_cache()
158    try:
159        yield
160    finally:
161        clear_cublass_cache()
162
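# Hedged usage sketch: warmup and recording paths in this file wrap model
# invocation in clear_cublas_manager() so that a cublas workspace allocated
# during the run is not carried across generations. The matmul below is
# illustrative only and requires a CUDA device.
def _example_clear_cublas_usage() -> None:
    a = torch.randn(8, 8, device="cuda")
    with clear_cublas_manager():
        torch.mm(a, a)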
163
164@contextlib.contextmanager
165def disable_conv_cache_emptying() -> Generator[None, None, None]:
166    prev = torch._C._cuda_get_conv_benchmark_empty_cache()
167    torch._C._cudnn_set_conv_benchmark_empty_cache(False)
168    try:
169        yield
170    finally:
171        torch._C._cudnn_set_conv_benchmark_empty_cache(prev)
172
173
174@contextlib.contextmanager
175def enable_history_recording() -> Generator[None, None, None]:
176    "Turns on history recording in the CUDA Caching Allocator"
177    enabled = torch._C._cuda_isHistoryEnabled()
178    try:
179        if not enabled:
180            torch.cuda.memory._record_memory_history()
181        yield
182    finally:
183        if not enabled:
184            torch.cuda.memory._record_memory_history(None)
185
186
187def get_history_recording() -> ContextManager[None]:
188    # TODO - remove, prevents cleanup
189    if not config.triton.cudagraph_trees_history_recording:
190        return contextlib.nullcontext()
191    return enable_history_recording()
192
193
194class TreeManagerContainer:
195    """
196    Manages the lifetime of the tree manager. Like `PrivatePool` in the CUDA caching allocator,
197    the tree and its corresponding memory pool should be kept alive as long as any outstanding
198    graph or tensor which is an output of a graph remains alive.
199
200    There is a single tree manager container per device.
201
202    The lifecycle of a tree_manager is:
203    -  It is constructed: no graph, no fns, no tensors
204    -  The tree manager is fetched, resulting in the tree manager being allocated
205    -  We generate a bunch of functions, calling add_strong_reference
206    -  These functions die, calling finalize_reference
207    -  When all the functions die, we call _finalize_tree_manager.
208
209    TODO: in the future, we would like to do the following once storage weak refs land
210    -  We look for all the live storages and add references to THOSE
211    -  We count as storages die
212    -  All the storages are dead, we deallocate the tree manager
213    """
214
215    def __init__(self, device_index: int) -> None:
216        # This class keeps a strong reference to tree_manager,
217        # but once all other strong references to the tree_manager die, it will reset it to None.
218        # We need a strong reference so that we can still access its attributes upon cleanup.
219        self.tree_manager: Optional[CUDAGraphTreeManager] = None
220
221        # Number of outstanding references to the current tree manager
222        self.live_cudagraphify_fns = 0
223
224        self.device_index = device_index
225
226        # The following two objects are only set in the case that Tensor outputs outlive
227        # the cudagraphify_fns. A reference to the Graph is needed to keep the private pool from
228        # deallocation.
229        self.live_storages_count = 0
230        self.graph: Optional[torch.cuda.CUDAGraph] = None
231
232        self.lock = threading.Lock()
233
234    def _finalize_tensor(self) -> None:
235        with self.lock:
236            self.live_storages_count -= 1
237            if self.live_storages_count == 0:
238                self.graph = None
239
240                # the manager was used again after cleanup began,
241                # so we shouldn't set it to None
242                if self.live_cudagraphify_fns == 0:
243                    self.tree_manager = None
244
245    def finalize_cudagraphify_fn(self) -> None:
246        with self.lock:
247            self.live_cudagraphify_fns -= 1
248            if self.live_cudagraphify_fns == 0:
249                self._finalize_tree_manager()
250
251    def _finalize_tree_manager(self) -> None:
252        assert self.lock.locked()
253        self.tree_manager = None
254
255        # TODO - when issue #91395 is landed, we can set a weakref on
256        # storages and trigger a deallocation when all outputs of the
257        # cudagraph are dead.
258
259        # live_storages = list(
260        #     tree_manager.live_cudagraph_pool_storages_in_curr_execution()
261        # )
262
263        # # Maintain reference to graph to keep tensors alive
264        # assert len(tree_manager.roots) > 0, "expected at least one use"
265        # root = next(tree_manager.get_roots())
266        # self.graph = root.graph
267        # seen_storages = set()
268        # for stor in live_storages:
269        #     if stor in seen_storages:
270        #         continue
271        #     seen_storages.add(stor)
272        #     self.live_storages_count += 1
273        #     weakref.finalize(stor, self._finalize_tensor)
274
275    def add_strong_reference(self, fn: Callable[..., Any]) -> None:
276        with self.lock:
277            self.live_cudagraphify_fns += 1
278
279        weakref.finalize(fn, self.finalize_cudagraphify_fn)
280
281    def get_tree_manager(self) -> CUDAGraphTreeManager:
282        with self.lock:
283            if self.tree_manager is None:
284                self.tree_manager = CUDAGraphTreeManager(self.device_index)
285            return self.tree_manager
286
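# Minimal sketch (illustrative only) of the finalization pattern the container
# relies on: weakref.finalize fires when the registered object is collected,
# which is what decrements live_cudagraphify_fns above. The _Fn class is a
# stand-in for a cudagraphify_fn.
def _example_finalize_pattern() -> None:
    events: List[str] = []

    class _Fn:
        pass

    fn = _Fn()
    weakref.finalize(fn, lambda: events.append("finalized"))
    del fn
    gc.collect()
    assert events == ["finalized"]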
287
288local = threading.local()
289
290# one tree manager per device
291local.tree_manager_containers = {}
292local.tree_manager_locks = defaultdict(threading.Lock)
293
294
295# only decremented by user calls of mark_step_begin
296class MarkStepBox:
297    mark_step_counter = 0
298
299
300# We need to register this as an object that will be copied over as TLS when new
301# threads are created in autograd
302torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers)
303torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)
304
305
306def mark_step_begin() -> None:
307    "Indicates that a new iteration of inference or training is about to begin."
308
309    # count down to distinguish from the GenerationTracking counter
310    MarkStepBox.mark_step_counter -= 1
311
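# Hedged usage sketch: user code calls mark_step_begin() (exposed publicly as
# torch.compiler.cudagraph_mark_step_begin) between iterations so that tensors
# from the previous iteration may be freed from the cudagraph pool. The model,
# compile mode, and shapes below are illustrative and require a CUDA device.
def _example_mark_step_usage() -> None:
    model = torch.compile(torch.nn.Linear(8, 8).cuda(), mode="reduce-overhead")
    inp = torch.randn(2, 8, device="cuda")
    for _ in range(3):
        mark_step_begin()
        model(inp)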
312
313def reset_cudagraph_trees() -> None:
314    "Clear all cudagraph trees"
315    # see shutdown below for why this is necessary
316    container_dict = get_obj(local, "tree_manager_containers")
317    locks_dict = get_obj(local, "tree_manager_locks")
318    for device, lock in locks_dict.items():
319        with lock:
320            container = container_dict.get(device)
321            if not container or not container.tree_manager:
322                continue
323
324            container.tree_manager.shutdown()
325
326    _set_cached_tensors_enabled(False)
327    container_dict.clear()
328
329    MarkStepBox.mark_step_counter = 0
330
331
332def get_obj(local: Any, attr_name: str) -> Any:
333    if hasattr(local, attr_name):
334        return getattr(local, attr_name)
335    else:
336        assert torch._C._is_key_in_tls(attr_name)
337        return torch._C._get_obj_in_tls(attr_name)
338
339
340def get_container(device_index: int) -> TreeManagerContainer:
341    container_dict = get_obj(local, "tree_manager_containers")
342    lock = get_obj(local, "tree_manager_locks")[device_index]
343
344    with lock:
345        if device_index not in container_dict:
346            container_dict[device_index] = TreeManagerContainer(device_index)
347
348        return container_dict[device_index]
349
350
351def get_manager(
352    device_index: int, create_if_none_exists: bool = True
353) -> Optional[CUDAGraphTreeManager]:
354    if create_if_none_exists:
355        return get_container(device_index).get_tree_manager()
356    return get_container(device_index).tree_manager
357
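# Hedged sketch: containers are memoized per device index, so repeated lookups
# return the same TreeManagerContainer and, once created, the same tree
# manager. Device index 0 below is illustrative.
def _example_per_device_container() -> None:
    container = get_container(0)
    assert get_container(0) is container
    assert get_manager(0, create_if_none_exists=False) is container.tree_manager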
358
359def cudagraphify_impl(
360    model: ModelType,
361    inputs: List[InputType],
362    static_input_idxs: Sequence[int],
363    *args: Any,
364    **kwargs: Any,
365) -> ModelType:
366    fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {}
367
368    # Detect int inputs: we need to index on these
369    int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
370    get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None
371
372    has_warn = False
373
374    del inputs
375
376    def deferred_cudagraphify(inputs: List[InputType]) -> OutputType:
377        nonlocal has_warn
378
379        int_key = get_ints(inputs)
380        fn = fn_cache.get(int_key)
381        if fn is not None:
382            return fn(inputs)
383
384        if int_key is None:
385            log.info("recording cudagraph tree for graph without symints")
386        else:
387            log.info("recording cudagraph tree for symint key %s", int_key)
388
389        if not has_warn:
390            has_warn = maybe_warning_due_to_dynamic_shape(fn_cache, int_key)
391
392        # first get indices we need to check to align, then update our static inputs,
393        # and finally copy
394        check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
395        new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
396        copy_misaligned_inputs(inputs, check_input_idxs)
397
398        fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
399        fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs)
400        fn_cache[int_key] = fn
401
402        return out
403
404    return deferred_cudagraphify
405
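# Small sketch of the int-key dispatch above: plain Python ints among the
# inputs (dynamic shape symints) select which cached cudagraph-wrapped
# function to call. The input values below are made up.
def _example_int_key_dispatch() -> None:
    inputs: List[Any] = [torch.randn(2, 8), 8, torch.randn(8), 3]
    int_idxs = [i for i, v in enumerate(inputs) if isinstance(v, int)]
    get_ints = operator.itemgetter(*int_idxs)
    assert get_ints(inputs) == (8, 3)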
406
407def cudagraphify(
408    model: ModelType,
409    inputs: List[InputType],
410    static_input_idxs: Sequence[int] = (),
411    *,
412    device_index: int,
413    is_backward: bool,
414    is_inference: bool,
415    stack_traces: Optional[StackTraces] = None,
416    constants: Tuple[torch.Tensor, ...] = (),
417    placeholders: Tuple[PlaceholderInfo, ...] = (),
418    mutated_input_idxs: Tuple[int, ...] = (),
419) -> Tuple[ModelType, OutputType]:
420    manager = get_container(device_index).get_tree_manager()
421    assert not (is_backward and is_inference)
422    mode = (
423        CompilationMode.BACKWARD
424        if is_backward
425        else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD)
426    )
427
428    return manager.add_function(
429        model,
430        inputs,
431        static_input_idxs,
432        stack_traces,
433        mode,
434        constants,
435        placeholders,
436        mutated_input_idxs,
437    )
438
439
440class StorageWeakRefWrapper:
441    """
442    Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
443    """
444
445    __slots__ = ["ref", "_data_ptr", "extra_ref_check"]
446
447    storage_ref: Optional[StorageWeakRef]
448
449    def __init__(
450        self,
451        inp: Union[Tensor, UntypedStorage],
452        extra_ref_check: Optional[Callable[[], bool]] = None,
453    ) -> None:
454        """
455        extra_ref_check is an additional check we need to run to check if the
456        weak ref has expired. In checking the storage use count, we assume extra_ref_check
457        will hold an additional reference to the storage.
458        """
459        if isinstance(inp, Tensor):
460            stor = inp.untyped_storage()
461        else:
462            assert isinstance(inp, UntypedStorage)
463            stor = inp
464        self.ref = StorageWeakRef(stor)
465        self._data_ptr = stor.data_ptr()
466        self.extra_ref_check = extra_ref_check
467
468    @classmethod
469    def from_weakref_and_data_ptr(
470        cls: Type[S],
471        cdata: Any,
472        data_ptr: int,
473        extra_ref_check: Optional[Callable[[], bool]] = None,
474    ) -> StorageWeakRefWrapper:
475        instance = cls.__new__(cls)
476        instance._data_ptr = data_ptr
477        instance.ref = StorageWeakRef.from_weakref(cdata)
478        instance.extra_ref_check = extra_ref_check
479        return instance
480
481    def __call__(self) -> Optional[StorageWeakRefPointer]:
482        if self.expired():
483            return None
484
485        return self.ref.cdata
486
487    def swap_weakref(self, cdata: Any) -> None:
488        self.ref.__del__()
489        self.ref.cdata = cdata
490
491    def data_ptr(self) -> int:
492        "NB: returns the data ptr even if the storage has expired"
493        return self._data_ptr
494
495    def remove_extra_reference(self) -> None:
496        self.extra_ref_check = None
497
498    def expired(self) -> bool:
499        if self.extra_ref_check is not None and not self.extra_ref_check():
500            return False
501
502        # if extra_ref_check is not None we expect an additional reference
503        stor_count = torch._C._storage_Use_Count(self.ref.cdata)
504        return (stor_count - (self.extra_ref_check is not None)) == 0
505
506    def __repr__(self) -> str:
507        if self.ref is None or self.ref.expired():
508            return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
509        else:
510            return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"
511
512
513def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool:
514    return maybe_deref(weak_ref) is not None
515
516
517def maybe_deref(
518    weak_ref: Optional[StorageWeakRefWrapper],
519) -> Optional[Tuple[StorageWeakRefPointer, int]]:
520    if weak_ref is None:
521        return None
522    r = weak_ref()
523    if r is None:
524        return None
525    # NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr()
526    return r, weak_ref.data_ptr()
527
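# Illustrative sketch of the liveness helpers above: a StorageWeakRefWrapper
# stays "live" only while some storage still references the underlying memory.
# Uses a CPU tensor for simplicity; this shows the expected behavior, it is
# not a test.
def _example_storage_liveness() -> None:
    t = torch.ones(4)
    ref = StorageWeakRefWrapper(t)
    assert is_live(ref) and maybe_deref(ref) is not None
    del t
    assert not is_live(ref)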
528
529@contextlib.contextmanager
530def _use_cuda_memory_pool_manager(
531    device: int, mem_pool: Tuple[int, int], stream: torch.cuda.Stream
532) -> Generator[None, None, None]:
533    """
534    Context manager to use the cuda graph pool for new allocations. If you use this manager,
535    all cudagraph tensors in use should be reflected in the allocator or they will be overwritten.
536    The mem_pool must already exist (e.g., it has already been used in a graph capture),
537    because this manager will not preserve a reference to the pool which keeps it alive.
538    """
539    torch.cuda.synchronize()
540    stream.wait_stream(torch.cuda.current_stream())
541
542    with torch.cuda.stream(stream), torch.device(device):
543        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
544        try:
545            yield
546        finally:
547            torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
548            torch._C._cuda_releasePool(device, mem_pool)
549
550    torch.cuda.current_stream().wait_stream(stream)
551
552
553def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
554    if not isinstance(t, torch.Tensor):
555        assert t is None
556        return None
557    return StorageWeakRefWrapper(t)
558
559
560# A path index of (depth, offset) indexes into a graph that is `depth` nodes from the root,
561# at graph output `offset`
562PathOutputIndex = Tuple[int, int]
563
564# For each node in the path, for each output, is the output alive
565PathLiveness = List[List[bool]]
566
567StackTraces = List[Optional[str]]
568
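# Tiny sketch of how a PathOutputIndex addresses an output: (depth, offset)
# indexes path_weakrefs[depth][offset]. The structure below is made up.
def _example_path_output_index() -> None:
    path_weakrefs: List[List[Optional[StorageWeakRefWrapper]]] = [
        [None, None],  # depth 0: outputs of the root node
        [None, None, None],  # depth 1: outputs of its child
    ]
    index: PathOutputIndex = (1, 2)  # third output of the node at depth 1
    depth, offset = index
    assert path_weakrefs[depth][offset] is None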
569
570class CUDAWarmupNode:
571    """
572    Simplified wrapper around a CUDA model that wraps outputs in storage refs and exposes
573    APIs to get the live storages in the current chain of warmup.
574
575    A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have
576    CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable
577    memory addresses.
578
579    CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes.
580    - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of the first recording. In the
581    first instance of warmup, these are not finalized yet.
582    - All inputs to the RecordedFunction must be copied over to the cuda graph memory pool; this is unnecessary in warmup.
583    - A CUDAWarmupNode is only used once, so it does not need as much bookkeeping. It is much simpler.
584
585    NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`,
586    `outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility.
587    """
588
589    def __init__(
590        self,
591        wrapped_function: WrappedFunction,
592        parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]],
593        cuda_graphs_pool: Tuple[int, int],
594        existing_cuda_graph: Optional[torch.cuda.CUDAGraph],
595        device_index: int,
596        stack_traces: Optional[StackTraces],
597        stream: torch.cuda.Stream,
598        already_warm: bool,
599        id: GraphID,
600    ) -> None:
601        self.wrapped_function = wrapped_function
602        self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent
603        self.cuda_graphs_pool = cuda_graphs_pool
604        self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
605        self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
606        self.existing_cuda_graph = existing_cuda_graph
607        self.has_run = False
608        self.device_index = device_index
609        self.stack_traces = stack_traces
610        self.stream = stream
611        self.already_warm = already_warm
612        self.id = id
613
614    def run(self, new_inputs: Any) -> OutputType:
615        assert not self.has_run, "Wrapped function should never be run twice"
616
617        # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created
618        # storages in path_live_weakrefs.
619        existing_path_data_ptrs = {
620            t.data_ptr() for t in self.path_live_weakrefs() if t()
621        }
622
623        def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
624            non_cudagraph_inps = []
625            for t in itertools.chain(new_inputs, self.wrapped_function.constants):
626                if (
627                    isinstance(t, torch.Tensor)
628                    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
629                ):
630                    non_cudagraph_inps.append(weakref.ref(t.untyped_storage()))
631            return non_cudagraph_inps
632
633        non_cudagraph_inps_storages = get_non_cudagraph_inps()
634
635        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
636            refs = list(self.path_live_weakrefs())
637            check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
638
639        with torch.cuda.device(
640            self.device_index
641        ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
642            self.device_index, self.cuda_graphs_pool, self.stream
643        ), get_history_recording():
644            out = self.wrapped_function.model(new_inputs)
645
646        # We need to know which outputs are allocated within the cudagraph pool
647        # so that we can deallocate them at the beginning of the next cudagraph step,
648        # and set their access to error.
649        # We use a weakref to the inputs storage, in case a block which was previously
650        # allocated to the general caching allocator pool gets reallocated to a private pool.
651
652        non_cudagraph_inps_storage_ptrs = set()
653        for storage in non_cudagraph_inps_storages:
654            s = storage()
655            if s is not None:
656                non_cudagraph_inps_storage_ptrs.add(s._cdata)
657
658        assert len(new_inputs) == 0
659
660        # sdpa returns cpu tensors when not recording cuda graph
661        def add_ref(o: Any) -> bool:
662            return (
663                isinstance(o, torch.Tensor)
664                and o.is_cuda
665                and o.untyped_storage()._cdata not in non_cudagraph_inps_storage_ptrs
666                and o.untyped_storage().data_ptr() != 0
667            )
668
669        self.outputs_weakrefs.extend(
670            [map_to_ref(o) if add_ref(o) else None for o in out]
671        )
672        self.tensor_weakrefs.extend(
673            [TensorWeakRef(o) if add_ref(o) else None for o in out]
674        )
675
676        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
677            out_refs = list(self.path_live_weakrefs())
678            check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
679
680        return out
681
682    @property
683    def _path_from_root(
684        self,
685    ) -> Generator[Union[CUDAGraphNode, CUDAWarmupNode], None, None]:
686        nodes = []
687        node: Union[CUDAGraphNode, CUDAWarmupNode] = self
688        while node:
689            nodes.append(node)
690            node = node.parent  # type: ignore[assignment]
691
692        yield from reversed(nodes)
693
694    def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
695        "Returns all live storages weakrefs that created by nodes in this path"
696        for node in self._path_from_root:
697            for output in node.outputs_weakrefs:
698                if is_live(output):
699                    yield output  # type: ignore[misc]
700
701    def all_outputs_are_dead(self) -> bool:
702        return not list(self.path_live_weakrefs())
703
704    def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool:
705        for storage_weak_ref in self.path_live_weakrefs():
706            if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr():
707                return True
708        return False
709
710
711# Aliases for List that say what the indices denote
712InputList = List  # input indexes
713OutputList = List  # output indexes
714LevelList = List  # levels (distance from root of tree)
715
716
717class OutputAliasInfo:
718    pass
719
720
721class _UnaliasedStorage(OutputAliasInfo):
722    "Singleton to mark that the graph output constructs a new alias or is None"
723
724
725UnaliasedStorage = _UnaliasedStorage()
726
727
728class AliasesPriorGraphOutput(OutputAliasInfo):
729    "Marks that the graph output aliases an output of a prior graph"
730    __slots__ = ["index"]
731
732    index: PathOutputIndex
733
734    def __init__(self, index: PathOutputIndex) -> None:
735        assert isinstance(index, tuple)
736        self.index = index
737
738
739class AliasesNewOutput(OutputAliasInfo):
740    "Marks that the graph output aliases an index in the new, returned outputs"
741
742    __slots__ = ["index"]
743
744    index: int
745
746    def __init__(self, index: int) -> None:
747        assert isinstance(index, int)
748        self.index = index
749
750
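# Sketch of how per-output alias information is classified (values are made
# up): an output is either a fresh unaliased storage, an alias of an output of
# an earlier node in the path, an alias of another output of this recording,
# or None when the output is an int rather than a tensor.
def _example_output_alias_info() -> None:
    infos: List[Optional[OutputAliasInfo]] = [
        UnaliasedStorage,  # fresh storage (or a None output)
        AliasesPriorGraphOutput((0, 1)),  # aliases output 1 of the node at depth 0
        AliasesNewOutput(0),  # aliases output 0 of this same recording
        None,  # the output was an int
    ]
    assert isinstance(infos[1], AliasesPriorGraphOutput)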
751class CUDAGraphNode:
752    """
753    A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool
754    and are structured into a tree, where there is a single recording that can precede it (parent) and multiple
755    subsequent recordings that may follow (children). A node will have no parent if it is the first recording
756    in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which
757    would force a dependency.
758
759    On first recording, all of the live tensors in the current CUDA Graph Node path will be
760    reflected in the corresponding private pool. On subsequent executions, the caching allocator
761    is unaffected when the graph is replayed.
762
763    In order to support a subsequent cuda graph recording after execution of this graph,
764    we checkpoint the state of the memory pool so that it may later be resumed.
765
766    WrappedFunction should have already been warmed up prior to invocation.
767
768    See [setCheckpointPoolState] for further explanation, as well as
769    https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png
770    """
771
772    def __init__(
773        self,
774        wrapped_function: WrappedFunction,
775        id: GraphID,
776        parent: Optional[CUDAGraphNode],
777        inputs: List[InputType],
778        cuda_graphs_pool: Tuple[int, int],
779        device_index: int,
780        stack_traces: Optional[StackTraces],
781        stream: torch.cuda.Stream,
782    ) -> None:
783        assert isinstance(inputs, (list, tuple))
784
785        self.wrapped_function = wrapped_function
786        self.id = id
787        self.device = device_index
788        self.stack_traces = stack_traces
789        self.stream = stream
790
791        # Enable re-recording a cudagraph when a static tensor address changes.
792        # If not, we should error when it changes.
793        self.rerecord_if_static_inputs_change = (
794            torch._dynamo.config.inline_inbuilt_nn_modules
795            or torch._inductor.config.triton.cudagraph_support_input_mutation
796        )
797
798        # if this is a root, parent will be None. Use a weakref to prevent a reference cycle
799        self._parent = weakref.ref(parent) if parent is not None else None
800        # reference to the shared memory pool for the entire cuda graphs tree
801        self.cuda_graphs_pool = cuda_graphs_pool
802
803        # A single wrapped function may be recorded multiple times if memory patterns or
804        # invariants change from one execution to the next
805        self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
806
807        # StorageWeakRef maintains whether the Storage C++ object remains allocated,
808        # not whether the corresponding memory has been deallocated. In order
809        # to use them to track memory deallocations we must maintain a single StorageWeakRef
810        # for all Storages that reference that memory (even if we are constructing Storages
811        # that do not have a deallocator function). We maintain one single storage_cache
812        # as we execute any tree path. When we retrieve a storage from the cache we
813        # check that it is still alive, and we hash based on observed recording data ptr
814        # and storage cdata.
815
816        # we preserve a single reference to executed outputs that is then referenced
817        # in children to avoid children having to chase parent pointers in the hot path
818        # DO NOT reassign output_weakrefs, only call `clear()`
819        # Path is a series of nodes from root to the current node
820        self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = []
821        self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
822            node.outputs_weakrefs for node in self._path_from_root
823        ]
824        self.path_stacktraces: LevelList[Optional[StackTraces]] = [
825            node.stack_traces for node in self._path_from_root
826        ]
827        self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []
828
829        # tensors which are outputs of previous graphs in the tree
830        self.cudagraph_managed_idxs: List[int] = [
831            idx
832            for idx, t in enumerate(inputs)
833            if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
834        ]
835
836        self.static_input_idxs: List[int] = list(
837            set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
838        )
839
840        self.non_static_input_idx: LevelList[int] = [
841            i for i in range(len(inputs)) if i not in self.static_input_idxs
842        ]
843
844        counters["inductor"]["cudagraph_recorded_non_static_inputs"] += len(
845            self.non_static_input_idx
846        )
847
848        self.non_managed_static_input_idxs: LevelList[int] = [
849            i
850            for i in wrapped_function.static_input_idxs
851            if i not in self.cudagraph_managed_idxs
852        ]
853
854        def maybe_get_static_data_ptr(
855            idx: int,
856            inputs: List[Union[torch.Tensor, int]],
857            static_input_idxs: List[int],
858        ) -> Optional[int]:
859            inp = inputs[idx]
860            if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
861                return inp.data_ptr()
862            return None
863
864        self.static_input_data_ptrs: InputList[Optional[int]] = [
865            maybe_get_static_data_ptr(i, inputs, self.static_input_idxs)
866            for i in range(len(inputs))
867        ]
868
869        # When we checkpoint, and free generations, we will be manually freeing the outputs
870        # of CUDAGraphNodes. We should not be freeing parameters, nor do we need to account for
871        # their liveness (they are static), so we need to compute which outputs are aliases of
872        # parameters. Some static inputs are saved tensors from the forward that die in the backward.
873        # Their locations are static but lifetimes are not. We only include the persistent static
874        # data ptrs below because the non persistent data ptrs may be outputs of this record and
875        # fresh allocations.
876
877        # precompute expanded dims to avoid computing in the hot path
878        self.expanded_dims: List[List[int]] = [
879            get_expanded_dims(x)
880            if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
881            else []
882            for idx, x in enumerate(inputs)
883        ]
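        # Hedged illustration: for an input created as torch.randn(4, 1).expand(4, 8),
        # get_expanded_dims returns [1] (the stride-0 dim), and index_expanded_dims
        # slices that dim down to a single element, so the copy in
        # _copy_inputs_and_remove_from_src only touches the materialized memory.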
884
885        # For each node in path, which outputs were observed to be live
886        # before invoking graph recording, and after graph recording
887        self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = []
888        self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = []
889
890        # List of Tuples of (depth, output_index) that index into node at depth
891        # number of nodes from root and output_index of outputs. Will index into
892        # path_weakrefs.
893        self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
894        self.expected_dead_indices_after_graph: List[PathOutputIndex] = []
895
896        # all live indices after graph recording
897        self.live_indices_after_graph: List[PathOutputIndex] = []
898
899        if self.parent is not None:
900            previous_liveness = self.parent.recorded_liveness_after_graph
901            curr_liveness = self._get_liveness(self.path_weakrefs)
902
903            different_indices = self._get_different_indices(
904                previous_liveness, curr_liveness
905            )
906
907            self.recorded_liveness_before_graph = curr_liveness
908            self.expected_dead_indices_before_graph = different_indices
909
910        recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
911        # recording inputs will copy over memory, so we can free non recording inputs
912        inputs.clear()
913        del inputs
914
915        # graph used for recording model invocation
916        self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
917
918        # we allocate non-static inputs within the same memory pool as the CUDAGraph
919        # which we will record the model with. For memory efficiency, it is important
920        # to reclaim the input memory when the inputs are no longer live. To accomplish this,
921        # we reconstruct tensors at the correct data pointers of our inputs which are
922        # non owning and do not prevent deallocation. On subsequent executions, input values
923        # will be copied over to these tensors.
924        self.reconstructed_inputs: List[InputType] = [
925            self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
926            if isinstance(x, torch.Tensor)
927            else x
928            for x in recording_inputs
929        ]
930
931        # DO THE RECORDING!!!
932        # We record the CUDA graph in the constructor of CUDAGraphNode, which
933        # gives you what the CPU side compute of the function would do.  We
934        # don't throw the recording outputs away: their memory is
935        # correctly accounted for in the CUDAGraphs caching allocator.  This
936        # means on the very FIRST run of the CUDA graph node, we can directly
937        # do more recording, because we have a valid caching allocator state.
938        # NB: This relies on run() being called immediately after the
939        # constructor, otherwise this optimization would not be valid.
940
941        # initialized below in _record
942
943        self.checkpointed_caching_state: Optional[AllocatorState] = None
944
945        # Output Storage Alias information, can be:
946        # - A new, unaliased storage, or the output is None
947        # - An alias of an output of a prior graph
948        # - An alias of an output already created in the reconstructed outputs
949        # This is None if the output in question is an int
950        self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = []
951
952        # Is the output Storage unaliased in subsequent outputs, of all subsequent paths?
953        # If it is, we cache the output tensor and adjust storage liveness tracking to also
954        # check that the output tensor does not have an additional python reference.
955        # If a descendant node discovers it has an alias of a prior output, then the output
956        # will no longer be cached in the ancestor.
957        # The large majority of tensors are unaliased, and preserving aliased output tensors would add
958        # significant additional complexity with marginal gains.
959        # The cached tensor outputs are added on the first execution, and cleared whenever we need
960        # to do subsequent recording
961        self.unaliased_in_all_paths: OutputList[bool] = []
962        self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []
963
964        # if an output aliases a static, persistent input then the corresponding Tensor will
965        # be set here. These are different than cached tensors, because they are tensors that
966        # are aliases of parameters that are always live.
967        self.static_output_tensors: OutputList[Optional[Tensor]] = []
968
969        # Cleared after recording
970        self.recording_outputs: Optional[OutputType] = self._record(
971            wrapped_function.model, recording_inputs
972        )
973        self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []
974
975        # As with inputs, we do not want to keep the outputs permanently alive because that would prevent
976        # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
977        # needed to reconstruct instead.
978        assert self.recording_outputs is not None
979        for out in self.recording_outputs:
980            if isinstance(out, torch.Tensor):
981                self.outputs_metadata.append(
982                    self._tensor_metadata(out, ignore_storage_offset=False)
983                )
984            else:
985                assert isinstance(out, (int, type(None))), type(out)
986                self.outputs_metadata.append(out)
987
988        self.graph.replay()
989
990    def _copy_inputs_and_remove_from_src(
991        self, dsts: List[InputType], srcs: List[InputType]
992    ) -> None:
993        dst_tensors = []
994        src_tensors = []
995        for idx in self.non_static_input_idx:
996            if not isinstance(srcs[idx], torch.Tensor):
997                continue
998            expanded_dims = self.expanded_dims[idx]
999            dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims))  # type: ignore[arg-type]
1000            src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims))  # type: ignore[arg-type]
1001            srcs[idx] = None  # type: ignore[call-overload]
1002        # Fails on empty lists
1003        if dst_tensors:
1004            torch._foreach_copy_(dst_tensors, src_tensors)
1005
1006    def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None:
1007        # avoid checking managed tensor static pointers since we already checked those in check_invariants
1008        if (
1009            not self.rerecord_if_static_inputs_change
1010            and not torch._C._tensors_data_ptrs_at_indices_equal(
1011                new_inputs,  # type: ignore[arg-type]
1012                self.static_input_data_ptrs,
1013                self.non_managed_static_input_idxs,
1014            )
1015        ):
1016            # this should error
1017            error_msg = log_data_ptr_mismatch(
1018                self.wrapped_function.placeholders,
1019                new_inputs,
1020                self.static_input_data_ptrs,
1021                self.non_managed_static_input_idxs,
1022                CheckInvariantStatus.StaticInputIdxMismatch,
1023            )
1024            torch._check(False, lambda: error_msg)
1025
1026    def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType:
1027        if config.triton.fast_path_cudagraph_asserts:
1028            self.debug_check_invariants_before_invocation()
1029
1030        # graph is already invoked in the __init__
1031        # inputs are copied over in _allocate_and_copy_recording_inputs and subsequently cleared
1032        assert len(new_inputs) == 0
1033        outputs = self.recording_outputs
1034        self.recording_outputs = None
1035        assert outputs is not None
1036        return outputs
1037
1038    def run(self, new_inputs: List[InputType]) -> OutputType:
1039        self.check_static_inputs_are_stable(new_inputs)
1040
1041        self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
1042        new_inputs.clear()
1043
1044        self.run_graph()
1045
1046        outputs = self.reconstruct_outputs()
1047
1048        if config.triton.fast_path_cudagraph_asserts:
1049            self.debug_check_invariants_after_invocation()
1050
1051        if config.triton.force_cudagraph_sync:
1052            torch.cuda.synchronize()
1053
1054        # Reset this to run the check in the future
1055        self.static_inputs_stable = False
1056
1057        return outputs
1058
1059    def reconstruct_outputs(self) -> OutputType:
1060        "Reconstruct output tensors according to their saved metadata and alias information"
1061
1062        # Cached tensors will not yet be set on the first execution
1063        # They are also cleared in checkpointing, so if we checkpoint this node
1064        # and then execute it again we will need to repopulate cached tensors
1065        if not self.cached_tensor_outputs:
1066            self._initialize_cached_tensors()
1067
1068        outputs: OutputType = []
1069
1070        for i, (storage_info, metadata) in enumerate(
1071            zip(self.output_storage_alias, self.outputs_metadata)
1072        ):
1073            if not isinstance(metadata, dict):  # tensor metadata
1074                assert isinstance(metadata, (int, type(None)))
1075                outputs.append(metadata)
1076                continue
1077
1078            cached_t = self.cached_tensor_outputs[i]
1079            if cached_t is not None:
1080                # this output represents a freshly allocated tensor.
1081                # We return the same TensorImpl from run to run to avoid overhead.
1082                # autograd.Function will reset the Autograd meta of output tensors
1083                # as part of aot_autograd, but _backward_hooks are stored on tensors separately,
1084                # so we need to manually reset hooks.
1085                if cached_t._backward_hooks is not None:
1086                    cached_t._backward_hooks = None
1087
1088                # No need to update weakrefs, already correctly initialized
1089                outputs.append(cached_t)
1090                continue
1091
1092            static_t = self.static_output_tensors[i]
1093            if static_t is not None:
1094                assert self.outputs_weakrefs[i] is None
1095                outputs.append(static_t)
1096                continue
1097
1098            storage = self.prepare_alias_info_for_tensor_construction(
1099                storage_info, metadata
1100            )
1101
1102            if isinstance(storage, UntypedStorage) or storage is None:
1103                out = self._reconstruct_from_tensor_metadata(metadata, storage)
1104            else:
1105                assert isinstance(storage, int)
1106                out = self._reconstruct_from_tensor_metadata(
1107                    metadata, cast(torch.Tensor, outputs[storage]).untyped_storage()
1108                )
1109
1110            outputs.append(out)
1111            w = self.outputs_weakrefs[i]
1112            assert w is not None
1113            w.swap_weakref(out.untyped_storage()._weak_ref())
1114
1115        return outputs
1116
1117    def prepare_alias_info_for_tensor_construction(
1118        self,
1119        out_alias_info: Optional[OutputAliasInfo],
1120        metadata: Union[Dict[str, Any], int, None],
1121    ) -> Union[UntypedStorage, None, int]:
1122        if (
1123            isinstance(metadata, (int, type(None)))
1124            or out_alias_info is UnaliasedStorage
1125        ):
1126            return None
1127
1128        if isinstance(out_alias_info, AliasesPriorGraphOutput):
1129            depth, existing_output_index = out_alias_info.index
1130            ref = self.path_weakrefs[depth][existing_output_index]
1131            assert ref is not None
1132            return torch.UntypedStorage._new_with_weak_ptr(ref())
1133
1134        assert isinstance(out_alias_info, AliasesNewOutput)
1135        return out_alias_info.index
1136
1137    def prepare_storages_for_construction(
1138        self,
1139    ) -> List[Union[UntypedStorage, None, int]]:
1140        output_storages = []
1141        for output_storage_alias, metadata in zip(
1142            self.output_storage_alias, self.outputs_metadata
1143        ):
1144            output_storages.append(
1145                self.prepare_alias_info_for_tensor_construction(
1146                    output_storage_alias, metadata
1147                )
1148            )
1149
1150        return output_storages
1151
1152    def run_graph(self) -> None:
1153        assert self.graph is not None
1154        self.graph.replay()
1155
1156    def all_outputs_are_dead(self) -> bool:
1157        "All outputs of the path from this node to its root are dead"
1158        for depth, output_index in self.live_indices_after_graph:
1159            if is_live(self.path_weakrefs[depth][output_index]):
1160                return False
1161        return True
1162
1163    def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType:
1164        "Record the model"
1165
1166        def static_input_iter() -> Generator[torch.Tensor, None, None]:
1167            for i in self.wrapped_function.static_input_idxs:
1168                _inp = inputs[i]
1169                if isinstance(
1170                    _inp, torch.Tensor
1171                ) and not self._is_cuda_graph_recorded_tensor(_inp):
1172                    yield _inp
1173
1174        # see: output_is_alias_of_persistent_static_inputs above
1175        static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
1176            inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
1177            for inp in itertools.chain(
1178                static_input_iter(), self.wrapped_function.constants
1179            )
1180        }
1181
1182        if config.triton.slow_path_cudagraph_asserts:
1183            # need to use parent live weakrefs because live_indices isn't set yet
1184            memory = (
1185                [] if self.parent is None else list(self.parent.path_live_weakrefs())
1186            )
1187            memory += [
1188                StorageWeakRefWrapper(elem)
1189                for i, elem in enumerate(inputs)
1190                if isinstance(elem, torch.Tensor)
1191                and i not in self.wrapped_function.static_input_idxs
1192                and elem.untyped_storage().data_ptr() != 0
1193            ]
1194            check_memory_pool(self.device, self.cuda_graphs_pool, memory)
1195
1196        with preserve_rng_state(), torch.cuda.device(
1197            self.device
1198        ), clear_cublas_manager(), torch.cuda.graph(
1199            self.graph,
1200            stream=self.stream,
1201            pool=self.cuda_graphs_pool,
1202            capture_error_mode="thread_local",
1203        ), get_history_recording():
1204            static_outputs = model(inputs)
1205
1206        # running model should reclaim memory
1207        assert len(inputs) == 0
1208
1209        if not isinstance(static_outputs, (list, tuple)):
1210            static_outputs = (static_outputs,)
1211
1212        self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs)
1213
1214        return static_outputs
1215
1216    def _add_first_outputs(
1217        self,
1218        outputs: OutputType,
1219        static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
1220    ) -> None:
1221        "Add the outputs from the first invocation of the node and set up metadata"
1222
1223        # getting liveness before we have added the outputs to path, so the length
1224        # of the two lists is equal
1225        prev_liveness = self.recorded_liveness_before_graph
1226        curr_liveness = self._get_liveness(self.path_weakrefs)
1227
1228        delta = self._get_different_indices(prev_liveness, curr_liveness)
1229        self.expected_dead_indices_after_graph = delta
1230
1231        assert len(self.outputs_weakrefs) == 0
1232        # index from data pointer to index in outputs
1233        output_new_storages_index: Dict[StorageDataPtr, int] = {}
1234
1235        self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
1236        self.static_output_tensors = [None for _ in range(len(outputs))]
1237
1238        for i, o in enumerate(outputs):
1239            if o is None or not isinstance(o, torch.Tensor):
1240                self.output_storage_alias.append(UnaliasedStorage)
1241                continue
1242
1243            torch._check(
1244                o.is_cuda or o.untyped_storage().data_ptr() == 0,
1245                lambda: (
1246                    "Expected all cuda outputs in cuda graph recording. Non cuda output "
1247                    f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
1248                ),
1249            )
1250
1251            ref = static_input_persistent_storage_ptrs.get(
1252                o.untyped_storage().data_ptr(), None
1253            )
1254            # also treat empty storages as static outputs because we do not need to manage their lifetime
1255            # and they should not participate in checkpointing
1256            is_empty_storage = o.untyped_storage().data_ptr() == 0
1257            if (ref and ref() is not None) or is_empty_storage:
1258                self.output_storage_alias.append(None)
1259                self.static_output_tensors[i] = o
1260                continue
1261
1262            path_ref = self._is_alias_of_live_recorded_tensor(o)
1263            if path_ref is not None:
1264                self._mark_prior_graph_output_as_aliased(path_ref)
1265                self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
1266                continue
1267
1268            if o.untyped_storage().data_ptr() in output_new_storages_index:
1269                index = output_new_storages_index[o.untyped_storage().data_ptr()]
1270                self.unaliased_in_all_paths[index] = False
1271                self.output_storage_alias.append(AliasesNewOutput(index))
1272                continue
1273
1274            output_new_storages_index[o.untyped_storage().data_ptr()] = i
1275            self.output_storage_alias.append(UnaliasedStorage)
1276            self.unaliased_in_all_paths[i] = True
1277
1278        if self.stack_traces is None:
1279            self.stack_traces = [None for _ in range(len(outputs))]
1280        else:
1281            assert len(self.stack_traces) == len(
1282                outputs
1283            ), "Wrong number of stack traces passed in"
1284
1285        assert not self.outputs_weakrefs
1286        for out, static_output_tensor in zip(outputs, self.static_output_tensors):
1287            if not isinstance(out, torch.Tensor) or static_output_tensor is not None:
1288                self.outputs_weakrefs.append(None)
1289                self.tensor_weakrefs.append(None)
1290            else:
1291                self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
1292                self.tensor_weakrefs.append(TensorWeakRef(out))
1293
1294        self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
1295        self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
1296            self.device, self.cuda_graphs_pool
1297        )
1298
1299        # now, get liveness with outputs added
1300        for depth in range(len(self.path_weakrefs)):
1301            for output_index in range(len(self.path_weakrefs[depth])):
1302                if is_live(self.path_weakrefs[depth][output_index]):
1303                    self.live_indices_after_graph.append((depth, output_index))
1304
1305        self.debug_check_invariants_after_invocation()
1306        if config.triton.slow_path_cudagraph_asserts:
1307            check_memory_pool(
1308                self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs())
1309            )
1310
1311    def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex) -> None:
1312        "Remove a graph output from the unaliased, cached tensors in an ancestor node"
1313        depth, output_index = index
1314        node = list(self._path_from_root)[depth]
1315        node.unaliased_in_all_paths[output_index] = False
1316        x = self.path_weakrefs[depth][output_index]
1317        assert x is not None
1318        x.remove_extra_reference()
1319
1320    def _initialize_cached_tensors(self) -> None:
1321        # we should not be clearing output_weakrefs, and they should be set in the first
1322        # record run
1323        assert len(self.outputs_weakrefs) == len(self.outputs_metadata)
1324
1325        for i, (storage_info, metadata, make_cached) in enumerate(
1326            zip(
1327                self.output_storage_alias,
1328                self.outputs_metadata,
1329                self.unaliased_in_all_paths,
1330            )
1331        ):
1332            if not make_cached:
1333                self.cached_tensor_outputs.append(None)
1334                continue
1335
1336            assert storage_info is UnaliasedStorage
1337            assert isinstance(metadata, dict)
1338            s = self.create_storage(metadata)
1339            out = self._reconstruct_from_tensor_metadata(metadata, storage=s)  # type: ignore[arg-type]
1340
1341            # XXX: let autograd know that there will be an additional reference to the tensor
1342            # that can be ignored when deciding whether to do gradient buffer inplacing.
1343            # Otherwise, inplacing could differ between tracing and subsequent execution.
1344            # For some models we tested this led to inputs no longer being in cudagraph pools,
1345            # leading to spurious re-recordings.
1346            # It also tells the AMP cache that these tensor impls cannot be cached
1347            # in dtype conversions.
1348
1349            torch._C._add_cached_tensor(out)
1350
1351            self_ref = weakref.ref(self)
1352
1353            # one reference in our array, and calling sys.getrefcount bumps the refcount by one
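            # Illustrative plain-Python analogue of that count (hypothetical, not executed here):
            #   cache = [torch.empty(1, device="cuda")]
            #   sys.getrefcount(cache[0]) == 2  # one ref held by the list, one by the call argument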
1354            def check_refcount(i: int) -> bool:
1355                self_loc = self_ref()
1356                if self_loc is None:
1357                    return False
1358                return self_loc.get_output_refcount(i) == 2
1359
1360            check = functools.partial(check_refcount, i=i)
1361
1362            self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check)
1363            self.cached_tensor_outputs.append(out)
1364
1365    def get_output_refcount(self, index: int) -> int:
1366        return sys.getrefcount(self.cached_tensor_outputs[index])
1367
1368    @property
1369    def parent(self) -> Optional[CUDAGraphNode]:
1370        "unwraps the weakref to _parent"
1371        return self._parent() if self._parent is not None else None
1372
1373    @property
1374    def _path_to_root(self) -> Generator[CUDAGraphNode, None, None]:
1375        "Returns all nodes in the path starting at self and ending at root"
1376        node = self
1377        while node:
1378            yield node
1379            node = node.parent  # type: ignore[assignment]
1380
1381    @property
1382    def _path_from_root(self) -> Generator[CUDAGraphNode, None, None]:
1383        "Returns all nodes in the path starting at the root and ending at self"
1384        nodes = reversed(list(self._path_to_root))
1385        yield from nodes
1386
1387    def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool:
1388        "Is this tensor an output of a node in this path"
1389        for output_refs in self.path_weakrefs:
1390            for storage_weak_ref in output_refs:
1391                if storage_weak_ref is None:
1392                    continue
1393                # don't need to check liveness of storage since the cuda graph managed
1394                # memory is never released.
1395                data_ptr = storage_weak_ref.data_ptr()
1396                if t.untyped_storage().data_ptr() == data_ptr:
1397                    return True
1398
1399        return False
1400
1401    def _is_alias_of_live_recorded_tensor(
1402        self, t: torch.Tensor
1403    ) -> Optional[PathOutputIndex]:
1404        for depth, output_refs in enumerate(self.path_weakrefs):
1405            for output_index, storage_ref in enumerate(output_refs):
1406                if (storage_and_ptr := maybe_deref(storage_ref)) is not None:
1407                    storage, ptr = storage_and_ptr
1408                    if ptr == t.untyped_storage().data_ptr():
1409                        return (depth, output_index)
1410
1411        return None
1412
1413    @staticmethod
1414    def _check_liveness(
1415        indices: List[PathOutputIndex],
1416        output_refs: List[List[Optional[StorageWeakRefWrapper]]],
1417    ) -> bool:
1418        "Check that all of the indices specified are dead references"
1419        for depth, output_index in indices:
1420            w = output_refs[depth][output_index]
1421            assert w is not None
1422            if w() is not None:
1423                return False
1424        return True
1425
1426    def add_child(self, function_id: FunctionID, node: CUDAGraphNode) -> None:
1427        "Adds node as a child of self"
1428        self.children[function_id].append(node)
1429
1430    @staticmethod
1431    def _get_different_indices(
1432        prev: List[List[bool]], curr: List[List[bool]]
1433    ) -> List[PathOutputIndex]:
1434        "Find indices where the two lists differ."
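        # e.g. (hypothetical liveness snapshots, indexed as [node depth][output index]):
        #   prev = [[True, True]], curr = [[True, False]]  ->  [(0, 1)]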
1435        dead_indices = []
1436        assert len(prev) <= len(curr)
1437        for i, (outputs1, outputs2) in enumerate(zip(prev, curr)):
1438            assert len(outputs1) == len(outputs2)
1439            for j, (output1, output2) in enumerate(zip(outputs1, outputs2)):
1440                if output1 != output2:
1441                    dead_indices.append((i, j))
1442
1443        return dead_indices
1444
1445    @staticmethod
1446    def _get_liveness(
1447        weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
1448    ) -> List[List[bool]]:
1449        "Maps weakrefs to true if the reference is alive and false otherwise"
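        # e.g. (hypothetical) [[ref_to_live_storage, ref_to_dead_storage]] -> [[True, False]]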
1450        if len(weakrefs) == 0:
1451            return []
1452
1453        return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
1454
1455    def debug_assert_invariants(
1456        self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
1457    ) -> None:
1458        if not config.triton.fast_path_cudagraph_asserts:
1459            return
1460
1461        for i, node in enumerate(self._path_from_root):
1462            assert self.path_weakrefs[i] is node.outputs_weakrefs
1463
1464        nodes = list(self._path_from_root)
1465
1466        live_blocks = get_block_addrs(self.cuda_graphs_pool)
1467
1468        live_storage_data_ptrs = set()
1469        live_storage_weak_ptrs = set()
1470
1471        for depth, outputs_liveness in enumerate(expected_liveness):
1472            for output_idx, output_liveness in enumerate(outputs_liveness):
1473                # tensor can die early, but it can't be alive when it should be dead
1474                w = self.path_weakrefs[depth][output_idx]
1475                if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None:
1476                    assert output_liveness
1477                    stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr
1478                    assert (stor_data_ptr in live_storage_data_ptrs) == (
1479                        stor_weak_ptr in live_storage_weak_ptrs
1480                    )
1481                    live_storage_data_ptrs.add(stor_data_ptr)
1482                    live_storage_weak_ptrs.add(stor_weak_ptr)
1483
1484                    is_persistent_alias = (
1485                        nodes[depth].static_output_tensors[output_idx] is not None
1486                    )
1487
1488                    if is_persistent_alias:
1489                        assert stor_data_ptr not in live_blocks
1490
1491        for depth, output_index in newly_dead:
1492            assert not is_live(self.path_weakrefs[depth][output_index])
1493
1494    def debug_check_invariants_before_invocation(self) -> None:
1495        self.debug_assert_invariants(
1496            self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph
1497        )
1498
1499    def debug_check_invariants_after_invocation(self) -> None:
1500        self.debug_assert_invariants(
1501            self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
1502        )
1503
1504    def data_ptrs_dead_since_invocation(self) -> List[int]:
1505        """
1506        Since this node was invoked, return data ptrs of all tensor outputs that have died
1507        in the current executing tree path.
1508        """
1509        curr_liveness = self._get_liveness(self.path_weakrefs)
1510        _get_different_indices = self._get_different_indices(
1511            self.recorded_liveness_after_graph, curr_liveness
1512        )
1513
1514        path = list(self._path_from_root)
1515        ptrs_to_deallocate = []
1516        for depth, output_index in _get_different_indices:
1517            ptrs_to_deallocate.append(
1518                path[depth].outputs_metadata[output_index]["data_ptr"]  # type: ignore[index]
1519            )
1520
1521        return ptrs_to_deallocate
1522
1523    def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
1524        for i, j in self.live_indices_after_graph:
1525            out = self.path_weakrefs[i][j]
1526            if out is not None and is_live(out):
1527                yield out
1528
1529    def remove_node_cached_tensors(self) -> None:
1530        for t in self.cached_tensor_outputs:
1531            if t is not None:
1532                torch._C._remove_cached_tensor(t)
1533        self.cached_tensor_outputs.clear()
1534
1535        for i, unaliased in enumerate(self.unaliased_in_all_paths):
1536            if unaliased:
1537                n = self.outputs_weakrefs[i]
1538                assert n is not None
1539                n.remove_extra_reference()
1540
1541    def remove_path_cached_tensors(self) -> None:
1542        for node in self._path_from_root:
1543            node.remove_node_cached_tensors()
1544
1545    def clear_path_state(self) -> None:
1546        "Clear the path state in the currently executing node"
1547        # this doesn't actually do anything right now, leaving it as a placeholder
1548
1549    @staticmethod
1550    def _tensor_metadata(
1551        x: torch.Tensor, ignore_storage_offset: bool = True
1552    ) -> Dict[str, Any]:
1553        assert isinstance(x, torch.Tensor)
1554        # We ignore the storage offset for inputs, but not for outputs
1555        # TODO: should we make the storage resizable?
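        # Illustrative result for a hypothetical contiguous 2x3 float32 CUDA tensor:
        #   {"nbytes": 24, "data_ptr": <pool address>, "size": torch.Size([2, 3]), "stride": (3, 1),
        #    "dtype": torch.float32, "device": device(type='cuda', index=0), "storage_offset": 0}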
1556        return {
1557            "nbytes": x.untyped_storage().nbytes(),
1558            "data_ptr": x.untyped_storage().data_ptr(),
1559            "size": x.shape,
1560            "stride": x.stride(),
1561            "dtype": x.dtype,
1562            "device": x.device,
1563            "storage_offset": x.storage_offset() if not ignore_storage_offset else 0,
1564        }
1565
1566    def _reconstruct_from_tensor_metadata(
1567        self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None
1568    ) -> Tensor:
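        # Sketch of the round trip (illustrative): metadata captured from a graph output is turned
        # back into a tensor that aliases the same cudagraph-pool memory, e.g.
        #   meta = self._tensor_metadata(out, ignore_storage_offset=False)
        #   alias = self._reconstruct_from_tensor_metadata(meta)
        #   assert alias.untyped_storage().data_ptr() == out.untyped_storage().data_ptr()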
1569        s = self.create_storage(metadata) if storage is None else storage
1570        return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s)  # type: ignore[arg-type]
1571
1572    def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage:
1573        return torch._C._construct_storage_from_data_pointer(
1574            metadata["data_ptr"], metadata["device"], metadata["nbytes"]
1575        )
1576
1577    def _allocate_and_copy_recording_inputs(
1578        self, inputs: List[InputType]
1579    ) -> List[Union[torch.Tensor, int]]:
1580        """
1581        Allocate inputs for non-static, non-cudagraph-managed tensors in the memory pool
1582        and copy over the tensor values.
1583        """
1584
1585        torch.cuda.synchronize()
1586        self.stream.wait_stream(torch.cuda.current_stream())
1587        recording_inputs: List[InputType] = []
1588
1589        with warnings.catch_warnings(record=True), torch.cuda.device(
1590            self.device
1591        ), _use_cuda_memory_pool_manager(
1592            self.device,
1593            mem_pool=self.cuda_graphs_pool,
1594            stream=self.stream,
1595        ):
1596            for i, inp in enumerate(inputs):
1597                if not isinstance(inp, torch.Tensor):
1598                    assert isinstance(inp, int)
1599                    recording_inputs.append(inp)
1600                elif i not in self.static_input_idxs:
1601                    # static_input does an allocation!
1602                    recording_inputs.append(static_input(inp))
1603                else:
1604                    recording_inputs.append(inp)
1605
1606            self._copy_inputs_and_remove_from_src(recording_inputs, inputs)
1607
1608        return recording_inputs
1609
1610    def check_invariants(
1611        self, inputs: List[InputType]
1612    ) -> Tuple[CheckInvariantStatus, Callable[..., str]]:
1613        """
1614        Checks if this node can be run. The same pattern of tensor liveness, static inputs,
1615        and tensors managed in the cudagraph private pool must remain stable.
1616        """
1617
1618        _logger = functools.partial(
1619            log_data_ptr_mismatch,
1620            self.wrapped_function.placeholders,
1621            inputs,
1622            self.static_input_data_ptrs,
1623        )
1624
1625        # previously managed data pointers remain stable
1626        # this is on the hot path so moved to C++. equivalent to:
1627        # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs))
1628        if not torch._C._tensors_data_ptrs_at_indices_equal(
1629            inputs,  # type: ignore[arg-type]
1630            self.static_input_data_ptrs,
1631            self.cudagraph_managed_idxs,
1632        ):
1633            status = CheckInvariantStatus.CudagraphManagedIdxMismatch
1634            _logger = functools.partial(
1635                _logger,
1636                self.cudagraph_managed_idxs,
1637                status,
1638            )
1639            return status, _logger
1640
1641        if not self._check_liveness(
1642            self.expected_dead_indices_before_graph, self.path_weakrefs
1643        ):
1644            status = CheckInvariantStatus.ExpectedDeadIndicesBeforeGraphMismatch
1645            return status, lambda: f"{status}"
1646
1647        # static input data pointers should remain stable
1648        # if we are inlining builtin nn modules we re-record in this case
1649        # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable
1650        # and error if they are not stable
1651        if (
1652            self.rerecord_if_static_inputs_change
1653            and not torch._C._tensors_data_ptrs_at_indices_equal(
1654                inputs,  # type: ignore[arg-type]
1655                self.static_input_data_ptrs,
1656                self.static_input_idxs,
1657            )
1658        ):
1659            status = CheckInvariantStatus.StaticInputIdxMismatch
1660            _logger = functools.partial(
1661                _logger,
1662                self.static_input_idxs,
1663                status,
1664            )
1665            return status, _logger
1666
1667        # the cudagraph managed tensors which died upon recording must also die upon
1668        # this invocation. it is too late to check after we've replayed the graph,
1669        # because we would have already written over their memory.
1670        for idx in self.cudagraph_managed_idxs:
1671            inputs[idx] = None  # type: ignore[call-overload]
1672
1673        torch._check(
1674            self._check_liveness(
1675                self.expected_dead_indices_after_graph, self.path_weakrefs
1676            ),
1677            lambda: "TODO: graph recording observed an input tensor deallocate during graph "
1678            "recording that did not occur during replay. Please file an issue.",
1679        )
1680        return CheckInvariantStatus.SUCCESS, lambda: f"{CheckInvariantStatus.SUCCESS}"
1681
1682    def num_descendants(self) -> int:
1683        "Total number of descendants of this node"
1684        num_desc = 0
1685        for children in self.children.values():
1686            for child in children:
1687                num_desc += 1
1688                num_desc += child.num_descendants()
1689        return num_desc
1690
1691
1692def get_cudagraph_segments(pool_id: Tuple[int, int]) -> Any:
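    # torch.cuda.memory_snapshot() returns a list of segment dicts; the keys relied on here and in
    # get_block_addrs / check_memory_pool look roughly like this (illustrative and abridged -- the
    # real snapshot carries more fields):
    #   {"segment_pool_id": (0, 1), "address": 140000000000, "blocks": [
    #       {"size": 512, "state": "active_allocated", "frames": [...]},
    #       {"size": 2048, "state": "inactive"}]}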
1693    segments = torch.cuda.memory_snapshot()
1694    return [segment for segment in segments if segment["segment_pool_id"] == pool_id]
1695
1696
1697def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[int]:
1698    blocks = []
1699
1700    for segment in get_cudagraph_segments(pool_id):
1701        addr = segment["address"]
1702        for block in segment["blocks"]:
1703            if block["state"] == "active_allocated" or not live_only:
1704                blocks.append(addr)
1705
1706            addr += block["size"]
1707
1708    return blocks
1709
1710
1711def format_tb(frames: List[Any]) -> str:
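    # Each frame entry is expected to be a dict with "filename", "line", and "name" keys,
    # e.g. (hypothetical): {"filename": "model.py", "line": 10, "name": "forward"}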
1712    formatted_traceback = []
1713
1714    for entry in frames:
1715        formatted_traceback.append(
1716            traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
1717        )
1718
1719    return "".join(traceback.format_list(formatted_traceback))
1720
1721
1722def check_memory_pool(
1723    device: int,
1724    pool_id: Tuple[int, int],
1725    live_storages_ptrs: List[StorageWeakRefWrapper],
1726) -> None:
1727    assert all(
1728        isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
1729    )  # noqa: C419
1730    unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}
1731
1732    # check if there is a divergence first, then do the expensive snapshot call after
1733    # we know it will error
1734    if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages):
1735        return
1736
1737    # at this point we are past the fast-path. we have seen rare cases where a tensor is dead,
1738    # but hasn't been gc'd yet, and gives a false positive for allocated_not_in_live_storages
1739    gc.collect()
1740
1741    segments = get_cudagraph_segments(pool_id)
1742
1743    allocated_not_in_live_storages = {}
1744
1745    for segment in segments:
1746        addr = segment["address"]
1747        for block in segment["blocks"]:
1748            if block["state"] == "active_allocated":
1749                if addr not in unique_storages:
1750                    allocated_not_in_live_storages[addr] = block
1751                else:
1752                    unique_storages.remove(addr)
1753
1754            addr += block["size"]
1755
1756    torch._check(
1757        len(unique_storages) == 0,
1758        lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
1759    )
1760
1761    if len(allocated_not_in_live_storages) != 0:
1762        formatted = []
1763        for dp, block in allocated_not_in_live_storages.items():
1764            trace = format_tb(block.get("frames", []))
1765            formatted.append(f"Data Pointer: {dp}, history: \n{trace}")
1766        formatted_s = "\n".join(formatted)
1767        msg = (
1768            f"These live storage data ptrs are in the cudagraph pool but not "
1769            f"accounted for as an output of cudagraph trees: \n\n{formatted_s}"
1770        )
1771        raise RuntimeError(msg)
1772
1773
1774class ExecutionState(Enum):
1775    """
1776    Represents the state of the CUDAGraph Tree. Will be NONE if there is no live memory currently allocated
1777    in the cuda graph pool. Otherwise will reflect the state of the most recently executed node.
1778    """
1779
1780    NONE = auto()
1781    WARMUP = auto()
1782    RECORDING = auto()
1783    EXECUTION = auto()
1784
1785
1786class CompilationMode(Enum):
1787    FORWARD = auto()
1788    BACKWARD = auto()
1789    INFERENCE = auto()
1790
1791
1792class CUDAGraphTreeManager:
1793    """
1794    Groups individual recordings or executions of cuda graphs into a tree of recordings,
1795    checks required invariants, and manages warmups of graphs.
1796
1797    When graphs are recorded in the same tree, it enforces subsequent execution
1798    to follow the same order and have the same output tensor lifespans. To remove
1799    unnecessary coupling of cuda graphs (and additional imposed invariants),
1800    the tree manager will end a currently recording tree whenever it is valid - when
1801    the memory pool no longer has any live allocations.
1802
1803    We ignore outputs from a previous generation that correspond to prior model outputs.
1804    Currently this is hardcoded to `GenerationTracker.generation`, tracked in torch dynamo.
1805    # TODO: make generation increment configurable, warn on overwrite.
1806
1807    We run graph warmups in the cudagraph memory pool and return the result on the first invocation
1808    of a function. For many models it is important to reclaim activations as you run the backward.
1809    If we were to warm up the model and keep an extra copy of the inputs around to subsequently
1810    use for recording, we would incur a memory penalty. Additionally, if we are part way through training
1811    your model and need to recompile, memory will be allocated to the cuda graph pool, so we run the
1812    warmup in the cuda graph memory pool. As with recording, warmup needs the state of live tensors
1813    to be accurately reflected, so we checkpoint the allocator state if we need to warm up following graph
1814    replay.
1815    """
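    # Typical flow, sketched only from the methods defined below (illustrative, not executed):
    #   manager = CUDAGraphTreeManager(device_index=0)
    #   fn, first_outputs = manager.add_function(
    #       model, inputs, static_input_idxs, stack_traces, CompilationMode.INFERENCE,
    #       constants, placeholders, mutated_input_idxs,
    #   )
    #   later_outputs = fn(new_inputs)  # warmup on the first call(s), then a recording, then replays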
1816
1817    def __init__(self, device_index: int) -> None:
1818        # roots are functions which have no dependencies on another node. I.e.,
1819        # when they are first invoked, none of their inputs are outputs
1820        # of another node, nor are there any live outputs of another node whose
1821        # liveness would create a dependency.
1822        self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
1823
1824        # mapping from function id to wrapped function
1825        self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}
1826
1827        self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {}
1828
1829        self.warmed_up_functions: Set[FunctionID] = set()
1830        # if we fail to increment generation, and are stuck warming up,
1831        # only warn on each function once
1832        self.warned_functions: Set[FunctionID] = set()
1833        torch._C._set_cached_tensors_enabled(True)
1834
1835        # warn only once if a function mutates inputs
1836        self.warned_mutation: Set[FunctionID] = set()
1837
1838        # NB: cuda caching allocator will remember the stream a segment is allocated to
1839        # and only allocate that segment to the same stream. we need to use a single stream
1840        # for all allocations to the memory pool, otherwise the allocations to separate streams
1841        # will not be reused; separate recordings would use the same memory pool, but not
1842        # the same memory.
1843
1844        with torch.cuda.device(device_index):
1845            torch.cuda.synchronize()
1846            self.stream = torch.cuda.Stream()
1847            self.stream.wait_stream(torch.cuda.current_stream())
1848
1849            # Keeps Memory Pool Alive
1850            self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
1851            self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()
1852
1853            with warnings.catch_warnings(record=True), torch.cuda.graph(
1854                self.graph,
1855                pool=self.cuda_graphs_thread_pool,
1856                stream=self.stream,
1857                capture_error_mode="thread_local",
1858            ):
1859                pass
1860
1861        self.graph_counter = itertools.count(0)
1862        self.func_counter = itertools.count(0)
1863
1864        # mapping from graph_id to (function id to mutation type hint) since we are
1865        # specializing on a particular combination of Parent Node -> Function ID.
1866        self.non_cudagraph_managed_mutation_hint: Dict[
1867            Optional[GraphID], Dict[FunctionID, bool]
1868        ] = defaultdict(dict)
1869        self.warmup_node_counter = itertools.count(start=-1, step=-1)
1870
1871        # mapping from graph_id to (function id to re-record count). We fall back to
1872        # eager function if a function is re-recorded frequently on a node.
1873        self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict(
1874            lambda: defaultdict(lambda: 0)
1875        )
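        # e.g. (hypothetical) self.num_rerecord[GraphID(3)][FunctionID(7)] == 2 would mean function 7
        # has been unexpectedly re-recorded twice as a child of graph node 3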
1876
1877        # whether the current node is in a state of warmup, recording, or execution. If
1878        # there is no current node the state will be ExecutionState.NONE.
1879        self.path_state = ExecutionState.NONE
1880        self.device_index = device_index
1881
1882        # the most recently invoked cudagraph wrapping of a function. Will be None
1883        # when there is no output from a previous recording or execution whose memory
1884        # we need to respect in the cuda caching allocator. If you incremented the generation,
1885        # this will also be None, as we ignore those allocations.
1886        self.current_node: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = None
1887
1888        # current generation of cudagraph invocations. when torch.compile is run
1889        # we increment the current generation. We are willing to ignore live outputs
1890        # of a previous generation when checking liveness.
1891        self.current_gen: int = -1
1892
1893        # number of instances we are in execution and failed to match to an
1894        # existing child
1895        self.debug_fail_counter = 0
1896        # number of instances we had to checkpoint the function
1897        self.debug_checkpointing_counter = 0
1898
1899        self.id_to_mode: Dict[FunctionID, CompilationMode] = {}
1900
1901        # Note: [Backward Generation Handling]
1902        # We generally perform a sequence of forward executions followed by backward executions.
1903        # If multiple torch.compile wrapped forwards are executed with their backwards pending,
1904        # we should not disregard the outputs from a prior torch.compile since the entire training
1905        # loop hasn't completed.  Occasionally, a backward pass corresponding to a forward pass may
1906        # not be executed, so we cannot wait for all pending forward-pass backwards to complete.
1907        # Instead we wait for a single backward
1908        # invocation. Triggering a backward pass typically doesn't lead to another torch.compile
1909        # invocation, making it less likely for the generation to increase between multiple
1910        # backward calls. The following use case is covered by this approach:
1911        # mod1 = torch.compile(...)
1912        # mod2 = torch.compile(...)
1913        # mod2(mod1(x)).sum().backward()
1914
1915        self.running_forwards_with_pending_backwards = False
1916
1917    def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
1918        assert self.graph is not None, "Running CUDAGraph after shutdown"
1919        out = self._run(new_inputs, function_id)
1920
1921        # The forwards are only pending following invocation, not before
1922        mode = self.id_to_mode[function_id]
1923        if mode == CompilationMode.FORWARD:
1924            self.running_forwards_with_pending_backwards = True
1925        elif mode == CompilationMode.BACKWARD:
1926            self.running_forwards_with_pending_backwards = False
1927
1928        return out
1929
1930    def set_to_running_backward(self) -> None:
1931        self.running_forwards_with_pending_backwards = False
1932
1933    def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]:
1934        return (
1935            self.current_node._is_cuda_graph_recorded_tensor
1936            if isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode))
1937            else lambda _: False
1938        )
1939
1940    def new_warmup_node_id(self) -> GraphID:
1941        return GraphID(next(self.warmup_node_counter))
1942
1943    def _update_non_cudagraph_managed_mutation(
1944        self, function_id: FunctionID, inputs: List[InputType]
1945    ) -> None:
1946        node_id = self._get_node_id()
1947        if maybe_mutation_str := check_for_mutation(
1948            self.ids_to_funcs[function_id],
1949            inputs,
1950            self._get_cuda_graph_recorded_tensor_checker(),
1951        ):
1952            self.non_cudagraph_managed_mutation_hint[node_id][function_id] = True
1953            # warn once per function_id
1954            if function_id in self.warned_mutation:
1955                return
1956            self.warned_mutation.add(function_id)
1957            log_cudagraph_skip_and_bump_counter(maybe_mutation_str)
1958        else:
1959            self.non_cudagraph_managed_mutation_hint[node_id][function_id] = False
1960
1961    def _get_node_id(self) -> Optional[GraphID]:
1962        if self.current_node is None:
1963            return None
1964        elif isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)):
1965            return self.current_node.id
1966        else:
1967            raise RuntimeError(f"Unknown node type {type(self.current_node)}")
1968
1969    def exceed_rerecord_limit(
1970        self, node_id: Optional[GraphID], function_id: FunctionID
1971    ) -> bool:
1972        if torch._dynamo.config.inline_inbuilt_nn_modules:
1973            return False
1974
1975        return (
1976            self.num_rerecord[node_id][function_id]
1977            > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit
1978        )
1979
1980    def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
1981        # we will try to end the current execution lazily, since
1982        # we don't want to do unnecessary checking of the existing outputs
1983        # on the hot path, but both recording and warmup only happen once
1984        # so we check up front
1985        if self.in_recording:
1986            self.try_end_curr_recording(function_id)
1987
1988        if self.in_warmup:
1989            self.try_end_curr_warmup(function_id)
1990
1991        node_id = self._get_node_id()
1992        if function_id not in self.non_cudagraph_managed_mutation_hint[node_id]:
1993            self._update_non_cudagraph_managed_mutation(function_id, new_inputs)
1994
1995        # Early exit if the function mutates inputs which are neither parameters/buffers nor
1996        # cudagraph recorded tensors. This check should happen after `try_end_curr_recording`
1997        # and `try_end_curr_warmup` which may change self.current_node.
1998        if self.non_cudagraph_managed_mutation_hint[node_id][
1999            function_id
2000        ] or self.exceed_rerecord_limit(node_id, function_id):
2001            return self.ids_to_funcs[function_id].model(new_inputs)
2002
2003        # warming up a function and subsequently recording may use different memory addresses
2004        # because both depend on the state of the caching allocator. if we warm up graph A,
2005        # then warm up graph B and make more allocations, the subsequent recording of A will not
2006        # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only
2007        # be followed by warm up runs.
2008        if (
2009            (
2010                not (
2011                    function_id in self.warmed_up_functions
2012                    or config.triton.skip_cudagraph_warmup
2013                )
2014            )
2015            or self.in_warmup
2016            or config.triton.force_cudagraphs_warmup
2017        ):
2018            # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state.
2019            # Both Recording and Warmup will be reflected in the allocator and don't need changes
2020            if self.path_state == ExecutionState.EXECUTION:
2021                self.apply_checkpoint_execution_state_in_allocator()
2022
2023            return self.run_eager(new_inputs, function_id)
2024
2025        assert not isinstance(self.current_node, CUDAWarmupNode)
2026        child_nodes = (
2027            self.roots if self.current_node is None else self.current_node.children
2028        )
2029
2030        if not self.in_recording:
2031            unexpected_rerecord, unexpected_rerecord_reason = False, lambda: ""
2032            for child in child_nodes[function_id]:
2033                # here we are checking memory consistency between recording and execution,
2034                # as well as things like stability of tensor locations, etc.
2036                status, status_logger = child.check_invariants(new_inputs)
2037                if status == CheckInvariantStatus.SUCCESS:
2038                    return self.execute_node(child, new_inputs)
2039
2040                if (
2041                    status == CheckInvariantStatus.StaticInputIdxMismatch
2042                    or status == CheckInvariantStatus.CudagraphManagedIdxMismatch
2043                ):
2044                    unexpected_rerecord = True
2045                    unexpected_rerecord_reason = status_logger
2046
2047            # now that we know the new function can't be run as a child of the
2048            # current node, if it is a root, try to end the current execution.
2049            # as noted above, we want to do this lazily to avoid having to
2050            # check all existing outputs
2051            if self.current_node is not None and function_id in self.roots:
2052                self.try_end_curr_execution()
2053
2054                # run again to hit the root matching case which must succeed
2055                if self.current_node is None:
2056                    return self.run(new_inputs, function_id)
2057
2058            if len(self.ids_to_funcs[function_id].mutated_input_idxs) > 0:
2059                self._update_non_cudagraph_managed_mutation(function_id, new_inputs)
2060                if self.non_cudagraph_managed_mutation_hint[self._get_node_id()][
2061                    function_id
2062                ]:
2063                    return self.ids_to_funcs[function_id].model(new_inputs)
2064
2065            # nb: run before checkpointing because checkpointing is slow, and we will
2066            # be using the eager caching allocator pool which does not require live
2067            # accounting of tensors in cudagraph allocator
2068            if unexpected_rerecord:
2069                curr_node_id = self._get_node_id()
2070                self.num_rerecord[curr_node_id][function_id] += 1
2071                if self.exceed_rerecord_limit(curr_node_id, function_id):
2072                    _id = curr_node_id.id if curr_node_id else None
2073                    log_cudagraph_skip_and_bump_counter(
2074                        f"skipping cudagraph due to function {function_id.id} exceeding max "
2075                        f"re-recording limit "
2076                        f"(={torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit}) "
2077                        f"on cudagraph node {_id} due to {unexpected_rerecord_reason()}."
2078                    )
2079                    return self.ids_to_funcs[function_id].model(new_inputs)
2080
2081            # at this point, we necessarily will do a new recording
2082            self.debug_fail_counter += 1
2083
2084            self.try_end_curr_execution()
2085            if self.current_node is not None:
2086                self.apply_checkpoint_execution_state_in_allocator()
2087
2088        # now, we are in a recording state !
2089        return self.record_function(new_inputs, function_id)
2090
2091    def shutdown(self) -> None:
2092        """
2093        Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
2094        might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown
2095        to avoid a reference cycle.
2096        """
2097        nodes = []
2098        for roots in self.roots.values():
2099            nodes.extend(roots)
2100
2101        while nodes:
2102            node = nodes.pop()
2103            for children in node.children.values():
2104                nodes.extend(children)
2105            node.remove_node_cached_tensors()
2106            node.graph = None
2107
2108        self.graph = None
2109        self.roots = None  # type: ignore[assignment]
2110        self.current_node = None
2111
2112    def record_function(
2113        self, new_inputs: List[InputType], function_id: FunctionID
2114    ) -> OutputType:
2115        assert not isinstance(self.current_node, CUDAWarmupNode)
2116        graph_id = self.new_graph_id()
2117        log.debug(
2118            "Recording function %d of graph recording id %d",
2119            function_id.id,
2120            graph_id.id,
2121        )
2122        torch.cuda.synchronize()
2123        node = CUDAGraphNode(
2124            self.ids_to_funcs[function_id],
2125            graph_id,
2126            self.current_node,
2127            new_inputs,
2128            self.cuda_graphs_thread_pool,
2129            self.device_index,
2130            self.ids_to_stack_traces[function_id],
2131            self.stream,
2132        )
2133        if self.current_node is None:
2134            self.roots[function_id].append(node)
2135        else:
2136            self.current_node.add_child(function_id, node)
2137        self.current_node = node
2138        self.path_state = ExecutionState.RECORDING
2139        self.update_generation()
2140        torch.cuda.synchronize()
2141        return node.run_first_inputs(new_inputs)
2142
2143    def execute_node(
2144        self, node: CUDAGraphNode, new_inputs: List[InputType]
2145    ) -> OutputType:
2146        self.current_node = node
2147        self.path_state = ExecutionState.EXECUTION
2148        self.update_generation()
2149        return node.run(new_inputs)
2150
2151    def run_eager(
2152        self, new_inputs: List[InputType], function_id: FunctionID
2153    ) -> OutputType:
2154        # this is only stored on current node, because when we start a new path,
2155        # we will deallocate it
2156        already_warm = function_id in self.warmed_up_functions
2157        if not already_warm:
2158            log.debug("Running warmup of function %d", function_id.id)
2159        else:
2160            log.debug(
2161                "Running eager of function %d because ancestor needed to warm up",
2162                function_id.id,
2163            )
2164        self.warmed_up_functions.add(function_id)
2165        node = CUDAWarmupNode(
2166            self.ids_to_funcs[function_id],
2167            self.current_node,
2168            self.cuda_graphs_thread_pool,
2169            self.graph,
2170            self.device_index,
2171            self.ids_to_stack_traces[function_id],
2172            self.stream,
2173            already_warm,
2174            self.new_warmup_node_id(),
2175        )
2176        self.current_node = node
2177        self.path_state = ExecutionState.WARMUP
2178        self.update_generation()
2179        return node.run(new_inputs)
2180
2181    def new_graph_id(self) -> GraphID:
2182        return GraphID(next(self.graph_counter))
2183
2184    def new_func_id(self) -> FunctionID:
2185        return FunctionID(next(self.func_counter))
2186
2187    def add_function(
2188        self,
2189        model: ModelType,
2190        inputs: List[InputType],
2191        static_input_idxs: Sequence[int],
2192        stack_traces: Optional[StackTraces],
2193        mode: CompilationMode,
2194        constants: Tuple[torch.Tensor, ...],
2195        placeholders: Tuple[PlaceholderInfo, ...],
2196        mutated_input_idxs: Tuple[int, ...],
2197    ) -> Tuple[ModelType, OutputType,]:
2198        id = self.new_func_id()
2199        self.ids_to_stack_traces[id] = stack_traces
2200        self.ids_to_funcs[id] = WrappedFunction(
2201            model,
2202            list(static_input_idxs),
2203            id,
2204            tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
2205            placeholders,
2206            mutated_input_idxs,
2207        )
2208        self.id_to_mode[id] = mode
2209        fn = functools.partial(self.run, function_id=id)
2210
2211        # container needs to set up cleanup for when fn dies
2212        get_container(self.device_index).add_strong_reference(fn)
2213        return fn, fn(inputs)
2214
2215    @property
2216    def in_recording(self) -> bool:
2217        return self.path_state == ExecutionState.RECORDING
2218
2219    @property
2220    def in_warmup(self) -> bool:
2221        return self.path_state == ExecutionState.WARMUP
2222
2223    def get_roots(self) -> Iterator[CUDAGraphNode]:
2224        for nodes in self.roots.values():
2225            yield from nodes
2226
2227    @property
2228    def current_node(self) -> Optional[Union[CUDAGraphNode, CUDAWarmupNode]]:
2229        return self._current_node
2230
2231    @current_node.setter
2232    def current_node(
2233        self, value: Optional[Union[CUDAGraphNode, CUDAWarmupNode]]
2234    ) -> None:
2235        self._current_node = value
2236        if value is None:
2237            self.path_state = ExecutionState.NONE
2238
2239    def update_generation(self) -> None:
2240        self.current_gen = self.get_curr_generation()
2241
2242    @staticmethod
2243    def get_curr_generation() -> int:
2244        if MarkStepBox.mark_step_counter != 0:
2245            return MarkStepBox.mark_step_counter
2246
2247        return GenerationTracker.generation
2248
2249    @staticmethod
2250    def user_invoked_mark_step() -> bool:
2251        return MarkStepBox.mark_step_counter != 0
2252
2253    def can_start_new_generation(self) -> bool:
2254        if not self.in_new_torch_compile_invocation():
2255            return False
2256
2257        if self.user_invoked_mark_step():
2258            return True
2259
2260        return not self.running_forwards_with_pending_backwards
2261
2262    def in_new_torch_compile_invocation(self) -> bool:
2263        return self.current_gen != self.get_curr_generation()
2264
2265    def try_end_curr_recording(self, function_id: FunctionID) -> None:
2266        """
2267        Check if the current recording can be terminated, either because all outputs of the
2268        previously recorded node are dead or because it was executed in a different
2269        generation. Will set current_node to None and in_recording to False if successful.
2270        """
2271        assert self.in_recording
2272        assert self.current_node is not None
2273
2274        # multiple invocations, allow overwriting the previous generation
2275        if self.can_start_new_generation():
2276            self.dealloc_current_path_weakrefs()
2277            self.clear_current_path_state_and_set_to_none()
2278            return
2279
2280        if self.current_node.all_outputs_are_dead():
2281            self.clear_current_path_state_and_set_to_none()
2282            return
2283
2284        self.check_warn_on_unable_to_start_executing(function_id)
2285
2286    def try_end_curr_execution(self) -> None:
2287        """
2288        Check if the current executing node can be terminated, either because all outputs of the
2289        previously executed node are dead or because it was executed in a different generation.
2290        Will set current_node to None if successful.
2291        """
2292
2293        assert not self.in_recording
2294        if self.current_node is None:
2295            return
2296
2297        if self.can_start_new_generation():
2298            self.clear_current_path_state_and_set_to_none()
2299            return
2300
2301        if self.current_node.all_outputs_are_dead():
2302            self.clear_current_path_state_and_set_to_none()
2303
2304    def try_end_curr_warmup(self, function_id: FunctionID) -> None:
2305        if self.can_start_new_generation():
2306            self.dealloc_current_path_weakrefs()
2307            self.current_node = None
2308            return
2309
2310        assert self.current_node is not None
2311        if self.current_node.all_outputs_are_dead():
2312            self.current_node = None
2313            return
2314
2315        self.check_warn_on_unable_to_start_executing(function_id)
2316
2317    def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None:
2318        "Warn if we are in a potential loop where we are unable to hit the fast path"
2319        if (
2320            function_id in self.warned_functions
2321            or not self.in_new_torch_compile_invocation()
2322        ):
2323            return
2324
2325        assert self.current_node is not None
2326        existing_nodes = [
2327            node
2328            for node in self.current_node._path_from_root
2329            if node.wrapped_function.id == function_id
2330        ]
2331
2332        if len(existing_nodes) <= 1:
2333            return
2334
2335        # repeated same pattern
2336        parents = {
2337            n.parent.wrapped_function.id
2338            for n in itertools.chain(existing_nodes, (self.current_node,))
2339            if n.parent is not None
2340        }
2341        if len(parents) == len(existing_nodes):
2342            return
2343
2344        self.warned_functions.add(function_id)
2345        warnings.warn(
2346            "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. "
2347            "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() "
2348            "before each model invocation"
2349        )
2350
2351    def dealloc_current_path_weakrefs(self) -> None:
2352        assert self.current_node is not None
2353        # TODO: we could also allow these weak refs to continue to be allocated,
2354        # but that adds some complications.
2355        for node in self.current_node._path_from_root:
2356            assert node.stack_traces is not None
2357            assert len(node.tensor_weakrefs) == len(node.stack_traces)
2358            for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces):
2359                ten = None if t is None else t()
2360                if ten is None:
2361                    continue
2362
2363                stack_trace = (
2364                    stack_trace.strip()
2365                    if stack_trace
2366                    else "[Could not find stack trace]"
2367                )
2368                msg = (
2369                    "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
2370                    f"Stack trace: {stack_trace}. "
2371                    "To prevent overwriting, clone the tensor outside of torch.compile() "
2372                    "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
2373                )
2374                torch._C._set_storage_access_error_msg(ten, msg)
2375
2376        deleted = set()
2377        for storage_ref in self.current_node.path_live_weakrefs():
2378            _storage_deref = storage_ref()
2379            if _storage_deref and storage_ref.data_ptr() not in deleted:
2380                deleted.add(storage_ref.data_ptr())
2381                torch._C._free_And_Remove_DeleterFn(_storage_deref)
2382
2383    def clear_current_path_state_and_set_to_none(self) -> None:
2384        assert isinstance(self.current_node, CUDAGraphNode)
2385        self.current_node.clear_path_state()
2386        self.current_node = None
2387
2388    def apply_checkpoint_execution_state_in_allocator(self) -> None:
2389        """
2390        Checkpoint the current execution state in the caching allocator so that
2391        additional cudagraph recordings can be made respecting existing live storages.
2392        """
2393        assert isinstance(self.current_node, CUDAGraphNode)
2394        self.debug_checkpointing_counter += 1
2395        log.debug(
2396            "Checkpointing cuda caching allocator state. Number of checkpoints %d",
2397            self.debug_checkpointing_counter,
2398        )
2399
2400        state = self.current_node.checkpointed_caching_state
2401        device = self.current_node.device
2402        assert state is not None and device is not None
2403
2404        # currently we deallocate instead of allowing stale recordings
2405        stale_storages: List[int] = []
2406
2407        # remove cached tensors, otherwise they would prevent memory from being
2408        # reclaimed in subsequent recordings
2409        self.current_node.remove_path_cached_tensors()
2410        live_storages_wrappers = list(self.current_node.path_live_weakrefs())
2411
2412        # path_live_weakrefs guarantees that t() will not be None
2413        live_storages_weak_refs: list[int] = [t() for t in live_storages_wrappers]  # type: ignore[misc]
2414        ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation()
2415        torch._C._cuda_setCheckpointPoolState(
2416            device, state, stale_storages, live_storages_weak_refs
2417        )
2418
2419        # NB: deduplicate aliased outputs
2420        for ptr in set(ptrs_to_deallocate):
2421            torch._C._cuda_cudaCachingAllocator_raw_delete(ptr)
2422
2423        # Now the live blocks should be exactly equal to the live storages in private pool
2424        if config.triton.slow_path_cudagraph_asserts:
2425            check_memory_pool(
2426                self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers
2427            )
2428            for wrapper in live_storages_wrappers:
2429                storage_ptr = wrapper()
2430                assert storage_ptr is not None
2431                assert torch._C._has_Standard_Deleter(storage_ptr)
2432                assert wrapper.data_ptr() not in ptrs_to_deallocate
2433
2434    def live_cudagraph_pool_storages_in_curr_execution(
2435        self,
2436    ) -> List[StorageWeakRefPointer]:
2437        if self.current_node is None:
2438            return []
2439        # explicitly ignoring previously recorded outputs from a past path
2440        # path_live_weakrefs() guarantees that t() will not be None
2441        return [t() for t in self.current_node.path_live_weakrefs()]  # type: ignore[misc]
2442