# mypy: allow-untyped-defs
import gc
import typing

import torch

from .._utils import _dummy_type


if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
    torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
    torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
        "_cuda_isCurrentStreamCapturing"
    )

from torch._C import (  # noqa: F401
    _cuda_isCurrentStreamCapturing,
    _CUDAGraph,
    _graph_pool_handle,
)


def is_current_stream_capturing():
    r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.

    If a CUDA context does not exist on the current device, returns False without initializing the context.
    """
    return _cuda_isCurrentStreamCapturing()


# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
    r"""Return an opaque token representing the id of a graph memory pool.

    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    return _graph_pool_handle()


# Python shim helps Sphinx process docstrings more reliably.
class CUDAGraph(torch._C._CUDAGraph):
    r"""Wrapper around a CUDA graph.

    .. warning::
        This API is in beta and may change in future releases.
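
    A minimal capture/replay sketch (illustrative only; it assumes a CUDA device
    is available and that tensor shapes stay static between capture and replay),
    using warmup on a side stream, capture via :class:`~torch.cuda.graph`, and
    :meth:`replay`::

        >>> # xdoctest: +SKIP
        >>> model = torch.nn.Linear(16, 4).cuda()
        >>> static_input = torch.zeros(8, 16, device="cuda")
        >>> # Warmup on a side stream so lazy initialization stays out of the capture.
        >>> s = torch.cuda.Stream()
        >>> s.wait_stream(torch.cuda.current_stream())
        >>> with torch.cuda.stream(s):
        ...     static_output = model(static_input)
        >>> torch.cuda.current_stream().wait_stream(s)
        >>> g = torch.cuda.CUDAGraph()
        >>> with torch.cuda.graph(g):
        ...     static_output = model(static_input)
        >>> # Fill the static input with new data, then replay the captured work.
        >>> static_input.copy_(torch.randn(8, 16, device="cuda"))
        >>> g.replay()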
    """

    def __new__(cls):
        return super().__new__(cls)

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Begin capturing CUDA work on the current stream.

        Typically, you shouldn't call ``capture_begin`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
                Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
        """  # noqa: B950
        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)

    def capture_end(self):
        r"""End CUDA graph capture on the current stream.

        After ``capture_end``, ``replay`` may be called on this instance.

        Typically, you shouldn't call ``capture_end`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Replay the CUDA work captured by this graph."""
        super().replay()

    def reset(self):
        r"""Delete the graph currently held by this instance."""
        super().reset()

    def pool(self):
        r"""Return an opaque token representing the id of this graph's memory pool.

        This id can optionally be passed to another graph's ``capture_begin``,
        which hints the other graph may share the same memory pool.
        """
        return super().pool()

    def enable_debug_mode(self):
        r"""Enable debugging mode for CUDAGraph.debug_dump."""
        return super().enable_debug_mode()

    def debug_dump(self, debug_path):
        r"""Dump the captured graph to ``debug_path`` if debugging was enabled via ``CUDAGraph.enable_debug_mode()``.

        Arguments:
            debug_path (required): Path to dump the graph to.
        """
        return super().debug_dump(debug_path)


class graph:
    r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    Arguments:
        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.

    .. _cudaStreamCaptureMode:
        https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
    """  # noqa: B950

    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None

    def __init__(
        self,
        cuda_graph,
        pool=None,
        stream=None,
        capture_error_mode: str = "global",
    ):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
        # restriction that only one capture may be underway at a time in the process.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.cuda.Stream()

        self.pool = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        assert self.capture_stream is not None
        self.stream_ctx = torch.cuda.stream(self.capture_stream)
        self.cuda_graph = cuda_graph
        self.capture_error_mode = capture_error_mode

    def __enter__(self):
        # Free as much memory as we can for the graph
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        # Stackoverflow seems comfortable with this pattern
        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
        self.stream_ctx.__enter__()

        self.cuda_graph.capture_begin(
            *self.pool, capture_error_mode=self.capture_error_mode
        )

    def __exit__(self, exc_type, exc_value, traceback):
        self.cuda_graph.capture_end()
        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()


def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and return graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward CUDA work as a CUDA graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as a CUDA graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
            11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.
        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
            with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.

    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.cuda.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with
        caching disabled: the `torch.cuda.amp.autocast()` context manager must use `cache_enabled=False`.
    """
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support autocast caching. Please set `cache_enabled=False`."
        )

    just_one_callable = False

    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)

    flatten_sample_args = []

    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            assert (
                len(c._backward_hooks) == 0
                and len(c._forward_hooks) == 0
                and len(c._forward_pre_hooks) == 0
            ), (
                "Modules must not have hooks registered at the time they are passed. However, registering hooks "
                + "on modules after passing them through make_graphed_callables is allowed."
            )
            assert all(b.requires_grad is False for b in c.buffers()), (
                "In any :class:`~torch.nn.Module` passed to "
                + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
                + "``requires_grad=False``."
            )
        flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
        flatten_sample_args.append(tuple(flatten_arg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
            "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
        )

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
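    # For example, for an nn.Linear module called as module(x), the static input surface
    # is (x, module.weight, module.bias): parameter gradients flow through the same
    # captured backward graph, so the parameters must be part of the fixed set of input
    # tensors whose storage the captured graphs read from.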
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in callables
    ]
    per_callable_static_input_surfaces = [
        flatten_sample_args[i] + per_callable_module_params[i]
        for i in range(len(callables))
    ]

    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func, args, static_input_surface in zip(
            callables, sample_args, per_callable_static_input_surfaces
        ):
            grad_inputs, outputs, outputs_grad = None, None, None
            for _ in range(num_warmup_iters):
                outputs = torch.utils._pytree.tree_leaves(func(*args))
                outputs_grad = tuple(o for o in outputs if o.requires_grad)
                if len(outputs_grad) > 0:
                    grad_inputs = torch.autograd.grad(
                        outputs=outputs_grad,
                        inputs=tuple(
                            i for i in static_input_surface if i.requires_grad
                        ),
                        grad_outputs=tuple(
                            torch.empty_like(o) for o in outputs if o.requires_grad
                        ),
                        only_inputs=True,
                        allow_unused=allow_unused_input,
                    )
            for v in [outputs, outputs_grad, grad_inputs]:
                del v

    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_unflatten_spec = []
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
        per_callable_static_outputs.append(tuple(flatten_outputs))
        per_callable_output_unflatten_spec.append(spec)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph, module_params in zip(
        reversed(per_callable_static_input_surfaces),
        reversed(per_callable_static_outputs),
        reversed(bwd_graphs),
        reversed(per_callable_module_params),
    ):
        # For now, assumes all static_outputs require grad
        # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
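        # Grad-output placeholders are allocated only for outputs that require grad;
        # ``None`` entries keep positions aligned with ``static_outputs`` so that
        # ``Graphed.backward`` can later zip incoming grads with these placeholders.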
        static_grad_outputs = tuple(
            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
        )

        outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
        grad_inputs = None
        if len(outputs_grad) > 0:
            with torch.cuda.graph(bwd_graph, pool=mempool):
                grad_inputs = torch.autograd.grad(
                    outputs=outputs_grad,
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
        # I couldn't think of a slick one-liner for this pattern.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad and grad_inputs is not None:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs.reverse()
    per_callable_static_grad_inputs.reverse()
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
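    # Summary of the per-callable bookkeeping (index i == ith callable, in call order):
    #   fwd_graphs[i] / bwd_graphs[i]             captured forward / backward CUDA graphs
    #   per_callable_static_input_surfaces[i]     flattened user args + module params the graphs read
    #   per_callable_static_outputs[i]            tensors the forward graph writes its results into
    #   per_callable_static_grad_outputs[i]       placeholders for incoming grads (None where not needed)
    #   per_callable_static_grad_inputs[i]        grads the backward graph writes (None where not required)
    # The autograd wrapper below copies live data into these static tensors and replays the graphs.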

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
            return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)

                return new_fwd

            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)
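

# The sketch below is illustrative only and is not part of this module's API: the helper
# name ``_example_make_graphed_callables_usage`` is hypothetical, it is never called, and
# it assumes a CUDA device, static input shapes, and (if autocast is used) ``cache_enabled=False``.
# It mirrors the documented flow: graph the module once with sample args, then use the graphed
# callable as a drop-in replacement inside an ordinary autograd-enabled training loop.
def _example_make_graphed_callables_usage():
    model = torch.nn.Linear(64, 32).cuda()
    loss_fn = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # Sample args must match the real inputs' shapes, dtypes, and requires_grad state.
    sample_x = torch.randn(16, 64, device="cuda")
    graphed_model = make_graphed_callables(model, (sample_x,))

    for _ in range(10):
        x = torch.randn(16, 64, device="cuda")
        target = torch.randn(16, 32, device="cuda")
        opt.zero_grad(set_to_none=True)
        # Forward replays the captured graph; the loss and optimizer run eagerly.
        loss = loss_fn(graphed_model(x), target)
        loss.backward()
        opt.step()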