# mypy: allow-untyped-defs
import gc
import typing

import torch

from .._utils import _dummy_type


if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
    torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
    torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
        "_cuda_isCurrentStreamCapturing"
    )

from torch._C import (  # noqa: F401
    _cuda_isCurrentStreamCapturing,
    _CUDAGraph,
    _graph_pool_handle,
)


def is_current_stream_capturing():
    r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.

    If a CUDA context does not exist on the current device, returns False without initializing the context.
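
    Example (a minimal sketch; assumes a CUDA device is available)::

        >>> # Outside any capture, this reports False and does not start a capture.
        >>> torch.cuda.is_current_stream_capturing()
        False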
    """
    return _cuda_isCurrentStreamCapturing()


# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
    r"""Return an opaque token representing the id of a graph memory pool.

    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
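
    Example (a minimal sketch; assumes a CUDA device is available, and ``workload_a`` and
    ``workload_b`` are hypothetical callables that launch CUDA work)::

        >>> pool = torch.cuda.graph_pool_handle()
        >>> g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
        >>> with torch.cuda.graph(g1, pool=pool):
        ...     out_a = workload_a()
        >>> with torch.cuda.graph(g2, pool=pool):
        ...     out_b = workload_b()
        >>> # Replay in the same order the graphs were captured.
        >>> g1.replay()
        >>> g2.replay()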
    """
    return _graph_pool_handle()


# Python shim helps Sphinx process docstrings more reliably.
class CUDAGraph(torch._C._CUDAGraph):
    r"""Wrapper around a CUDA graph.

    .. warning::
        This API is in beta and may change in future releases.
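
    Example (a minimal sketch of manual capture and replay; assumes a CUDA device is available,
    and ``static_input`` is a hypothetical placeholder tensor)::

        >>> g = torch.cuda.CUDAGraph()
        >>> static_input = torch.zeros(8, device="cuda")
        >>> # Warm up before capturing so one-time lazy initialization stays out of the graph.
        >>> _ = static_input * 2
        >>> torch.cuda.synchronize()
        >>> s = torch.cuda.Stream()
        >>> s.wait_stream(torch.cuda.current_stream())
        >>> with torch.cuda.stream(s):
        ...     g.capture_begin()
        ...     static_output = static_input * 2
        ...     g.capture_end()
        >>> torch.cuda.current_stream().wait_stream(s)
        >>> static_input.fill_(3.0)
        >>> g.replay()  # static_output now holds the doubled values for the new input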
    """

    def __new__(cls):
        return super().__new__(cls)

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Begin capturing CUDA work on the current stream.

        Typically, you shouldn't call ``capture_begin`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                with the indicated pool.  See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
                Can be "global", "thread_local" or "relaxed". During CUDA graph capture, some actions, such as cudaMalloc,
                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_.
        """  # noqa: B950
        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)

    def capture_end(self):
        r"""End CUDA graph capture on the current stream.

        After ``capture_end``, ``replay`` may be called on this instance.

        Typically, you shouldn't call ``capture_end`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Replay the CUDA work captured by this graph."""
        super().replay()

    def reset(self):
        r"""Delete the graph currently held by this instance."""
        super().reset()

    def pool(self):
        r"""Return an opaque token representing the id of this graph's memory pool.

        This id can optionally be passed to another graph's ``capture_begin``,
        which hints the other graph may share the same memory pool.
        """
        return super().pool()

    def enable_debug_mode(self):
        r"""Enable debugging mode for CUDAGraph.debug_dump."""
        return super().enable_debug_mode()

    def debug_dump(self, debug_path):
        r"""Dump the captured graph to ``debug_path``.

        The dump occurs only if debugging was enabled via ``CUDAGraph.enable_debug_mode()``.

        Arguments:
            debug_path (required): Path to dump the graph to.
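
        Example (a minimal sketch; assumes a CUDA device is available and the output path is arbitrary)::

            >>> g = torch.cuda.CUDAGraph()
            >>> g.enable_debug_mode()  # enable before capturing
            >>> # ... capture work into ``g`` ...
            >>> g.debug_dump("cuda_graph_debug.dot")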
        """
        return super().debug_dump(debug_path)


class graph:
    r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    Arguments:
        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During CUDA graph capture, some actions, such as cudaMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_.

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.
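
    Example (a minimal sketch of the typical capture-and-replay pattern; assumes a CUDA device is
    available and ``model`` is a hypothetical callable whose inputs and outputs have static shapes)::

        >>> g = torch.cuda.CUDAGraph()
        >>> static_input = torch.zeros(8, 16, device="cuda")
        >>> # Warm up on a side stream so one-time lazy initialization stays out of the capture.
        >>> s = torch.cuda.Stream()
        >>> s.wait_stream(torch.cuda.current_stream())
        >>> with torch.cuda.stream(s):
        ...     static_output = model(static_input)
        >>> torch.cuda.current_stream().wait_stream(s)
        >>> with torch.cuda.graph(g):
        ...     static_output = model(static_input)
        >>> # Copy new data into the static input and replay; static_output is updated in place.
        >>> static_input.copy_(torch.randn(8, 16, device="cuda"))
        >>> g.replay()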

    .. _cudaStreamCaptureMode:
        https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
    """  # noqa: B950

    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None

    def __init__(
        self,
        cuda_graph,
        pool=None,
        stream=None,
        capture_error_mode: str = "global",
    ):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
        # restriction that only one capture may be underway at a time in the process.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.cuda.Stream()

        self.pool = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        assert self.capture_stream is not None
        self.stream_ctx = torch.cuda.stream(self.capture_stream)
        self.cuda_graph = cuda_graph
        self.capture_error_mode = capture_error_mode

    def __enter__(self):
        # Free as much memory as we can for the graph
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        # Stackoverflow seems comfortable with this pattern
        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
        self.stream_ctx.__enter__()

        self.cuda_graph.capture_begin(
            *self.pool, capture_error_mode=self.capture_error_mode
        )

    def __exit__(self, exc_type, exc_value, traceback):
        self.cuda_graph.capture_end()
        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()


def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and return graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward CUDA work as a CUDA graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as a CUDA graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate.  If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
            11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.
        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
            with the indicated pool.  See :ref:`Graph memory management<graph-memory-management>`.

    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.cuda.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with
        caching disabled: the context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
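
    Example (a minimal sketch; assumes a CUDA device is available; the modules and sample tensors
    are hypothetical placeholders)::

        >>> module1 = torch.nn.Linear(16, 16).cuda()
        >>> module2 = torch.nn.Linear(16, 16).cuda()
        >>> # requires_grad of each sample must match the corresponding real input:
        >>> # raw data won't require grad, module1's output (module2's input) will.
        >>> sample1 = torch.randn(8, 16, device="cuda")
        >>> sample2 = torch.randn(8, 16, device="cuda", requires_grad=True)
        >>> module1, module2 = torch.cuda.make_graphed_callables(
        ...     (module1, module2), ((sample1,), (sample2,))
        ... )
        >>> # The graphed modules are drop-in replacements in the training loop.
        >>> real_input = torch.randn(8, 16, device="cuda")
        >>> loss = module2(module1(real_input)).sum()
        >>> loss.backward()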
    """
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support autocast caching. Please set `cache_enabled=False`."
        )

    just_one_callable = False

    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)

    flatten_sample_args = []

    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            assert (
                len(c._backward_hooks) == 0
                and len(c._forward_hooks) == 0
                and len(c._forward_pre_hooks) == 0
            ), (
                "Modules must not have hooks registered at the time they are passed. However, registering hooks "
                + "on modules after passing them through make_graphed_callables is allowed."
            )
            assert all(b.requires_grad is False for b in c.buffers()), (
                "In any :class:`~torch.nn.Module` passed to "
                + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
                + "``requires_grad=False``."
            )
        flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
        flatten_sample_args.append(tuple(flatten_arg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
            "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
        )

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (i.e., its sample_args) AND the module's parameter attributes.
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in callables
    ]
    per_callable_static_input_surfaces = [
        flatten_sample_args[i] + per_callable_module_params[i]
        for i in range(len(callables))
    ]

    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization CUDA work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func, args, static_input_surface in zip(
            callables, sample_args, per_callable_static_input_surfaces
        ):
            grad_inputs, outputs, outputs_grad = None, None, None
            for _ in range(num_warmup_iters):
                outputs = torch.utils._pytree.tree_leaves(func(*args))
                outputs_grad = tuple(o for o in outputs if o.requires_grad)
                if len(outputs_grad) > 0:
                    grad_inputs = torch.autograd.grad(
                        outputs=outputs_grad,
                        inputs=tuple(
                            i for i in static_input_surface if i.requires_grad
                        ),
                        grad_outputs=tuple(
                            torch.empty_like(o) for o in outputs if o.requires_grad
                        ),
                        only_inputs=True,
                        allow_unused=allow_unused_input,
                    )
            # Drop references to warmup outputs and grads so their memory can be reclaimed
            # before capture. (Deleting only a loop variable would leave these names bound.)
            del outputs, outputs_grad, grad_inputs

    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_unflatten_spec = []
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
        per_callable_static_outputs.append(tuple(flatten_outputs))
        per_callable_output_unflatten_spec.append(spec)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph, module_params in zip(
        reversed(per_callable_static_input_surfaces),
        reversed(per_callable_static_outputs),
        reversed(bwd_graphs),
        reversed(per_callable_module_params),
    ):
        # For now, assumes all static_outputs require grad
        # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
        static_grad_outputs = tuple(
            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
        )

        outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
        grad_inputs = None
        if len(outputs_grad) > 0:
            with torch.cuda.graph(bwd_graph, pool=mempool):
                grad_inputs = torch.autograd.grad(
                    outputs=outputs_grad,
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
        # I couldn't think of a slick one-liner for this pattern.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad and grad_inputs is not None:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs.reverse()
    per_callable_static_grad_inputs.reverse()
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
            return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)

                return new_fwd

            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)