xref: /aosp_15_r20/external/tensorflow/tensorflow/core/lib/random/random_distributions.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
17 #define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
18 
#include <algorithm>
#include <cmath>
#include <cstring>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions_utils.h"
#include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 namespace random {
30 
31 // Helper function to convert a 16-bit integer to a half between [0..1).
32 PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x);
33 // Helper function to convert a 16-bit integer to a bfloat16 between [0..1).
34 PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x);
35 
36 // Computes a + b. Requires that the result is representable in the destination
37 // type and that b is not maximal (i.e. b + 1 is not 0). Notably, the addend b
38 // need *not* be representable in that type. (The condition on b excludes the
39 // extremal case INT_MIN + UINT_MAX = INT_MAX, which this function cannot
40 // compute.)
41 template <typename Int>
SignedAdd(Int a,typename std::make_unsigned<Int>::type b)42 PHILOX_DEVICE_INLINE Int SignedAdd(Int a,
43                                    typename std::make_unsigned<Int>::type b) {
44   // Implementation note: both b_div_2 and b - b_div_2 are positive and
45   // representable as Int.
46   auto b_div_2 = b >> 1;
47   return a + static_cast<Int>(b_div_2) + static_cast<Int>(b - b_div_2);
48 }
49 
50 // A class that generates uniform distribution random numbers from the
51 // underlying random integer generator.
52 // Arguments:
53 //   Generator: a generator type that returns a number of uint32 upon each
54 //              invocation. It needs to define kResultElementCount for the
55 //              sample count for each invocation, and ResultType for the
56 //              actual returned sample type.
57 //   RealType: the data type of the real numbers that will be returned by the
58 //             distribution. This could be either float or double for now.
59 // This class is meant to be implemented through specialization. The default
60 // is not defined by design.
61 template <class Generator, typename RealType>
62 class UniformDistribution;
63 
64 template <class Generator>
65 class UniformDistribution<Generator, Eigen::half> {
66  public:
67   // The number of elements that will be returned.
68   static constexpr int kResultElementCount = Generator::kResultElementCount;
69   // Cost of generation of a single element (in cycles).
70   static constexpr int kElementCost = 3;
71   // Indicate that this distribution may take variable number of samples
72   // during the runtime.
73   static constexpr bool kVariableSamplesPerOutput = false;
74   typedef Array<Eigen::half, kResultElementCount> ResultType;
75   typedef Eigen::half ResultElementType;
76 
77   PHILOX_DEVICE_INLINE
operator()78   ResultType operator()(Generator* gen) {
79     typename Generator::ResultType sample = (*gen)();
80     ResultType result;
81     for (int i = 0; i < kResultElementCount; ++i) {
82       result[i] = Uint16ToHalf(sample[i]);  // Truncate the upper 16 bits.
83     }
84     return result;
85   }
86 };
87 
88 template <class Generator>
89 class UniformDistribution<Generator, bfloat16> {
90  public:
91   // The number of elements that will be returned.
92   static constexpr int kResultElementCount = Generator::kResultElementCount;
93   // Cost of generation of a single element (in cycles).
94   static constexpr int kElementCost = 3;
95   // Indicate that this distribution may take variable number of samples
96   // during the runtime.
97   static constexpr bool kVariableSamplesPerOutput = false;
98   typedef Array<bfloat16, kResultElementCount> ResultType;
99   typedef bfloat16 ResultElementType;
100 
101   PHILOX_DEVICE_INLINE
operator()102   ResultType operator()(Generator* gen) {
103     typename Generator::ResultType sample = (*gen)();
104     ResultType result;
105     for (int i = 0; i < kResultElementCount; ++i) {
106       result[i] = Uint16ToGfloat16(sample[i]);
107     }
108     return result;
109   }
110 };
111 
112 template <class Generator>
113 class UniformDistribution<Generator, float> {
114  public:
115   // The number of elements that will be returned.
116   static constexpr int kResultElementCount = Generator::kResultElementCount;
117   // Cost of generation of a single element (in cycles).
118   static constexpr int kElementCost = 3;
119   // Indicate that this distribution may take variable number of samples
120   // during the runtime.
121   static constexpr bool kVariableSamplesPerOutput = false;
122   typedef Array<float, kResultElementCount> ResultType;
123   typedef float ResultElementType;
124 
125   PHILOX_DEVICE_INLINE
operator()126   ResultType operator()(Generator* gen) {
127     typename Generator::ResultType sample = (*gen)();
128     ResultType result;
129     for (int i = 0; i < kResultElementCount; ++i) {
130       result[i] = Uint32ToFloat(sample[i]);
131     }
132     return result;
133   }
134 };
135 
136 template <class Generator>
137 class UniformDistribution<Generator, double> {
138  public:
139   // The number of elements that will be returned.
140   static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
141   // Cost of generation of a single element (in cycles).
142   static constexpr int kElementCost = 3;
143   // Indicate that this distribution may take variable number of samples
144   // during the runtime.
145   static constexpr bool kVariableSamplesPerOutput = false;
146   typedef Array<double, kResultElementCount> ResultType;
147   typedef double ResultElementType;
148 
149   PHILOX_DEVICE_INLINE
operator()150   ResultType operator()(Generator* gen) {
151     typename Generator::ResultType sample = (*gen)();
152     ResultType result;
153     for (int i = 0; i < kResultElementCount; ++i) {
154       result[i] = Uint64ToDouble(sample[2 * i], sample[2 * i + 1]);
155     }
156     return result;
157   }
158 };
159 
160 template <class Generator>
161 class UniformDistribution<Generator, int32> {
162  public:
163   // The number of elements that will be returned.
164   static constexpr int kResultElementCount = Generator::kResultElementCount;
165   // Cost of generation of a single element (in cycles).
166   static constexpr int kElementCost = 3;
167   // Indicate that this distribution may take variable number of samples
168   // during the runtime.
169   static constexpr bool kVariableSamplesPerOutput = false;
170   typedef Array<int32, kResultElementCount> ResultType;
171   typedef int32 ResultElementType;
172 
173   // Must have lo < hi
UniformDistribution(int32_t lo,int32_t hi)174   UniformDistribution(int32_t lo, int32_t hi)
175       : lo_(lo), range_(static_cast<uint32>(hi) - static_cast<uint32>(lo)) {}
176 
177   PHILOX_DEVICE_INLINE
operator()178   ResultType operator()(Generator* gen) {
179     typename Generator::ResultType sample = (*gen)();
180     ResultType result;
181     for (int i = 0; i < kResultElementCount; ++i) {
182       result[i] = SignedAdd(lo_, sample[i] % range_);
183     }
184     return result;
185   }
186 
187  private:
188   // Note that lo_ is intentionally signed while range_ is intentionally
189   // unsigned.  This is because hi - lo can overflow signed integers if
190   // lo < 0 < hi, but always fits in unsigned.
191   int32 lo_;
192   uint32 range_;
193 };
194 
195 template <class Generator>
196 class UniformDistribution<Generator, int64_t> {
197  public:
198   // The number of elements that will be returned.
199   static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
200   // Cost of generation of a single element (in cycles).
201   static constexpr int kElementCost = 3;
202   // Indicate that this distribution may take variable number of samples
203   // during the runtime.
204   static constexpr bool kVariableSamplesPerOutput = false;
205   typedef Array<int64_t, kResultElementCount> ResultType;
206   typedef int64_t ResultElementType;
207 
208   // Must have lo < hi
UniformDistribution(int64_t lo,int64_t hi)209   UniformDistribution(int64_t lo, int64_t hi)
210       : lo_(lo), range_(static_cast<uint64>(hi) - static_cast<uint64>(lo)) {}
211 
212   PHILOX_DEVICE_INLINE
operator()213   ResultType operator()(Generator* gen) {
214     typename Generator::ResultType sample = (*gen)();
215     ResultType result;
216     for (int i = 0; i < kResultElementCount; ++i) {
217       auto bits = sample[2 * i] | static_cast<uint64>(sample[2 * i + 1]) << 32;
218       result[i] = SignedAdd(lo_, bits % range_);
219     }
220     return result;
221   }
222 
223  private:
224   // Note that lo_ is intentionally signed while range_ is intentionally
225   // unsigned.  This is because hi - lo can overflow signed integers if
226   // lo < 0 < hi, but always fits in unsigned.
227   int64_t lo_;
228   uint64 range_;
229 };
230 
231 // Similar to `UniformDistribution`, except that instead of generating numbers
232 // in the range [low, high), it generates numbers covering the whole range of
233 // the integer type.
234 template <typename Generator, typename IntType>
235 class UniformFullIntDistribution;
236 
237 template <typename Generator, typename IntType>
238 class UniformFullIntDistribution32 {
239  public:
240   // The number of elements that will be returned.
241   static constexpr int kResultElementCount = Generator::kResultElementCount;
242   // Cost of generation of a single element (in cycles).
243   static constexpr int kElementCost = 3;
244   // Indicate that this distribution may take variable number of samples
245   // during the runtime.
246   static constexpr bool kVariableSamplesPerOutput = false;
247   typedef Array<IntType, kResultElementCount> ResultType;
248   typedef IntType ResultElementType;
249 
250   PHILOX_DEVICE_INLINE
operator()251   ResultType operator()(Generator* gen) {
252     typename Generator::ResultType sample = (*gen)();
253     ResultType result;
254     for (int i = 0; i < kResultElementCount; ++i) {
255       result[i] = sample[i];
256     }
257     return result;
258   }
259 };
260 
261 template <typename Generator, typename IntType>
262 class UniformFullIntDistribution64 {
263  public:
264   // The number of elements that will be returned.
265   static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
266   // Cost of generation of a single element (in cycles).
267   static constexpr int kElementCost = 3;
268   // Indicate that this distribution may take variable number of samples
269   // during the runtime.
270   static constexpr bool kVariableSamplesPerOutput = false;
271   typedef Array<IntType, kResultElementCount> ResultType;
272   typedef IntType ResultElementType;
273 
274   PHILOX_DEVICE_INLINE
operator()275   ResultType operator()(Generator* gen) {
276     typename Generator::ResultType sample = (*gen)();
277     ResultType result;
278     for (int i = 0; i < kResultElementCount; ++i) {
279       result[i] = sample[2 * i] | static_cast<uint64>(sample[2 * i + 1]) << 32;
280     }
281     return result;
282   }
283 };
284 
285 template <typename Generator>
286 class UniformFullIntDistribution<Generator, int32>
287     : public UniformFullIntDistribution32<Generator, int32> {};
288 template <typename Generator>
289 class UniformFullIntDistribution<Generator, uint32>
290     : public UniformFullIntDistribution32<Generator, uint32> {};
291 template <typename Generator>
292 class UniformFullIntDistribution<Generator, int64_t>
293     : public UniformFullIntDistribution64<Generator, int64_t> {};
294 template <typename Generator>
295 class UniformFullIntDistribution<Generator, uint64>
296     : public UniformFullIntDistribution64<Generator, uint64> {};
297 
298 // A class that adapts the underlying native multiple samples to return a single
299 // sample at a time.
300 template <class Generator>
301 class SingleSampleAdapter {
302  public:
303   // The number of elements that will be returned.
304   static constexpr int kResultElementCount = 1;
305   // The number of elements that will be returned by the underlying generator.
306   static constexpr int kNativeElementCount = Generator::kResultElementCount;
307   typedef typename Generator::ResultElementType ResultType;
308   typedef typename Generator::ResultElementType ResultElementType;
309 
310   PHILOX_DEVICE_INLINE
SingleSampleAdapter(Generator * gen)311   explicit SingleSampleAdapter(Generator* gen)
312       : generator_(gen), used_result_index_(Generator::kResultElementCount) {}
313 
314   PHILOX_DEVICE_INLINE
operator()315   ResultType operator()() {
316     if (used_result_index_ == Generator::kResultElementCount) {
317       unused_results_ = (*generator_)();
318       used_result_index_ = 0;
319     }
320 
321     return unused_results_[used_result_index_++];
322   }
323 
324   PHILOX_DEVICE_INLINE
Skip(uint64 num_skips)325   void Skip(uint64 num_skips) {
326     if (!num_skips) {
327       return;
328     }
329     int num_unused_results = kNativeElementCount - used_result_index_;
330     if (num_skips <= num_unused_results) {
331       used_result_index_ += num_skips;
332       return;
333     }
334     num_skips -= num_unused_results;
335     used_result_index_ = kNativeElementCount;
336     SkipFromGenerator(num_skips / kNativeElementCount);
337     num_skips = num_skips % kNativeElementCount;
338     if (num_skips) {
339       unused_results_ = (*generator_)();
340       used_result_index_ = num_skips;
341     }
342   }
343 
344  private:
345   // This implementation iteratively skips over `num_skips` samples
346   // from `generator_`. There is an O(1) implementation for PhiloxRandom
347   // in random_distributions.cc.
348   PHILOX_DEVICE_INLINE
SkipFromGenerator(uint64 num_skips)349   void SkipFromGenerator(uint64 num_skips) {
350     while (num_skips--) {
351       (*generator_)();
352     }
353   }
354 
355   Generator* generator_;
356   typename Generator::ResultType unused_results_;
357   int used_result_index_;
358 };
359 
360 // A class that generates unit normal distribution random numbers from the
361 // underlying random integer generator.
362 // Arguments:
//   Generator: a generator type that returns a number of uint32 upon each
//              invocation. It needs to define kResultElementCount for the
365 //              sample count for each invocation, and ResultType for actual
366 //              returned sample type.
367 //   RealType: the data type of the real numbers that will be returned by the
368 //             distribution. This could be either float or double for now.
369 // This class is meant to be implemented through specialization. The default
370 // is not defined by design.
371 template <class Generator, typename RealType>
372 class NormalDistribution;
373 
374 PHILOX_DEVICE_INLINE
375 void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
376                      double* d1);
377 
378 // Exactly like the float version, except that we convert to half afterwards;
379 // since we don't have half-precision sin/cos even on GPUs, there's nothing to
380 // gain from working in half internally.
381 template <class Generator>
382 class NormalDistribution<Generator, Eigen::half> {
383  public:
384   // The number of elements that will be returned.
385   static constexpr int kResultElementCount = Generator::kResultElementCount;
386   // Cost of generation of a single element (in cycles).
387   static constexpr int kElementCost = 70;
388   // Indicate that this distribution may take variable number of samples
389   // during the runtime.
390   static constexpr bool kVariableSamplesPerOutput = false;
391   typedef Array<Eigen::half, kResultElementCount> ResultType;
392   typedef Eigen::half ResultElementType;
393 
394   PHILOX_DEVICE_INLINE
operator()395   ResultType operator()(Generator* gen) {
396     typename Generator::ResultType sample = (*gen)();
397     ResultType result;
398     for (int i = 0; i < kResultElementCount; i += 2) {
399       float f[2];
400       BoxMullerFloat(sample[i], sample[i + 1], &f[0], &f[1]);
401       result[i] = Eigen::half(f[0]);
402       result[i + 1] = Eigen::half(f[1]);
403     }
404     return result;
405   }
406 };
407 
408 template <class Generator>
409 class NormalDistribution<Generator, bfloat16> {
410  public:
411   // The number of elements that will be returned.
412   static constexpr int kResultElementCount = Generator::kResultElementCount;
413   // Cost of generation of a single element (in cycles).
414   static constexpr int kElementCost = 70;
415   // Indicate that this distribution may take variable number of samples
416   // during the runtime.
417   static constexpr bool kVariableSamplesPerOutput = false;
418   typedef Array<bfloat16, kResultElementCount> ResultType;
419   typedef bfloat16 ResultElementType;
420 
421   PHILOX_DEVICE_INLINE
operator()422   ResultType operator()(Generator* gen) {
423     typename Generator::ResultType sample = (*gen)();
424     ResultType result;
425     static_assert(kResultElementCount % 2 == 0,
426                   "kResultElementCount should be an even number");
427     for (int i = 0; i < kResultElementCount; i += 2) {
428       float f[2];
429       // Box-Muller transform requires processing 2 elements at a time.
430       BoxMullerFloat(sample[i], sample[i + 1], &f[0], &f[1]);
431       result[i] = bfloat16(f[0]);
432       result[i + 1] = bfloat16(f[1]);
433     }
434     return result;
435   }
436 };
437 
438 template <class Generator>
439 class NormalDistribution<Generator, float> {
440  public:
441   // The number of elements that will be returned.
442   static constexpr int kResultElementCount = Generator::kResultElementCount;
443   // Cost of generation of a single element (in cycles).
444   static constexpr int kElementCost = 70;
445   // Indicate that this distribution may take variable number of samples
446   // during the runtime.
447   static constexpr bool kVariableSamplesPerOutput = false;
448   typedef Array<float, kResultElementCount> ResultType;
449   typedef float ResultElementType;
450 
451   PHILOX_DEVICE_INLINE
operator()452   ResultType operator()(Generator* gen) {
453     typename Generator::ResultType sample = (*gen)();
454     ResultType result;
455     for (int i = 0; i < kResultElementCount; i += 2) {
456       BoxMullerFloat(sample[i], sample[i + 1], &result[i], &result[i + 1]);
457     }
458     return result;
459   }
460 };
461 
462 template <class Generator>
463 class NormalDistribution<Generator, double> {
464  public:
465   // The number of elements that will be returned.
466   static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
467   // Cost of generation of a single element (in cycles).
468   static constexpr int kElementCost = 70;
469   // Indicate that this distribution may take variable number of samples
470   // during the runtime.
471   static constexpr bool kVariableSamplesPerOutput = false;
472   typedef Array<double, kResultElementCount> ResultType;
473   typedef double ResultElementType;
474 
475   PHILOX_DEVICE_INLINE
operator()476   ResultType operator()(Generator* gen) {
477     typename Generator::ResultType sample = (*gen)();
478     ResultType result;
479     for (int i = 0; i < kResultElementCount; i += 2) {
480       const int i2 = 2 * i;
481       BoxMullerDouble(sample[i2], sample[i2 + 1], sample[i2 + 2],
482                       sample[i2 + 3], &result[i], &result[i + 1]);
483     }
484     return result;
485   }
486 };
487 
488 // A class that returns standard normal distribution between
489 // [-kTruncateValue, kTruncateValue].
490 // Arguments:
//   Generator: a generator type that returns a number of uint32 upon each
//              invocation. It needs to define kResultElementCount for the
493 //              sample count for each invocation, and ResultType for actual
494 //              returned sample type.
495 //   RealType: the data type of the real numbers that will be returned by the
496 //             distribution. This could be either float or double for now.
497 // This class is meant to be implemented through specialization. The default
498 // is not defined by design.
499 template <class SingleSampleGenerator, typename RealType>
500 class TruncatedNormalDistribution;
501 
502 // Exactly like the float version, except that we convert to half afterwards;
503 // since we don't have half-precision sin/cos even on GPUs, there's nothing to
504 // gain from working in half internally.
505 template <class SingleSampleGenerator>
506 class TruncatedNormalDistribution<SingleSampleGenerator, Eigen::half> {
507  public:
508   // The number of elements that will be returned.
509   static constexpr int kResultElementCount =
510       SingleSampleGenerator::kNativeElementCount;
511   // Cost of generation of a single element (in cycles).
512   static constexpr int kElementCost = 90;
513   // Indicate that this distribution may take variable number of samples
514   // during the runtime.
515   static constexpr bool kVariableSamplesPerOutput = true;
516   // The threshold where the normal distribution is truncated.
517   const float kTruncateValue = 2.0f;
518 
519   typedef Array<Eigen::half, kResultElementCount> ResultType;
520   typedef Eigen::half ResultElementType;
521 
522   PHILOX_DEVICE_INLINE
operator()523   ResultType operator()(SingleSampleGenerator* gen) {
524     ResultType results;
525     int index = 0;
526     while (true) {
527       // Repeatedly take samples from the normal distribution, until we have
528       // the desired number of elements that fall within the pre-defined cutoff
529       // threshold.
530       const uint32 x0 = (*gen)();
531       const uint32 x1 = (*gen)();
532       float f[2];
533       BoxMullerFloat(x0, x1, &f[0], &f[1]);
534 
535       if (Eigen::numext::abs(f[0]) < kTruncateValue) {
536         results[index++] = Eigen::half(f[0]);
537         if (index >= kResultElementCount) {
538           return results;
539         }
540       }
541       if (Eigen::numext::abs(f[1]) < kTruncateValue) {
542         results[index++] = Eigen::half(f[1]);
543         if (index >= kResultElementCount) {
544           return results;
545         }
546       }
547     }
548   }
549 };
550 
551 template <class SingleSampleGenerator>
552 class TruncatedNormalDistribution<SingleSampleGenerator, bfloat16> {
553  public:
554   // The number of elements that will be returned.
555   static constexpr int kResultElementCount =
556       SingleSampleGenerator::kNativeElementCount;
557   // Cost of generation of a single element (in cycles).
558   static constexpr int kElementCost = 90;
559   // Indicate that this distribution may take variable number of samples
560   // during the runtime.
561   static constexpr bool kVariableSamplesPerOutput = true;
562   // The threshold where the normal distribution is truncated.
563   const float kTruncateValue = 2.0f;
564 
565   typedef Array<bfloat16, kResultElementCount> ResultType;
566   typedef bfloat16 ResultElementType;
567 
568   PHILOX_DEVICE_INLINE
operator()569   ResultType operator()(SingleSampleGenerator* gen) {
570     ResultType results;
571     int index = 0;
572     while (true) {
573       // Repeatedly take samples from the normal distribution, until we have
574       // the desired number of elements that fall within the pre-defined cutoff
575       // threshold.
576       const uint32 x0 = (*gen)();
577       const uint32 x1 = (*gen)();
578       float f[2];
579       BoxMullerFloat(x0, x1, &f[0], &f[1]);
580 
581       if (Eigen::numext::abs(f[0]) < kTruncateValue) {
582         results[index++] = bfloat16(f[0]);
583         if (index >= kResultElementCount) {
584           return results;
585         }
586       }
587       if (Eigen::numext::abs(f[1]) < kTruncateValue) {
588         results[index++] = bfloat16(f[1]);
589         if (index >= kResultElementCount) {
590           return results;
591         }
592       }
593     }
594   }
595 };
596 
597 // Partial specialization for float.
598 template <class SingleSampleGenerator>
599 class TruncatedNormalDistribution<SingleSampleGenerator, float> {
600  public:
601   // The number of elements that will be returned.
602   static constexpr int kResultElementCount =
603       SingleSampleGenerator::kNativeElementCount;
604   // Cost of generation of a single element (in cycles).
605   static constexpr int kElementCost = 90;
606   // Indicate that this distribution may take variable number of samples
607   // during the runtime.
608   static constexpr bool kVariableSamplesPerOutput = true;
609   // The threshold where the normal distribution is truncated.
610   const float kTruncateValue = 2.0f;
611 
612   typedef Array<float, kResultElementCount> ResultType;
613   typedef float ResultElementType;
614 
615   PHILOX_DEVICE_INLINE
operator()616   ResultType operator()(SingleSampleGenerator* gen) {
617     ResultType results;
618     int index = 0;
619     while (true) {
620       // Repeatedly take samples from the normal distribution, until we have
621       // the desired number of elements that fall within the pre-defined cutoff
622       // threshold.
623       const uint32 x0 = (*gen)();
624       const uint32 x1 = (*gen)();
625       float f[2];
626       BoxMullerFloat(x0, x1, &f[0], &f[1]);
627 
628       if (Eigen::numext::abs(f[0]) < kTruncateValue) {
629         results[index++] = f[0];
630         if (index >= kResultElementCount) {
631           return results;
632         }
633       }
634       if (Eigen::numext::abs(f[1]) < kTruncateValue) {
635         results[index++] = f[1];
636         if (index >= kResultElementCount) {
637           return results;
638         }
639       }
640     }
641   }
642 };
643 
644 // Partial specialization for double.
645 template <class SingleSampleGenerator>
646 class TruncatedNormalDistribution<SingleSampleGenerator, double> {
647  public:
648   // The number of elements that will be returned.
649   static constexpr int kResultElementCount =
650       (SingleSampleGenerator::kNativeElementCount > 1)
651           ? SingleSampleGenerator::kNativeElementCount / 2
652           : 1;
653   // Cost of generation of a single element (in cycles).
654   static constexpr int kElementCost = 90;
655   // Indicate that this distribution may take variable number of samples
656   // during the runtime.
657   static constexpr bool kVariableSamplesPerOutput = true;
658   typedef Array<double, kResultElementCount> ResultType;
659   typedef double ResultElementType;
660   const double kTruncateValue = 2.0;
661 
662   PHILOX_DEVICE_INLINE
operator()663   ResultType operator()(SingleSampleGenerator* gen) {
664     ResultType results;
665     int index = 0;
666     while (true) {
667       const uint32 x0 = (*gen)();
668       const uint32 x1 = (*gen)();
669       const uint32 x2 = (*gen)();
670       const uint32 x3 = (*gen)();
671       double d[2];
672       BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]);
673 
674       if (Eigen::numext::abs(d[0]) < kTruncateValue) {
675         results[index++] = d[0];
676         if (index >= kResultElementCount) {
677           return results;
678         }
679       }
680       if (Eigen::numext::abs(d[1]) < kTruncateValue) {
681         results[index++] = d[1];
682         if (index >= kResultElementCount) {
683           return results;
684         }
685       }
686     }
687   }
688 };
689 
690 // Helper function to convert four 32-bit uniform integers to two doubles
691 // under the unit normal distribution.
692 PHILOX_DEVICE_INLINE
BoxMullerDouble(uint32 x0,uint32 x1,uint32 x2,uint32 x3,double * d0,double * d1)693 void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
694                      double* d1) {
695   // This function implements the Box-Muller transform:
696   // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
697   // Do not send a really small number to log().
698   // We cannot mark "epsilon" as "static const" because NVCC would complain
699   const double epsilon = 1.0e-7;
700   double u1 = Uint64ToDouble(x0, x1);
701   if (u1 < epsilon) {
702     u1 = epsilon;
703   }
704   const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3);
705   const double u2 = Eigen::numext::sqrt(-2.0 * Eigen::numext::log(u1));
706 #if !defined(__linux__)
707   *d0 = Eigen::numext::sin(v1);
708   *d1 = Eigen::numext::cos(v1);
709 #else
710   sincos(v1, d0, d1);
711 #endif
712   *d0 *= u2;
713   *d1 *= u2;
714 }
715 
716 // Helper function to convert an 16-bit integer to a half between [0..1).
Uint16ToHalf(uint16 x)717 PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x) {
718   // IEEE754 halfs are formatted as follows (MSB first):
719   //    sign(1) exponent(5) mantissa(10)
720   // Conceptually construct the following:
721   //    sign == 0
722   //    exponent == 15  -- an excess 15 representation of a zero exponent
723   //    mantissa == 10 random bits
724   const uint16 man = x & 0x3ffu;  // 10 bit mantissa
725   const uint16 exp = static_cast<uint16>(15);
726   const uint16 val = (exp << 10) | man;
727 
728   Eigen::half result = Eigen::numext::bit_cast<Eigen::half>(val);
729   return result - Eigen::half(1.0);
730 }
731 
732 // Helper function to convert an 16-bit integer to a bfloat16 between [0..1).
733 // This can create a uniform distribution of values between [0..1).
Uint16ToGfloat16(uint16 x)734 PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x) {
735   // bfloat are formatted as follows (MSB first):
736   //    sign(1) exponent(8) mantissa(7)
737   // Conceptually construct the following:
738   //    sign == 0
739   //    exponent == 127  -- an excess 127 representation of a zero exponent
740   //    mantissa == 7 random bits
741   const uint16 man = x & 0x7fu;  // 7 bit mantissa
742   const uint16 exp = static_cast<uint16>(127);
743   const uint16 val = (exp << 7) | man;
744 
745   bfloat16 result;
746   memcpy(&result, &val, sizeof(val));
747   // The mantissa has an implicit leading 1, so the above code creates a value
748   // in [1, 2). The minus will not cause a rounding that makes the result 1.
749   // Instead it will just be close to 1.
750   return result - bfloat16(1.0);
751 }
752 
753 }  // namespace random
754 }  // namespace tensorflow
755 
756 #endif  // TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
757