1 #include <gtest/gtest.h>
2 #include <ATen/Generator.h>
3 #include <ATen/Tensor.h>
4 #include <ATen/native/TensorIterator.h>
5 #include <torch/library.h>
6 #include <optional>
7 #include <torch/all.h>
8 #include <stdexcept>
9
10 namespace {
11
// Extremes of int64_t. They are returned by the floating-point overloads of
// _min_val/_max_val below, and int64_min_val is also compared against `from`
// in test_random_from_to to recognize the full 64-bit range case.
constexpr int64_t int64_min_val = std::numeric_limits<int64_t>::min();
constexpr int64_t int64_max_val = std::numeric_limits<int64_t>::max();
14 template <typename T,
15 typename std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
_min_val()16 constexpr int64_t _min_val() {
17 return int64_min_val;
18 }
19
/// Smallest value representable by integral type T, widened to int64_t.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
constexpr int64_t _min_val() {
  using Limits = std::numeric_limits<T>;
  return static_cast<int64_t>(Limits::lowest());
}
25
/// Smallest `from` bound exercised for a floating-point T: -(2^digits), where
/// `digits` is T's binary mantissa precision (24 for IEEE float, 53 for IEEE
/// double). Every integer in [-2^digits, 2^digits] is exactly representable
/// in such a T.
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
constexpr int64_t _min_from() {
  const int64_t magnitude = int64_t{1} << std::numeric_limits<T>::digits;
  return -magnitude;
}
31
32 template <typename T,
33 typename std::enable_if_t<std::is_integral_v<T>, int> = 0>
_min_from()34 constexpr int64_t _min_from() {
35 return _min_val<T>();
36 }
37
38 template <typename T,
39 typename std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
_max_val()40 constexpr int64_t _max_val() {
41 return int64_max_val;
42 }
43
/// Largest value representable by integral type T, converted to int64_t.
/// NOTE(review): for T = uint64_t the maximum does not fit in int64_t, so the
/// result is not T's true maximum.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
constexpr int64_t _max_val() {
  using Limits = std::numeric_limits<T>;
  return static_cast<int64_t>(Limits::max());
}
49
/// Largest `to` bound exercised for a floating-point T: 2^digits, the upper
/// edge of the interval in which every integer is exactly representable in T
/// (2^24 for IEEE float, 2^53 for IEEE double).
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
constexpr int64_t _max_to() {
  const int digits = std::numeric_limits<T>::digits;
  return int64_t{1} << digits;
}
55
56 template <typename T,
57 typename std::enable_if_t<std::is_integral_v<T>, int> = 0>
_max_to()58 constexpr int64_t _max_to() {
59 return _max_val<T>();
60 }
61
62 template<typename RNG, c10::ScalarType S, typename T>
test_random_from_to(const at::Device & device)63 void test_random_from_to(const at::Device& device) {
64
65 constexpr int64_t max_val = _max_val<T>();
66 constexpr int64_t max_to = _max_to<T>();
67
68 constexpr auto uint64_max_val = std::numeric_limits<uint64_t>::max();
69
70 std::vector<int64_t> froms;
71 std::vector<::std::optional<int64_t>> tos;
72 if constexpr (::std::is_same_v<T, bool>) {
73 froms = {
74 0L
75 };
76 tos = {
77 1L,
78 static_cast<::std::optional<int64_t>>(::std::nullopt)
79 };
80 } else if constexpr (::std::is_signed_v<T>) {
81 constexpr int64_t min_from = _min_from<T>();
82 froms = {
83 min_from,
84 -42L,
85 0L,
86 42L
87 };
88 tos = {
89 ::std::optional<int64_t>(-42L),
90 ::std::optional<int64_t>(0L),
91 ::std::optional<int64_t>(42L),
92 ::std::optional<int64_t>(max_to),
93 static_cast<::std::optional<int64_t>>(::std::nullopt)
94 };
95 } else {
96 froms = {
97 0L,
98 42L
99 };
100 tos = {
101 ::std::optional<int64_t>(42L),
102 ::std::optional<int64_t>(max_to),
103 static_cast<::std::optional<int64_t>>(::std::nullopt)
104 };
105 }
106
107 const std::vector<uint64_t> vals = {
108 0L,
109 42L,
110 static_cast<uint64_t>(max_val),
111 static_cast<uint64_t>(max_val) + 1,
112 uint64_max_val
113 };
114
115 bool full_64_bit_range_case_covered = false;
116 bool from_to_case_covered = false;
117 bool from_case_covered = false;
118 for (const int64_t from : froms) {
119 for (const ::std::optional<int64_t> & to : tos) {
120 if (!to.has_value() || from < *to) {
121 for (const uint64_t val : vals) {
122 auto gen = at::make_generator<RNG>(val);
123
124 auto actual = torch::empty({3, 3}, torch::TensorOptions().dtype(S).device(device));
125 actual.random_(from, to, gen);
126
127 T exp;
128 uint64_t range;
129 if (!to.has_value() && from == int64_min_val) {
130 exp = static_cast<int64_t>(val);
131 full_64_bit_range_case_covered = true;
132 } else {
133 if (to.has_value()) {
134 range = static_cast<uint64_t>(*to) - static_cast<uint64_t>(from);
135 from_to_case_covered = true;
136 } else {
137 range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
138 from_case_covered = true;
139 }
140 if (range < (1ULL << 32)) {
141 exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
142 } else {
143 exp = static_cast<T>(static_cast<int64_t>((val % range + from)));
144 }
145 }
146 ASSERT_TRUE(from <= exp);
147 if (to.has_value()) {
148 ASSERT_TRUE(static_cast<int64_t>(exp) < *to);
149 }
150 const auto expected = torch::full_like(actual, exp);
151 if constexpr (::std::is_same_v<T, bool>) {
152 ASSERT_TRUE(torch::allclose(actual.toType(torch::kInt), expected.toType(torch::kInt)));
153 } else {
154 ASSERT_TRUE(torch::allclose(actual, expected));
155 }
156 }
157 }
158 }
159 }
160 if constexpr (::std::is_same_v<T, int64_t>) {
161 ASSERT_TRUE(full_64_bit_range_case_covered);
162 } else {
163 (void)full_64_bit_range_case_covered;
164 }
165 ASSERT_TRUE(from_to_case_covered);
166 ASSERT_TRUE(from_case_covered);
167 }
168
169 template<typename RNG, c10::ScalarType S, typename T>
test_random(const at::Device & device)170 void test_random(const at::Device& device) {
171 const auto max_val = _max_val<T>();
172 const auto uint64_max_val = std::numeric_limits<uint64_t>::max();
173
174 const std::vector<uint64_t> vals = {
175 0L,
176 42L,
177 static_cast<uint64_t>(max_val),
178 static_cast<uint64_t>(max_val) + 1,
179 uint64_max_val
180 };
181
182 for (const uint64_t val : vals) {
183 auto gen = at::make_generator<RNG>(val);
184
185 auto actual = torch::empty({3, 3}, torch::TensorOptions().dtype(S).device(device));
186 actual.random_(gen);
187
188 uint64_t range;
189 if constexpr (::std::is_floating_point_v<T>) {
190 range = static_cast<uint64_t>((1ULL << ::std::numeric_limits<T>::digits) + 1);
191 } else if constexpr (::std::is_same_v<T, bool>) {
192 range = 2;
193 } else {
194 range = static_cast<uint64_t>(::std::numeric_limits<T>::max()) + 1;
195 }
196 T exp;
197 if constexpr (::std::is_same_v<T, double> || ::std::is_same_v<T, int64_t>) {
198 exp = val % range;
199 } else {
200 exp = static_cast<uint32_t>(val) % range;
201 }
202
203 ASSERT_TRUE(0 <= static_cast<int64_t>(exp));
204 ASSERT_TRUE(static_cast<uint64_t>(exp) < range);
205
206 const auto expected = torch::full_like(actual, exp);
207 if constexpr (::std::is_same_v<T, bool>) {
208 ASSERT_TRUE(torch::allclose(actual.toType(torch::kInt), expected.toType(torch::kInt)));
209 } else {
210 ASSERT_TRUE(torch::allclose(actual, expected));
211 }
212 }
213 }
214
215 }
216