#pragma once

#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>
#include <mutex>

#ifdef CPU_CAPABILITY_AVX2
#include <ATen/native/cpu/avx_mathfun.h>
#endif

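// This header provides the CPU kernel templates for the random/distribution
// operators. Each kernel is templated on an RNG type so that, besides the
// default CPU generator, custom generator implementations can reuse the same
// kernels by instantiating the functor wrappers defined below.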
namespace at::native::templates::cpu {
namespace {

// ==================================================== Random ========================================================

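// Fills `iter` with integers drawn uniformly from [base, base + range) using
// `generator`. `range` and `base` are precomputed by the caller from the
// user-facing `from`/`to` arguments.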
template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
      uniform_int_from_to_distribution<scalar_t> random(range, base);
      return random(generator);
    });
  }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

// This is a special kernel that handles a single specific case:
// from (inclusive) = std::numeric_limits<int64_t>::lowest()
// to (exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
// Here the requested range spans all 2^64 values of int64_t, which cannot be
// represented in the uint64_t `range` argument of random_from_to_kernel, so a
// dedicated full-range distribution is used instead.
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
    if constexpr (std::is_same_v<scalar_t, int64_t> ||
        std::is_same_v<scalar_t, double> ||
        std::is_same_v<scalar_t, float> ||
        std::is_same_v<scalar_t, at::BFloat16>) {
      std::lock_guard<std::mutex> lock(generator->mutex_);
      cpu_serial_kernel(iter, [generator]() -> scalar_t {
        uniform_int_full_range_distribution<scalar_t> random;
        return random(generator);
      });
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
    }
  });
}

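// The *Kernel functors below are thin wrappers that unpack the optional
// Generator and forward to the kernels above. They are meant to be plugged
// into the generic dispatch templates in ATen/native/DistributionTemplates.h.
// A minimal sketch of direct use (illustrative only; `iter` is assumed to be
// an already-built TensorIteratorBase writing to an integral output):
//
//   auto gen = at::detail::getDefaultCPUGenerator();
//   RandomFromToKernel<at::CPUGeneratorImpl>()(
//       iter, /*range=*/10, /*base=*/0, gen);  // fills with values in [0, 10)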
template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};

template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG generator) {
  std::lock_guard<std::mutex> lock(generator->mutex_);
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
    cpu_serial_kernel(iter, [generator]() -> scalar_t {
      uniform_int_distribution<scalar_t> random;
      return random(generator);
    });
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_kernel(iter, check_generator<RNG>(gen));
  }
};

// ==================================================== Normal ========================================================

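// The normal_fill_* helpers below draw normal samples in two passes: the
// buffer is first filled with uniform samples in [0, 1), which are then
// transformed in place, 16 at a time, with the Box-Muller transform:
// given independent u1, u2 ~ U(0, 1],
//   z0 = sqrt(-2 * ln(u1)) * cos(2 * pi * u2)
//   z1 = sqrt(-2 * ln(u1)) * sin(2 * pi * u2)
// are independent standard normal samples; each result is then scaled by
// `std` and shifted by `mean`.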
#ifdef CPU_CAPABILITY_AVX2
static void normal_fill_16_AVX2(float *data,
                                const __m256* two_pi,
                                const __m256* one,
                                const __m256* minus_two,
                                const __m256* mean,
                                const __m256* std_v) {
  const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
  const __m256 u2 = _mm256_loadu_ps(data + 8);
  // sincos256_ps and log256_ps are from avx_mathfun.h
  const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
  const __m256 theta = _mm256_mul_ps(*two_pi, u2);
  __m256 sintheta, costheta;
  sincos256_ps(theta, &sintheta, &costheta);
  const __m256 n1 = _mm256_mul_ps(radius, costheta);
  const __m256 n2 = _mm256_mul_ps(radius, sintheta);
  _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
  _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
}

template<typename RNG>
void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
  float *data = self.data_ptr<float>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<float> uniform(0, 1);
    data[i] = uniform(generator);
  }
  const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
  const __m256 one = _mm256_set1_ps(1.0f);
  const __m256 minus_two = _mm256_set1_ps(-2.0f);
  const __m256 mean_v = _mm256_set1_ps(mean);
  const __m256 std_v = _mm256_set1_ps(std);

  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }

  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<float> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }
}
#endif

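// Scalar fallback: transforms 16 uniform samples already stored in
// data[0..15] in place into normal samples with the given mean and std via
// the Box-Muller transform described above.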
template <typename scalar_t>
static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
  for (const auto j : c10::irange(8)) {
    const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
    const scalar_t u2 = data[j + 8];
    const scalar_t radius = std::sqrt(-2 * std::log(u1));
    const scalar_t theta = 2.0f * c10::pi<double> * u2;
    data[j] = radius * std::cos(theta) * std + mean;
    data[j + 8] = radius * std::sin(theta) * std + mean;
  }
}

#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
static void normal_fill_16_VSX(float *data,
                               const Vectorized<float> &two_pi,
                               const Vectorized<float> &one,
                               const Vectorized<float> &minus_two,
                               const Vectorized<float> &mean,
                               const Vectorized<float> &std) {
  using Vec = Vectorized<float>;
  Vec u1 = one - Vec::loadu(data);
  Vec u2 = Vec::loadu(data + 8);
  Vec radius = (minus_two * u1.log()).sqrt();
  Vec theta = two_pi * u2;
  Vec output_vec = radius * theta.cos() * std + mean;
  Vec output_vec2 = radius * theta.sin() * std + mean;
  output_vec.store(data);
  output_vec2.store(data + 8);
}

template <typename scalar_t, typename RNG>
void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
  float *data = self.data_ptr<float>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<scalar_t> uniform(0, 1);
    data[i] = uniform(generator);
  }

  using Vec = Vectorized<float>;
  const Vec two_pi = Vec(2.0f * c10::pi<double>);
  const Vec one = Vec(1.0f);
  const Vec minus_two = Vec(-2.0f);
  const Vec std_vec = Vec(std);
  const Vec mean_vec = Vec(mean);

  for (int64_t i = 0; i < size - 15; i += 16) {
    if (Vec::size() == 8) {
      normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, std_vec);
    } else {
      normal_fill_16<scalar_t>(data + i, mean, std);
    }
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<scalar_t> uniform(0, 1);
      data[i] = uniform(generator);
    }
    if (Vec::size() == 8) {
      normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, std_vec);
    } else {
      normal_fill_16<scalar_t>(data, mean, std);
    }
  }
}
#endif // VSX

template <typename scalar_t, typename RNG>
void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
  scalar_t *data = self.data_ptr<scalar_t>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<scalar_t> uniform(0, 1);
    data[i] = uniform(generator);
  }

  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16<scalar_t>(data + i, mean, std);
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values: the main loop only handles full blocks of
    // 16, so draw fresh uniforms for the final 16 elements (overlapping some
    // already-transformed values) and transform them again.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<scalar_t> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16<scalar_t>(data, mean, std);
  }
}

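// normal_kernel picks the fastest available path: contiguous float tensors
// with at least 16 elements use the vectorized fills above (AVX2/VSX when
// compiled in, the scalar normal_fill otherwise); other floating-point dtypes
// that are contiguous and large enough use the scalar normal_fill; everything
// else falls back to a serial per-element kernel over a TensorIterator.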
template<typename RNG>
void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
  auto size = self.numel();
  if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
#ifdef CPU_CAPABILITY_AVX2
    normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
    normal_fill_VSX(self, static_cast<float>(mean), static_cast<float>(std), generator);
#else
    normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
      if (size >= 16 && self.is_contiguous()) {
        normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
      } else {
        auto iter = TensorIterator::borrowing_nullary_op(self);
        std::lock_guard<std::mutex> lock(generator->mutex_);
        cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
          at::normal_distribution<double> normal(mean, std);
          return static_cast<scalar_t>(normal(generator));
        });
      }
    });
  }
}

template<typename RNG>
struct NormalKernel {
  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};

// ==================================================== Uniform =======================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    at::uniform_real_distribution<scalar_t> uniform(from, to);
    cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
      return static_cast<scalar_t>(uniform(generator));
    });
  });
}

template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};

// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::cauchy_distribution<double> cauchy(median, sigma);
    cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
      return static_cast<scalar_t>(cauchy(generator));
    });
  });
}

template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};

// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::lognormal_distribution<double> logNormal(mean, std);
    cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
      return static_cast<scalar_t>(logNormal(generator));
    });
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};

// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::geometric_distribution<double> geometric(p);
    cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
      return static_cast<scalar_t>(geometric(generator));
    });
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};

// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point type, but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::exponential_distribution<double> exponential(lambda);
    cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
      return static_cast<scalar_t>(exponential(generator));
    });
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};

// ================================================== Bernoulli =======================================================

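// Tensor-probability variant: `p_` is moved to CPU and broadcast to `self`'s
// shape, then each output element is drawn as Bernoulli(p) from the matching
// element of `p_`. The outer dispatch is on `self`'s dtype and the inner one
// on `p_`'s dtype.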
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
  self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    using self_t = scalar_t;
    auto p_cpu = p_.to(kCPU);
    auto p = expand_inplace(self, p_cpu);
    auto iter = TensorIteratorConfig()
        .add_output(self)
        .add_const_input(*p)
        .check_all_same_dtype(false)
        .build();
    if (p->scalar_type() == kDouble) {
      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
        at::bernoulli_distribution<double> bernoulli(p_val);
        return static_cast<self_t>(bernoulli(generator));
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
        using p_t = scalar_t;
        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
          at::bernoulli_distribution<float> bernoulli(p_val);
          return static_cast<self_t>(bernoulli(generator));
        });
      });
    }
  });
}

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
  self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto iter = TensorIterator::borrowing_nullary_op(self);
    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
      at::bernoulli_distribution<double> bernoulli(p);
      return static_cast<scalar_t>(bernoulli(generator));
    });
  });
}

template<typename RNG>
struct BernoulliKernel {
  void operator()(const TensorBase &self, double p, std::optional<Generator> gen) {
    bernoulli_kernel(self, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};

} // namespace
} // namespace at::native::templates::cpu