xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/rng_test.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <ATen/Generator.h>
3 #include <ATen/Tensor.h>
4 #include <ATen/native/TensorIterator.h>
5 #include <torch/library.h>
6 #include <optional>
7 #include <torch/all.h>
8 #include <stdexcept>
9 
10 namespace {
11 
// Full int64_t range. Used as the sentinel bounds for floating-point dtypes
// and to recognise the "whole 64-bit range" case in test_random_from_to.
constexpr int64_t int64_min_val = std::numeric_limits<int64_t>::min();
constexpr int64_t int64_max_val = std::numeric_limits<int64_t>::max();
14 template <typename T,
15           typename std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
_min_val()16 constexpr int64_t _min_val() {
17   return int64_min_val;
18 }
19 
// Lower bound for integral T (bool included): numeric_limits<T>::lowest(),
// widened to int64_t.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
constexpr int64_t _min_val() {
  constexpr T lowest = std::numeric_limits<T>::lowest();
  return static_cast<int64_t>(lowest);
}
25 
// Most negative `from` exercised for floating-point T: -(2^digits), where
// digits is std::numeric_limits<T>::digits (-2^24 for IEEE float, -2^53 for
// IEEE double).
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
constexpr int64_t _min_from() {
  constexpr int64_t magnitude = int64_t{1} << std::numeric_limits<T>::digits;
  return -magnitude;
}
31 
32 template <typename T,
33           typename std::enable_if_t<std::is_integral_v<T>, int> = 0>
_min_from()34 constexpr int64_t _min_from() {
35   return _min_val<T>();
36 }
37 
38 template <typename T,
39           typename std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
_max_val()40 constexpr int64_t _max_val() {
41   return int64_max_val;
42 }
43 
// Upper bound for integral T (bool included): numeric_limits<T>::max(),
// converted to int64_t.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
constexpr int64_t _max_val() {
  constexpr T highest = std::numeric_limits<T>::max();
  return static_cast<int64_t>(highest);
}
49 
// Largest `to` exercised for floating-point T, and the inclusive upper bound
// used when `to` is omitted: 2^digits (2^24 for IEEE float, 2^53 for IEEE
// double).
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
constexpr int64_t _max_to() {
  return int64_t{1} << std::numeric_limits<T>::digits;
}
55 
56 template <typename T,
57           typename std::enable_if_t<std::is_integral_v<T>, int> = 0>
_max_to()58 constexpr int64_t _max_to() {
59   return _max_val<T>();
60 }
61 
62 template<typename RNG, c10::ScalarType S, typename T>
test_random_from_to(const at::Device & device)63 void test_random_from_to(const at::Device& device) {
64 
65   constexpr int64_t max_val = _max_val<T>();
66   constexpr int64_t max_to = _max_to<T>();
67 
68   constexpr auto uint64_max_val = std::numeric_limits<uint64_t>::max();
69 
70   std::vector<int64_t> froms;
71   std::vector<::std::optional<int64_t>> tos;
72   if constexpr (::std::is_same_v<T, bool>) {
73     froms = {
74       0L
75     };
76     tos = {
77       1L,
78       static_cast<::std::optional<int64_t>>(::std::nullopt)
79     };
80   } else if constexpr (::std::is_signed_v<T>) {
81     constexpr int64_t min_from = _min_from<T>();
82     froms = {
83       min_from,
84       -42L,
85       0L,
86       42L
87     };
88     tos = {
89       ::std::optional<int64_t>(-42L),
90       ::std::optional<int64_t>(0L),
91       ::std::optional<int64_t>(42L),
92       ::std::optional<int64_t>(max_to),
93       static_cast<::std::optional<int64_t>>(::std::nullopt)
94     };
95   } else {
96     froms = {
97       0L,
98       42L
99     };
100     tos = {
101       ::std::optional<int64_t>(42L),
102       ::std::optional<int64_t>(max_to),
103       static_cast<::std::optional<int64_t>>(::std::nullopt)
104     };
105   }
106 
107   const std::vector<uint64_t> vals = {
108     0L,
109     42L,
110     static_cast<uint64_t>(max_val),
111     static_cast<uint64_t>(max_val) + 1,
112     uint64_max_val
113   };
114 
115   bool full_64_bit_range_case_covered = false;
116   bool from_to_case_covered = false;
117   bool from_case_covered = false;
118   for (const int64_t from : froms) {
119     for (const ::std::optional<int64_t> & to : tos) {
120       if (!to.has_value() || from < *to) {
121         for (const uint64_t val : vals) {
122           auto gen = at::make_generator<RNG>(val);
123 
124           auto actual = torch::empty({3, 3}, torch::TensorOptions().dtype(S).device(device));
125           actual.random_(from, to, gen);
126 
127           T exp;
128           uint64_t range;
129           if (!to.has_value() && from == int64_min_val) {
130             exp = static_cast<int64_t>(val);
131             full_64_bit_range_case_covered = true;
132           } else {
133             if (to.has_value()) {
134               range = static_cast<uint64_t>(*to) - static_cast<uint64_t>(from);
135               from_to_case_covered = true;
136             } else {
137               range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
138               from_case_covered = true;
139             }
140             if (range < (1ULL << 32)) {
141               exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
142             } else {
143               exp = static_cast<T>(static_cast<int64_t>((val % range + from)));
144             }
145           }
146           ASSERT_TRUE(from <= exp);
147           if (to.has_value()) {
148             ASSERT_TRUE(static_cast<int64_t>(exp) < *to);
149           }
150           const auto expected = torch::full_like(actual, exp);
151           if constexpr (::std::is_same_v<T, bool>) {
152             ASSERT_TRUE(torch::allclose(actual.toType(torch::kInt), expected.toType(torch::kInt)));
153           } else {
154             ASSERT_TRUE(torch::allclose(actual, expected));
155           }
156         }
157       }
158     }
159   }
160   if constexpr (::std::is_same_v<T, int64_t>) {
161     ASSERT_TRUE(full_64_bit_range_case_covered);
162   } else {
163     (void)full_64_bit_range_case_covered;
164   }
165   ASSERT_TRUE(from_to_case_covered);
166   ASSERT_TRUE(from_case_covered);
167 }
168 
169 template<typename RNG, c10::ScalarType S, typename T>
test_random(const at::Device & device)170 void test_random(const at::Device& device) {
171   const auto max_val = _max_val<T>();
172   const auto uint64_max_val = std::numeric_limits<uint64_t>::max();
173 
174   const std::vector<uint64_t> vals = {
175     0L,
176     42L,
177     static_cast<uint64_t>(max_val),
178     static_cast<uint64_t>(max_val) + 1,
179     uint64_max_val
180   };
181 
182   for (const uint64_t val : vals) {
183     auto gen = at::make_generator<RNG>(val);
184 
185     auto actual = torch::empty({3, 3}, torch::TensorOptions().dtype(S).device(device));
186     actual.random_(gen);
187 
188     uint64_t range;
189     if constexpr (::std::is_floating_point_v<T>) {
190       range = static_cast<uint64_t>((1ULL << ::std::numeric_limits<T>::digits) + 1);
191     } else if constexpr (::std::is_same_v<T, bool>) {
192       range = 2;
193     } else {
194       range = static_cast<uint64_t>(::std::numeric_limits<T>::max()) + 1;
195     }
196     T exp;
197     if constexpr (::std::is_same_v<T, double> || ::std::is_same_v<T, int64_t>) {
198       exp = val % range;
199     } else {
200       exp = static_cast<uint32_t>(val) % range;
201     }
202 
203     ASSERT_TRUE(0 <= static_cast<int64_t>(exp));
204     ASSERT_TRUE(static_cast<uint64_t>(exp) < range);
205 
206     const auto expected = torch::full_like(actual, exp);
207     if constexpr (::std::is_same_v<T, bool>) {
208       ASSERT_TRUE(torch::allclose(actual.toType(torch::kInt), expected.toType(torch::kInt)));
209     } else {
210       ASSERT_TRUE(torch::allclose(actual, expected));
211     }
212   }
213 }
214 
215 }
216