1 #pragma once 2 3 #include <c10/util/irange.h> 4 5 // define constants like M_PI and C keywords for MSVC 6 #ifdef _MSC_VER 7 #ifndef _USE_MATH_DEFINES 8 #define _USE_MATH_DEFINES 9 #endif 10 #include <math.h> 11 #endif 12 13 #include <array> 14 #include <cmath> 15 #include <cstdint> 16 17 namespace at { 18 19 constexpr int MERSENNE_STATE_N = 624; 20 constexpr int MERSENNE_STATE_M = 397; 21 constexpr uint32_t MATRIX_A = 0x9908b0df; 22 constexpr uint32_t UMASK = 0x80000000; 23 constexpr uint32_t LMASK = 0x7fffffff; 24 25 /** 26 * Note [Mt19937 Engine implementation] 27 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 28 * Originally implemented in: 29 * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c 30 * and modified with C++ constructs. Moreover the state array of the engine 31 * has been modified to hold 32 bit uints instead of 64 bits. 32 * 33 * Note that we reimplemented mt19937 instead of using std::mt19937 because, 34 * at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2 35 * by default and following are the benchmark numbers (benchmark code can be found at 36 * https://github.com/syed-ahmed/benchmark-rngs): 37 * 38 * with -O2 39 * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s 40 * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s 41 * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s 42 * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s 43 * 44 * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution, 45 * however we can't use std::uniform_real_distribution because of this bug: 46 * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used 47 * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm 48 * than what's in pytorch currently and that messes up the tests in tests_distributions.py. 49 * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower 50 * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter. 51 * 52 * Copyright notice: 53 * A C-program for MT19937, with initialization improved 2002/2/10. 54 * Coded by Takuji Nishimura and Makoto Matsumoto. 55 * This is a faster version by taking Shawn Cokus's optimization, 56 * Matthe Bellew's simplification, Isaku Wada's real version. 57 * 58 * Before using, initialize the state by using init_genrand(seed) 59 * or init_by_array(init_key, key_length). 60 * 61 * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, 62 * All rights reserved. 63 * 64 * Redistribution and use in source and binary forms, with or without 65 * modification, are permitted provided that the following conditions 66 * are met: 67 * 68 * 1. Redistributions of source code must retain the above copyright 69 * notice, this list of conditions and the following disclaimer. 70 * 71 * 2. Redistributions in binary form must reproduce the above copyright 72 * notice, this list of conditions and the following disclaimer in the 73 * documentation and/or other materials provided with the distribution. 74 * 75 * 3. The names of its contributors may not be used to endorse or promote 76 * products derived from this software without specific prior written 77 * permission. 78 * 79 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 80 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 81 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 82 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 83 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 84 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 85 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 86 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 87 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 88 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 89 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 90 * 91 * 92 * Any feedback is very welcome. 93 * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html 94 * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space) 95 */ 96 97 /** 98 * mt19937_data_pod is used to get POD data in and out 99 * of mt19937_engine. Used in torch.get_rng_state and 100 * torch.set_rng_state functions. 101 */ 102 struct mt19937_data_pod { 103 uint64_t seed_; 104 int left_; 105 bool seeded_; 106 uint32_t next_; 107 std::array<uint32_t, MERSENNE_STATE_N> state_; 108 }; 109 110 class mt19937_engine { 111 public: 112 113 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 114 inline explicit mt19937_engine(uint64_t seed = 5489) { 115 init_with_uint32(seed); 116 } 117 data()118 inline mt19937_data_pod data() const { 119 return data_; 120 } 121 set_data(const mt19937_data_pod & data)122 inline void set_data(const mt19937_data_pod& data) { 123 data_ = data; 124 } 125 seed()126 inline uint64_t seed() const { 127 return data_.seed_; 128 } 129 is_valid()130 inline bool is_valid() { 131 if ((data_.seeded_ == true) 132 && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N) 133 && (data_.next_ <= MERSENNE_STATE_N)) { 134 return true; 135 } 136 return false; 137 } 138 operator()139 inline uint32_t operator()() { 140 if (--(data_.left_) == 0) { 141 next_state(); 142 } 143 uint32_t y = *(data_.state_.data() + data_.next_++); 144 y ^= (y >> 11); 145 y ^= (y << 7) & 0x9d2c5680; 146 y ^= (y << 15) & 0xefc60000; 147 y ^= (y >> 18); 148 149 return y; 150 } 151 152 private: 153 mt19937_data_pod data_; 154 init_with_uint32(uint64_t seed)155 inline void init_with_uint32(uint64_t seed) { 156 data_.seed_ = seed; 157 data_.seeded_ = true; 158 data_.state_[0] = seed & 0xffffffff; 159 for (const auto j : c10::irange(1, MERSENNE_STATE_N)) { 160 data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j); 161 } 162 data_.left_ = 1; 163 data_.next_ = 0; 164 } 165 mix_bits(uint32_t u,uint32_t v)166 inline uint32_t mix_bits(uint32_t u, uint32_t v) { 167 return (u & UMASK) | (v & LMASK); 168 } 169 twist(uint32_t u,uint32_t v)170 inline uint32_t twist(uint32_t u, uint32_t v) { 171 return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0); 172 } 173 next_state()174 inline void next_state() { 175 uint32_t* p = data_.state_.data(); 176 data_.left_ = MERSENNE_STATE_N; 177 data_.next_ = 0; 178 179 for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) { 180 *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]); 181 } 182 183 for(int j = MERSENNE_STATE_M; --j; p++) { 184 *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]); 185 } 186 187 *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]); 188 } 189 190 }; 191 192 typedef mt19937_engine mt19937; 193 194 } // namespace at 195