xref: /aosp_15_r20/external/pytorch/c10/core/GeneratorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstdint>
4 #include <mutex>
5 
6 #include <c10/core/Device.h>
7 #include <c10/core/DispatchKeySet.h>
8 #include <c10/core/TensorImpl.h>
9 #include <c10/macros/Export.h>
10 #include <c10/util/intrusive_ptr.h>
11 #include <c10/util/python_stub.h>
12 
13 /**
14  * Note [Generator]
15  * ~~~~~~~~~~~~~~~~
 * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm
 * to generate a seemingly random sequence of numbers that may later be used
 * in creating a random distribution. Such an engine almost always maintains a
 * state and requires a seed to start off the creation of random numbers.
 * Users have often found it beneficial to be able to explicitly create,
 * retain, and destroy PRNG states, and to have control over the seed value.
23  *
24  * A Generator in ATen gives users the ability to read, write and modify a PRNG
25  * engine. For instance, it does so by letting users seed a PRNG engine, fork
26  * the state of the engine, etc.
27  *
28  * By default, there is one generator per device, and a device's generator is
29  * lazily created. A user can use the torch.Generator() api to create their own
30  * generator. Currently torch.Generator() can only create a CPUGeneratorImpl.
31  */
32 
33 /**
34  * Note [Acquire lock when using random generators]
35  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
36  * Generator and its derived classes are NOT thread-safe. Please note that most
37  * of the places where we have inserted locking for generators are historically
38  * based, and we haven't actually checked that everything is truly thread safe
 * (and it probably isn't). Please lock the public mutex_ when calling any
 * methods of these classes, except for the read-only methods. For examples of
 * the intended usage, see the unit tests (aten/src/ATen/cpu_generator_test.cpp)
 * and other places where we have used lock_guard.
43  *
44  * TODO: Look into changing the threading semantics of Generators in ATen (e.g.,
45  * making them non-thread safe and instead making the generator state
46  * splittable, to accommodate forks into other threads).
47  */
48 
49 namespace c10 {
50 
// Fallback seed value. NOTE(review): by its name this is the seed used when
// none is supplied explicitly; its consumers are not in this header. The value
// was picked as a large number whose binary representation contains a good
// mix of 0 and 1 bits.
constexpr auto default_rng_seed_val = uint64_t{67280421310721};
54 
55 struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
56   // Constructors
57   GeneratorImpl(Device device_in, DispatchKeySet key_set);
58 
59   // Delete all copy and move assignment in favor of clone()
60   // method
61   GeneratorImpl(const GeneratorImpl& other) = delete;
62   GeneratorImpl(GeneratorImpl&& other) = delete;
63   GeneratorImpl& operator=(const GeneratorImpl& other) = delete;
64 
65   ~GeneratorImpl() override = default;
66   c10::intrusive_ptr<GeneratorImpl> clone() const;
67 
68   // Common methods for all generators
69   virtual void set_current_seed(uint64_t seed) = 0;
70   virtual void set_offset(uint64_t offset) = 0;
71   virtual uint64_t get_offset() const = 0;
72   virtual uint64_t current_seed() const = 0;
73   virtual uint64_t seed() = 0;
74   virtual void set_state(const c10::TensorImpl& new_state) = 0;
75   virtual c10::intrusive_ptr<c10::TensorImpl> get_state() const = 0;
76   virtual void graphsafe_set_state(
77       const c10::intrusive_ptr<c10::GeneratorImpl>& new_state);
78   virtual c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const;
79   Device device() const;
80 
81   // See Note [Acquire lock when using random generators]
82   std::mutex mutex_;
83 
key_setGeneratorImpl84   DispatchKeySet key_set() const {
85     return key_set_;
86   }
87 
set_pyobjGeneratorImpl88   inline void set_pyobj(PyObject* pyobj) noexcept {
89     pyobj_ = pyobj;
90   }
91 
pyobjGeneratorImpl92   inline PyObject* pyobj() const noexcept {
93     return pyobj_;
94   }
95 
96  protected:
97   Device device_;
98   DispatchKeySet key_set_;
99   PyObject* pyobj_ = nullptr;
100 
101   virtual GeneratorImpl* clone_impl() const = 0;
102 };
103 
104 namespace detail {
105 
106 C10_API uint64_t getNonDeterministicRandom(bool is_cuda = false);
107 
108 } // namespace detail
109 
110 } // namespace c10
111