// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/cuda_random.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

namespace torch {
namespace jit {
namespace tensorexpr {

// CUDA C++ source text (it is NOT compiled as part of this translation unit)
// that defines:
//   * class Philox      - a Philox-4x32-10 counter-based generator that yields
//                         32-bit values, four per counter increment;
//   * Uint32ToFloat(x)  - maps a 32-bit value to a float in [0, 1).
// NOTE(review): presumably this text is prepended to generated kernels and
// compiled at runtime by the TensorExpr CUDA backend; confirm with the caller.
// Because this is a raw string literal, its contents must never contain the
// two-character sequence close-paren followed by double-quote.
constexpr auto philox_random_string = R"(

// Philox-4x32-10 counter-based random number generator.
//   seed        - 64-bit key, split into key.x (low) and key.y (high).
//   subsequence - stored in the upper counter words z (low) and w (high).
//   offset      - number of 32-bit values to skip. Only offset / 4 counter
//                 steps are applied; offset % 4 is discarded, so the first
//                 value returned is always word x of that counter block.
class Philox {
public:
  __device__ inline Philox(unsigned long long seed,
                           unsigned long long subsequence,
                           unsigned long long offset) {
    key.x = (unsigned int)seed;
    key.y = (unsigned int)(seed >> 32);
    counter = make_uint4(0, 0, 0, 0);
    counter.z = (unsigned int)(subsequence);
    counter.w = (unsigned int)(subsequence >> 32);
    STATE = 0;
    // Each counter value produces four outputs, hence the division by 4.
    incr_n(offset / 4);
  }

  // Returns the next 32-bit random value (always below 2^32, although the
  // return type is unsigned long). A fresh block of four words is generated
  // on every fourth call, after which the counter advances by one.
  __device__ inline unsigned long operator()() {
    if(STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      // Nine rounds with a key bump after each, then a tenth round.
      for(int i = 0; i < 9; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A); key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    unsigned long ret;
    switch(STATE) {
      case 0: ret = output.x; break;
      case 1: ret = output.y; break;
      case 2: ret = output.z; break;
      default: ret = output.w; break;  // STATE == 3
    }
    STATE = (STATE + 1) % 4;
    return ret;
  }

private:
  uint4 counter;       // 128-bit counter; z/w start as the subsequence.
  uint4 output;        // Most recently generated block of four words.
  uint2 key;           // 64-bit key taken from the seed.
  unsigned int STATE;  // Index (0..3) of the next output word to return.

  // Adds the 64-bit value n to the 128-bit counter, propagating carries.
  // NOTE: if n >> 32 == 0xFFFFFFFF and the low word carries, nhi wraps to 0
  // and that carry is lost. This is unreachable from the constructor, where
  // n = offset / 4 < 2^62.
  __device__ inline void incr_n(unsigned long long n) {
    unsigned int nlo = (unsigned int)(n);
    unsigned int nhi = (unsigned int)(n >> 32);
    counter.x += nlo;
    if (counter.x < nlo)
      nhi++;
    counter.y += nhi;
    if (nhi <= counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }

  // Increments the 128-bit counter by one, propagating carries.
  __device__ inline void incr() {
    if (++counter.x)
      return;
    if (++counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }

  // Returns the low 32 bits of a * b and stores the high 32 bits in
  // *result_high.
  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                    unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a*b;
  }

  // One Philox-4x32 round: two 32x32 multiplies whose high halves are mixed
  // with the other counter words and the key.
  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
    unsigned int hi1;
    unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
    unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);

    uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
    return ret;
  }

  // Added to key.x / key.y between rounds.
  static const unsigned long kPhilox10A = 0x9E3779B9;
  static const unsigned long kPhilox10B = 0xBB67AE85;
  // Round multipliers.
  static const unsigned long kPhiloxSA = 0xD2511F53;
  static const unsigned long kPhiloxSB = 0xCD9E8D57;
};

// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f

// Maps a 32-bit value to a float in [0, 1).
// x * 2^-32 alone is not enough: converting any x >= 2^32 - 128 to float
// rounds up to 2^32, which would make the product exactly 1.0f. Those inputs
// are clamped to the largest float below one (1 - 2^-24); every other result
// is unchanged.
__device__  __inline__ float Uint32ToFloat(unsigned int x) {
  float r = x * M_RAN_INVM32;
  return r < 1.0f ? r : 0.99999994f;
}

)";

} // namespace tensorexpr
} // namespace jit
} // namespace torch
105