xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDAGeneratorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Context.h>
4 #include <ATen/core/Generator.h>
5 #include <ATen/core/TensorBase.h>
6 #include <ATen/cuda/PhiloxCudaState.h>
7 #include <atomic>
8 #include <limits>
9 #include <memory>
10 #include <unordered_set>
11 namespace at {
12 
13 namespace cuda {
14 struct CUDAGraph;
15 }
16 
17 /**
18  * Note [CUDA Graph-safe RNG states]
19  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20  *
21  * Strategy:
22  * ~~~~~~~~~
23  * (It helps to look at
24  * cuda/detail/PhiloxCudaStateRaw.cuh and
25  * cuda/detail/UnpackRaw.cuh
26  * while you read this.)
27  *
28  * A CUDA graph containing multiple RNG ops behaves like a
29  * single giant kernel from the perspective of ops external
30  * to the graph.  During graph capture, logic in CUDAGeneratorImpl
31  * records the total of all offset increments that occur in the
32  * graphed region, and records the final total as the offset for
33  * the entire graph.
34  *
35  * When the graph reruns, the logic that reruns it
36  * increments this device's CUDA generator's offset
37  * by that total.
38  *
39  * Meanwhile, within the graph, at capture time, instead of
40  * populating PhiloxCudaStates with the uint64_t offset pulled
41  * directly from the global state, PhiloxCudaState uses a pointer
42  * to a one-element stream-local int64_t device tensor
43  * holding an initial offset value, and a uint64_t holding an
44  * intra-graph offset. (The intra-graph offset starts from zero
45  * when capture begins.)  In each consumer kernel,
46  * at::cuda::philox::unpack computes the offset to use for this kernel
47  * as intra-graph offset + *initial offset.
48  *
49  * When the graph reruns, the logic that reruns it first
50  * fill_s the initial offset tensor with this device's
51  * CUDA generator's current offset.
52  *
53  * The control flow above ensures graphed execution is bitwise
54  * identical to eager execution as long as RNG ops are enqueued
55  * from a single thread, even if RNG ops and graphs containing
56  * RNG ops are enqueued and run simultaneously on multiple streams.
57  *
58  * Usage:
59  * ~~~~~~
60  * PhiloxCudaState in this file, and unpack() in
61  * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
62  * CUDAGeneratorImpl whether graph capture is underway or not.
63  *
64  * Each PhiloxCudaState instance should be used for one and only one
65  * consumer kernel.
66  *
67  * Example (see e.g. native/cuda/Dropout.cu):
68  *
69  * #include <ATen/cuda/CUDAGeneratorImpl.h>
70  * #include <ATen/cuda/CUDAGraphsUtils.cuh>
71  *
72  * __global__ void kernel(..., PhiloxCudaState philox_args) {
73  *   auto seeds = at::cuda::philox::unpack(philox_args);
74  *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
75  *   curandStatePhilox4_32_10_t state;
76  *   curand_init(std::get<0>(seeds), // seed
77  *               idx,                // per-thread subsequence
78  *               std::get<1>(seeds), // offset in subsequence
79  *               &state);
80  *   ...
81  * }
82  *
83  * host_caller(...) {
84  *   PhiloxCudaState rng_engine_inputs;
85  *   {
86  *     // See Note [Acquire lock when using random generators]
87  *     std::lock_guard<std::mutex> lock(gen->mutex_);
88  *
89  *     // gen could be HostState or DevState here! No divergent code needed!
90  *     rng_engine_inputs = gen->philox_cuda_state(offset_increment);
91  *   }
92  *   kernel<<<...>>>(..., rng_engine_inputs);
93  * }
94  *
95  */
96 
97 struct CUDAGeneratorState : public c10::intrusive_ptr_target {
98   uint64_t seed_;
99   uint64_t philox_offset_per_thread_;
100   uint32_t offset_intragraph_;
101   bool capturing_{};
102   std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
103   at::TensorBase seed_extragraph_{};
104   at::TensorBase offset_extragraph_{};
105 
106   CUDAGeneratorState(
107       uint64_t seed = default_rng_seed_val,
108       uint64_t philox_offset_per_thread = 0,
109       uint32_t offset_intragraph = 0)
seed_CUDAGeneratorState110       : seed_(seed),
111         philox_offset_per_thread_(philox_offset_per_thread),
112         offset_intragraph_(offset_intragraph) {}
113 
114   void increase(uint64_t increment);
115 
116   void register_graph(cuda::CUDAGraph* graph);
117   void unregister_graph(cuda::CUDAGraph* graph);
118 
119   void capture_prologue();
120   // capture_epilogue returns the wholegraph_increment
121   uint64_t capture_epilogue();
122   void replay_prologue(uint64_t wholegraph_increment);
123   c10::intrusive_ptr<CUDAGeneratorState> clone();
124 };
125 
126 struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
127   // Constructors
128   CUDAGeneratorImpl(DeviceIndex device_index = -1);
129   CUDAGeneratorImpl(
130       DeviceIndex device_index,
131       c10::intrusive_ptr<CUDAGeneratorState> state_);
132   ~CUDAGeneratorImpl() override = default;
133 
134   // CUDAGeneratorImpl methods
135   std::shared_ptr<CUDAGeneratorImpl> clone() const;
136   void set_current_seed(uint64_t seed) override;
137   void set_offset(uint64_t offset) override;
138   uint64_t get_offset() const override;
139   uint64_t current_seed() const override;
140   uint64_t seed() override;
141   void set_state(const c10::TensorImpl& new_state) override;
142   c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
143   void graphsafe_set_state(
144       const c10::intrusive_ptr<GeneratorImpl>& state) override;
145   c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const override;
146 
147   void set_philox_offset_per_thread(uint64_t offset);
148   uint64_t philox_offset_per_thread() const;
149 
150   void register_graph(cuda::CUDAGraph* graph);
151   void unregister_graph(cuda::CUDAGraph* graph);
152 
153   // Generates a PhiloxCudaState with a specified increment, and increment
154   // current state
155   PhiloxCudaState philox_cuda_state(uint64_t increment);
156 
reset_rnn_stateCUDAGeneratorImpl157   bool reset_rnn_state() {
158     return !no_reset_rnn_state_.test_and_set();
159   }
160 
161   // Temporarily accommodates call sites that use philox_engine_inputs.
162   // Allows incremental refactor of call sites to use philox_cuda_state.
163   std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
164 
165   static c10::DeviceType device_type();
166 
167  private:
168   CUDAGeneratorImpl* clone_impl() const override;
169 
170   c10::intrusive_ptr<CUDAGeneratorState> state_;
171   std::atomic_flag no_reset_rnn_state_;
172 };
173 
namespace cuda::detail {

// Returns the default CUDA generator for `device_index` (-1 selects the
// current device; exact semantics defined in the .cpp).
TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
    DeviceIndex device_index = -1);
// Creates a fresh, independently-seeded CUDA generator for `device_index`.
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);

} // namespace cuda::detail
181 } // namespace at
182