#pragma once

#include <ATen/Context.h>
#include <ATen/core/Generator.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/PhiloxCudaState.h>

#include <atomic>
#include <limits>
#include <memory>
#include <unordered_set>

namespace at {

namespace cuda {
struct CUDAGraph;
}

/**
 * Note [CUDA Graph-safe RNG states]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Strategy:
 * ~~~~~~~~~
 * (It helps to look at
 * cuda/detail/PhiloxCudaStateRaw.cuh and
 * cuda/detail/UnpackRaw.cuh
 * while you read this.)
 *
 * A CUDA graph containing multiple RNG ops behaves like a
 * single giant kernel from the perspective of ops external
 * to the graph. During graph capture, logic in CUDAGeneratorImpl
 * records the total of all offset increments that occur in the
 * graphed region, and records the final total as the offset for
 * the entire graph.
 *
 * When the graph reruns, the logic that reruns it
 * increments this device's CUDA generator's offset
 * by that total.
 *
 * Meanwhile, within the graph, at capture time, instead of
 * populating PhiloxCudaStates with the uint64_t offset pulled
 * directly from the global state, PhiloxCudaState uses a pointer
 * to a one-element stream-local int64_t device tensor
 * holding an initial offset value, and a uint64_t holding an
 * intra-graph offset. (The intra-graph offset starts from zero
 * when capture begins.) In each consumer kernel,
 * at::cuda::philox::unpack computes the offset to use for this kernel
 * as intra-graph offset + *initial offset.
 *
 * When the graph reruns, the logic that reruns it first
 * fill_s the initial offset tensor with this device's
 * CUDA generator's current offset.
 *
 * The control flow above ensures graphed execution is bitwise
 * identical to eager execution as long as RNG ops are enqueued
 * from a single thread, even if RNG ops and graphs containing
 * RNG ops are enqueued and run simultaneously on multiple streams.
 *
 * Usage:
 * ~~~~~~
 * PhiloxCudaState (in ATen/cuda/PhiloxCudaState.h), and unpack() in
 * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
 * CUDAGeneratorImpl whether graph capture is underway or not.
 *
 * Each PhiloxCudaState instance should be used for one and only one
 * consumer kernel.
 *
 * Example (see e.g. native/cuda/Dropout.cu):
 *
 * #include <ATen/cuda/CUDAGeneratorImpl.h>
 * #include <ATen/cuda/CUDAGraphsUtils.cuh>
 *
 * __global__ void kernel(..., PhiloxCudaState philox_args) {
 *   auto seeds = at::cuda::philox::unpack(philox_args);
 *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
 *   curandStatePhilox4_32_10_t state;
 *   curand_init(std::get<0>(seeds), // seed
 *               idx,                // per-thread subsequence
 *               std::get<1>(seeds), // offset in subsequence
 *               &state);
 *   ...
 * }
 *
 * host_caller(...) {
 *   PhiloxCudaState rng_engine_inputs;
 *   {
 *     // See Note [Acquire lock when using random generators]
 *     std::lock_guard<std::mutex> lock(gen->mutex_);
 *
 *     // gen could be HostState or DevState here! No divergent code needed!
 *     rng_engine_inputs = gen->philox_cuda_state(offset_increment);
 *   }
 *   kernel<<<...>>>(..., rng_engine_inputs);
 * }
 */
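/**
 * Capture/replay lifecycle (illustrative):
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * A minimal sketch of how the CUDAGeneratorState hooks declared below are
 * expected to pair up around one graph. This is not the actual CUDAGraph
 * implementation; `graph`, `stream`, and `launch_rng_kernels` are
 * hypothetical stand-ins.
 *
 *   state->register_graph(&graph);  // make the graph known to this state;
 *                                   // backed by the extragraph device tensors
 *   state->capture_prologue();      // enter captured mode; the intra-graph
 *                                   // offset restarts from 0
 *   launch_rng_kernels(stream);     // RNG ops recorded into the graph; each
 *                                   // bumps offset_intragraph_
 *   uint64_t wholegraph_increment = state->capture_epilogue();
 *
 *   // Before every replay: refresh the tensors the captured kernels read,
 *   // then advance the generator past everything the graph will consume.
 *   state->replay_prologue(wholegraph_increment);
 *   graph.replay();                 // hypothetical replay entry point
 */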
struct CUDAGeneratorState : public c10::intrusive_ptr_target {
  uint64_t seed_;
  uint64_t philox_offset_per_thread_;
  uint32_t offset_intragraph_;
  bool capturing_{};
  std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
  at::TensorBase seed_extragraph_{};
  at::TensorBase offset_extragraph_{};

  CUDAGeneratorState(
      uint64_t seed = default_rng_seed_val,
      uint64_t philox_offset_per_thread = 0,
      uint32_t offset_intragraph = 0)
      : seed_(seed),
        philox_offset_per_thread_(philox_offset_per_thread),
        offset_intragraph_(offset_intragraph) {}

  void increase(uint64_t increment);

  void register_graph(cuda::CUDAGraph* graph);
  void unregister_graph(cuda::CUDAGraph* graph);

  void capture_prologue();
  // capture_epilogue returns the wholegraph_increment
  uint64_t capture_epilogue();
  void replay_prologue(uint64_t wholegraph_increment);
  c10::intrusive_ptr<CUDAGeneratorState> clone();
};

struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  CUDAGeneratorImpl(DeviceIndex device_index = -1);
  CUDAGeneratorImpl(
      DeviceIndex device_index,
      c10::intrusive_ptr<CUDAGeneratorState> state_);
  ~CUDAGeneratorImpl() override = default;

  // CUDAGeneratorImpl methods
  std::shared_ptr<CUDAGeneratorImpl> clone() const;
  void set_current_seed(uint64_t seed) override;
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  uint64_t current_seed() const override;
  uint64_t seed() override;
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  void graphsafe_set_state(
      const c10::intrusive_ptr<GeneratorImpl>& state) override;
  c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const override;

  void set_philox_offset_per_thread(uint64_t offset);
  uint64_t philox_offset_per_thread() const;

  void register_graph(cuda::CUDAGraph* graph);
  void unregister_graph(cuda::CUDAGraph* graph);

  // Generates a PhiloxCudaState with the given increment, and advances
  // the current state by that increment.
  PhiloxCudaState philox_cuda_state(uint64_t increment);

  bool reset_rnn_state() {
    return !no_reset_rnn_state_.test_and_set();
  }

  // Temporarily accommodates call sites that use philox_engine_inputs.
  // Allows incremental refactor of call sites to use philox_cuda_state.
  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);

  static c10::DeviceType device_type();

 private:
  CUDAGeneratorImpl* clone_impl() const override;

  c10::intrusive_ptr<CUDAGeneratorState> state_;
  std::atomic_flag no_reset_rnn_state_;
};

namespace cuda::detail {

TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
    DeviceIndex device_index = -1);
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);

} // namespace cuda::detail
} // namespace at
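/**
 * Sizing the increment (illustrative):
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * The `increment` passed to philox_cuda_state() must be at least the number
 * of Philox counter steps any single thread of the consumer kernel may
 * consume. The sketch below follows the shape of native/cuda/Dropout.cu, but
 * `numel`, `block_size`, `grid_size`, and `UNROLL` are illustrative values,
 * not part of this header's API.
 *
 *   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
 *       std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
 *   // Each thread strides through the tensor; with curand's Philox4_32_10
 *   // yielding 4 values per call, a thread advances its counter at most:
 *   uint64_t counter_offset =
 *       ((numel - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL;
 *   PhiloxCudaState rng_engine_inputs;
 *   {
 *     // See Note [Acquire lock when using random generators]
 *     std::lock_guard<std::mutex> lock(gen->mutex_);
 *     rng_engine_inputs = gen->philox_cuda_state(counter_offset);
 *   }
 *   // rng_engine_inputs is now safe to pass to exactly one kernel launch,
 *   // whether or not a CUDA graph capture is underway.
 */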