#pragma once

namespace torch {
namespace jit {
namespace tensorexpr {

// CUDA device source for a Philox-style counter-based random number
// generator, stored as a raw string instead of being compiled with this
// header.
// NOTE(review): presumably spliced into kernels that are compiled at
// runtime -- confirm against the users of this constant.
//
// Contents of the string:
//   * class Philox
//       - Philox(seed, subsequence, offset): the two 32-bit halves of `seed`
//         become the key, the two halves of `subsequence` fill
//         counter.z / counter.w, and the 128-bit counter is then advanced by
//         offset / 4 (one counter value yields four outputs; the remainder
//         offset % 4 is not applied, STATE always starts at 0).
//       - operator()(): returns the next 32-bit output widened to
//         unsigned long. When STATE is 0 it runs ten rounds (nine in the
//         loop plus a final one) over the current counter, stores the four
//         results in `output`, and increments the counter by one.
//       - incr_n(n): adds the 64-bit value n to the low two counter words
//         and carries into counter.z / counter.w. The addition into
//         counter.y is done in 64 bits so the carry is still detected when
//         the high word of n plus the carry out of counter.x would wrap a
//         32-bit value.
//       - single_round / mulhilo32: one round built from two 32x32 -> 64 bit
//         multiplies (low half via a * b, high half via __umulhi).
//   * Uint32ToFloat(x): multiplies x by M_RAN_INVM32, the inverse of 2^32.
//
// Everything between the raw-string delimiters is part of the constant's
// value, so any edit there changes what consumers of this string receive.
constexpr auto philox_random_string = R"(

class Philox {
public:
  __device__ inline Philox(unsigned long long seed,
                           unsigned long long subsequence,
                           unsigned long long offset) {
    key.x = (unsigned int)seed;
    key.y = (unsigned int)(seed >> 32);
    counter = make_uint4(0, 0, 0, 0);
    counter.z = (unsigned int)(subsequence);
    counter.w = (unsigned int)(subsequence >> 32);
    STATE = 0;
    incr_n(offset / 4);
  }

  __device__ inline unsigned long operator()() {
    if(STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      for(int i = 0; i < 9; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A); key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    unsigned long ret = 0;
    switch(STATE) {
      case 0: ret = output.x; break;
      case 1: ret = output.y; break;
      case 2: ret = output.z; break;
      case 3: ret = output.w; break;
    }
    STATE = (STATE + 1) % 4;
    return ret;
  }

private:
  uint4 counter;
  uint4 output;
  uint2 key;
  unsigned int STATE;
  __device__ inline void incr_n(unsigned long long n) {
    unsigned int nlo = (unsigned int)(n);
    unsigned int nhi = (unsigned int)(n >> 32);
    counter.x += nlo;
    // Sum in 64 bits so the carry out of counter.y is seen even when
    // nhi plus the carry from counter.x would wrap a 32-bit value.
    unsigned long long y = (unsigned long long)counter.y + nhi;
    if (counter.x < nlo)
      y += 1;
    counter.y = (unsigned int)y;
    if ((y >> 32) == 0)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }
  __device__ inline void incr() {
    if (++counter.x)
      return;
    if (++counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a*b;
  }

  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
    unsigned int hi1;
    unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
    unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);

    uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
    return ret;
  }

  static const unsigned long kPhilox10A = 0x9E3779B9;
  static const unsigned long kPhilox10B = 0xBB67AE85;
  static const unsigned long kPhiloxSA = 0xD2511F53;
  static const unsigned long kPhiloxSB = 0xCD9E8D57;
};

// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__  __inline__ float Uint32ToFloat(unsigned int x) {
  return x * M_RAN_INVM32;
}

)";

} // namespace tensorexpr
} // namespace jit
} // namespace torch