xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cuda/CUDAGeneratorImpl.h>
2 #include <ATen/cuda/CUDAGraph.h>
3 #include <ATen/cuda/Exceptions.h>
4 #include <ATen/Functions.h>
5 #include <c10/cuda/CUDACachingAllocator.h>
6 #include <c10/cuda/CUDAFunctions.h>
7 
8 #include <chrono>
9 #include <cstddef>
10 #include <cstdint>
11 #include <thread>
12 #include <vector>
13 
14 namespace at::cuda {
15 
16 static bool _cuda_graphs_debug = false;
17 constexpr int kSynchronizeBusyWaitMillis = 10;
18 
graph_pool_handle()19 MempoolId_t graph_pool_handle() {
20   // Sets just the second value, to distinguish it from MempoolId_ts created from
21   // cudaStreamGetCaptureInfo id_s in capture_begin.
22   auto new_pool = c10::cuda::MemPool();
23   return new_pool.id();
24 }
25 
26 /**
27  * Note [CUDA Graph Wrapper Class]
28  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
29  * Q: Why do we need graph capture and launch bindings in Pytorch?
30  *    Why can't they live in a user extension, for example?
31  *
32  * A1: Convenience.
33  * A2: To ensure valid numerics on replay, some native CUDA ops (like RNG ops with
34  *     CPU statefulness) need cooperation from the capture and replay bindings
35  *     (see Note [CUDA Graph-safe RNG states] in CUDAGeneratorImpl.h).
36  *
37  *     We can't expect users to know about this cooperation.  If users write capture
38  *     bindings naively in an extension, they likely won't interact with the native
39  *     ops properly.  Their graphs would yield invalid numerics on replay.
40  */
41 
42 /**
43  * Note [Interaction with CUDA graph capture] in CUDACachingAllocator.cpp
44  * describes memory management for captures.
45  */
46 
47 std::atomic<int> CUDAGraph::pending_event_queries = 0;
48 
49 // Track any outstanding event queries that could happen e.g., in a NCCL watchdog so that they
50 // can be resolved before the capture begins. Note that event queries are not allowed during a
51 // graph capture in the default capture mode.
inc_pending_event_queries()52 void CUDAGraph::inc_pending_event_queries() {
53   pending_event_queries++;
54 }
55 
dec_pending_event_queries()56 void CUDAGraph::dec_pending_event_queries() {
57   TORCH_INTERNAL_ASSERT(pending_event_queries > 0,
58     "Attempted to decrement the number of outstanding events to be queried, but it was <= 0.");
59   pending_event_queries--;
60 }
61 
num_pending_event_queries()62 int CUDAGraph::num_pending_event_queries() {
63   return pending_event_queries;
64 }
65 
CUDAGraph()66 CUDAGraph::CUDAGraph()
67   // CUDAStreams may not be default-constructed.
68   : capture_stream_(at::cuda::getCurrentCUDAStream()) {
69 }
70 
register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state)71 void CUDAGraph::register_generator_state(
72     c10::intrusive_ptr<at::CUDAGeneratorState> state) {
73   captured_generator_states_[std::move(state)] = 0;
74 }
75 
register_generator_state(const at::Generator & generator)76 void CUDAGraph::register_generator_state(const at::Generator& generator) {
77   c10::intrusive_ptr<CUDAGeneratorImpl> cuda_gen =
78       dynamic_intrusive_pointer_cast<CUDAGeneratorImpl>(
79           generator.getIntrusivePtr());
80   cuda_gen->register_graph(this);
81 }
82 
// Begins CUDA graph capture on the current (non-default) stream, routing this
// graph's allocations into a private — or caller-supplied shared — memory pool.
//
// pool:         a MempoolId_t from another graph's pool() or from
//               graph_pool_handle(); {0, 0} (the default) means "create a new
//               internal pool".
// capture_mode: forwarded verbatim to cudaStreamBeginCapture.
//
// Throws if this instance already owns a captured graph, or if the current
// stream is the default stream.
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
  TORCH_CHECK(!has_graph_exec_,
              "This CUDAGraph instance already owns a captured graph. "
              "To capture a new graph, create a new instance.");

  // default generator is always registered
  auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
      std::nullopt, cuda::detail::getDefaultCUDAGenerator());
  gen->register_graph(this);

  // Prepare every registered generator's state for capture (see
  // Note [CUDA Graph-safe RNG states] in CUDAGeneratorImpl.h).
  for (auto& [generator_state, wholegraph_increments] :
       captured_generator_states_) {
    generator_state->capture_prologue();
  }

  auto stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(stream != at::cuda::getDefaultCUDAStream(),
              "CUDA graphs must be captured on a non-default stream. "
              "(However, after capture, it's ok to replay them on the "
              "default stream.)");

  capture_stream_ = stream;
  capture_dev_ = c10::cuda::current_device();

  if (pool.first != 0 || pool.second != 0) {
    // Either value being nonzero means the user supplied a pool to share.
    // But only one should be nonzero.
    // If pool was created by another graph's capture_begin, first should be nonzero.
    // If pool was created by graph_pool_handle, second should be nonzero.
    TORCH_INTERNAL_ASSERT(!(pool.first && pool.second));
    mempool_id_ = pool;
  } else {
    // User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
    // Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
    auto mempool = c10::cuda::MemPool({}, false);
    mempool_id_ = mempool.id();
    TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
  }

  // Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an
  // autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator
  // due to the capture status being updated _after_ a capture had already started.
  //
  // The filter lambda admits an allocation into this graph's pool only when
  // the querying stream is actively capturing under this graph's capture_id_.
  // capture_id_ is assigned further below (after cudaStreamBeginCapture); until
  // capture actually starts, the status check short-circuits the comparison.
  c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](cudaStream_t stream) {
      cudaStreamCaptureStatus status;
      CaptureId_t stream_capture_id;
      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
      return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
  });

  // At this point, any NCCL watchdogs should be aware that we are in capture mode
  // and therefore should not enqueue any additional work that could be event-queried.
  // We still must wait on any existing work that has not been cleaned up.
  while (num_pending_event_queries()) {
    TORCH_WARN_ONCE("Waiting for pending NCCL work to finish before starting graph capture.");
    std::this_thread::sleep_for(
      std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
  }

  // cudaStreamCaptureModeGlobal is the most conservative option to
  // prevent potentially unsafe CUDA API calls during capture.  See
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
  AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, capture_mode));

  // Record the id of the capture we just started; the allocator filter lambda
  // above compares against it for every allocation made during capture.
  cudaStreamCaptureStatus status;
  AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_));
  TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);

}
152 
// Ends stream capture: extracts graph_ from the capture stream, stops routing
// allocations into this graph's pool, instantiates graph_ into graph_exec_,
// and records each registered generator's whole-graph offset increment for
// replay(). Must be called on the same stream capture_begin() used.
void CUDAGraph::capture_end() {
  auto stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(stream == capture_stream_,
              "Capture must end on the same stream it began on.");

  AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));

  // Undo the beginAllocateToPool from capture_begin.
  c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);

  TORCH_CHECK(graph_ != NULL, "Invalid capture.");
  has_graph_ = true;

  // In typical graph usage some tensors (e.g. the tensors used for graph IO) are not freed
  // between replays.
  // If Pytorch compiles and runs with a CUDA 11.4+ toolkit, there's a chance the allocator backend
  // is cudaMallocAsync.
  // cudaMallocAsync is generally graph-safe, but if some tensors are not freed between replays,
  // the graph's internal bookkeeping requires that we instantiate with
  // cudaGraphInstantiateFlagAutoFreeOnLaunch. See
  // cudaGraphLaunch
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
  // cudaGraphInstantiateWithFlags
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040)
  int version;
  // The driver at runtime may be older than the toolkit we compiled against,
  // so check the driver version too before using the 11.4+ instantiate path.
  AT_CUDA_CHECK(cudaDriverGetVersion(&version));
  if (version < 11040) {
#endif
    // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
    // who prefer not to report error message through these arguments moving forward
    // (they prefer return value, or errors on api calls internal to the capture)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
    AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
#else
    AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040)
  } else {
    AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
                                                graph_,
                                                cudaGraphInstantiateFlagAutoFreeOnLaunch));
  }
#endif

  has_graph_exec_ = true;

  // Record how far each captured generator's offset advanced during capture,
  // so replay() can advance the live state by the same amount beforehand.
  for (auto& [generator_state, wholegraph_increments] :
       captured_generator_states_) {
    wholegraph_increments = generator_state->capture_epilogue();
  }

  // An empty graph is almost certainly a capture done on the wrong
  // device/stream; warn rather than fail, since it is technically valid.
  size_t numCUDAGraphNodes = 0;
  AT_CUDA_CHECK(cudaGraphGetNodes(graph_, NULL, &numCUDAGraphNodes));
  if (numCUDAGraphNodes == 0) {
      TORCH_WARN("The CUDA Graph is empty. This usually means that the graph was ",
                 "attempted to be captured on wrong device or stream.");
  }

  // check if debug path is set
  if (!_cuda_graphs_debug) {
    // Now that we've instantiated graph_ into graph_exec_,
    // we don't need graph_ anymore.
    AT_CUDA_CHECK(cudaGraphDestroy(graph_));
    has_graph_ = false;
  } else {
    TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called.");
  }
}
222 
replay()223 void CUDAGraph::replay() {
224   TORCH_CHECK(has_graph_exec_,
225               "Called CUDAGraph::replay without a preceding successful capture.");
226 
227   c10::OptionalDeviceGuard device_guard{capture_stream_.device()};
228 
229   for (auto& [generator_state, wholegraph_increments] :
230        captured_generator_states_) {
231     generator_state->replay_prologue(wholegraph_increments);
232   }
233   // graph_exec_ may be replayed in any stream.
234   AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream()));
235 
236   int version;
237   AT_CUDA_CHECK(cudaDriverGetVersion(&version));
238   if (version < 11040) {
239     // Workaround for bug in libcuda.so that causes replayed graphs with
240     // certain topologies to be corrupted (kernels elided, internal syncs
241     // ignored) when replayed back to back without a sync in between.
242     // The bug is fixed in CUDA 11.4+.
243     AT_CUDA_CHECK(cudaDeviceSynchronize());
244   }
245 }
246 
enable_debug_mode()247 void CUDAGraph::enable_debug_mode() {
248   _cuda_graphs_debug = true;
249 }
250 
// Writes graph_ to debug_path as a Graphviz DOT file (most verbose flags) and
// then frees graph_. Only has an effect if enable_debug_mode() was called,
// which makes capture_end() retain graph_ past instantiation.
void CUDAGraph::debug_dump(const std::string& debug_path) {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11030)|| defined(USE_ROCM)
  if (_cuda_graphs_debug) {
    TORCH_WARN("DEBUG: calling debug_dump()");
    if (has_graph_) {
      TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path);
      C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), cudaGraphDebugDotFlagsVerbose)); // most verbose output
      // Dumping is one-shot: release graph_ now that it has been printed.
      AT_CUDA_CHECK(cudaGraphDestroy(graph_));
      has_graph_ = false;
    }
  } else {
    TORCH_WARN("CUDA Graphs debug not enabled, set with torch._C._cuda_enable_graphs_debug_mode");
  }
#else
  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.3 or ROCM >= 5.6");
#endif
}
268 
// Releases the memory pool, graph_, and graph_exec_ owned by this instance,
// resetting it to an un-captured state. Failures are reported as warnings,
// not exceptions (see below).
void CUDAGraph::reset() {
  // I'd prefer these checks throw exceptions, not print warnings,
  // but the destructor calls reset(), and at least one CI build
  // refuses to compile with a throwing destructor.
  //
  // Instead of calling reset() in the destructor to clean up, I could
  // call reset() in the __del__ method of a thin Python wrapper,
  // in which case reset would be allowed to throw exceptions.
  // But Stackoverflow does not like user-defined __del__.
  // __del__ prevents Graph instances from EVER being garbage collected
  // if they participate in a reference cycle.
  // And exceptions thrown in __del__ only print a warning anyway.
  //
  // Calling reset() in the C++ destructor, with warnings instead of exceptions
  // if calls fail, is the compromise we chose.
  //
  // If capture_begin, the capture, or capture_end failed at some point, this CUDAGraph, the generator,
  // and the allocator could end up in all kinds of weird states depending where failure occurred.
  // If the user catches the failure exception in a script, or is running in REPL or (god forbid)
  // a Jupyter notebook, I don't see an easy way for reset() to gracefully fix all such possible error states.
  if (has_graph_ || has_graph_exec_) {
    // notifyCaptureDestroy may throw. How should we handle this?
    c10::cuda::CUDACachingAllocator::releasePool(capture_dev_, mempool_id_);
  }
  if (has_graph_) {
    C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_));
    has_graph_ = false;
  }
  if (has_graph_exec_) {
    C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
    has_graph_exec_ = false;
  }
}
302 
303 // Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
pool()304 MempoolId_t CUDAGraph::pool() {
305 TORCH_CHECK(has_graph_exec_,
306               "Called CUDAGraph::pool() without a preceding successful capture.");
307   return mempool_id_;
308 }
309 
~CUDAGraph()310 CUDAGraph::~CUDAGraph() {
311   for (auto& [generator_state, wholegraph_increments] :
312        captured_generator_states_) {
313     generator_state->unregister_graph(this);
314   }
315   reset();
316 }
317 
318 } // namespace at::cuda
319