1.. meta::
2   :description: A guide to torch.cuda, a PyTorch module to run CUDA operations
3   :keywords: memory management, PYTORCH_CUDA_ALLOC_CONF, optimize PyTorch, CUDA
4
5.. _cuda-semantics:
6
7CUDA semantics
8==============
9
10
11:mod:`torch.cuda` is used to set up and run CUDA operations. It keeps track of
12the currently selected GPU, and all CUDA tensors you allocate will by default be
13created on that device. The selected device can be changed with a
14:any:`torch.cuda.device` context manager.
15
16However, once a tensor is allocated, you can do operations on it irrespective
17of the selected device, and the results will be always placed on the same
18device as the tensor.
19
20Cross-GPU operations are not allowed by default, with the exception of
21:meth:`~torch.Tensor.copy_` and other methods with copy-like functionality
22such as :meth:`~torch.Tensor.to` and :meth:`~torch.Tensor.cuda`.
23Unless you enable peer-to-peer memory access, any attempts to launch ops on
24tensors spread across different devices will raise an error.
25
26Below you can find a small example showcasing this::
27
28    cuda = torch.device('cuda')     # Default CUDA device
29    cuda0 = torch.device('cuda:0')
30    cuda2 = torch.device('cuda:2')  # GPU 2 (these are 0-indexed)
31
32    x = torch.tensor([1., 2.], device=cuda0)
33    # x.device is device(type='cuda', index=0)
34    y = torch.tensor([1., 2.]).cuda()
35    # y.device is device(type='cuda', index=0)
36
37    with torch.cuda.device(1):
38        # allocates a tensor on GPU 1
39        a = torch.tensor([1., 2.], device=cuda)
40
41        # transfers a tensor from CPU to GPU 1
42        b = torch.tensor([1., 2.]).cuda()
43        # a.device and b.device are device(type='cuda', index=1)
44
45        # You can also use ``Tensor.to`` to transfer a tensor:
46        b2 = torch.tensor([1., 2.]).to(device=cuda)
47        # b.device and b2.device are device(type='cuda', index=1)
48
49        c = a + b
50        # c.device is device(type='cuda', index=1)
51
52        z = x + y
53        # z.device is device(type='cuda', index=0)
54
55        # even within a context, you can specify the device
56        # (or give a GPU index to the .cuda call)
57        d = torch.randn(2, device=cuda2)
58        e = torch.randn(2).to(cuda2)
59        f = torch.randn(2).cuda(cuda2)
60        # d.device, e.device, and f.device are all device(type='cuda', index=2)
61
62.. _tf32_on_ampere:
63
64TensorFloat-32 (TF32) on Ampere (and later) devices
65---------------------------------------------------
66
Starting in PyTorch 1.7, there is a flag called `allow_tf32`. This flag
defaults to True in PyTorch 1.7 through PyTorch 1.11, and to False in PyTorch 1.12 and later.
This flag controls whether PyTorch is allowed to use the TensorFloat-32 (TF32) tensor cores,
available on NVIDIA GPUs since Ampere, internally to compute matmuls (matrix multiplies
and batched matrix multiplies) and convolutions.
72
73TF32 tensor cores are designed to achieve better performance on matmul and convolutions on
74`torch.float32` tensors by rounding input data to have 10 bits of mantissa, and accumulating
75results with FP32 precision, maintaining FP32 dynamic range.
76
Matmuls and convolutions are controlled separately, and their corresponding flags can be accessed via:
78
79.. code:: python
80
81  # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
82  # in PyTorch 1.12 and later.
83  torch.backends.cuda.matmul.allow_tf32 = True
84
85  # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
86  torch.backends.cudnn.allow_tf32 = True
87
The precision of matmuls can also be set more broadly (not limited just to CUDA) via :func:`torch.set_float32_matmul_precision`.
Note that besides matmuls and convolutions themselves, functions and nn modules that internally use
matmuls or convolutions are also affected. These include `nn.Linear`, `nn.Conv*`, cdist, tensordot,
affine grid and grid sample, adaptive log softmax, GRU, and LSTM.
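
As a quick illustration of the broader control, the matmul precision can be set
globally (a minimal sketch; ``"high"`` permits TF32 or comparable
reduced-precision math for float32 matmuls):

.. code:: python

  import torch

  # "highest" (the default) keeps full FP32 precision;
  # "high" and "medium" allow TF32 (or similar reduced-precision math)
  # for float32 matrix multiplications.
  torch.set_float32_matmul_precision("high")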
92
93To get an idea of the precision and speed, see the example code and benchmark data (on A100) below:
94
95.. code:: python
96
97  a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
98  b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
99  ab_full = a_full @ b_full
100  mean = ab_full.abs().mean()  # 80.7277
101
102  a = a_full.float()
103  b = b_full.float()
104
105  # Do matmul at TF32 mode.
106  torch.backends.cuda.matmul.allow_tf32 = True
107  ab_tf32 = a @ b  # takes 0.016s on GA100
108  error = (ab_tf32 - ab_full).abs().max()  # 0.1747
109  relative_error = error / mean  # 0.0022
110
111  # Do matmul with TF32 disabled.
112  torch.backends.cuda.matmul.allow_tf32 = False
113  ab_fp32 = a @ b  # takes 0.11s on GA100
114  error = (ab_fp32 - ab_full).abs().max()  # 0.0031
115  relative_error = error / mean  # 0.000039
116
From the above example, we can see that with TF32 enabled, the matmul is ~7x faster on A100, and that
the relative error compared to double precision is approximately 2 orders of magnitude larger. Note that
the exact ratio of TF32 to single-precision speed depends on the hardware generation, as properties
such as the ratio of memory bandwidth to compute, as well as the ratio of TF32 to FP32 matmul throughput,
may vary from generation to generation or model to model.
122If full FP32 precision is needed, users can disable TF32 by:
123
124.. code:: python
125
126  torch.backends.cuda.matmul.allow_tf32 = False
127  torch.backends.cudnn.allow_tf32 = False
128
129To toggle the TF32 flags off in C++, you can do
130
131.. code:: C++
132
133  at::globalContext().setAllowTF32CuBLAS(false);
134  at::globalContext().setAllowTF32CuDNN(false);
135
136For more information about TF32, see:
137
138- `TensorFloat-32`_
139- `CUDA 11`_
140- `Ampere architecture`_
141
142.. _TensorFloat-32: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
143.. _CUDA 11: https://devblogs.nvidia.com/cuda-11-features-revealed/
144.. _Ampere architecture: https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/
145
146.. _fp16reducedprecision:
147
148Reduced Precision Reduction in FP16 GEMMs
149-----------------------------------------
150
151fp16 GEMMs are potentially done with some intermediate reduced precision reductions (e.g., in fp16 rather than fp32). These selective reductions in precision can allow for higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures at the cost of numerical precision and potential for overflow.
152
153Some example benchmark data on V100:
154
155.. code::
156
157  [--------------------------- bench_gemm_transformer --------------------------]
158        [  m ,  k  ,  n  ]    |  allow_fp16_reduc=True  |  allow_fp16_reduc=False
159  1 threads: --------------------------------------------------------------------
160        [4096, 4048, 4096]    |           1634.6        |           1639.8
161        [4096, 4056, 4096]    |           1670.8        |           1661.9
162        [4096, 4080, 4096]    |           1664.2        |           1658.3
163        [4096, 4096, 4096]    |           1639.4        |           1651.0
164        [4096, 4104, 4096]    |           1677.4        |           1674.9
165        [4096, 4128, 4096]    |           1655.7        |           1646.0
166        [4096, 4144, 4096]    |           1796.8        |           2519.6
167        [4096, 5096, 4096]    |           2094.6        |           3190.0
168        [4096, 5104, 4096]    |           2144.0        |           2663.5
169        [4096, 5112, 4096]    |           2149.1        |           2766.9
170        [4096, 5120, 4096]    |           2142.8        |           2631.0
171        [4096, 9728, 4096]    |           3875.1        |           5779.8
172        [4096, 16384, 4096]   |           6182.9        |           9656.5
173  (times in microseconds).
174
175If full precision reductions are needed, users can disable reduced precision reductions in fp16 GEMMs with:
176
177.. code:: python
178
179  torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
180
181To toggle the reduced precision reduction flags in C++, one can do
182
183.. code:: C++
184
185  at::globalContext().setAllowFP16ReductionCuBLAS(false);
186
187.. _bf16reducedprecision:
188
189Reduced Precision Reduction in BF16 GEMMs
190-----------------------------------------
191
A similar flag (as above) exists for BFloat16 GEMMs.
Note that this switch is set to `True` by default for BF16. If you observe
numerical instability in your workload, you may wish to set it to `False`.
195
196If reduced precision reductions are not desired, users can disable reduced
197precision reductions in bf16 GEMMs with:
198
199.. code:: python
200
201  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
202
203To toggle the reduced precision reduction flags in C++, one can do
204
205.. code:: C++
206
  at::globalContext().setAllowBF16ReductionCuBLAS(false);
208
209Asynchronous execution
210----------------------
211
212By default, GPU operations are asynchronous.  When you call a function that
213uses the GPU, the operations are *enqueued* to the particular device, but not
214necessarily executed until later.  This allows us to execute more computations
215in parallel, including operations on CPU or other GPUs.
216
217In general, the effect of asynchronous computation is invisible to the caller,
218because (1) each device executes operations in the order they are queued, and
219(2) PyTorch automatically performs necessary synchronization when copying data
220between CPU and GPU or between two GPUs.  Hence, computation will proceed as if
221every operation was executed synchronously.
222
223You can force synchronous computation by setting environment variable
224``CUDA_LAUNCH_BLOCKING=1``.  This can be handy when an error occurs on the GPU.
225(With asynchronous execution, such an error isn't reported until after the
226operation is actually executed, so the stack trace does not show where it was
227requested.)
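
For example, you can enable launch blocking from Python before CUDA is
initialized (a minimal sketch; setting the variable in the shell before
starting the process works just as well)::

    import os

    # Must be set before CUDA is initialized in this process. With it set,
    # each kernel launch blocks until the kernel finishes, so errors surface
    # at the call site that triggered them.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    import torch  # imported after setting the variable on purpose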
228
229A consequence of the asynchronous computation is that time measurements without
230synchronizations are not accurate. To get precise measurements, one should either
231call :func:`torch.cuda.synchronize()` before measuring, or use :class:`torch.cuda.Event`
to record times as follows::
233
234    start_event = torch.cuda.Event(enable_timing=True)
235    end_event = torch.cuda.Event(enable_timing=True)
236    start_event.record()
237
238    # Run some things here
239
240    end_event.record()
241    torch.cuda.synchronize()  # Wait for the events to be recorded!
242    elapsed_time_ms = start_event.elapsed_time(end_event)
243
244As an exception, several functions such as :meth:`~torch.Tensor.to` and
245:meth:`~torch.Tensor.copy_` admit an explicit :attr:`non_blocking` argument,
246which lets the caller bypass synchronization when it is unnecessary.
247Another exception is CUDA streams, explained below.
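
For example, an asynchronous host-to-device copy might look like this (a
minimal sketch; the copy can only be truly asynchronous when the source
tensor lives in pinned memory)::

    x_cpu = torch.randn(64, 128).pin_memory()
    x_gpu = x_cpu.to('cuda', non_blocking=True)  # enqueued; may return before the copy finishes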
248
249CUDA streams
250^^^^^^^^^^^^
251
252A `CUDA stream`_ is a linear sequence of execution that belongs to a specific
253device.  You normally do not need to create one explicitly: by default, each
254device uses its own "default" stream.
255
256Operations inside each stream are serialized in the order they are created,
257but operations from different streams can execute concurrently in any
258relative order, unless explicit synchronization functions (such as
259:meth:`~torch.cuda.synchronize` or :meth:`~torch.cuda.Stream.wait_stream`) are
260used.  For example, the following code is incorrect::
261
262    cuda = torch.device('cuda')
263    s = torch.cuda.Stream()  # Create a new stream.
264    A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
265    with torch.cuda.stream(s):
266        # sum() may start execution before normal_() finishes!
267        B = torch.sum(A)
268
269When the "current stream" is the default stream, PyTorch automatically performs
270necessary synchronization when data is moved around, as explained above.
271However, when using non-default streams, it is the user's responsibility to
272ensure proper synchronization.  The fixed version of this example is::
273
274    cuda = torch.device('cuda')
275    s = torch.cuda.Stream()  # Create a new stream.
276    A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
277    s.wait_stream(torch.cuda.default_stream(cuda))  # NEW!
278    with torch.cuda.stream(s):
279        B = torch.sum(A)
280    A.record_stream(s)  # NEW!
281
There are two new additions.  The :meth:`torch.cuda.Stream.wait_stream` call
ensures that the ``normal_()`` execution has finished before we start running
``sum(A)`` on a side stream.  The :meth:`torch.Tensor.record_stream` call
ensures that we do not deallocate ``A`` before ``sum(A)`` has
completed.  You can also manually wait on the stream at some later point in
time with ``torch.cuda.default_stream(cuda).wait_stream(s)`` (note that it
is pointless to wait immediately, since that would prevent the stream execution
from running in parallel with other work on the default stream).  See the
documentation for :meth:`torch.Tensor.record_stream` for more details on when
to use one approach or the other.
292
293Note that this synchronization is necessary even when there is no
294read dependency, e.g., as seen in this example::
295
296    cuda = torch.device('cuda')
297    s = torch.cuda.Stream()  # Create a new stream.
298    A = torch.empty((100, 100), device=cuda)
299    s.wait_stream(torch.cuda.default_stream(cuda))  # STILL REQUIRED!
300    with torch.cuda.stream(s):
301        A.normal_(0.0, 1.0)
302        A.record_stream(s)
303
Even though the computation on ``s`` does not read the contents of ``A``, and there are no
other uses of ``A``, it is still necessary to synchronize, because ``A``
may correspond to memory reallocated by the CUDA caching allocator, with
pending operations from the old (deallocated) memory.
308
309.. _bwd-cuda-stream-semantics:
310
311Stream semantics of backward passes
312^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
313
314Each backward CUDA op runs on the same stream that was used for its corresponding forward op.
315If your forward pass runs independent ops in parallel on different streams,
316this helps the backward pass exploit that same parallelism.
317
318The stream semantics of a backward call with respect to surrounding ops are the same
319as for any other call. The backward pass inserts internal syncs to ensure this even when
320backward ops run on multiple streams as described in the previous paragraph.
321More concretely, when calling
322:func:`autograd.backward<torch.autograd.backward>`,
323:func:`autograd.grad<torch.autograd.grad>`, or
324:meth:`tensor.backward<torch.Tensor.backward>`,
and optionally supplying CUDA tensor(s) as the initial gradient(s) (e.g.,
326:func:`autograd.backward(..., grad_tensors=initial_grads)<torch.autograd.backward>`,
327:func:`autograd.grad(..., grad_outputs=initial_grads)<torch.autograd.grad>`, or
328:meth:`tensor.backward(..., gradient=initial_grad)<torch.Tensor.backward>`),
329the acts of
330
3311. optionally populating initial gradient(s),
3322. invoking the backward pass, and
3333. using the gradients
334
335have the same stream-semantics relationship as any group of ops::
336
337    s = torch.cuda.Stream()
338
339    # Safe, grads are used in the same stream context as backward()
340    with torch.cuda.stream(s):
341        loss.backward()
342        use grads
343
344    # Unsafe
345    with torch.cuda.stream(s):
346        loss.backward()
347    use grads
348
349    # Safe, with synchronization
350    with torch.cuda.stream(s):
351        loss.backward()
352    torch.cuda.current_stream().wait_stream(s)
353    use grads
354
355    # Safe, populating initial grad and invoking backward are in the same stream context
356    with torch.cuda.stream(s):
357        loss.backward(gradient=torch.ones_like(loss))
358
359    # Unsafe, populating initial_grad and invoking backward are in different stream contexts,
360    # without synchronization
361    initial_grad = torch.ones_like(loss)
362    with torch.cuda.stream(s):
363        loss.backward(gradient=initial_grad)
364
365    # Safe, with synchronization
366    initial_grad = torch.ones_like(loss)
367    s.wait_stream(torch.cuda.current_stream())
368    with torch.cuda.stream(s):
369        initial_grad.record_stream(s)
370        loss.backward(gradient=initial_grad)
371
372BC note: Using grads on the default stream
373~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
374
375In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced
376the default stream with all backward ops, so the following pattern::
377
378    with torch.cuda.stream(s):
379        loss.backward()
380    use grads
381
382was safe as long as ``use grads`` happened on the default stream.
In current versions of PyTorch, that pattern is no longer safe. If ``backward()``
384and ``use grads`` are in different stream contexts, you must sync the streams::
385
386    with torch.cuda.stream(s):
387        loss.backward()
388    torch.cuda.current_stream().wait_stream(s)
389    use grads
390
391even if ``use grads`` is on the default stream.
392
393.. _CUDA stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams
394
395.. _cuda-memory-management:
396
397Memory management
398-----------------
399
400PyTorch uses a caching memory allocator to speed up memory allocations. This
401allows fast memory deallocation without device synchronizations. However, the
402unused memory managed by the allocator will still show as if used in
403``nvidia-smi``. You can use :meth:`~torch.cuda.memory_allocated` and
404:meth:`~torch.cuda.max_memory_allocated` to monitor memory occupied by
405tensors, and use :meth:`~torch.cuda.memory_reserved` and
406:meth:`~torch.cuda.max_memory_reserved` to monitor the total amount of memory
407managed by the caching allocator. Calling :meth:`~torch.cuda.empty_cache`
releases all **unused** cached memory from PyTorch so that it can be used
by other GPU applications. However, GPU memory occupied by tensors will not
be freed, so this cannot increase the amount of GPU memory available for PyTorch.
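
For example, a minimal sketch of querying these counters::

    import torch

    x = torch.empty(1024, 1024, device='cuda')   # ~4 MiB of float32
    print(torch.cuda.memory_allocated())          # bytes currently occupied by tensors
    print(torch.cuda.memory_reserved())           # bytes held by the caching allocator
    del x
    torch.cuda.empty_cache()                      # release unused cached blocks to the driver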
411
412To better understand how CUDA memory is being used over time,
413:ref:`torch_cuda_memory` describes tools for capturing and visualizing traces of memory use.
414
415For more advanced users, we offer more comprehensive memory benchmarking via
416:meth:`~torch.cuda.memory_stats`. We also offer the capability to capture a
417complete snapshot of the memory allocator state via
418:meth:`~torch.cuda.memory_snapshot`, which can help you understand the
419underlying allocation patterns produced by your code.
420
421.. _cuda-memory-envvars:
422
Optimizing memory usage with ``PYTORCH_CUDA_ALLOC_CONF``
424^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
425
426Use of a caching allocator can interfere with memory checking tools such as
427``cuda-memcheck``.  To debug memory errors using ``cuda-memcheck``, set
428``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
429
430The behavior of the caching allocator can be controlled via the environment variable
431``PYTORCH_CUDA_ALLOC_CONF``.
The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...``
(a usage example follows the option list below).
Available options:
434
435* ``backend`` allows selecting the underlying allocator implementation.
436  Currently, valid options are ``native``, which uses PyTorch's native
437  implementation, and ``cudaMallocAsync``, which uses
438  `CUDA's built-in asynchronous allocator`_.
439  ``cudaMallocAsync`` requires CUDA 11.4 or newer. The default is ``native``.
440  ``backend`` applies to all devices used by the process, and can't be
441  specified on a per-device basis.
442* ``max_split_size_mb`` prevents the native allocator
443  from splitting blocks larger than this size (in MB). This can reduce
444  fragmentation and may allow some borderline workloads to complete without
445  running out of memory. Performance cost can range from 'zero' to 'substantial'
446  depending on allocation patterns.  Default value is unlimited, i.e. all blocks
447  can be split. The
448  :meth:`~torch.cuda.memory_stats` and
449  :meth:`~torch.cuda.memory_summary` methods are useful for tuning.  This
450  option should be used as a last resort for a workload that is aborting
451  due to 'out of memory' and showing a large amount of inactive split blocks.
452  ``max_split_size_mb`` is only meaningful with ``backend:native``.
453  With ``backend:cudaMallocAsync``, ``max_split_size_mb`` is ignored.
* ``roundup_power2_divisions`` helps with rounding the requested allocation
  size up to the nearest power-of-2 division, making better use of the blocks. In
  the native CUDACachingAllocator, sizes are rounded up in multiples of a
  block size of 512, so this works fine for smaller sizes. However, this
  can be inefficient for large allocations of nearby sizes, as each goes into a
  different bucket of block sizes and re-use of those blocks is minimized. This can create
  lots of unused blocks and waste GPU memory capacity. This option enables
  rounding the allocation size up to the nearest power-of-2 division. For example, if
  we need to round up a size of 1200 and the number of divisions is 4,
  the size 1200 lies between 1024 and 2048, and 4 divisions between
  them give the values 1024, 1280, 1536, and 1792. So an allocation size of 1200
  will be rounded up to 1280, the nearest ceiling among the power-of-2 divisions.
  Specify a single value to apply to all allocation sizes, or specify an
  array of key-value pairs to set the power-of-2 divisions individually for each
  power-of-two interval. For example, to set 1 division for all allocations
  under 256MB, 2 divisions for allocations between 256MB and 512MB, 4 divisions
  for allocations between 512MB and 1GB, and 8 divisions for any larger allocations,
  set the knob value to: [256:1,512:2,1024:4,>:8].
  ``roundup_power2_divisions`` is only meaningful with ``backend:native``.
  With ``backend:cudaMallocAsync``, ``roundup_power2_divisions`` is ignored.
* ``garbage_collection_threshold`` helps actively reclaim unused GPU memory to
  avoid triggering an expensive sync-and-reclaim-all operation (release_cached_blocks),
  which can be unfavorable to latency-critical GPU applications (e.g., servers).
  Upon setting this threshold (e.g., 0.8), the allocator will start reclaiming
  GPU memory blocks if the GPU memory capacity usage exceeds the threshold (i.e.,
  80% of the total memory allocated to the GPU application). The algorithm prefers
  to free old & unused blocks first to avoid freeing blocks that are actively being
  reused. The threshold value should be greater than 0.0 and less than 1.0.
  ``garbage_collection_threshold`` is only meaningful with ``backend:native``.
  With ``backend:cudaMallocAsync``, ``garbage_collection_threshold`` is ignored.
* ``expandable_segments`` (experimental, default: `False`) If set to `True`, this setting instructs
  the allocator to create CUDA allocations that can later be expanded, to better handle cases
  where a job changes allocation sizes frequently, such as having a changing batch size.
  Normally for large (>2MB) allocations, the allocator calls cudaMalloc to get allocations
  that are the same size as what the user requests. In the future, parts of these
  allocations can be reused for other requests if they are free. This works well
  when the program makes many requests of exactly the same size or of sizes that are
  even multiples of that size. Many deep learning models follow this behavior.
  However, one common exception is when the batch size changes slightly from one
  iteration to the next, e.g. in batched inference. When the program runs
  initially with batch size `N`, it will make allocations appropriate for that size.
  If in the future it runs at size `N - 1`, the existing allocations will still be
  big enough. However, if it runs at size `N + 1`, then it will have to make new
  allocations that are slightly larger. Not all the tensors are the same size.
  Some might be `(N + 1)*A` and others `(N + 1)*A*B`, where `A` and `B` are some non-batch
  dimensions in the model. Because the allocator reuses existing allocations when
  they are big enough, some number of `(N + 1)*A` allocations will actually fit in
  the already existing `N*B*A` segments, though not perfectly. As the model runs, it
  will partially fill up all of these segments, leaving unusable free slices of
  memory at the end of these segments. The allocator will at some point need to
  `cudaMalloc` a new `(N + 1)*A*B` segment. If there is not enough memory, there is
  now no way to recover the slices of memory that are free at the end of existing
  segments. With models 50+ layers deep, this pattern might repeat 50+ times,
  creating many slivers.
508
509  `expandable_segments` allows the allocator to create a segment initially and then
510  expand its size later when more memory is needed. Instead of making one segment
511  per allocation, it tries to make one segment (per stream) that grows as
512  necessary. Now when the `N + 1` case runs, the allocations will tile nicely into
513  the one large segment until it fills up. Then more memory is requested and
514  appended to the end of the segment. This process does not create as many slivers
515  of unusable memory, so it is more likely to succeed at finding this memory.
516
  The `pinned_use_cuda_host_register` option is a boolean flag that determines whether to
  use the CUDA API's cudaHostRegister function for allocating pinned memory instead
  of the default cudaHostAlloc. When set to True, the memory is allocated using regular
  malloc and then pages are mapped to the memory before calling cudaHostRegister.
  This pre-mapping of pages helps reduce the lock time during the execution
  of cudaHostRegister.

  The `pinned_num_register_threads` option is only valid when `pinned_use_cuda_host_register`
  is set to True. By default, one thread is used to map the pages. This option allows
  using more threads to parallelize the page-mapping operations to reduce the overall
  allocation time of pinned memory. Based on benchmarking results, a good value for
  this option is 8.
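
As an example, the following sets several of these options at once from Python
(a minimal sketch: the values are illustrative, not recommendations, and the
variable must be set before the first CUDA allocation in the process)::

    import os

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
        "max_split_size_mb:128,garbage_collection_threshold:0.8"
    )

    import torch  # imported after setting the variable on purpose

    x = torch.zeros(1024, device="cuda")  # allocator picks up the configuration here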
529
530.. note::
531
532    Some stats reported by the
533    :ref:`CUDA memory management API<cuda-memory-management-api>`
534    are specific to ``backend:native``, and are not meaningful with
535    ``backend:cudaMallocAsync``.
536    See each function's docstring for details.
537
538.. _CUDA's built-in asynchronous allocator:
539    https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
540
541.. _cuda-memory-custom-allocator:
542
543Using custom memory allocators for CUDA
544---------------------------------------
545
It is possible to define allocators as simple functions in C/C++ and compile
them as a shared library. The code below shows a basic allocator that just
traces all the memory operations.
549
550.. code:: C++
551
552   #include <sys/types.h>
553   #include <cuda_runtime_api.h>
554   #include <iostream>
555   // Compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC
556   extern "C" {
557   void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
558      void *ptr;
559      cudaMalloc(&ptr, size);
      std::cout<<"alloc "<<ptr<<" "<<size<<std::endl;
561      return ptr;
562   }
563
564   void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
565      std::cout<<"free "<<ptr<< " "<<stream<<std::endl;
566      cudaFree(ptr);
567   }
568   }
569
570
571This can be used in python through the :class:`torch.cuda.memory.CUDAPluggableAllocator`.
572The user is responsible for supplying the path to the `.so` file and the name
573of the alloc/free functions that match the signatures specified above.
574
575.. code:: python
576
577   import torch
578
579   # Load the allocator
580   new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
581       'alloc.so', 'my_malloc', 'my_free')
582   # Swap the current allocator
583   torch.cuda.memory.change_current_allocator(new_alloc)
584   # This will allocate memory in the device using the new allocator
585   b = torch.zeros(10, device='cuda')
586
587
In contrast, attempting to swap the allocator after memory has already been
allocated with the current allocator raises an error, as in the following example:

.. code:: python
589
590   import torch
591
   # Do an initial memory allocation with the default allocator
593   b = torch.zeros(10, device='cuda')
594   # Load the allocator
595   new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
596       'alloc.so', 'my_malloc', 'my_free')
597   # This will error since the current allocator was already instantiated
598   torch.cuda.memory.change_current_allocator(new_alloc)
599
.. _cublas-workspaces:
601
602cuBLAS workspaces
603-----------------
604
605For each combination of cuBLAS handle and CUDA stream, a cuBLAS workspace will be allocated
606if that handle and stream combination executes a cuBLAS kernel that requires a workspace.
607In order to avoid repeatedly allocating workspaces, these workspaces are not deallocated unless
608``torch._C._cuda_clearCublasWorkspaces()`` is called. The workspace size per allocation can be
609specified via the environment variable ``CUBLAS_WORKSPACE_CONFIG`` with the format ``:[SIZE]:[COUNT]``.
As an example, the default workspace configuration is ``CUBLAS_WORKSPACE_CONFIG=:4096:2:16:8``,
which specifies a total workspace size of ``2 * 4096 + 8 * 16 KiB``. To force cuBLAS to avoid using workspaces,
set ``CUBLAS_WORKSPACE_CONFIG=:0:0``.
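
For example (a minimal sketch; the size chosen is illustrative, and
``torch._C._cuda_clearCublasWorkspaces()`` is an internal API that may change)::

    import os

    # One workspace of 8192 KiB per cuBLAS handle / stream combination.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":8192:1"

    import torch  # imported after setting the variable on purpose

    a = torch.randn(1024, 1024, device="cuda")
    b = a @ a                                  # may allocate a workspace for this handle/stream
    torch._C._cuda_clearCublasWorkspaces()     # frees the cached workspaces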
613
614.. _cufft-plan-cache:
615
616cuFFT plan cache
617----------------
618
For each CUDA device, an LRU cache of cuFFT plans is used to speed up repeatedly
running FFT methods (e.g., :func:`torch.fft.fft`) on CUDA tensors of the same geometry
with the same configuration. Because some cuFFT plans may allocate GPU memory,
these caches have a maximum capacity.

You may control and query the properties of the cache of the current device with
the following APIs:
626
627* ``torch.backends.cuda.cufft_plan_cache.max_size`` gives the capacity of the
628  cache (default is 4096 on CUDA 10 and newer, and 1023 on older CUDA versions).
629  Setting this value directly modifies the capacity.
630
631* ``torch.backends.cuda.cufft_plan_cache.size`` gives the number of plans
632  currently residing in the cache.
633
634* ``torch.backends.cuda.cufft_plan_cache.clear()`` clears the cache.
635
636To control and query plan caches of a non-default device, you can index the
637``torch.backends.cuda.cufft_plan_cache`` object with either a :class:`torch.device`
638object or a device index, and access one of the above attributes. E.g., to set
639the capacity of the cache for device ``1``, one can write
640``torch.backends.cuda.cufft_plan_cache[1].max_size = 10``.
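
For example, a short sketch using these attributes on the current device::

    import torch

    torch.backends.cuda.cufft_plan_cache.max_size = 32   # illustrative capacity

    x = torch.randn(64, 64, device='cuda')
    y = torch.fft.fft(x)                                  # the plan created here is cached

    print(torch.backends.cuda.cufft_plan_cache.size)      # number of cached plans
    torch.backends.cuda.cufft_plan_cache.clear()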
641
642.. _cuda-just-in-time-compilation:
643
644Just-in-Time Compilation
645------------------------
646
PyTorch just-in-time compiles some operations, like ``torch.special.zeta``, when
648performed on CUDA tensors. This compilation can be time consuming
649(up to a few seconds depending on your hardware and software)
650and may occur multiple times for a single operator since many PyTorch operators actually
651select from a variety of kernels, each of which must be compiled once, depending on their input.
652This compilation occurs once per process, or just once if a kernel cache is used.
653
By default, PyTorch creates a kernel cache in ``$XDG_CACHE_HOME/torch/kernels`` if
``XDG_CACHE_HOME`` is defined and in ``$HOME/.cache/torch/kernels`` if it's not (except on Windows,
where the kernel cache is not yet supported). The caching behavior can be directly
controlled with two environment variables. If ``USE_PYTORCH_KERNEL_CACHE`` is set to ``0`` then no
cache will be used, and if ``PYTORCH_KERNEL_CACHE_PATH`` is set then that path will be used
as a kernel cache instead of the default location.
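
A minimal sketch of redirecting the kernel cache (the path is a placeholder)::

    import os

    # Must be set before the first JIT-compiled CUDA kernel is built in this process.
    os.environ["PYTORCH_KERNEL_CACHE_PATH"] = "/tmp/my_torch_kernels"

    import torch  # imported after setting the variable on purpose

    x = torch.rand(1024, device="cuda")
    y = torch.special.zeta(x + 2, x + 1)   # may trigger JIT compilation on first use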
660
661Best practices
662--------------
663
664Device-agnostic code
665^^^^^^^^^^^^^^^^^^^^
666
667Due to the structure of PyTorch, you may need to explicitly write
668device-agnostic (CPU or GPU) code; an example may be creating a new tensor as
669the initial hidden state of a recurrent neural network.
670
671The first step is to determine whether the GPU should be used or not. A common
672pattern is to use Python's ``argparse`` module to read in user arguments, and
673have a flag that can be used to disable CUDA, in combination with
674:meth:`~torch.cuda.is_available`. In the following, ``args.device`` results in a
675:class:`torch.device` object that can be used to move tensors to CPU or CUDA.
676
677::
678
679    import argparse
680    import torch
681
682    parser = argparse.ArgumentParser(description='PyTorch Example')
683    parser.add_argument('--disable-cuda', action='store_true',
684                        help='Disable CUDA')
685    args = parser.parse_args()
686    args.device = None
687    if not args.disable_cuda and torch.cuda.is_available():
688        args.device = torch.device('cuda')
689    else:
690        args.device = torch.device('cpu')
691
692.. note::
693
694    When assessing the availability of CUDA in a given environment (:meth:`~torch.cuda.is_available`), PyTorch's default
695    behavior is to call the CUDA Runtime API method `cudaGetDeviceCount`_. Because this call in turn initializes the
696    CUDA Driver API (via `cuInit`_) if it is not already initialized, subsequent forks of a process that has run
697    :meth:`~torch.cuda.is_available` will fail with a CUDA initialization error.
698
699    One can set ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` in your environment before importing PyTorch modules that execute
700    :meth:`~torch.cuda.is_available` (or before executing it directly) in order to direct
701    :meth:`~torch.cuda.is_available` to attempt an NVML-based assessment (`nvmlDeviceGetCount_v2`_). If the
702    NVML-based assessment is successful (i.e. NVML discovery/initialization does not fail),
703    :meth:`~torch.cuda.is_available` calls will not poison subsequent forks.
704
705    If NVML discovery/initialization fails, :meth:`~torch.cuda.is_available` will fallback to the standard CUDA Runtime
706    API assessment and the aforementioned fork constraint will apply.
707
708    Note that the above NVML-based CUDA availability assessment provides a weaker guarantee than the default CUDA
709    Runtime API approach (which requires CUDA initialization to succeed). In some circumstances, the NVML-based check
710    may succeed while later CUDA initialization fails.
711
712Now that we have ``args.device``, we can use it to create a Tensor on the
713desired device.
714
715::
716
717    x = torch.empty((8, 42), device=args.device)
718    net = Network().to(device=args.device)
719
This can be used in a number of cases to produce device-agnostic code. Below
is an example using a dataloader:
722
723::
724
725    cuda0 = torch.device('cuda:0')  # CUDA GPU 0
726    for i, x in enumerate(train_loader):
727        x = x.to(cuda0)
728
When working with multiple GPUs on a system, you can use the
``CUDA_VISIBLE_DEVICES`` environment variable to manage which GPUs are available to
731PyTorch. As mentioned above, to manually control which GPU a tensor is created
732on, the best practice is to use a :any:`torch.cuda.device` context manager.
733
734::
735
736    print("Outside device is 0")  # On device 0 (default in most scenarios)
737    with torch.cuda.device(1):
738        print("Inside device is 1")  # On device 1
739    print("Outside device is still 0")  # On device 0
740
741If you have a tensor and would like to create a new tensor of the same type on
742the same device, then you can use a ``torch.Tensor.new_*`` method
743(see :class:`torch.Tensor`).
744Whilst the previously mentioned ``torch.*`` factory functions
745(:ref:`tensor-creation-ops`) depend on the current GPU context and
the attribute arguments you pass in, ``torch.Tensor.new_*`` methods preserve
747the device and other attributes of the tensor.
748
749This is the recommended practice when creating modules in which new
750tensors need to be created internally during the forward pass.
751
752::
753
754    cuda = torch.device('cuda')
755    x_cpu = torch.empty(2)
756    x_gpu = torch.empty(2, device=cuda)
757    x_cpu_long = torch.empty(2, dtype=torch.int64)
758
759    y_cpu = x_cpu.new_full([3, 2], fill_value=0.3)
760    print(y_cpu)
761
762        tensor([[ 0.3000,  0.3000],
763                [ 0.3000,  0.3000],
764                [ 0.3000,  0.3000]])
765
766    y_gpu = x_gpu.new_full([3, 2], fill_value=-5)
767    print(y_gpu)
768
769        tensor([[-5.0000, -5.0000],
770                [-5.0000, -5.0000],
771                [-5.0000, -5.0000]], device='cuda:0')
772
773    y_cpu_long = x_cpu_long.new_tensor([[1, 2, 3]])
774    print(y_cpu_long)
775
776        tensor([[ 1,  2,  3]])
777
778
779If you want to create a tensor of the same type and size of another tensor, and
780fill it with either ones or zeros, :meth:`~torch.ones_like` or
781:meth:`~torch.zeros_like` are provided as convenient helper functions (which
782also preserve :class:`torch.device` and :class:`torch.dtype` of a Tensor).
783
784::
785
786    x_cpu = torch.empty(2, 3)
787    x_gpu = torch.empty(2, 3)
788
789    y_cpu = torch.ones_like(x_cpu)
790    y_gpu = torch.zeros_like(x_gpu)
791
792
793.. _cuda-memory-pinning:
794
795Use pinned memory buffers
796^^^^^^^^^^^^^^^^^^^^^^^^^
797
798.. warning::
799
800    This is an advanced tip. If you overuse pinned memory, it can cause serious
801    problems when running low on RAM, and you should be aware that pinning is
802    often an expensive operation.
803
804Host to GPU copies are much faster when they originate from pinned (page-locked)
memory. CPU tensors and storages expose a :meth:`~torch.Tensor.pin_memory`
method that returns a copy of the object with its data placed in a pinned region.
807
808Also, once you pin a tensor or storage, you can use asynchronous GPU copies.
809Just pass an additional ``non_blocking=True`` argument to a
810:meth:`~torch.Tensor.to` or a :meth:`~torch.Tensor.cuda` call. This can be used
811to overlap data transfers with computation.
812
813You can make the :class:`~torch.utils.data.DataLoader` return batches placed in
814pinned memory by passing ``pin_memory=True`` to its constructor.
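
A brief sketch combining both points (the dataset here is a stand-in for your own)::

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(1000, 16), torch.randn(1000, 1))
    loader = DataLoader(dataset, batch_size=64, pin_memory=True)

    for x, y in loader:
        # Batches are already in pinned memory, so these copies can be asynchronous
        x = x.to('cuda', non_blocking=True)
        y = y.to('cuda', non_blocking=True)
        # ... forward/backward on x and y ...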
815
816.. _cuda-nn-ddp-instead:
817
818Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel
819^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
820
821Most use cases involving batched inputs and multiple GPUs should default to
822using :class:`~torch.nn.parallel.DistributedDataParallel` to utilize more
823than one GPU.
824
825There are significant caveats to using CUDA models with
826:mod:`~torch.multiprocessing`; unless care is taken to meet the data handling
827requirements exactly, it is likely that your program will have incorrect or
828undefined behavior.
829
830It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
831instead of :class:`~torch.nn.DataParallel` to do multi-GPU training, even if
832there is only a single node.
833
834The difference between :class:`~torch.nn.parallel.DistributedDataParallel` and
835:class:`~torch.nn.DataParallel` is: :class:`~torch.nn.parallel.DistributedDataParallel`
836uses multiprocessing where a process is created for each GPU, while
837:class:`~torch.nn.DataParallel` uses multithreading. By using multiprocessing,
each GPU has its own dedicated process, which avoids the performance overhead caused
by the Python interpreter's GIL.
840
If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you can use the
`torch.distributed.launch` utility to launch your program; see :ref:`distributed-launch`.
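
For illustration, a minimal sketch of wrapping a model with
:class:`~torch.nn.parallel.DistributedDataParallel` (assuming one process per GPU,
a launcher that sets the rendezvous variables and ``LOCAL_RANK`` in the environment,
and ``MyModel`` as a placeholder for your own module)::

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend="nccl")        # rendezvous info comes from the launcher
    local_rank = int(os.environ["LOCAL_RANK"])     # provided by the launcher
    torch.cuda.set_device(local_rank)

    model = MyModel().cuda(local_rank)             # MyModel is a placeholder
    ddp_model = DDP(model, device_ids=[local_rank])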
843
844.. _cudaGetDeviceCount:
845    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g18808e54893cfcaafefeab31a73cc55f
846
847.. _cuInit:
848    https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html#group__CUDA__INITIALIZE_1g0a2f1517e1bd8502c7194c3a8c134bc3
849
850.. _nvmlDeviceGetCount_v2:
851    https://docs.nvidia.com/deploy/nvml-api/group__nvmlDeviceQueries.html#group__nvmlDeviceQueries_1ga93623b195bff04bbe3490ca33c8a42d
852
853.. _cuda-graph-semantics:
854
855CUDA Graphs
856-----------
857
858A CUDA graph is a record of the work (mostly kernels and their arguments) that a
859CUDA stream and its dependent streams perform.
860For general principles and details on the underlying CUDA API, see
861`Getting Started with CUDA Graphs`_ and the
862`Graphs section`_ of the CUDA C Programming Guide.
863
864PyTorch supports the construction of CUDA graphs using `stream capture`_, which puts a
865CUDA stream in *capture mode*. CUDA work issued to a capturing stream doesn't actually
866run on the GPU. Instead, the work is recorded in a graph.
867
868After capture, the graph can be *launched* to run the GPU work as many times as needed.
869Each replay runs the same kernels with the same arguments. For pointer arguments this
870means the same memory addresses are used.
871By filling input memory with new data (e.g., from a new batch) before each replay,
872you can rerun the same work on new data.
873
874Why CUDA Graphs?
875^^^^^^^^^^^^^^^^
876
877Replaying a graph sacrifices the dynamic flexibility of typical eager execution in exchange for
878**greatly reduced CPU overhead**. A graph's arguments and kernels are fixed, so a graph replay
879skips all layers of argument setup and kernel dispatch, including Python, C++, and CUDA driver
880overheads. Under the hood, a replay submits the entire graph's work to the GPU with
881a single call to `cudaGraphLaunch`_.  Kernels in a replay also execute slightly faster
882on the GPU, but eliding CPU overhead is the main benefit.
883
884You should try CUDA graphs if all or part of your network is graph-safe (usually this means
885static shapes and static control flow, but see the other :ref:`constraints<capture-constraints>`)
886and you suspect its runtime is at least somewhat CPU-limited.
887
888.. _Getting Started with CUDA Graphs:
889    https://developer.nvidia.com/blog/cuda-graphs/
890.. _Graphs section:
891    https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs
892.. _stream capture:
893    https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
894.. _cudaGraphLaunch:
895    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
896
897PyTorch API
898^^^^^^^^^^^
899
900.. warning::
901    This API is in beta and may change in future releases.
902
903PyTorch exposes graphs via a raw :class:`torch.cuda.CUDAGraph` class
904and two convenience wrappers,
905:class:`torch.cuda.graph` and
906:class:`torch.cuda.make_graphed_callables`.
907
908:class:`torch.cuda.graph` is a simple, versatile context manager that
909captures CUDA work in its context.
910Before capture, warm up the workload to be captured by running
911a few eager iterations. Warmup must occur on a side stream.
912Because the graph reads from and writes to the same memory addresses in every
913replay, you must maintain long-lived references to tensors that hold
914input and output data during capture.
915To run the graph on new input data, copy new data to the capture's input tensor(s),
916replay the graph, then read the new output from the capture's output tensor(s).
917Example::
918
919    g = torch.cuda.CUDAGraph()
920
921    # Placeholder input used for capture
922    static_input = torch.empty((5,), device="cuda")
923
924    # Warmup before capture
925    s = torch.cuda.Stream()
926    s.wait_stream(torch.cuda.current_stream())
927    with torch.cuda.stream(s):
928        for _ in range(3):
929            static_output = static_input * 2
930    torch.cuda.current_stream().wait_stream(s)
931
932    # Captures the graph
933    # To allow capture, automatically sets a side stream as the current stream in the context
934    with torch.cuda.graph(g):
935        static_output = static_input * 2
936
937    # Fills the graph's input memory with new data to compute on
938    static_input.copy_(torch.full((5,), 3, device="cuda"))
939    g.replay()
940    # static_output holds the results
941    print(static_output)  # full of 3 * 2 = 6
942
943    # Fills the graph's input memory with more data to compute on
944    static_input.copy_(torch.full((5,), 4, device="cuda"))
945    g.replay()
946    print(static_output)  # full of 4 * 2 = 8
947
948See
949:ref:`Whole-network capture<whole-network-capture>`,
950:ref:`Usage with torch.cuda.amp<graphs-with-amp>`, and
951:ref:`Usage with multiple streams<multistream-capture>`
952for realistic and advanced patterns.
953
954:class:`~torch.cuda.make_graphed_callables` is more sophisticated.
955:class:`~torch.cuda.make_graphed_callables` accepts Python functions and
956:class:`torch.nn.Module`\s. For each passed function or Module,
957it creates separate graphs of the forward-pass and backward-pass work. See
958:ref:`Partial-network capture<partial-network-capture>`.
959
960.. _capture-constraints:
961
962Constraints
963~~~~~~~~~~~
964
965A set of ops is *capturable* if it doesn't violate any of the following constraints.
966
967Constraints apply to all work in a
968:class:`torch.cuda.graph` context and all work in the forward and backward passes
969of any callable you pass to :func:`torch.cuda.make_graphed_callables`.
970
971Violating any of these will likely cause a runtime error:
972
973* Capture must occur on a non-default stream. (This is only a concern if you use the raw
974  :meth:`CUDAGraph.capture_begin<torch.cuda.CUDAGraph.capture_begin>` and
975  :meth:`CUDAGraph.capture_end<torch.cuda.CUDAGraph.capture_end>` calls.
976  :class:`~torch.cuda.graph` and
977  :func:`~torch.cuda.make_graphed_callables` set a side stream for you.)
978* Ops that synchronize the CPU with the GPU (e.g., ``.item()`` calls) are prohibited.
979* CUDA RNG operations are permitted, and when using multiple :class:`torch.Generator` instances within a graph,
980  they must be registered using :meth:`CUDAGraph.register_generator_state<torch.cuda.CUDAGraph.register_generator_state>` before graph capture.
  Avoid using :meth:`Generator.get_state<torch.Generator.get_state>` and :meth:`Generator.set_state<torch.Generator.set_state>` during capture;
982  instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
983  for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.
984
985
986Violating any of these will likely cause silent numerical errors or undefined behavior:
987
988* Within a process, only one capture may be underway at a time.
989* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
990* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
991* Every replay reads from and writes to the same (virtual) memory addresses.
992* Dynamic control flow (based on CPU or GPU data) is prohibited.
993* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
994  has the same size and layout in every replay.
995* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
996
997Non-constraints
998~~~~~~~~~~~~~~~
999
1000* Once captured, the graph may be replayed on any stream.
1001
1002.. _whole-network-capture:
1003
1004Whole-network capture
1005^^^^^^^^^^^^^^^^^^^^^^
1006
1007If your entire network is capturable, you can capture and replay an entire iteration::
1008
1009    N, D_in, H, D_out = 640, 4096, 2048, 1024
1010    model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
1011                                torch.nn.Dropout(p=0.2),
1012                                torch.nn.Linear(H, D_out),
1013                                torch.nn.Dropout(p=0.1)).cuda()
1014    loss_fn = torch.nn.MSELoss()
1015    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
1016
1017    # Placeholders used for capture
1018    static_input = torch.randn(N, D_in, device='cuda')
1019    static_target = torch.randn(N, D_out, device='cuda')
1020
1021    # warmup
1022    # Uses static_input and static_target here for convenience,
1023    # but in a real setting, because the warmup includes optimizer.step()
1024    # you must use a few batches of real data.
1025    s = torch.cuda.Stream()
1026    s.wait_stream(torch.cuda.current_stream())
1027    with torch.cuda.stream(s):
1028        for i in range(3):
1029            optimizer.zero_grad(set_to_none=True)
1030            y_pred = model(static_input)
1031            loss = loss_fn(y_pred, static_target)
1032            loss.backward()
1033            optimizer.step()
1034    torch.cuda.current_stream().wait_stream(s)
1035
1036    # capture
1037    g = torch.cuda.CUDAGraph()
1038    # Sets grads to None before capture, so backward() will create
1039    # .grad attributes with allocations from the graph's private pool
1040    optimizer.zero_grad(set_to_none=True)
1041    with torch.cuda.graph(g):
1042        static_y_pred = model(static_input)
1043        static_loss = loss_fn(static_y_pred, static_target)
1044        static_loss.backward()
1045        optimizer.step()
1046
1047    real_inputs = [torch.rand_like(static_input) for _ in range(10)]
1048    real_targets = [torch.rand_like(static_target) for _ in range(10)]
1049
1050    for data, target in zip(real_inputs, real_targets):
1051        # Fills the graph's input memory with new data to compute on
1052        static_input.copy_(data)
1053        static_target.copy_(target)
1054        # replay() includes forward, backward, and step.
1055        # You don't even need to call optimizer.zero_grad() between iterations
1056        # because the captured backward refills static .grad tensors in place.
1057        g.replay()
1058        # Params have been updated. static_y_pred, static_loss, and .grad
1059        # attributes hold values from computing on this iteration's data.
1060
1061.. _partial-network-capture:
1062
1063Partial-network capture
1064^^^^^^^^^^^^^^^^^^^^^^^^^
1065
1066If some of your network is unsafe to capture (e.g., due to dynamic control flow,
1067dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe
1068part(s) eagerly and use :func:`torch.cuda.make_graphed_callables` to graph only
1069the capture-safe part(s).
1070
1071By default, callables returned by :func:`~torch.cuda.make_graphed_callables`
1072are autograd-aware, and can be used in the training loop as direct replacements
1073for the functions or :class:`nn.Module<torch.nn.Module>`\ s you passed.
1074
1075:func:`~torch.cuda.make_graphed_callables` internally creates
1076:class:`~torch.cuda.CUDAGraph` objects, runs warmup iterations, and maintains
1077static inputs and outputs as needed.  Therefore (unlike with
1078:class:`torch.cuda.graph`) you don't need to handle those manually.
1079
1080In the following example, data-dependent dynamic control flow means the
1081network isn't capturable end-to-end, but
1082:func:`~torch.cuda.make_graphed_callables`
1083lets us capture and run graph-safe sections as graphs regardless::
1084
    from itertools import chain

    N, D_in, H, D_out = 640, 4096, 2048, 1024
1086
1087    module1 = torch.nn.Linear(D_in, H).cuda()
1088    module2 = torch.nn.Linear(H, D_out).cuda()
1089    module3 = torch.nn.Linear(H, D_out).cuda()
1090
1091    loss_fn = torch.nn.MSELoss()
1092    optimizer = torch.optim.SGD(chain(module1.parameters(),
1093                                      module2.parameters(),
1094                                      module3.parameters()),
1095                                lr=0.1)
1096
1097    # Sample inputs used for capture
1098    # requires_grad state of sample inputs must match
1099    # requires_grad state of real inputs each callable will see.
1100    x = torch.randn(N, D_in, device='cuda')
1101    h = torch.randn(N, H, device='cuda', requires_grad=True)
1102
1103    module1 = torch.cuda.make_graphed_callables(module1, (x,))
1104    module2 = torch.cuda.make_graphed_callables(module2, (h,))
1105    module3 = torch.cuda.make_graphed_callables(module3, (h,))
1106
1107    real_inputs = [torch.rand_like(x) for _ in range(10)]
1108    real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)]
1109
1110    for data, target in zip(real_inputs, real_targets):
1111        optimizer.zero_grad(set_to_none=True)
1112
1113        tmp = module1(data)  # forward ops run as a graph
1114
1115        if tmp.sum().item() > 0:
1116            tmp = module2(tmp)  # forward ops run as a graph
1117        else:
1118            tmp = module3(tmp)  # forward ops run as a graph
1119
1120        loss = loss_fn(tmp, target)
1121        # module2's or module3's (whichever was chosen) backward ops,
1122        # as well as module1's backward ops, run as graphs
1123        loss.backward()
1124        optimizer.step()
1125
1126.. _graphs-with-amp:
1127
1128Usage with torch.cuda.amp
1129^^^^^^^^^^^^^^^^^^^^^^^^^
1130
1131For typical optimizers, :meth:`GradScaler.step<torch.cuda.amp.GradScaler.step>` syncs
1132the CPU with the GPU, which is prohibited during capture. To avoid errors, either use
1133:ref:`partial-network capture<partial-network-capture>`, or (if forward, loss,
1134and backward are capture-safe) capture forward, loss, and backward but not the
1135optimizer step::
1136
1137    # warmup
1138    # In a real setting, use a few batches of real data.
1139    s = torch.cuda.Stream()
1140    s.wait_stream(torch.cuda.current_stream())
1141    with torch.cuda.stream(s):
1142        for i in range(3):
1143            optimizer.zero_grad(set_to_none=True)
1144            with torch.cuda.amp.autocast():
1145                y_pred = model(static_input)
1146                loss = loss_fn(y_pred, static_target)
1147            scaler.scale(loss).backward()
1148            scaler.step(optimizer)
1149            scaler.update()
1150    torch.cuda.current_stream().wait_stream(s)
1151
1152    # capture
1153    g = torch.cuda.CUDAGraph()
1154    optimizer.zero_grad(set_to_none=True)
1155    with torch.cuda.graph(g):
1156        with torch.cuda.amp.autocast():
1157            static_y_pred = model(static_input)
1158            static_loss = loss_fn(static_y_pred, static_target)
1159        scaler.scale(static_loss).backward()
1160        # don't capture scaler.step(optimizer) or scaler.update()
1161
1162    real_inputs = [torch.rand_like(static_input) for _ in range(10)]
1163    real_targets = [torch.rand_like(static_target) for _ in range(10)]
1164
1165    for data, target in zip(real_inputs, real_targets):
1166        static_input.copy_(data)
1167        static_target.copy_(target)
1168        g.replay()
1169        # Runs scaler.step and scaler.update eagerly
1170        scaler.step(optimizer)
1171        scaler.update()
1172
1173.. _multistream-capture:
1174
1175Usage with multiple streams
1176^^^^^^^^^^^^^^^^^^^^^^^^^^^
1177
1178Capture mode automatically propagates to any streams that sync with a capturing stream.
1179Within capture, you may expose parallelism by issuing calls to different streams,
1180but the overall stream dependency DAG must branch out from the
1181initial capturing stream after capture begins and rejoin the initial stream
1182before capture ends::
1183
1184    with torch.cuda.graph(g):
1185        # at context manager entrance, torch.cuda.current_stream()
1186        # is the initial capturing stream
1187
1188        # INCORRECT (does not branch out from or rejoin initial stream)
1189        with torch.cuda.stream(s):
1190            cuda_work()
1191
1192        # CORRECT:
1193        # branches out from initial stream
1194        s.wait_stream(torch.cuda.current_stream())
1195        with torch.cuda.stream(s):
1196            cuda_work()
1197        # rejoins initial stream before capture ends
1198        torch.cuda.current_stream().wait_stream(s)
1199
1200.. note::
1201
1202    To avoid confusion for power users looking at replays in nsight systems or nvprof:
1203    Unlike eager execution, the graph interprets a nontrivial stream DAG in capture
1204    as a hint, not a command. During replay, the graph may reorganize independent ops
1205    onto different streams or enqueue them in a different order (while respecting your
1206    original DAG's overall dependencies).
1207
1208Usage with DistributedDataParallel
1209^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1210
1211NCCL < 2.9.6
1212~~~~~~~~~~~~
1213
1214NCCL versions earlier than 2.9.6 don't allow collectives to be captured.
1215You must use :ref:`partial-network capture<partial-network-capture>`,
1216which defers allreduces to happen outside graphed sections of backward.
1217
1218Call :func:`~torch.cuda.make_graphed_callables` on graphable network sections
1219*before* wrapping the network with DDP.
1220
1221NCCL >= 2.9.6
1222~~~~~~~~~~~~~
1223
1224NCCL versions 2.9.6 or later allow collectives in the graph.
1225Approaches that capture an :ref:`entire backward pass<whole-network-capture>`
1226are a viable option, but need three setup steps.
1227
12281. Disable DDP's internal async error handling::
1229
1230    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
1231    torch.distributed.init_process_group(...)
1232
12332. Before full-backward capture, DDP must be constructed in a side-stream context::
1234
1235    with torch.cuda.stream(s):
1236        model = DistributedDataParallel(model)
1237
12383. Your warmup must run at least 11 DDP-enabled eager iterations before capture.
1239
1240.. _graph-memory-management:
1241
1242Graph memory management
1243^^^^^^^^^^^^^^^^^^^^^^^
1244
1245A captured graph acts on the same virtual addresses every time it replays.
1246If PyTorch frees the memory, a later replay can hit an illegal memory access.
1247If PyTorch reassigns the memory to new tensors, the replay can corrupt the values
1248seen by those tensors.  Therefore, the virtual addresses used by the graph must be
1249reserved for the graph across replays. The PyTorch caching allocator achieves this
1250by detecting when capture is underway and satisfying the capture's allocations
1251from a graph-private memory pool. The private pool stays alive until its
1252:class:`~torch.cuda.CUDAGraph` object and all tensors created during capture
1253go out of scope.
1254
1255Private pools are maintained automatically. By default, the allocator creates a
1256separate private pool for each capture. If you capture multiple graphs,
1257this conservative approach ensures graph replays never corrupt each other's values,
1258but sometimes needlessly wastes memory.
1259
1260Sharing memory across captures
1261~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1262
1263To economize the memory stashed in private pools, :class:`torch.cuda.graph`
1264and :func:`torch.cuda.make_graphed_callables` optionally allow different
1265captures to share the same private pool.
1266It's safe for a set of graphs to share a private pool if you know they'll always
1267be replayed in the same order they were captured,
1268and never be replayed concurrently.
1269
1270:class:`torch.cuda.graph`'s ``pool`` argument is a hint to use a particular private pool,
1271and can be used to share memory across graphs as shown::
1272
1273    g1 = torch.cuda.CUDAGraph()
1274    g2 = torch.cuda.CUDAGraph()
1275
1276    # (create static inputs for g1 and g2, run warmups of their workloads...)
1277
1278    # Captures g1
1279    with torch.cuda.graph(g1):
1280        static_out_1 = g1_workload(static_in_1)
1281
1282    # Captures g2, hinting that g2 may share a memory pool with g1
1283    with torch.cuda.graph(g2, pool=g1.pool()):
1284        static_out_2 = g2_workload(static_in_2)
1285
1286    static_in_1.copy_(real_data_1)
1287    static_in_2.copy_(real_data_2)
1288    g1.replay()
1289    g2.replay()
1290
1291With :func:`torch.cuda.make_graphed_callables`, if you want to graph several
1292callables and you know they'll always run in the same order (and never concurrently)
1293pass them as a tuple in the same order they'll run in the live workload, and
1294:func:`~torch.cuda.make_graphed_callables` will capture their graphs using a shared
1295private pool.
1296
1297If, in the live workload, your callables will run in an order that occasionally changes,
1298or if they'll run concurrently, passing them as a tuple to a single invocation of
1299:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
1300:func:`~torch.cuda.make_graphed_callables` separately for each one.
1301