xref: /aosp_15_r20/external/pytorch/aten/src/ATen/CPUGeneratorImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/CPUGeneratorImpl.h>
2 #include <ATen/Utils.h>
3 #include <ATen/core/MT19937RNGEngine.h>
4 #include <c10/util/MathConstants.h>
5 #include <algorithm>
6 
7 namespace at {
8 
9 namespace detail {
10 
11 /**
12  * CPUGeneratorImplStateLegacy is a POD class needed for memcpys
13  * in torch.get_rng_state() and torch.set_rng_state().
14  * It is a legacy class and even though it is replaced with
15  * at::CPUGeneratorImpl, we need this class and some of its fields
16  * to support backward compatibility on loading checkpoints.
17  */
18 struct CPUGeneratorImplStateLegacy {
19   /* The initial seed. */
20   uint64_t the_initial_seed;
21   int left;  /* = 1; */
22   int seeded; /* = 0; */
23   uint64_t next;
24   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
25   uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector  */
26 
27   /********************************/
28 
29   /* For normal distribution */
30   double normal_x;
31   double normal_y;
32   double normal_rho;
33   int normal_is_valid; /* = 0; */
34 };
35 
36 /**
37  * CPUGeneratorImplState is a POD class containing
38  * new data introduced in at::CPUGeneratorImpl and the legacy state. It is used
39  * as a helper for torch.get_rng_state() and torch.set_rng_state()
40  * functions.
41  */
42 struct CPUGeneratorImplState {
43   CPUGeneratorImplStateLegacy legacy_pod;
44   float next_float_normal_sample;
45   bool is_next_float_normal_sample_valid;
46 };
47 
48 /**
49  * PyTorch maintains a collection of default generators that get
50  * initialized once. The purpose of these default generators is to
51  * maintain a global running state of the pseudo random number generation,
52  * when a user does not explicitly mention any generator.
53  * getDefaultCPUGenerator gets the default generator for a particular
54  * device.
55  */
getDefaultCPUGenerator()56 const Generator& getDefaultCPUGenerator() {
57   static auto default_gen_cpu = createCPUGenerator(c10::detail::getNonDeterministicRandom());
58   return default_gen_cpu;
59 }
60 
61 /**
62  * Utility to create a CPUGeneratorImpl. Returns a shared_ptr
63  */
createCPUGenerator(uint64_t seed_val)64 Generator createCPUGenerator(uint64_t seed_val) {
65   return make_generator<CPUGeneratorImpl>(seed_val);
66 }
67 
/**
 * Helper function that concatenates two 32 bit unsigned ints into one
 * 64 bit unsigned int: hi becomes the upper 32 bits of the result and
 * lo the lower 32 bits.
 */
inline uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) {
  // Widen before shifting so the upper half is not shifted out of a 32 bit value.
  uint64_t combined = hi;
  combined <<= 32;
  combined |= lo;
  return combined;
}
75 
76 } // namespace detail
77 
78 /**
79  * CPUGeneratorImpl class implementation
80  */
CPUGeneratorImpl(uint64_t seed_in)81 CPUGeneratorImpl::CPUGeneratorImpl(uint64_t seed_in)
82   : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(c10::DispatchKey::CPU)},
83     engine_{seed_in},
84     next_float_normal_sample_{std::optional<float>()},
85     next_double_normal_sample_{std::optional<double>()} { }
86 
87 /**
88  * Manually seeds the engine with the seed input
89  * See Note [Acquire lock when using random generators]
90  */
set_current_seed(uint64_t seed)91 void CPUGeneratorImpl::set_current_seed(uint64_t seed) {
92   next_float_normal_sample_.reset();
93   next_double_normal_sample_.reset();
94   engine_ = mt19937(seed);
95 }
96 
97 /**
98  * Sets the offset of RNG state.
99  * See Note [Acquire lock when using random generators]
100  */
set_offset(uint64_t offset)101 void CPUGeneratorImpl::set_offset(uint64_t offset [[maybe_unused]]) {
102   TORCH_CHECK(false, "CPU Generator does not use offset");
103 }
104 
105 /**
106  * Gets the current offset of CPUGeneratorImpl.
107  */
get_offset() const108 uint64_t CPUGeneratorImpl::get_offset() const {
109   TORCH_CHECK(false, "CPU Generator does not use offset");
110 }
111 
112 /**
113  * Gets the current seed of CPUGeneratorImpl.
114  */
current_seed() const115 uint64_t CPUGeneratorImpl::current_seed() const {
116   return engine_.seed();
117 }
118 
119 /**
120  * Gets a nondeterministic random number from /dev/urandom or time,
121  * seeds the CPUGeneratorImpl with it and then returns that number.
122  *
123  * FIXME: You can move this function to Generator.cpp if the algorithm
124  * in getNonDeterministicRandom is unified for both CPU and CUDA
125  */
seed()126 uint64_t CPUGeneratorImpl::seed() {
127   auto random = c10::detail::getNonDeterministicRandom();
128   this->set_current_seed(random);
129   return random;
130 }
131 
132 /**
133  * Sets the internal state of CPUGeneratorImpl. The new internal state
134  * must be a strided CPU byte tensor and of the same size as either
135  * CPUGeneratorImplStateLegacy (for legacy CPU generator state) or
136  * CPUGeneratorImplState (for new state).
137  *
138  * FIXME: Remove support of the legacy state in the future?
139  */
set_state(const c10::TensorImpl & new_state)140 void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
141   using detail::CPUGeneratorImplState;
142   using detail::CPUGeneratorImplStateLegacy;
143 
144   static_assert(std::is_standard_layout_v<CPUGeneratorImplStateLegacy>, "CPUGeneratorImplStateLegacy is not a PODType");
145   static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType");
146 
147   static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy);
148   static const size_t size_current = sizeof(CPUGeneratorImplState);
149   static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size");
150 
151   detail::check_rng_state(new_state);
152 
153   at::mt19937 engine;
154   auto float_normal_sample = std::optional<float>();
155   auto double_normal_sample = std::optional<double>();
156 
157   // Construct the state of at::CPUGeneratorImpl based on input byte tensor size.
158   CPUGeneratorImplStateLegacy* legacy_pod{nullptr};
159   auto new_state_size = new_state.numel();
160   if (new_state_size == size_legacy) {
161     legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data();
162     // Note that in CPUGeneratorImplStateLegacy, we didn't have float version
163     // of normal sample and hence we leave the std::optional<float> as is
164 
165     // Update next_double_normal_sample.
166     // Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y)
167     // and a rho value (normal_rho). These three values were redundant and in the new
168     // DistributionsHelper.h, we store the actual extra normal sample, rather than three
169     // intermediate values.
170     if (legacy_pod->normal_is_valid) {
171       auto r = legacy_pod->normal_rho;
172       auto theta = 2.0 * c10::pi<double> * legacy_pod->normal_x;
173       // we return the sin version of the normal sample when in caching mode
174       double_normal_sample = std::optional<double>(r * ::sin(theta));
175     }
176   } else if (new_state_size == size_current) {
177     auto rng_state = (CPUGeneratorImplState*)new_state.data();
178     legacy_pod = &rng_state->legacy_pod;
179     // update next_float_normal_sample
180     if (rng_state->is_next_float_normal_sample_valid) {
181       float_normal_sample = std::optional<float>(rng_state->next_float_normal_sample);
182     }
183 
184     // Update next_double_normal_sample.
185     // Note that in getRNGState, we now return the actual normal sample in normal_y
186     // and if it's valid in normal_is_valid. The redundant normal_x and normal_rho
187     // are squashed to 0.0.
188     if (legacy_pod->normal_is_valid) {
189       double_normal_sample = std::optional<double>(legacy_pod->normal_y);
190     }
191   } else {
192     AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy,
193              " or a CPUGeneratorImplState of size ", size_current,
194              " but found the input RNG state size to be ", new_state_size);
195   }
196 
197   // construct engine_
198   // Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our
199   // redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are
200   // doing a std::copy.
201   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
202   at::mt19937_data_pod rng_data;
203   std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin());
204   rng_data.seed_ = legacy_pod->the_initial_seed;
205   rng_data.left_ = legacy_pod->left;
206   rng_data.seeded_ = legacy_pod->seeded;
207   rng_data.next_ = static_cast<uint32_t>(legacy_pod->next);
208   engine.set_data(rng_data);
209   TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state");
210   this->engine_ = engine;
211   this->next_float_normal_sample_ = float_normal_sample;
212   this->next_double_normal_sample_ = double_normal_sample;
213 }
214 
215 /**
216  * Gets the current internal state of CPUGeneratorImpl. The internal
217  * state is returned as a CPU byte tensor.
218  */
get_state() const219 c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
220   using detail::CPUGeneratorImplState;
221 
222   static const size_t size = sizeof(CPUGeneratorImplState);
223   static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType");
224 
225   auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
226   auto rng_state = state_tensor.data_ptr();
227 
228   // accumulate generator data to be copied into byte tensor
229   auto accum_state = std::make_unique<CPUGeneratorImplState>();
230   auto rng_data = this->engine_.data();
231   accum_state->legacy_pod.the_initial_seed = rng_data.seed_;
232   accum_state->legacy_pod.left = rng_data.left_;
233   accum_state->legacy_pod.seeded = rng_data.seeded_;
234   accum_state->legacy_pod.next = rng_data.next_;
235   std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state));
236   accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy
237   accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy
238   accum_state->legacy_pod.normal_is_valid = false;
239   accum_state->legacy_pod.normal_y = 0.0;
240   accum_state->next_float_normal_sample = 0.0f;
241   accum_state->is_next_float_normal_sample_valid = false;
242   if (this->next_double_normal_sample_) {
243     accum_state->legacy_pod.normal_is_valid = true;
244     accum_state->legacy_pod.normal_y = *(this->next_double_normal_sample_);
245   }
246   if (this->next_float_normal_sample_) {
247     accum_state->is_next_float_normal_sample_valid = true;
248     accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
249   }
250 
251   memcpy(rng_state, accum_state.get(), size);
252   return state_tensor.getIntrusivePtr();
253 }
254 
255 /**
256  * Gets the DeviceType of CPUGeneratorImpl.
257  * Used for type checking during run time.
258  */
device_type()259 DeviceType CPUGeneratorImpl::device_type() {
260   return DeviceType::CPU;
261 }
262 
263 /**
264  * Gets a random 32 bit unsigned integer from the engine
265  *
266  * See Note [Acquire lock when using random generators]
267  */
random()268 uint32_t CPUGeneratorImpl::random() {
269   return engine_();
270 }
271 
272 /**
273  * Gets a random 64 bit unsigned integer from the engine
274  *
275  * See Note [Acquire lock when using random generators]
276  */
random64()277 uint64_t CPUGeneratorImpl::random64() {
278   uint32_t random1 = engine_();
279   uint32_t random2 = engine_();
280   return detail::make64BitsFrom32Bits(random1, random2);
281 }
282 
283 /**
284  * Get the cached normal random in float
285  */
next_float_normal_sample()286 std::optional<float> CPUGeneratorImpl::next_float_normal_sample() {
287   return next_float_normal_sample_;
288 }
289 
290 /**
291  * Get the cached normal random in double
292  */
next_double_normal_sample()293 std::optional<double> CPUGeneratorImpl::next_double_normal_sample() {
294   return next_double_normal_sample_;
295 }
296 
297 /**
298  * Cache normal random in float
299  *
300  * See Note [Acquire lock when using random generators]
301  */
set_next_float_normal_sample(std::optional<float> randn)302 void CPUGeneratorImpl::set_next_float_normal_sample(std::optional<float> randn) {
303   next_float_normal_sample_ = randn;
304 }
305 
306 /**
307  * Cache normal random in double
308  *
309  * See Note [Acquire lock when using random generators]
310  */
set_next_double_normal_sample(std::optional<double> randn)311 void CPUGeneratorImpl::set_next_double_normal_sample(std::optional<double> randn) {
312   next_double_normal_sample_ = randn;
313 }
314 
315 /**
316  * Get the engine of the CPUGeneratorImpl
317  */
engine()318 at::mt19937 CPUGeneratorImpl::engine() {
319   return engine_;
320 }
321 
322 /**
323  * Set the engine of the CPUGeneratorImpl
324  *
325  * See Note [Acquire lock when using random generators]
326  */
set_engine(at::mt19937 engine)327 void CPUGeneratorImpl::set_engine(at::mt19937 engine) {
328   engine_ = engine;
329 }
330 
331 /**
332  * Public clone method implementation
333  *
334  * See Note [Acquire lock when using random generators]
335  */
clone() const336 std::shared_ptr<CPUGeneratorImpl> CPUGeneratorImpl::clone() const {
337   return std::shared_ptr<CPUGeneratorImpl>(this->clone_impl());
338 }
339 
340 /**
341  * Private clone method implementation
342  *
343  * See Note [Acquire lock when using random generators]
344  */
clone_impl() const345 CPUGeneratorImpl* CPUGeneratorImpl::clone_impl() const {
346   auto gen = new CPUGeneratorImpl();
347   gen->set_engine(engine_);
348   gen->set_next_float_normal_sample(next_float_normal_sample_);
349   gen->set_next_double_normal_sample(next_double_normal_sample_);
350   return gen;
351 }
352 
353 } // namespace at
354