#pragma once

#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <mutex>

#ifdef CPU_CAPABILITY_AVX2
#include <ATen/native/cpu/avx_mathfun.h>
#include <c10/util/irange.h>
#endif

#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/vec.h> // Vectorized<float>, needed by the VSX path below
#endif

namespace at::native::templates::cpu {
namespace {

// ==================================================== Random ========================================================

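// Every kernel below follows the same pattern: acquire the generator's
// mutex_ before drawing any samples (see Note [Acquire lock when using
// random generators]) and write the output element-by-element through
// cpu_serial_kernel. The *Kernel functor structs closing each section adapt
// the kernels to the generic distribution templates by unwrapping the
// std::optional<Generator> via check_generator<RNG>.
//
// Minimal usage sketch (an illustration only; the real call sites are the
// CPU distribution kernels that instantiate these templates):
//   auto iter = TensorIterator::borrowing_nullary_op(self);
//   random_kernel(iter, check_generator<CPUGeneratorImpl>(gen));
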
template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
      uniform_int_from_to_distribution<scalar_t> random(range, base);
      return random(generator);
    });
  }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

// This is a special kernel to handle the single specific case:
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
    if constexpr (std::is_same_v<scalar_t, int64_t> ||
        std::is_same_v<scalar_t, double> ||
        std::is_same_v<scalar_t, float> ||
        std::is_same_v<scalar_t, at::BFloat16>) {
      std::lock_guard<std::mutex> lock(generator->mutex_);
      cpu_serial_kernel(iter, [generator]() -> scalar_t {
        uniform_int_full_range_distribution<scalar_t> random;
        return random(generator);
      });
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
    }
  });
}

template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};

template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG generator) {
  std::lock_guard<std::mutex> lock(generator->mutex_);
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
    cpu_serial_kernel(iter, [generator]() -> scalar_t {
      uniform_int_distribution<scalar_t> random;
      return random(generator);
    });
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_kernel(iter, check_generator<RNG>(gen));
  }
};

// ==================================================== Normal ========================================================

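// All normal_fill* variants share the same two-pass scheme: the buffer is
// first filled with uniform(0, 1) samples, then transformed in place in
// blocks of 16 via the Box-Muller transform. Within a block the first 8
// values act as u1 and the last 8 as u2:
//   radius = sqrt(-2 * ln(u1)),  theta = 2 * pi * u2
//   n1 = radius * cos(theta),    n2 = radius * sin(theta)
// and each result is then scaled by std and shifted by mean. Any tail shorter
// than 16 elements is handled by recomputing the final 16 values.
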
#ifdef CPU_CAPABILITY_AVX2
static void normal_fill_16_AVX2(float *data,
                                const __m256* two_pi,
                                const __m256* one,
                                const __m256* minus_two,
                                const __m256* mean,
                                const __m256* std_v) {
  const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
  const __m256 u2 = _mm256_loadu_ps(data + 8);
  // sincos256_ps and log256_ps are from avx_mathfun.h
  const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
  const __m256 theta = _mm256_mul_ps(*two_pi, u2);
  __m256 sintheta, costheta;
  sincos256_ps(theta, &sintheta, &costheta);
  const __m256 n1 = _mm256_mul_ps(radius, costheta);
  const __m256 n2 = _mm256_mul_ps(radius, sintheta);
  _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
  _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
}

template<typename RNG>
void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
  float *data = self.data_ptr<float>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<float> uniform(0, 1);
    data[i] = uniform(generator);
  }
  const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
  const __m256 one = _mm256_set1_ps(1.0f);
  const __m256 minus_two = _mm256_set1_ps(-2.0f);
  const __m256 mean_v = _mm256_set1_ps(mean);
  const __m256 std_v = _mm256_set1_ps(std);

  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }

  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<float> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }
}
#endif

template <typename scalar_t>
static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
  for (const auto j : c10::irange(8)) {
    const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
    const scalar_t u2 = data[j + 8];
    const scalar_t radius = std::sqrt(-2 * std::log(u1));
    const scalar_t theta = 2.0f * c10::pi<double> * u2;
    data[j] = radius * std::cos(theta) * std + mean;
    data[j + 8] = radius * std::sin(theta) * std + mean;
  }
}

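// The VSX path mirrors the AVX2 Box-Muller routine above, but is written in
// terms of Vectorized<float>. When the vector width is not 8 floats it falls
// back to the scalar normal_fill_16 block transform.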
#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
static void normal_fill_16_VSX(float *data,
                               const Vectorized<float> &two_pi,
                               const Vectorized<float> &one,
                               const Vectorized<float> &minus_two,
                               const Vectorized<float> &mean,
                               const Vectorized<float> &std) {
  using Vec = Vectorized<float>;
  Vec u1 = one - Vec::loadu(data);
  Vec u2 = Vec::loadu(data + 8);
  Vec radius = (minus_two * u1.log()).sqrt();
  Vec theta = two_pi * u2;
  Vec output_vec = radius * theta.cos() * std + mean;
  Vec output_vec2 = radius * theta.sin() * std + mean;
  output_vec.store(data);
  output_vec2.store(data + 8);
}

template <typename scalar_t, typename RNG>
void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
  float *data = self.data_ptr<float>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<scalar_t> uniform(0, 1);
    data[i] = uniform(generator);
  }

  using Vec = Vectorized<float>;
  const Vec two_pi = Vec(2.0f * c10::pi<double>);
  const Vec one = Vec(1.0f);
  const Vec minus_two = Vec(-2.0f);
  const Vec var_vec = Vec(std);
  const Vec mean_vec = Vec(mean);

  for (int64_t i = 0; i < size - 15; i += 16) {
    if (Vec::size() == 8) {
      normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec);
    } else {
      normal_fill_16<scalar_t>(data + i, mean, std);
    }
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<scalar_t> uniform(0, 1);
      data[i] = uniform(generator);
    }
    if (Vec::size() == 8) {
      normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec);
    } else {
      normal_fill_16<scalar_t>(data, mean, std);
    }
  }
}
#endif // VSX

template <typename scalar_t, typename RNG>
void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
  scalar_t *data = self.data_ptr<scalar_t>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<scalar_t> uniform(0, 1);
    data[i] = uniform(generator);
  }

  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16<scalar_t>(data + i, mean, std);
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<scalar_t> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16<scalar_t>(data, mean, std);
  }
}

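// normal_kernel picks the fastest available fill: contiguous tensors with at
// least 16 elements use the vectorized (AVX2/VSX) path for float, or the
// blocked scalar normal_fill for the other floating types; anything smaller
// or non-contiguous falls back to a serial TensorIterator loop that samples
// a double-precision normal per element.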
template<typename RNG>
void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
  auto size = self.numel();
  if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
#ifdef CPU_CAPABILITY_AVX2
    normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
    normal_fill_VSX(self, static_cast<float>(mean), static_cast<float>(std), generator);
#else
    normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
      if (size >= 16 && self.is_contiguous()) {
        normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
      } else {
        auto iter = TensorIterator::borrowing_nullary_op(self);
        std::lock_guard<std::mutex> lock(generator->mutex_);
        cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
          at::normal_distribution<double> normal(mean, std);
          return static_cast<scalar_t>(normal(generator));
        });
      }
    });
  }
}

template<typename RNG>
struct NormalKernel {
  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};

// ==================================================== Uniform =======================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    at::uniform_real_distribution<scalar_t> uniform(from, to);
    cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
      return static_cast<scalar_t>(uniform(generator));
    });
  });
}

template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};

// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::cauchy_distribution<double> cauchy(median, sigma);
    cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
      return static_cast<scalar_t>(cauchy(generator));
    });
  });
}

template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};

// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::lognormal_distribution<double> logNormal(mean, std);
    cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
      return static_cast<scalar_t>(logNormal(generator));
    });
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};

// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::geometric_distribution<double> geometric(p);
    cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
      return static_cast<scalar_t>(geometric(generator));
    });
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};

// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::exponential_distribution<double> exponential(lambda);
    cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
      return static_cast<scalar_t>(exponential(generator));
    });
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};

// ================================================== Bernoulli =======================================================

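// Two overloads: the first samples element-wise with probabilities taken from
// a tensor p_ (moved to CPU and expanded against self), the second uses a
// single scalar probability p for every element.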
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
      self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    using self_t = scalar_t;
    auto p_cpu = p_.to(kCPU);
    auto p = expand_inplace(self, p_cpu);
    auto iter = TensorIteratorConfig()
        .add_output(self)
        .add_const_input(*p)
        .check_all_same_dtype(false)
        .build();
    if (p->scalar_type() == kDouble) {
      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
        at::bernoulli_distribution<double> bernoulli(p_val);
        return static_cast<self_t>(bernoulli(generator));
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
          p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
        using p_t = scalar_t;
        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
          at::bernoulli_distribution<float> bernoulli(p_val);
          return static_cast<self_t>(bernoulli(generator));
        });
      });
    }
  });
}

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
      self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto iter = TensorIterator::borrowing_nullary_op(self);
    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
      at::bernoulli_distribution<double> bernoulli(p);
      return static_cast<scalar_t>(bernoulli(generator));
    });
  });
}

template<typename RNG>
struct BernoulliKernel {
  void operator()(const TensorBase &self, double p, std::optional<Generator> gen) {
    bernoulli_kernel(self, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};

} // namespace
} // namespace at::native::templates::cpu