#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>

namespace at {

struct Generator;
struct CUDAGeneratorImpl;
struct CUDAGeneratorState;

namespace cuda {

// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin.
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
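//
// Example (a sketch only; "graph1" and "graph2" are hypothetical CUDAGraph
// objects whose captures should draw from one shared private mempool):
//
//   auto shared_pool = at::cuda::graph_pool_handle();
//   graph1.capture_begin(shared_pool);
//   // ... ops to capture into graph1 ...
//   graph1.capture_end();
//   graph2.capture_begin(shared_pool);
//   // ... ops to capture into graph2 ...
//   graph2.capture_end();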

struct TORCH_CUDA_CPP_API CUDAGraph {
  CUDAGraph();
  ~CUDAGraph();

  static void inc_pending_event_queries();
  static void dec_pending_event_queries();
  static int num_pending_event_queries();
  // See Note [Explicit Registration of Generators to the CUDA Graph]
  void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
  void register_generator_state(const at::Generator& generator);
  void capture_begin(
      MempoolId_t pool = {0, 0},
      cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
  void capture_end();
  void replay();
  void reset();
  MempoolId_t pool();
  void enable_debug_mode();
  void debug_dump(const std::string& debug_path);
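
  // Typical capture/replay flow (a sketch only; "stream", "x", and "some_op"
  // are placeholders, not part of this API; capture must run on a
  // non-default stream):
  //
  //   at::cuda::CUDAGraph graph;
  //   auto stream = c10::cuda::getStreamFromPool();
  //   {
  //     c10::cuda::CUDAStreamGuard guard(stream);  // <c10/cuda/CUDAGuard.h>
  //     // ... warmup iterations on this stream, then:
  //     graph.capture_begin();
  //     x = some_op(x);  // work recorded into the graph, not yet executed
  //     graph.capture_end();
  //   }
  //   graph.replay();  // re-launches the captured work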

 protected:
  cudaGraph_t graph_ = nullptr;
  cudaGraphExec_t graph_exec_ = nullptr;

  static std::atomic<int> pending_event_queries;

  // Internal state that lets reset() clean up as thoroughly as possible.
  //
  // has_graph_ is set to true in capture_end if cudaStreamEndCapture succeeded,
  // and set back to false soon after, once graph_ has been consumed by
  // cudaGraphInstantiate to create graph_exec_ and then deleted.
  bool has_graph_ = false;
  // Set to true in capture_end if cudaGraphInstantiate succeeded.
  bool has_graph_exec_ = false;

  // The ID assigned by CUDA during graph capture,
  // used to identify when a stream is participating in capture.
  CaptureId_t capture_id_ = -1;

  // UUID used to request a particular private mempool from CUDACachingAllocator.
  // By default, this will be set to {id_, 0}.
  //
  // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
  // will be set to the other graph's mempool_id_, so this graph will share a mempool
  // with the other graph.
  //
  // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
  // it will share a mempool with any other captures that used "pool=handle".
  //
  // Sharing a mempool across graphs saves memory, and it's safe if you
  // know you'll replay those graphs in the same order you captured them.
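  //
  // Example (a sketch only; "graph1" and "graph2" are hypothetical CUDAGraph
  // objects; see also graph_pool_handle() above):
  //
  //   graph1.capture_begin();
  //   // ... capture work into graph1 ...
  //   graph1.capture_end();
  //   // graph2's capture allocates from graph1's private mempool:
  //   graph2.capture_begin(graph1.pool());
  //   // ... capture work into graph2 ...
  //   graph2.capture_end();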
  MempoolId_t mempool_id_;

  // Stream on which capture began.
  at::cuda::CUDAStream capture_stream_;

  // Generator states registered with this graph, mapped to their
  // wholegraph_increments, all managed by the CUDA graph.
  ska::flat_hash_map<c10::intrusive_ptr<at::CUDAGeneratorState>, uint64_t>
      captured_generator_states_;
  // Device where capture occurred. Right now, for simplicity, we require all ops
  // in a capture to run on the same device, but this is a limitation of CUDAGraph,
  // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
  // captures if needed.
  int capture_dev_;
};

} // namespace cuda
} // namespace at