// aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/CallOnce.h>
#include <deque>

namespace at {
namespace cuda::detail {

namespace {

// Ensures we call cudaGetDeviceCount only once.
static c10::once_flag num_gpu_init_flag;

// Total number of GPUs in the system.
static int64_t num_gpus;

// Ensures default_gens_cuda is initialized once.
static std::deque<c10::once_flag> cuda_gens_init_flag;

// Default, global CUDA generators, one per GPU.
static std::vector<Generator> default_gens_cuda;

/*
 * Populates the global variables related to CUDA generators
 * Warning: this function must only be called once!
 */
static void initCUDAGenVector() {
  num_gpus = c10::cuda::device_count();
  cuda_gens_init_flag.resize(num_gpus);
  default_gens_cuda.resize(num_gpus);
}

} // anonymous namespace

/**
 * PyTorch maintains a collection of default generators that get
 * initialized once. The purpose of these default generators is to
 * maintain a global running state of the pseudo random number generation,
 * when a user does not explicitly mention any generator.
 * getDefaultCUDAGenerator gets the default generator for a particular
 * CUDA device.
 */
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index) {
  c10::call_once(num_gpu_init_flag, initCUDAGenVector);
  DeviceIndex idx = device_index;
  if (idx == -1) {
    idx = c10::cuda::current_device();
  } else {
    TORCH_CHECK(idx >= 0 && idx < num_gpus);
  }
  c10::call_once(cuda_gens_init_flag[idx], [&] {
    default_gens_cuda[idx] = make_generator<CUDAGeneratorImpl>(idx);
    default_gens_cuda[idx].seed();
  });
  return default_gens_cuda[idx];
}
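
// Illustrative usage (a sketch, not exercised by this file): callers typically
// fetch the per-device default generator and take its mutex before touching
// its state, per Note [Acquire lock when using random generators]. The device
// index 0 and seed value below are arbitrary example values.
//
//   auto gen = at::cuda::detail::getDefaultCUDAGenerator(0);
//   {
//     std::lock_guard<std::mutex> lock(gen.mutex());
//     auto* impl = at::check_generator<at::CUDAGeneratorImpl>(gen);
//     impl->set_current_seed(42);
//   }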

/**
 * Utility to create a CUDAGeneratorImpl. Returns a Generator that owns the
 * newly created CUDAGeneratorImpl.
 */
Generator createCUDAGenerator(DeviceIndex device_index) {
  c10::call_once(num_gpu_init_flag, initCUDAGenVector);
  DeviceIndex idx = device_index;
  if (idx == -1) {
    idx = c10::cuda::current_device();
  }
  TORCH_CHECK(idx >= 0 && idx < num_gpus, "The device_index is invalid.");
  auto gen = make_generator<CUDAGeneratorImpl>(idx);
  auto cuda_gen = check_generator<CUDAGeneratorImpl>(gen);
  cuda_gen->set_current_seed(default_rng_seed_val);
  cuda_gen->set_philox_offset_per_thread(0);
  return gen;
}
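
// Illustrative usage (a sketch): creating a private generator instead of
// mutating the global default one, then reseeding it. The seed value is an
// arbitrary example.
//
//   auto gen = at::cuda::detail::createCUDAGenerator(/*device_index=*/-1);
//   {
//     std::lock_guard<std::mutex> lock(gen.mutex());
//     gen.set_current_seed(1234);
//   }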

} // namespace cuda::detail

/**
 * Creates a clone of this CUDA Generator State.
 */
c10::intrusive_ptr<CUDAGeneratorState> CUDAGeneratorState::clone() {
  return make_intrusive<CUDAGeneratorState>(
      seed_, philox_offset_per_thread_, offset_intragraph_);
}

/**
 * Increases the internal offset by the specified increment.
 */
void CUDAGeneratorState::increase(uint64_t increment) {
  // Rounds increment up to the nearest multiple of 4 to meet alignment
  // requirements.
  // see Note [Why enforce RNG offset % 4 == 0?]
  increment = ((increment + 3) / 4) * 4;
  // Handling different behaviors based on whether capturing is active.
  if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
    // Ensures that the state is actually capturing.
    TORCH_CHECK(
        capturing_,
        "Attempt to increase offset for a CUDA generator not in capture mode.");
    // Ensures the offset is a multiple of 4
    // see Note [Why enforce RNG offset % 4 == 0?]
    TORCH_INTERNAL_ASSERT(
        offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
    // Ensures the increment does not cause overflow.
    TORCH_INTERNAL_ASSERT(
        offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
        "Increment causes overflow in the offset value.");
    offset_intragraph_ += increment;
  } else {
    // Checks that the increment is expected outside graph capturing.
    TORCH_CHECK(
        !capturing_,
        "Offset increment outside graph capture encountered unexpectedly.");
    // Ensures the offset is a multiple of 4
    // see Note [Why enforce RNG offset % 4 == 0?]
    TORCH_INTERNAL_ASSERT(
        philox_offset_per_thread_ % 4 == 0,
        "RNG offset must be a multiple of 4.");
    philox_offset_per_thread_ += increment;
  }
}
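
// Worked example of the rounding above: an increment of 10 becomes
// ((10 + 3) / 4) * 4 == 12, so consecutive calls always advance the offset in
// 4-aligned steps (e.g. 0 -> 12 -> 24 for repeated increments of 10).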

/**
 * Registers this state to a CUDA graph to manage within the graph.
 */
void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) {
  // Ensures that the RNG state is not currently being captured.
  at::cuda::assertNotCapturing(
      "Cannot register the state during capturing stage.");

  // If this is the first graph to be registered, allocate memory for the seed
  // and offset on the GPU.
  if (registered_graphs_.empty()) {
    auto options = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
    seed_extragraph_ = at::empty({1}, options);
    offset_extragraph_ = at::empty({1}, options);
  }

  // Insert the graph into the set of registered graphs if it's not already
  // registered.
  if (registered_graphs_.find(graph) == registered_graphs_.end()) {
    registered_graphs_.insert(graph);
  }
}

/**
 * Unregisters a CUDA graph from the RNG state.
 */
void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
  // Verify the graph was previously registered.
  TORCH_CHECK(
      registered_graphs_.find(graph) != registered_graphs_.end(),
      "The graph should be registered to the state");

  // Remove the graph from the set of registered graphs.
  registered_graphs_.erase(graph);

  // If no more graphs are registered, deallocate the GPU memory for the seed
  // and offset.
  if (registered_graphs_.empty()) {
    seed_extragraph_.reset();
    offset_extragraph_.reset();
  }
}

/**
 * Note [Explicit Registration of Generators to the CUDA Graph]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Ideally, it would be more user-friendly if the state could be exchanged and generators
 * could be registered with the CUDA graph implicitly. However, resetting GPU tensors during
 * the capture stage causes these reset operations to be recorded within the CUDA graph.
 * This behavior is undesirable because we do not want these tensors to be reset during
 * the replay stage of the graph.
 *
 * As of now, there is no available method to perform a CUDA operation during the graph's
 * recording phase without having that operation be included in the CUDA graph.
 * This limitation necessitates explicit user action to register generators with the graph.
 * By requiring users to manually register their generators, we can ensure that state resets
 * (capture_prologue) only occur before the graph capture begins, thus avoiding unintended
 * resets during the replay of the graph. See https://github.com/pytorch/pytorch/pull/114068.
 */
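
// Illustrative registration flow (a sketch only; assumes a CUDAGeneratorImpl*
// `gen_impl` and an at::cuda::CUDAGraph `graph` owned by the caller):
//
//   gen_impl->register_graph(&graph);   // must happen before capture begins
//   graph.capture_begin();
//   // ... launch work that draws philox_cuda_state() from this generator ...
//   graph.capture_end();
//   graph.replay();                     // replay refreshes the extragraph
//                                       // seed/offset tensors via
//                                       // replay_prologue()
//   gen_impl->unregister_graph(&graph); // when the graph is retired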

/**
 * Performs the prologue steps for capturing a CUDA graph state.
 * This method is intended to reset graph-related state variables before capturing begins.
 */
void CUDAGeneratorState::capture_prologue() {
  capturing_ = true;
  offset_intragraph_ = 0;
  seed_extragraph_.fill_(int64_t(seed_));
  offset_extragraph_.fill_(int64_t(0));
}

/**
 * Ends the capturing phase and resets related variables, returning the whole
 * graph increment.
 */
uint64_t CUDAGeneratorState::capture_epilogue() {
  capturing_ = false;
  return offset_intragraph_;
}

/**
 * Prepares the state for replay by setting initial state tensors and applying
 * the total increment.
 */
void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
  // Ensures the generator is not in capturing mode.
  at::cuda::assertNotCapturing(
      "Cannot prepare for replay during capturing stage.");
  seed_extragraph_.fill_(int64_t(seed_));
  offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
  // Applies the total increment achieved during previous captures to update the
  // offset.
  increase(wholegraph_increment);
}

/**
 * Note [Why enforce RNG offset % 4 == 0?]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Curand philox does allow offsets that aren't a multiple of 4.
 * But jit kernels don't use curand, they use a custom "Philox" class (see
 * torch/csrc/jit/tensorexpr/cuda_random.h or
 * torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu).
 * The "Philox" constructor computes offset/4 (a uint64_t division) to locate its
 * internal start in its virtual bitstream viewed as 128-bit chunks, then, when called
 * in a thread, returns one 32-bit chunk at a time from that start in the bitstream.
 * In other words, if the incoming offset is not a multiple of 4, each thread
 * might repeat some previously-generated 32-bit values in the bitstream. See
 * https://github.com/pytorch/pytorch/pull/50169.
 */
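
// Worked example of the note above: with offset == 8, the jit "Philox" class
// starts at 128-bit chunk 8/4 == 2, i.e. at 32-bit outputs 8, 9, 10, 11 of the
// bitstream, so nothing overlaps with a prior call that consumed outputs 0..7.
// With offset == 6 (after a prior call consumed outputs 0..5), integer
// division gives chunk 6/4 == 1, so the new call starts at output 4 and
// repeats outputs 4 and 5 that the prior call already used.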

/**
 * CUDAGeneratorImpl class implementation
 */
CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
  : c10::GeneratorImpl{Device(DeviceType::CUDA, device_index),
          DispatchKeySet(c10::DispatchKey::CUDA)} {
  at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl");
  state_ = make_intrusive<CUDAGeneratorState>();
  no_reset_rnn_state_.clear();
}

CUDAGeneratorImpl::CUDAGeneratorImpl(
    DeviceIndex device_index,
    c10::intrusive_ptr<CUDAGeneratorState> state)
    : c10::
          GeneratorImpl{Device(DeviceType::CUDA, device_index), DispatchKeySet(c10::DispatchKey::CUDA)},
      state_(std::move(state)) {
  no_reset_rnn_state_.clear();
}

/**
 * Sets the seed to be used by curandStatePhilox4_32_10
 * Resets the philox_offset_per_thread_ to 0
 *
 * See Note [Acquire lock when using random generators]
 */
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
  at::cuda::assertNotCapturing(
      "Cannot call CUDAGeneratorImpl::set_current_seed");
  state_->seed_ = seed;
  state_->philox_offset_per_thread_ = 0;
  no_reset_rnn_state_.clear();
}

/**
 * Sets the offset to be used by curandStatePhilox4_32_10
 *
 * See Note [Acquire lock when using random generators]
 */
void CUDAGeneratorImpl::set_offset(uint64_t offset) {
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_offset");
  // the set function checks if the offset is a multiple of 4.
  set_philox_offset_per_thread(offset);
  no_reset_rnn_state_.clear();
}

/**
 * Gets the current offset of CUDAGeneratorImpl.
 */
uint64_t CUDAGeneratorImpl::get_offset() const {
  // Debatable if get_offset() should be allowed in captured regions.
  // Conservatively disallow it for now.
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::get_offset");
  return state_->philox_offset_per_thread_;
}

/**
 * Gets the current seed of CUDAGeneratorImpl.
 */
uint64_t CUDAGeneratorImpl::current_seed() const {
  // Debatable if current_seed() should be allowed in captured regions.
  // Conservatively disallow it for now.
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
  return state_->seed_;
}

/**
 * Gets a nondeterministic random number from /dev/urandom or time,
 * seeds the CUDAGeneratorImpl with it and then returns that number.
 *
 * FIXME: You can move this function to Generator.cpp if the algorithm
 * in getNonDeterministicRandom is unified for both CPU and CUDA
 */
uint64_t CUDAGeneratorImpl::seed() {
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::seed");
  auto random = c10::detail::getNonDeterministicRandom(true);
  this->set_current_seed(random);
  return random;
}

/**
 * Gets the current internal state of CUDAGeneratorImpl. The internal
 * state is returned as a CPU byte tensor.
 */
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  // The RNG state comprises the seed, and an offset used for Philox.
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(int64_t);
  static const size_t total_size = seed_size + offset_size;

  auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
  auto rng_state = state_tensor.data_ptr<uint8_t>();
  auto current_seed = this->current_seed();
  auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>
  memcpy(rng_state, &current_seed, seed_size);
  memcpy(rng_state + seed_size, &offset, offset_size);

  return state_tensor.getIntrusivePtr();
}
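
// Layout of the returned byte tensor (16 bytes, host byte order):
//   bytes [0, 8)  : uint64_t seed
//   bytes [8, 16) : int64_t  philox offset per thread
//
// A sketch of decoding it on the caller side (assumes an at::Tensor `state`
// obtained from this generator's get_state()):
//
//   uint64_t seed = 0;
//   int64_t offset = 0;
//   auto* bytes = state.data_ptr<uint8_t>();
//   std::memcpy(&seed, bytes, sizeof(seed));
//   std::memcpy(&offset, bytes + sizeof(seed), sizeof(offset));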

/**
 * Sets the internal state of CUDAGeneratorImpl. The new internal state
 * must be a strided CPU byte tensor and have appropriate size. See
 * comments of CUDAGeneratorImpl::get_state for information about the layout
 * and size of the internal state.
 */
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  at::cuda::assertNotCapturing(
      "Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(int64_t);
  static const size_t total_size = seed_size + offset_size;

  detail::check_rng_state(new_state);

  bool no_philox_seed = false;
  auto new_state_size = new_state.numel();
  if (new_state_size == total_size - offset_size) {
    no_philox_seed = true;
  } else {
    TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
  }

  uint64_t input_seed = 0;
  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
  memcpy(&input_seed, new_rng_state, seed_size);
  this->set_current_seed(input_seed);
  int64_t philox_offset = 0;
  if (!no_philox_seed) {
    memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
  }
  this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
}
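
// Illustrative round trip (a sketch, assuming a Generator `gen` wrapping this
// impl): bytes produced by get_state() can be fed back through set_state();
// a legacy 8-byte state holding only the seed is also accepted, in which case
// the philox offset is reset to 0.
//
//   at::Tensor saved = gen.get_state();   // 16-byte CPU uint8 tensor
//   // ... run some random ops ...
//   gen.set_state(saved);                 // restores seed and philox offset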
372 
373 /**
374  * Sets the generator's current state to
375  * This function allows switching between different registered states of
376  * the generator.
377  */
graphsafe_set_state(const c10::intrusive_ptr<GeneratorImpl> & gen)378 void CUDAGeneratorImpl::graphsafe_set_state(
379     const c10::intrusive_ptr<GeneratorImpl>& gen) {
380   c10::intrusive_ptr<CUDAGeneratorImpl> cuda_gen =
381       dynamic_intrusive_pointer_cast<CUDAGeneratorImpl>(gen);
382   TORCH_CHECK(cuda_gen, "Expected a CUDA Generator");
383   state_ = cuda_gen->state_;
384 }
385 
386 /**
387  * Get the GeneratorImpl that point to current state_
388  */
graphsafe_get_state() const389 c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
390     const {
391   auto gen = make_intrusive<CUDAGeneratorImpl>(device().index(), state_);
392   return gen;
393 }
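
// Illustrative sketch: graphsafe_get_state()/graphsafe_set_state() share the
// underlying CUDAGeneratorState by intrusive_ptr rather than copying bytes,
// so two CUDAGeneratorImpl objects can advance a single philox stream.
// Assumes two CUDAGeneratorImpl* values `a` and `b` on the same device:
//
//   auto shared = a->graphsafe_get_state(); // points at a's state_
//   b->graphsafe_set_state(shared);         // b now advances the same stream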

/**
 * Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10
 *
 * See Note [Acquire lock when using random generators]
 */
void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
  // see Note [Why enforce RNG offset % 4 == 0?]
  TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
  state_->philox_offset_per_thread_ = offset;
}

/**
 * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
 */
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
  return state_->philox_offset_per_thread_;
}

/**
 * Registers this state to a CUDA graph to manage within the graph.
 */
void CUDAGeneratorImpl::register_graph(cuda::CUDAGraph* graph) {
  graph->register_generator_state(state_);
  state_->register_graph(graph);
}

/**
 * Unregisters a CUDA graph from the RNG state.
 */
void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) {
  state_->unregister_graph(graph);
}

/**
 * Gets the seed and philox offset value to be used in
 * curandStatePhilox4_32_10, packed in an opaque PhiloxCudaState that callers
 * can use non-divergently whether or not CUDA graph capture is underway.
 * See Note [CUDA Graph-safe RNG states]
 *
 * Each kernel using philox has to sensibly increment the offset
 * for future users of philox. So it gets the "old" value for
 * itself (before the add), and tells subsequent users which offset
 * they should use, since only the kernel knows how many randoms
 * it intends to generate.
 *
 * Increment should be at least the number of curand() random numbers used in
 * each thread. It is the user's responsibility to make sure the increment
 * for philox is never smaller than the number of curand() calls. An increment
 * larger than the number of curand() calls does no harm, but anything smaller
 * means random values from previous calls would be reused.
 *
 * See Note [Acquire lock when using random generators]
 */
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
  if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
    uint32_t offset = state_->offset_intragraph_;
    state_->increase(increment);
    return PhiloxCudaState(
        state_->seed_extragraph_.data_ptr<int64_t>(),
        state_->offset_extragraph_.data_ptr<int64_t>(),
        offset);
  } else {
    uint64_t offset = state_->philox_offset_per_thread_;
    state_->increase(increment);
    return PhiloxCudaState(state_->seed_, offset);
  }
}
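
// Kernel-side consumption, as a sketch (the kernel below is illustrative, not
// part of this file; at::cuda::philox::unpack resolves the seed/offset from a
// PhiloxCudaState regardless of whether it was built for graph capture):
//
//   __global__ void my_kernel(at::PhiloxCudaState philox_args /*, ... */) {
//     auto seeds = at::cuda::philox::unpack(philox_args);
//     curandStatePhilox4_32_10_t state;
//     curand_init(std::get<0>(seeds),                        // seed
//                 blockIdx.x * blockDim.x + threadIdx.x,     // subsequence
//                 std::get<1>(seeds),                        // offset
//                 &state);
//     // ... curand_uniform4(&state), etc. ...
//   }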

/**
 * Temporarily accommodates call sites that use philox_engine_inputs.
 * Allows incremental refactor of call sites to use philox_cuda_state.
 */
std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(
    uint64_t increment) {
  at::cuda::assertNotCapturing(
      "Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. Cannot call CUDAGeneratorImpl::philox_engine_inputs");
  uint64_t offset = state_->philox_offset_per_thread_;
  state_->increase(increment);
  return std::make_pair(state_->seed_, offset);
}

/*
 * Gets the DeviceType of CUDAGeneratorImpl.
 * Used for type checking during run time.
 */
DeviceType CUDAGeneratorImpl::device_type() {
  return DeviceType::CUDA;
}

/**
 * Public clone method implementation
 *
 * See Note [Acquire lock when using random generators]
 */
std::shared_ptr<CUDAGeneratorImpl> CUDAGeneratorImpl::clone() const {
  return std::shared_ptr<CUDAGeneratorImpl>(this->clone_impl());
}

/**
 * Private clone method implementation
 *
 * See Note [Acquire lock when using random generators]
 */
CUDAGeneratorImpl* CUDAGeneratorImpl::clone_impl() const {
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::clone_impl");
  auto gen = new CUDAGeneratorImpl(this->device().index(), state_->clone());
  return gen;
}

} // namespace at