#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/CallOnce.h>
#include <deque>

namespace at {
namespace cuda::detail {

namespace {

// Ensures we call cudaGetDeviceCount only once.
static c10::once_flag num_gpu_init_flag;

// Total number of GPUs in the system.
static int64_t num_gpus;

// Ensures default_gens_cuda is initialized once.
static std::deque<c10::once_flag> cuda_gens_init_flag;

// Default, global CUDA generators, one per GPU.
static std::vector<Generator> default_gens_cuda;

/*
 * Populates the global variables related to CUDA generators
 * Warning: this function must only be called once!
 */
static void initCUDAGenVector() {
  num_gpus = c10::cuda::device_count();
  cuda_gens_init_flag.resize(num_gpus);
  default_gens_cuda.resize(num_gpus);
}

} // anonymous namespace

/**
 * PyTorch maintains a collection of default generators that get
 * initialized once. The purpose of these default generators is to
 * maintain a global running state of pseudo random number generation
 * when a user does not explicitly specify a generator.
 * getDefaultCUDAGenerator gets the default generator for a particular
 * CUDA device.
 */
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index) {
  c10::call_once(num_gpu_init_flag, initCUDAGenVector);
  DeviceIndex idx = device_index;
  if (idx == -1) {
    idx = c10::cuda::current_device();
  } else {
    TORCH_CHECK(idx >= 0 && idx < num_gpus);
  }
  c10::call_once(cuda_gens_init_flag[idx], [&] {
    default_gens_cuda[idx] = make_generator<CUDAGeneratorImpl>(idx);
    default_gens_cuda[idx].seed();
  });
  return default_gens_cuda[idx];
}

/**
 * Utility to create a CUDAGeneratorImpl. Returns a shared_ptr
 */
Generator createCUDAGenerator(DeviceIndex device_index) {
  c10::call_once(num_gpu_init_flag, initCUDAGenVector);
  DeviceIndex idx = device_index;
  if (idx == -1) {
    idx = c10::cuda::current_device();
  }
  TORCH_CHECK(idx >= 0 && idx < num_gpus, "The device_index is invalid.");
  auto gen = make_generator<CUDAGeneratorImpl>(idx);
  auto cuda_gen = check_generator<CUDAGeneratorImpl>(gen);
  cuda_gen->set_current_seed(default_rng_seed_val);
  cuda_gen->set_philox_offset_per_thread(0);
  return gen;
}
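
/**
 * A minimal usage sketch (illustrative caller code, not part of this file;
 * the device index and seed value are arbitrary):
 *
 *   at::Generator gen = at::cuda::detail::getDefaultCUDAGenerator(0);
 *   {
 *     // See Note [Acquire lock when using random generators]
 *     std::lock_guard<std::mutex> lock(gen.mutex());
 *     at::check_generator<at::CUDAGeneratorImpl>(gen)->set_current_seed(42);
 *   }
 */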

} // namespace cuda::detail

/**
 * Creates a clone of this CUDA Generator State.
 */
c10::intrusive_ptr<CUDAGeneratorState> CUDAGeneratorState::clone() {
  return make_intrusive<CUDAGeneratorState>(
      seed_, philox_offset_per_thread_, offset_intragraph_);
}

/**
 * Function to increase the internal offset based on the specified increment.
 */
void CUDAGeneratorState::increase(uint64_t increment) {
  // Rounds increment up to the nearest multiple of 4 to meet alignment
  // requirements.
  // see Note [Why enforce RNG offset % 4 == 0?]
  increment = ((increment + 3) / 4) * 4;
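  // For example, an increment of 5 is rounded up to 8, while 8 stays 8.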
  // Handling different behaviors based on whether capturing is active.
  if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
    // Ensures that the state is actually capturing.
    TORCH_CHECK(
        capturing_,
        "Attempt to increase offset for a CUDA generator not in capture mode.");
    // Ensures the offset is a multiple of 4
    // see Note [Why enforce RNG offset % 4 == 0?]
    TORCH_INTERNAL_ASSERT(
        offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
    // Ensures the increment does not cause overflow.
    TORCH_INTERNAL_ASSERT(
        offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
        "Increment causes overflow in the offset value.");
    offset_intragraph_ += increment;
  } else {
    // Checks that the increment is expected outside graph capturing.
    TORCH_CHECK(
        !capturing_,
        "Offset increment outside graph capture encountered unexpectedly.");
    // Ensures the offset is a multiple of 4
    // see Note [Why enforce RNG offset % 4 == 0?]
    TORCH_INTERNAL_ASSERT(
        philox_offset_per_thread_ % 4 == 0,
        "RNG offset must be a multiple of 4.");
    philox_offset_per_thread_ += increment;
  }
}

/**
 * Registers this state to a CUDA graph to manage within the graph.
 */
void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) {
  // Ensures that the RNG state is not currently being captured.
  at::cuda::assertNotCapturing(
      "Cannot register the state during capturing stage.");

  // If this is the first graph to be registered, allocate memory for the seed
  // and offset on the GPU.
  if (registered_graphs_.empty()) {
    auto options = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
    seed_extragraph_ = at::empty({1}, options);
    offset_extragraph_ = at::empty({1}, options);
  }

  // Insert the graph into the set of registered graphs if it's not already
  // registered.
  if (registered_graphs_.find(graph) == registered_graphs_.end()) {
    registered_graphs_.insert(graph);
  }
}

/**
 * Unregisters a CUDA graph from the RNG state.
 */
void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
  // Verify the graph was previously registered.
  TORCH_CHECK(
      registered_graphs_.find(graph) != registered_graphs_.end(),
      "The graph should be registered to the state");

  // Remove the graph from the set of registered graphs.
  registered_graphs_.erase(graph);

  // If no more graphs are registered, deallocate the GPU memory for the seed
  // and offset.
  if (registered_graphs_.empty()) {
    seed_extragraph_.reset();
    offset_extragraph_.reset();
  }
}

/**
 * Note [Explicit Registration of Generators to the CUDA Graph]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Ideally, it would be more user-friendly if the state could be exchanged and generators
 * could be registered with the CUDA graph implicitly. However, resetting GPU tensors during
 * the capture stage causes these reset operations to be recorded within the CUDA graph.
 * This behavior is undesirable because we do not want these tensors to be reset during
 * the replay stage of the graph.
 *
 * As of now, there is no available method to perform a CUDA operation during the graph's
 * recording phase without having that operation be included in the CUDA graph.
 * This limitation necessitates explicit user action to register generators with the graph.
 * By requiring users to manually register their generators, we can ensure that state resets
 * (capture_prologue) only occur before the graph capture begins, thus avoiding unintended
 * resets during the replay of the graph. See https://github.com/pytorch/pytorch/pull/114068.
 */
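
/**
 * A minimal capture/replay sketch for the note above (illustrative caller
 * code; side-stream setup and error handling are omitted, and `graph` is a
 * hypothetical local):
 *
 *   at::cuda::CUDAGraph graph;
 *   auto gen = at::cuda::detail::getDefaultCUDAGenerator();
 *   at::check_generator<at::CUDAGeneratorImpl>(gen)->register_graph(&graph);
 *   graph.capture_begin();  // runs capture_prologue() on registered states
 *   // ... launch kernels that consume philox_cuda_state() ...
 *   graph.capture_end();    // capture_epilogue() reports the whole-graph increment
 *   graph.replay();         // replay_prologue() refreshes seed/offset tensors
 */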

/**
 * Performs the prologue steps for capturing a CUDA graph state.
 * This method is intended to reset graph-related state variables before capturing begins.
 */
void CUDAGeneratorState::capture_prologue() {
  capturing_ = true;
  offset_intragraph_ = 0;
  seed_extragraph_.fill_(int64_t(seed_));
  offset_extragraph_.fill_(int64_t(0));
}

/**
 * Ends the capturing phase and resets related variables, returning the whole
 * graph increment.
 */
uint64_t CUDAGeneratorState::capture_epilogue() {
  capturing_ = false;
  return offset_intragraph_;
}

/**
 * Prepares the state for replay by setting initial state tensors and applying
 * total increment.
 */
void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
  // Ensures the generator is not in capturing mode.
  at::cuda::assertNotCapturing(
      "Cannot prepare for replay during capturing stage.");
  seed_extragraph_.fill_(int64_t(seed_));
  offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
  // Applies the total increment achieved during previous captures to update the
  // offset.
  increase(wholegraph_increment);
}

/**
 * Note [Why enforce RNG offset % 4 == 0?]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Curand philox does allow offsets that aren't a multiple of 4.
 * But jit kernels don't use curand, they use a custom "Philox" class (see
 * torch/csrc/jit/tensorexpr/cuda_random.h or
 * torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu).
 * The "Philox" constructor computes offset/4 (a uint64_t division) to locate its
 * internal start in its virtual bitstream viewed as 128-bit chunks, then, when called
 * in a thread, returns one 32-bit chunk at a time from that start in the bitstream.
 * In other words, if the incoming offset is not a multiple of 4, each thread
 * might repeat some previously-generated 32-bit values in the bitstream. See
 * https://github.com/pytorch/pytorch/pull/50169.
 */
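
/*
 * A small worked example of the note above (numbers are illustrative):
 * with offset = 8, "Philox" starts at 128-bit chunk 8/4 = 2, i.e. at a fresh
 * block of four 32-bit values. With offset = 6, integer division gives
 * 6/4 = 1, so it starts at chunk 1 and its first two 32-bit outputs repeat
 * the values at positions 4 and 5 that an earlier call already handed out.
 */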

/**
 * CUDAGeneratorImpl class implementation
 */
CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
    : c10::GeneratorImpl{Device(DeviceType::CUDA, device_index),
                         DispatchKeySet(c10::DispatchKey::CUDA)} {
  at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl");
  state_ = make_intrusive<CUDAGeneratorState>();
  no_reset_rnn_state_.clear();
}

CUDAGeneratorImpl::CUDAGeneratorImpl(
    DeviceIndex device_index,
    c10::intrusive_ptr<CUDAGeneratorState> state)
    : c10::
          GeneratorImpl{Device(DeviceType::CUDA, device_index), DispatchKeySet(c10::DispatchKey::CUDA)},
      state_(std::move(state)) {
  no_reset_rnn_state_.clear();
}

/**
 * Sets the seed to be used by curandStatePhilox4_32_10
 * Resets the philox_offset_per_thread_ to 0
 *
 * See Note [Acquire lock when using random generators]
 */
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
  at::cuda::assertNotCapturing(
      "Cannot call CUDAGeneratorImpl::set_current_seed");
  state_->seed_ = seed;
  state_->philox_offset_per_thread_ = 0;
  no_reset_rnn_state_.clear();
}

/**
 * Sets the offset to be used by curandStatePhilox4_32_10
 *
 * See Note [Acquire lock when using random generators]
 */
void CUDAGeneratorImpl::set_offset(uint64_t offset) {
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_offset");
  // the set function checks if the offset is a multiple of 4.
  set_philox_offset_per_thread(offset);
  no_reset_rnn_state_.clear();
}

/**
 * Gets the current offset of CUDAGeneratorImpl.
 */
uint64_t CUDAGeneratorImpl::get_offset() const {
  // Debatable if get_offset() should be allowed in captured regions.
  // Conservatively disallow it for now.
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::get_offset");
  return state_->philox_offset_per_thread_;
}

/**
 * Gets the current seed of CUDAGeneratorImpl.
 */
uint64_t CUDAGeneratorImpl::current_seed() const {
  // Debatable if current_seed() should be allowed in captured regions.
  // Conservatively disallow it for now.
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
  return state_->seed_;
}

/**
 * Gets a nondeterministic random number from /dev/urandom or time,
 * seeds the CUDAGeneratorImpl with it and then returns that number.
 *
 * FIXME: You can move this function to Generator.cpp if the algorithm
 * in getNonDeterministicRandom is unified for both CPU and CUDA
 */
uint64_t CUDAGeneratorImpl::seed() {
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::seed");
  auto random = c10::detail::getNonDeterministicRandom(true);
  this->set_current_seed(random);
  return random;
}

/**
 * Gets the current internal state of CUDAGeneratorImpl. The internal
 * state is returned as a CPU byte tensor.
 */
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  // The RNG state comprises the seed, and an offset used for Philox.
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(int64_t);
  static const size_t total_size = seed_size + offset_size;

  auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
  auto rng_state = state_tensor.data_ptr<uint8_t>();
  auto current_seed = this->current_seed();
  auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>
  memcpy(rng_state, &current_seed, seed_size);
  memcpy(rng_state + seed_size, &offset, offset_size);

  return state_tensor.getIntrusivePtr();
}
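
/*
 * Layout of the byte tensor returned above (a reading aid for the code, not
 * a new format):
 *   bytes [0, 8)  : seed                       (uint64_t)
 *   bytes [8, 16) : philox_offset_per_thread   (int64_t)
 */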

/**
 * Sets the internal state of CUDAGeneratorImpl. The new internal state
 * must be a strided CPU byte tensor and have appropriate size. See
 * comments of CUDAGeneratorImpl::state for information about the layout
 * and size of the internal state.
 */
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  at::cuda::assertNotCapturing(
      "Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(int64_t);
  static const size_t total_size = seed_size + offset_size;

  detail::check_rng_state(new_state);

  bool no_philox_seed = false;
  auto new_state_size = new_state.numel();
  if (new_state_size == total_size - offset_size) {
    no_philox_seed = true;
  } else {
    TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
  }

  uint64_t input_seed = 0;
  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
  memcpy(&input_seed, new_rng_state, seed_size);
  this->set_current_seed(input_seed);
  int64_t philox_offset = 0;
  if (!no_philox_seed) {
    memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
  }
  this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
}

/**
 * Sets the generator's current state to the state of the provided generator.
 * This function allows switching between different registered states of
 * the generator.
 */
void CUDAGeneratorImpl::graphsafe_set_state(
    const c10::intrusive_ptr<GeneratorImpl>& gen) {
  c10::intrusive_ptr<CUDAGeneratorImpl> cuda_gen =
      dynamic_intrusive_pointer_cast<CUDAGeneratorImpl>(gen);
  TORCH_CHECK(cuda_gen, "Expected a CUDA Generator");
  state_ = cuda_gen->state_;
}

/**
 * Gets a GeneratorImpl that points to the current state_.
 */
c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
    const {
  auto gen = make_intrusive<CUDAGeneratorImpl>(device().index(), state_);
  return gen;
}

/**
 * Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10
 *
 * See Note [Acquire lock when using random generators]
 */
void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
  // see Note [Why enforce RNG offset % 4 == 0?]
  TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
  state_->philox_offset_per_thread_ = offset;
}

/**
 * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
 */
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
  return state_->philox_offset_per_thread_;
}

/**
 * Registers this state to a CUDA graph to manage within the graph.
 */
void CUDAGeneratorImpl::register_graph(cuda::CUDAGraph* graph) {
  graph->register_generator_state(state_);
  state_->register_graph(graph);
}

/**
 * Unregisters a CUDA graph from the RNG state.
 */
void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) {
  state_->unregister_graph(graph);
}

/**
 * Gets the seed and philox offset value to be used in
 * curandStatePhilox4_32_10, in an opaque PhiloxCudaState that's safe
 * and can be used non-divergently in callers whether CUDA graph
 * capture is underway or not. See
 * Note [CUDA Graph-safe RNG states]
 *
 * Each kernel using philox has to sensibly increment offset
 * for future users of philox. So it gets the "old" value for
 * itself (before add), and tells subsequent users which offset
 * they should use, since only the kernel knows how many randoms
 * it intends to generate.
 *
 * Increment should be at least the number of curand() random numbers used in
 * each thread. It is the user's responsibility to make sure the increment
 * for philox is never smaller than the number of curand() calls. Increment
 * value > the number of curand() calls won't harm but anything less would mean
 * that you would be reusing random values from previous calls.
 *
 * See Note [Acquire lock when using random generators]
 */
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
  if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
    uint32_t offset = state_->offset_intragraph_;
    state_->increase(increment);
    return PhiloxCudaState(
        state_->seed_extragraph_.data_ptr<int64_t>(),
        state_->offset_extragraph_.data_ptr<int64_t>(),
        offset);
  } else {
    uint64_t offset = state_->philox_offset_per_thread_;
    state_->increase(increment);
    return PhiloxCudaState(state_->seed_, offset);
  }
}
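
/**
 * A minimal kernel-side sketch of consuming a PhiloxCudaState (illustrative;
 * `my_kernel` and its indexing are hypothetical). Unpacking must happen
 * inside the kernel so the same code works with and without graph capture:
 *
 *   __global__ void my_kernel(at::PhiloxCudaState philox_args) {
 *     auto seeds = at::cuda::philox::unpack(philox_args);
 *     int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
 *     curandStatePhilox4_32_10_t state;
 *     curand_init(std::get<0>(seeds),  // seed
 *                 idx,                 // subsequence
 *                 std::get<1>(seeds),  // offset
 *                 &state);
 *     float r = curand_uniform(&state);
 *     // ... use r ...
 *   }
 */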

/**
 * Temporarily accommodates call sites that use philox_engine_inputs.
 * Allows incremental refactor of call sites to use philox_cuda_state.
 */
std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(
    uint64_t increment) {
  at::cuda::assertNotCapturing(
      "Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. Cannot call CUDAGeneratorImpl::philox_engine_inputs");
  uint64_t offset = state_->philox_offset_per_thread_;
  state_->increase(increment);
  return std::make_pair(state_->seed_, offset);
}

/*
 * Gets the DeviceType of CUDAGeneratorImpl.
 * Used for type checking during run time.
 */
DeviceType CUDAGeneratorImpl::device_type() {
  return DeviceType::CUDA;
}

/**
 * Public clone method implementation
 *
 * See Note [Acquire lock when using random generators]
 */
std::shared_ptr<CUDAGeneratorImpl> CUDAGeneratorImpl::clone() const {
  return std::shared_ptr<CUDAGeneratorImpl>(this->clone_impl());
}

/**
 * Private clone method implementation
 *
 * See Note [Acquire lock when using random generators]
 */
CUDAGeneratorImpl* CUDAGeneratorImpl::clone_impl() const {
  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::clone_impl");
  auto gen = new CUDAGeneratorImpl(this->device().index(), state_->clone());
  return gen;
}

} // namespace at