1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
17
18 #include <array>
19 #include <string>
20 #include <type_traits>
21
22 #include "absl/strings/string_view.h"
23
24 namespace xla {
25 namespace exhaustive_op_test {
26
// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 significant decimal
// digits, respectively, to guarantee that we're printing the full number
// (i.e. that the printed text converts back to exactly the same value).
//
// (The general formula is: given a floating-point type with S significand
// bits, counting the implicit leading bit, the number of significant decimal
// digits needed to print a value to full precision is
//
//   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103),
//
// e.g. S = 53, 24, 11, and 8 give 17, 9, 5, and 4 digits.
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
36 namespace {
37 template <typename T>
38 struct ComponentStringifyFormat {
39 static const absl::string_view value;
40 };
41
42 template <>
43 constexpr absl::string_view ComponentStringifyFormat<double>::value =
44 "%0.17g (0x%16x)";
45
46 template <>
47 constexpr absl::string_view ComponentStringifyFormat<float>::value =
48 "%0.9g (0x%08x)";
49
50 template <>
51 constexpr absl::string_view ComponentStringifyFormat<Eigen::half>::value =
52 "%0.5g (0x%04x)";
53
54 template <>
55 constexpr absl::string_view ComponentStringifyFormat<bfloat16>::value =
56 "%0.4g (0x%04x)";
57
58 template <typename Type, typename FuncPtr>
CallErrorSpec(FuncPtr * func,const std::array<Type,1> & in)59 ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 1>& in) {
60 return func(in[0]);
61 }
62
63 template <typename Type, typename FuncPtr>
CallErrorSpec(FuncPtr * func,const std::array<Type,2> & in)64 ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 2>& in) {
65 return func(in[0], in[1]);
66 }
67
// Evaluates a unary operation on the single element of `in` and returns its
// result.
template <typename Type, typename FuncPtr>
Type CallOperation(FuncPtr* func, const std::array<Type, 1>& in) {
  const Type& operand = std::get<0>(in);
  return func(operand);
}

// Evaluates a binary operation on the two elements of `in`, passed in order,
// and returns its result.
template <typename Type, typename FuncPtr>
Type CallOperation(FuncPtr* func, const std::array<Type, 2>& in) {
  const Type& lhs = std::get<0>(in);
  const Type& rhs = std::get<1>(in);
  return func(lhs, rhs);
}
77
78 // The number of values that can be substituted for subnormal inputs.
79 constexpr int kNumSubnormalSubstitutionValues = 4;
80
81 // Encodings used to determine where subnormal test values are cached.
82 constexpr int kPositiveMin = 0;
83 constexpr int kNegativeMin = 1;
84 constexpr int kPositiveZero = 2;
85 constexpr int kNegativeZero = 3;
86 constexpr int kNonSubnormal = -1;
87 constexpr int kInvalidCacheIndex = -1;
88
89 template <typename T>
90 struct is_complex_t : absl::disjunction<std::is_same<T, complex64>,
91 std::is_same<T, complex128>> {};
92
// When testing a value whose components are all subnormal, we also need to
// test the inputs formed by the Cartesian product of the substitution values
// for each subnormal component. These additional test inputs are common enough
// that it is efficient to cache their results. To do so, we need a one-to-one
// mapping between these Cartesian-product inputs and cache locations.
//
// The mapping assigns each component an integer in
// [0, kNumSubnormalSubstitutionValues) based on its substitution value. Using
// the n'th component's integer as the n'th digit (the first component being
// the most significant), each element of the Cartesian product becomes a
// unique base-kNumSubnormalSubstitutionValues number, which is its cache
// index. A complex component contributes two digits: one for its real part
// and one for its imaginary part.
//
// If any component is not one of the substitution values, the input must not
// be cached, so kNonSubnormal is returned.
109
110 template <
111 typename NativeRefT,
112 typename std::enable_if<!is_complex_t<NativeRefT>::value>::type* = nullptr>
GetCacheLocation(NativeRefT value)113 int GetCacheLocation(NativeRefT value) {
114 bool positive = !std::signbit(value);
115 if (std::abs(value) == std::numeric_limits<NativeRefT>::min()) {
116 return positive ? kPositiveMin : kNegativeMin;
117 } else if (value != 0) {
118 CHECK(std::fpclassify(value) != FP_SUBNORMAL);
119 return kNonSubnormal;
120 } else {
121 return positive ? kPositiveZero : kNegativeZero;
122 }
123 }
124
125 template <
126 typename NativeRefT,
127 typename std::enable_if<is_complex_t<NativeRefT>::value>::type* = nullptr>
GetCacheLocation(NativeRefT value)128 int GetCacheLocation(NativeRefT value) {
129 int real_loc =
130 GetCacheLocation<typename NativeRefT::value_type>(value.real());
131 int imag_loc =
132 GetCacheLocation<typename NativeRefT::value_type>(value.imag());
133 if (real_loc == kNonSubnormal || imag_loc == kNonSubnormal) {
134 return kNonSubnormal;
135 } else {
136 return real_loc * kNumSubnormalSubstitutionValues + imag_loc;
137 }
138 }
139
140 template <bool is_complex, typename NativeRefT, size_t N>
GetCacheLocation(const std::array<NativeRefT,N> & input)141 int GetCacheLocation(const std::array<NativeRefT, N>& input) {
142 int location = 0;
143 int cache_size_per_element = (is_complex ? kNumSubnormalSubstitutionValues *
144 kNumSubnormalSubstitutionValues
145 : kNumSubnormalSubstitutionValues);
146 for (int i = 0; i < N; ++i) {
147 int comp_loc = GetCacheLocation<NativeRefT>(input[i]);
148 if (i == kNonSubnormal) {
149 return kNonSubnormal;
150 }
151 location *= cache_size_per_element;
152 location += comp_loc;
153 }
154 return location;
155 }
156
157 // The inverse function of GetCacheLocation.
158
159 template <typename RetT,
160 typename std::enable_if<!is_complex_t<RetT>::value>::type* = nullptr>
FromCacheLocationComponent(int cache_loc)161 RetT FromCacheLocationComponent(int cache_loc) {
162 switch (cache_loc) {
163 case kPositiveMin:
164 return std::numeric_limits<RetT>::min();
165 case kNegativeMin:
166 return -std::numeric_limits<RetT>::min();
167 case kPositiveZero:
168 return static_cast<RetT>(0.0);
169 case kNegativeZero:
170 return static_cast<RetT>(-0.0);
171 default:
172 LOG(FATAL) << "Invalid cache_loc value of " << cache_loc;
173 }
174 }
175
176 template <typename RetT,
177 typename std::enable_if<is_complex_t<RetT>::value>::type* = nullptr>
FromCacheLocationComponent(int cache_loc)178 RetT FromCacheLocationComponent(int cache_loc) {
179 CHECK_LT(cache_loc,
180 kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues);
181 CHECK_GE(cache_loc, 0);
182
183 RetT value;
184 value.real(FromCacheLocationComponent<typename RetT::value_type>(
185 cache_loc / kNumSubnormalSubstitutionValues));
186 value.imag(FromCacheLocationComponent<typename RetT::value_type>(
187 cache_loc % kNumSubnormalSubstitutionValues));
188 return std::move(value);
189 }
190
191 template <bool is_complex, typename NativeRefT, size_t N>
FromCacheLocation(int cache_loc)192 std::array<NativeRefT, N> FromCacheLocation(int cache_loc) {
193 std::array<NativeRefT, N> input;
194 int cache_size_per_element = (is_complex ? kNumSubnormalSubstitutionValues *
195 kNumSubnormalSubstitutionValues
196 : kNumSubnormalSubstitutionValues);
197 for (int i = N - 1; i >= 0; --i) {
198 input[i] = FromCacheLocationComponent<NativeRefT>(cache_loc %
199 cache_size_per_element);
200 cache_loc /= cache_size_per_element;
201 }
202
203 return input;
204 }
205
206 // Returns a string that describes the test value for the actual value.
207 template <
208 typename NativeRefT,
209 typename std::enable_if<!is_complex_t<NativeRefT>::value>::type* = nullptr>
GetSubnormalDescription(NativeRefT test_val,NativeRefT actual_val)210 std::string GetSubnormalDescription(NativeRefT test_val,
211 NativeRefT actual_val) {
212 std::string sp_min_normal = "sign-preserving min-normal-float";
213 std::string sp_zero = "sign-preserving zero";
214 std::string nsp_zero = "non-sign-preserving zero";
215
216 switch (GetCacheLocation<NativeRefT>(test_val)) {
217 case kNegativeMin:
218 case kPositiveMin:
219 return sp_min_normal;
220 case kNegativeZero:
221 case kPositiveZero:
222 return (std::signbit(test_val) == std::signbit(actual_val)) ? sp_zero
223 : nsp_zero;
224 default:
225 return "";
226 }
227 }
228
229 template <
230 typename NativeRefT,
231 typename std::enable_if<is_complex_t<NativeRefT>::value>::type* = nullptr>
GetSubnormalDescription(NativeRefT test_val,NativeRefT actual_val)232 std::string GetSubnormalDescription(NativeRefT test_val,
233 NativeRefT actual_val) {
234 std::string real = GetSubnormalDescription<typename NativeRefT::value_type>(
235 test_val.real(), actual_val.real());
236 std::string imag = GetSubnormalDescription<typename NativeRefT::value_type>(
237 test_val.imag(), actual_val.imag());
238
239 if (real.empty()) {
240 if (imag.empty()) {
241 return "";
242 }
243 real = "real";
244 } else if (imag.empty()) {
245 imag = "imag";
246 }
247
248 return absl::StrCat("(", real, ", ", imag, ")");
249 }
250
251 template <bool is_complex, typename NativeRefT, size_t N>
GetSubnormalDescription(std::array<NativeRefT,N> test_vals,std::array<NativeRefT,N> actual_vals)252 std::string GetSubnormalDescription(std::array<NativeRefT, N> test_vals,
253 std::array<NativeRefT, N> actual_vals) {
254 if (N == 1) {
255 return GetSubnormalDescription<NativeRefT>(test_vals[0], actual_vals[0]);
256 }
257
258 std::array<std::string, N> str_vals;
259 for (int i = 0; i < N; ++i) {
260 str_vals[i] =
261 GetSubnormalDescription<NativeRefT>(test_vals[i], actual_vals[i]);
262 if (str_vals[i].empty()) {
263 str_vals[i] = "original";
264 }
265 }
266
267 return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
268 }
269
270 template <
271 typename NativeT, typename IntegralType,
272 typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
StringifyNum(NativeT x)273 std::string StringifyNum(NativeT x) {
274 return absl::StrFormat(ComponentStringifyFormat<NativeT>::value,
275 static_cast<double>(x), BitCast<IntegralType>(x));
276 }
277
278 template <
279 typename NativeT, typename IntegralType,
280 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
StringifyNum(NativeT x)281 std::string StringifyNum(NativeT x) {
282 return absl::StrCat(
283 "(", StringifyNum<typename NativeT::value_type, IntegralType>(x.real()),
284 ", ", StringifyNum<typename NativeT::value_type, IntegralType>(x.imag()),
285 ")");
286 }
287
288 template <typename NativeT, typename IntegralType, size_t N>
StringifyNum(const std::array<NativeT,N> & inputs)289 std::string StringifyNum(const std::array<NativeT, N>& inputs) {
290 if (N == 1) {
291 return StringifyNum<NativeT, IntegralType>(inputs[0]);
292 }
293
294 std::array<std::string, N> str_vals;
295 for (int i = 0; i < N; ++i) {
296 str_vals[i] = StringifyNum<NativeT, IntegralType>(inputs[i]);
297 }
298
299 return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
300 }
301
302 template <typename ErrorGenerator>
PrintMismatch(int64_t * mismatches,const ErrorGenerator & err_generator)303 void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) {
304 // We send a few mismatches to gunit so they show up nicely in test logs.
305 // Then we send more to LOG(ERROR). The remainder we squelch unless we're
306 // at vlog level 2.
307 constexpr int64_t kMaxMismatchesLoggedToGunit = 10;
308 constexpr int64_t kMaxMismatchesLoggedToErr = 1000;
309
310 (*mismatches)++;
311 if (*mismatches < kMaxMismatchesLoggedToGunit) {
312 FAIL() << err_generator();
313 } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
314 LOG(ERROR) << err_generator();
315 } else if (*mismatches == kMaxMismatchesLoggedToErr) {
316 LOG(ERROR) << "Not printing any more mismatches; pass "
317 "--vmodule=exhaustive_op_test=2 to see "
318 "all of them.";
319 }
320 }
321 } // namespace
322
323 template <PrimitiveType T, size_t N>
ExpectNear(const InputLiterals & input_literals,const Literal & result_literal,EvaluateOp evaluate_op,ErrorSpecGen error_spec_gen)324 void ExhaustiveOpTestBase<T, N>::ExpectNear(const InputLiterals& input_literals,
325 const Literal& result_literal,
326 EvaluateOp evaluate_op,
327 ErrorSpecGen error_spec_gen) {
328 // Cache for when all components are subnormal testing values.
329 std::vector<NativeRefT> pure_subnormal_cache;
330 // Since we take the cross product of all possible test values, and each
331 // component has kNumSubnormalSubstitutionValues possible test values, then
332 // the total number of different cache locations are
333 // kNumSubnormalSubstitutionValues raised to the num_components.
334 // num_components = N for the reals, and 2*N for the complex.
335 int64_t max_cache_size =
336 pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1));
337 pure_subnormal_cache.reserve(max_cache_size);
338 for (int i = 0; i < max_cache_size; ++i) {
339 pure_subnormal_cache.push_back(CallOperation(
340 evaluate_op, FromCacheLocation<kIsComplex, NativeRefT, N>(i)));
341 }
342
343 NativeInputsList inputs_arr;
344 for (int i = 0; i < N; ++i) {
345 const Literal& literal = input_literals[i];
346 inputs_arr[i] = literal.data<NativeT>();
347 }
348
349 absl::Span<const NativeT> result_arr = result_literal.data<NativeT>();
350
351 int64_t mismatches = 0;
352
353 for (int64_t i = 0; i < result_arr.size(); ++i) {
354 NativeInputs inputs;
355 NativeRefInputs inputs_ref_ty;
356
357 for (int j = 0; j < N; ++j) {
358 inputs[j] = inputs_arr[j][i];
359 inputs_ref_ty[j] = static_cast<NativeRefT>(inputs[j]);
360 }
361
362 NativeT actual = result_arr[i];
363 NativeT expected =
364 static_cast<NativeT>(CallOperation(evaluate_op, inputs_ref_ty));
365 ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs);
366
367 if (IsClose(static_cast<NativeRefT>(expected),
368 static_cast<NativeRefT>(actual), error_spec)) {
369 continue;
370 }
371
372 std::vector<NativeRefInputs> subnormal_test_inputs =
373 GetTestValuesWithSubnormalSubstitutions(inputs_ref_ty);
374
375 // Easy case: If `input` is not subnormal and !IsClose(expected, actual,
376 // error_spec), print an error.
377 if (subnormal_test_inputs.size() == 1) {
378 PrintMismatch(&mismatches, [&] {
379 return absl::StrFormat(
380 "Mismatch on %s. Expected %s, but got %s.",
381 StringifyNum<NativeT, ComponentIntegralNativeT, N>(inputs),
382 StringifyNum<NativeT, ComponentIntegralNativeT>(expected),
383 StringifyNum<NativeT, ComponentIntegralNativeT>(actual));
384 });
385 continue;
386 }
387
388 // Otherwise, we need to test the additional subnormal test values.
389 std::vector<NativeRefT> subnormal_test_results;
390 subnormal_test_results.reserve(subnormal_test_inputs.size());
391 bool passed_subnormal_test = false;
392
393 for (NativeRefInputs test_value : subnormal_test_inputs) {
394 NativeRefT result;
395 int cache_loc =
396 GetCacheLocation<kIsComplex, typename NativeRefInputs::value_type, N>(
397 test_value);
398 if (cache_loc == kInvalidCacheIndex) {
399 result = CallOperation(evaluate_op, test_value);
400 } else {
401 result = pure_subnormal_cache[cache_loc];
402 }
403
404 if (IsClose(result, static_cast<NativeRefT>(actual), error_spec)) {
405 passed_subnormal_test = true;
406 break;
407 }
408 subnormal_test_results.push_back(std::move(result));
409 }
410
411 if (passed_subnormal_test) {
412 continue;
413 }
414
415 std::string mismatch = absl::StrFormat(
416 "Mismatch on subnormal value %s. Expected one of:\n"
417 " %10s (evaluated at full-precision value)\n",
418 StringifyNum<NativeT, ComponentIntegralNativeT, N>(inputs),
419 StringifyNum<NativeT, ComponentIntegralNativeT>(expected));
420
421 CHECK_EQ(subnormal_test_inputs.size(), subnormal_test_results.size());
422 for (int i = 0; i < subnormal_test_inputs.size(); ++i) {
423 using IntegralNativeRefT =
424 typename ExhaustiveOpTestBase<RefT::value,
425 N>::ComponentIntegralNativeT;
426 absl::StrAppend(
427 &mismatch,
428 absl::StrFormat(" %10s (evaluated at %s)\n",
429 StringifyNum<NativeRefT, IntegralNativeRefT>(
430 subnormal_test_results[i]),
431 GetSubnormalDescription<kIsComplex, NativeRefT, N>(
432 subnormal_test_inputs[i], inputs_ref_ty)));
433 }
434 absl::StrAppend(
435 &mismatch,
436 absl::StrFormat(
437 "but got %s",
438 StringifyNum<NativeT, ComponentIntegralNativeT>(actual)));
439
440 PrintMismatch(&mismatches, [mismatch] { return mismatch; });
441 }
442 EXPECT_EQ(mismatches, 0);
443 }
444
445 template class ExhaustiveOpTestBase<C128, 1>;
446 template class ExhaustiveOpTestBase<C64, 1>;
447 template class ExhaustiveOpTestBase<F64, 1>;
448 template class ExhaustiveOpTestBase<F32, 1>;
449 template class ExhaustiveOpTestBase<F16, 1>;
450 template class ExhaustiveOpTestBase<BF16, 1>;
451
452 template class ExhaustiveOpTestBase<F64, 2>;
453 template class ExhaustiveOpTestBase<F32, 2>;
454 template class ExhaustiveOpTestBase<F16, 2>;
455 template class ExhaustiveOpTestBase<BF16, 2>;
456
457 } // namespace exhaustive_op_test
458 } // namespace xla
459