xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/MT19937RNGEngine.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 
5 // define constants like M_PI and C keywords for MSVC
6 #ifdef _MSC_VER
7 #ifndef _USE_MATH_DEFINES
8 #define _USE_MATH_DEFINES
9 #endif
10 #include <math.h>
11 #endif
12 
13 #include <array>
14 #include <cmath>
15 #include <cstdint>
16 
namespace at {

// Parameters of the 32-bit Mersenne Twister recurrence used below.
// Number of 32-bit words in the engine state.
constexpr int MERSENNE_STATE_N = 624;
// Distance between the two state words combined by next_state().
constexpr int MERSENNE_STATE_M = 397;
// Constant XORed in by twist() when the low bit of its second argument is set.
constexpr uint32_t MATRIX_A = 0x9908b0df;
// Selects the most significant bit of a state word (see mix_bits()).
constexpr uint32_t UMASK = 0x80000000;
// Selects the 31 least significant bits of a state word (see mix_bits()).
constexpr uint32_t LMASK = 0x7fffffff;

/**
 * Note [Mt19937 Engine implementation]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Originally implemented in:
 * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c
 * and modified with C++ constructs. Moreover the state array of the engine
 * has been modified to hold 32 bit uints instead of 64 bits.
 *
 * Note that we reimplemented mt19937 instead of using std::mt19937 because,
 * at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2
 * by default and following are the benchmark numbers (benchmark code can be found at
 * https://github.com/syed-ahmed/benchmark-rngs):
 *
 * with -O2
 * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s
 * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s
 * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s
 * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s
 *
 * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution,
 * however we can't use std::uniform_real_distribution because of this bug:
 * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used
 * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm
 * than what's in pytorch currently and that messes up the tests in tests_distributions.py.
 * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower
 * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter.
 *
 * Copyright notice:
 * A C-program for MT19937, with initialization improved 2002/2/10.
 * Coded by Takuji Nishimura and Makoto Matsumoto.
 * This is a faster version by taking Shawn Cokus's optimization,
 * Matthe Bellew's simplification, Isaku Wada's real version.
 *
 * Before using, initialize the state by using init_genrand(seed)
 * or init_by_array(init_key, key_length).
 *
 * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   1. Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 *
 *   2. Redistributions in binary form must reproduce the above copyright
 *   notice, this list of conditions and the following disclaimer in the
 *   documentation and/or other materials provided with the distribution.
 *
 *   3. The names of its contributors may not be used to endorse or promote
 *   products derived from this software without specific prior written
 *   permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 *
 * Any feedback is very welcome.
 * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html
 * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space)
 */

/**
 * mt19937_data_pod is used to get POD data in and out
 * of mt19937_engine. Used in torch.get_rng_state and
 * torch.set_rng_state functions.
 *
 * The member order and types are kept exactly as they are: this struct is
 * handed out and taken back as a whole by mt19937_engine::data()/set_data().
 */
struct mt19937_data_pod {
  // Seed the engine was last initialized with (full 64-bit value as given).
  uint64_t seed_;
  // Countdown used by operator(): it is decremented on every draw and the
  // state is regenerated when it reaches zero.
  int left_;
  // Set to true once the engine has been seeded.
  bool seeded_;
  // Index into state_ of the next word operator() will consume.
  uint32_t next_;
  // The MERSENNE_STATE_N words of generator state.
  std::array<uint32_t, MERSENNE_STATE_N> state_;
};

/**
 * 32-bit Mersenne Twister engine (see Note [Mt19937 Engine implementation]).
 *
 * Each call to operator() yields one 32-bit value. The complete engine state
 * can be copied out with data() and restored with set_data(); is_valid()
 * reports whether the currently held state is safe to draw from.
 */
class mt19937_engine {
public:

  /**
   * Creates an engine seeded with `seed`. Only the low 32 bits of `seed`
   * take part in the state initialization; the full value is remembered
   * and returned by seed().
   */
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  inline explicit mt19937_engine(uint64_t seed = 5489) {
    init_with_uint32(seed);
  }

  /** Returns a copy of the complete engine state. */
  inline mt19937_data_pod data() const {
    return data_;
  }

  /**
   * Replaces the complete engine state with `data`. No validation is done
   * here; call is_valid() afterwards when `data` comes from outside.
   */
  inline void set_data(const mt19937_data_pod& data) {
    data_ = data;
  }

  /** Returns the seed value the current state was initialized with. */
  inline uint64_t seed() const {
    return data_.seed_;
  }

  /**
   * Returns true when the held state can be used by operator() safely.
   *
   * Requirements:
   *  - the engine has been seeded;
   *  - left_ lies in [1, MERSENNE_STATE_N];
   *  - the words still to be consumed before the next regeneration lie
   *    inside state_.
   *
   * operator() consumes state_[next_], state_[next_ + 1], ... for
   * (left_ - 1) draws before it regenerates the state and resets next_ to
   * zero, so the last index touched is next_ + left_ - 2. Requiring
   * next_ + left_ <= MERSENNE_STATE_N + 1 keeps that index in bounds (and
   * implies next_ <= MERSENNE_STATE_N). Every state produced by this engine
   * satisfies the condition; checking left_ and next_ only independently
   * would also accept states that read past the end of state_.
   */
  inline bool is_valid() const {
    if (!data_.seeded_) {
      return false;
    }
    if (data_.left_ <= 0 || data_.left_ > MERSENNE_STATE_N) {
      return false;
    }
    // Widen before adding so the sum cannot wrap for an arbitrary next_.
    const int64_t end_of_reads = static_cast<int64_t>(data_.next_) + data_.left_;
    return end_of_reads <= static_cast<int64_t>(MERSENNE_STATE_N) + 1;
  }

  /**
   * Produces the next 32-bit output: regenerates the state when the
   * countdown hits zero, then tempers the next state word.
   */
  inline uint32_t operator()() {
    if (--(data_.left_) == 0) {
        next_state();
    }
    uint32_t y = *(data_.state_.data() + data_.next_++);
    y ^= (y >> 11);
    y ^= (y << 7) & 0x9d2c5680;
    y ^= (y << 15) & 0xefc60000;
    y ^= (y >> 18);

    return y;
  }

private:
  mt19937_data_pod data_;

  /**
   * Seeds the state from the low 32 bits of `seed` and arranges for the
   * first call to operator() to regenerate the state (left_ == 1).
   */
  inline void init_with_uint32(uint64_t seed) {
    data_.seed_ = seed;
    data_.seeded_ = true;
    data_.state_[0] = static_cast<uint32_t>(seed & 0xffffffff);
    for (int j = 1; j < MERSENNE_STATE_N; ++j) {
      const uint32_t prev = data_.state_[j - 1];
      // All arithmetic is done in uint32_t, i.e. modulo 2^32.
      data_.state_[j] =
          1812433253U * (prev ^ (prev >> 30)) + static_cast<uint32_t>(j);
    }
    data_.left_ = 1;
    data_.next_ = 0;
  }

  /** Combines the top bit of `u` with the low 31 bits of `v`. */
  static constexpr uint32_t mix_bits(uint32_t u, uint32_t v) {
    return (u & UMASK) | (v & LMASK);
  }

  /**
   * Shifts the mixed word right by one and XORs in MATRIX_A when the low
   * bit of `v` is set.
   */
  static constexpr uint32_t twist(uint32_t u, uint32_t v) {
    return (mix_bits(u, v) >> 1) ^ ((v & 1U) ? MATRIX_A : 0U);
  }

  /**
   * Regenerates all MERSENNE_STATE_N state words in place and resets the
   * countdown and read index. Words are updated in increasing index order,
   * so later updates see the already refreshed earlier words.
   */
  inline void next_state() {
    uint32_t* const s = data_.state_.data();
    data_.left_ = MERSENNE_STATE_N;
    data_.next_ = 0;

    constexpr int kSplit = MERSENNE_STATE_N - MERSENNE_STATE_M;
    constexpr int kLast = MERSENNE_STATE_N - 1;

    // Words whose partner (i + M) is still inside the array.
    for (int i = 0; i < kSplit; ++i) {
      s[i] = s[i + MERSENNE_STATE_M] ^ twist(s[i], s[i + 1]);
    }

    // Words whose partner wraps around to the already refreshed front part.
    for (int i = kSplit; i < kLast; ++i) {
      s[i] = s[i - kSplit] ^ twist(s[i], s[i + 1]);
    }

    // Final word: its successor wraps around to s[0].
    s[kLast] = s[kLast - kSplit] ^ twist(s[kLast], s[0]);
  }

};

using mt19937 = mt19937_engine;

} // namespace at
195