#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>

namespace at {

struct Generator;
struct CUDAGeneratorImpl;
struct CUDAGeneratorState;

namespace cuda {

// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();

// Manages one CUDA graph capture: work submitted between capture_begin()
// and capture_end() is recorded into graph_ (cudaGraph_t), instantiated
// into graph_exec_ (cudaGraphExec_t), and re-launched via replay().
struct TORCH_CUDA_CPP_API CUDAGraph {
  CUDAGraph();
  ~CUDAGraph();

  // Accessors for the process-wide pending_event_queries counter (below).
  // NOTE(review): the exact coordination protocol lives at the definition
  // site — confirm against the .cpp before relying on semantics here.
  static void inc_pending_event_queries();
  static void dec_pending_event_queries();
  static int num_pending_event_queries();

  // See Note [Explicit Registration of Generators to the CUDA Graph]
  void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
  void register_generator_state(const at::Generator& generator);

  // Begins stream capture. "pool" may come from graph_pool_handle() or from
  // another graph's pool() to share a private mempool (see mempool_id_ below);
  // the default {0, 0} requests a fresh pool. "capture_mode" is forwarded to
  // CUDA's stream-capture API.
  void capture_begin(
      MempoolId_t pool = {0, 0},
      cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
  // Ends capture and instantiates graph_exec_ (see has_graph_/has_graph_exec_).
  void capture_end();
  // Launches the instantiated graph_exec_ again.
  void replay();
  // Best-effort release of capture/instantiation state; which resources exist
  // is tracked by has_graph_ and has_graph_exec_ below.
  void reset();
  // Returns this graph's mempool_id_, suitable for passing as another graph's
  // capture_begin(pool=...) argument to share a mempool.
  MempoolId_t pool();
  void enable_debug_mode();
  void debug_dump(const std::string& debug_path);

 protected:
  // Raw captured graph; consumed by instantiation (see has_graph_ comment).
  cudaGraph_t graph_ = nullptr;
  // Executable form launched by replay().
  cudaGraphExec_t graph_exec_ = nullptr;

  // Process-wide count of in-flight event queries, shared by all CUDAGraph
  // instances via the static inc/dec/num accessors above.
  static std::atomic<int> pending_event_queries;

  // internal states so reset() can do its best cleaning up

  // Set to true in capture_end if cudaStreamEndCapture succeeded
  // Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
  // to create graph_exec_, then graph_ is deleted
  bool has_graph_ = false;
  // Set to true in capture_end if cudaGraphInstantiate succeeded
  bool has_graph_exec_ = false;

  // the ID assigned by cuda during graph capture,
  // used to identify when a stream is participating in capture
  // (-1 wraps to the max unsigned value, serving as the "not capturing" sentinel)
  CaptureId_t capture_id_ = -1;

  // uuid used to request a particular private mempool from CUDACachingAllocator.
  // By default, this will be set to {id_, 0}.
  //
  // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
  // will be set to the other graph's mempool_id_, and therefore share a mempool with the
  // other graph.
  //
  // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
  // it will share a mempool with any other captures that used "pool=handle".
  //
  // Sharing a mempool across graphs saves memory, and it's safe if you
  // know you'll replay those graphs in the same order you captured them.
  MempoolId_t mempool_id_;

  // Stream on which capture began
  at::cuda::CUDAStream capture_stream_;

  // multiple generator states and their wholegraph_increments in this graph
  // that are managed by the CUDA Graph
  ska::flat_hash_map<c10::intrusive_ptr<at::CUDAGeneratorState>, uint64_t>
      captured_generator_states_;

  // Device where capture occurred. Right now, for simplicity, we require all ops
  // in a capture to run on the same device, but this is a limitation of CUDAGraph,
  // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
  // captures if needed.
  int capture_dev_;
};

} // namespace cuda
} // namespace at