xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
17 
18 #include <array>
19 #include <string>
20 #include <type_traits>
21 
22 #include "absl/strings/string_view.h"
23 
24 namespace xla {
25 namespace exhaustive_op_test {
26 
// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 significant decimal
// digits to be guaranteed that we're printing the full number.
29 //
30 // (The general formula is, given a floating-point number with S significand
31 // bits, the number of decimal digits needed to print it to full precision is
32 //
33 //   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
34 //
35 // See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
36 namespace {
37 template <typename T>
38 struct ComponentStringifyFormat {
39   static const absl::string_view value;
40 };
41 
42 template <>
43 constexpr absl::string_view ComponentStringifyFormat<double>::value =
44     "%0.17g (0x%16x)";
45 
46 template <>
47 constexpr absl::string_view ComponentStringifyFormat<float>::value =
48     "%0.9g (0x%08x)";
49 
50 template <>
51 constexpr absl::string_view ComponentStringifyFormat<Eigen::half>::value =
52     "%0.5g (0x%04x)";
53 
54 template <>
55 constexpr absl::string_view ComponentStringifyFormat<bfloat16>::value =
56     "%0.4g (0x%04x)";
57 
58 template <typename Type, typename FuncPtr>
CallErrorSpec(FuncPtr * func,const std::array<Type,1> & in)59 ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 1>& in) {
60   return func(in[0]);
61 }
62 
63 template <typename Type, typename FuncPtr>
CallErrorSpec(FuncPtr * func,const std::array<Type,2> & in)64 ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 2>& in) {
65   return func(in[0], in[1]);
66 }
67 
// Evaluates a unary reference operation on the single element of `in`.
template <typename Type, typename FuncPtr>
Type CallOperation(FuncPtr* func, const std::array<Type, 1>& in) {
  const Type& operand = std::get<0>(in);
  return func(operand);
}
72 
// Evaluates a binary reference operation on the two elements of `in`, passing
// element 0 as the first argument and element 1 as the second.
template <typename Type, typename FuncPtr>
Type CallOperation(FuncPtr* func, const std::array<Type, 2>& in) {
  const Type& lhs = std::get<0>(in);
  const Type& rhs = std::get<1>(in);
  return func(lhs, rhs);
}
77 
// The number of values that can be substituted for subnormal inputs.
constexpr int kNumSubnormalSubstitutionValues = 4;

// Encodings used to determine where subnormal test values are cached. Each
// substitution value is one digit in [0, kNumSubnormalSubstitutionValues).
constexpr int kPositiveMin = 0;
constexpr int kNegativeMin = 1;
constexpr int kPositiveZero = 2;
constexpr int kNegativeZero = 3;

// Sentinel for a component that is not one of the substitution values above.
constexpr int kNonSubnormal = -1;
// Cache index meaning "not cached". It intentionally equals kNonSubnormal,
// because that sentinel is what the cache-location helpers return.
constexpr int kInvalidCacheIndex = kNonSubnormal;
88 
89 template <typename T>
90 struct is_complex_t : absl::disjunction<std::is_same<T, complex64>,
91                                         std::is_same<T, complex128>> {};
92 
93 // When we are testing a value such that all of its components are subnormal,
94 // we also need to test inputs made up of the Cartesian product of values
95 // replaced for each subnormal component. These additional test inputs are
// common enough that it is efficient to just cache the results of these
97 // Cartesian products. In order to cache these values, we need a one to one
98 // mapping between these Cartesian products and cache locations.
99 //
100 // Our mapping works by assigning each component an integer in
101 // [0, kNumSubnormalSubstitutionValues) based on its test value. By lining
102 // these integers up with the n'th component corresponding to the n'th digit,
103 // then for each Cartesian product element we essentially create a unique base
104 // kNumSubnormalSubstitutionValues number. This number represents our cache
105 // index.
106 //
// If any component is not subnormal, the value should not be cached, so we
// return kNonSubnormal.
109 
110 template <
111     typename NativeRefT,
112     typename std::enable_if<!is_complex_t<NativeRefT>::value>::type* = nullptr>
GetCacheLocation(NativeRefT value)113 int GetCacheLocation(NativeRefT value) {
114   bool positive = !std::signbit(value);
115   if (std::abs(value) == std::numeric_limits<NativeRefT>::min()) {
116     return positive ? kPositiveMin : kNegativeMin;
117   } else if (value != 0) {
118     CHECK(std::fpclassify(value) != FP_SUBNORMAL);
119     return kNonSubnormal;
120   } else {
121     return positive ? kPositiveZero : kNegativeZero;
122   }
123 }
124 
125 template <
126     typename NativeRefT,
127     typename std::enable_if<is_complex_t<NativeRefT>::value>::type* = nullptr>
GetCacheLocation(NativeRefT value)128 int GetCacheLocation(NativeRefT value) {
129   int real_loc =
130       GetCacheLocation<typename NativeRefT::value_type>(value.real());
131   int imag_loc =
132       GetCacheLocation<typename NativeRefT::value_type>(value.imag());
133   if (real_loc == kNonSubnormal || imag_loc == kNonSubnormal) {
134     return kNonSubnormal;
135   } else {
136     return real_loc * kNumSubnormalSubstitutionValues + imag_loc;
137   }
138 }
139 
140 template <bool is_complex, typename NativeRefT, size_t N>
GetCacheLocation(const std::array<NativeRefT,N> & input)141 int GetCacheLocation(const std::array<NativeRefT, N>& input) {
142   int location = 0;
143   int cache_size_per_element = (is_complex ? kNumSubnormalSubstitutionValues *
144                                                  kNumSubnormalSubstitutionValues
145                                            : kNumSubnormalSubstitutionValues);
146   for (int i = 0; i < N; ++i) {
147     int comp_loc = GetCacheLocation<NativeRefT>(input[i]);
148     if (i == kNonSubnormal) {
149       return kNonSubnormal;
150     }
151     location *= cache_size_per_element;
152     location += comp_loc;
153   }
154   return location;
155 }
156 
157 // The inverse function of GetCacheLocation.
158 
159 template <typename RetT,
160           typename std::enable_if<!is_complex_t<RetT>::value>::type* = nullptr>
FromCacheLocationComponent(int cache_loc)161 RetT FromCacheLocationComponent(int cache_loc) {
162   switch (cache_loc) {
163     case kPositiveMin:
164       return std::numeric_limits<RetT>::min();
165     case kNegativeMin:
166       return -std::numeric_limits<RetT>::min();
167     case kPositiveZero:
168       return static_cast<RetT>(0.0);
169     case kNegativeZero:
170       return static_cast<RetT>(-0.0);
171     default:
172       LOG(FATAL) << "Invalid cache_loc value of " << cache_loc;
173   }
174 }
175 
176 template <typename RetT,
177           typename std::enable_if<is_complex_t<RetT>::value>::type* = nullptr>
FromCacheLocationComponent(int cache_loc)178 RetT FromCacheLocationComponent(int cache_loc) {
179   CHECK_LT(cache_loc,
180            kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues);
181   CHECK_GE(cache_loc, 0);
182 
183   RetT value;
184   value.real(FromCacheLocationComponent<typename RetT::value_type>(
185       cache_loc / kNumSubnormalSubstitutionValues));
186   value.imag(FromCacheLocationComponent<typename RetT::value_type>(
187       cache_loc % kNumSubnormalSubstitutionValues));
188   return std::move(value);
189 }
190 
191 template <bool is_complex, typename NativeRefT, size_t N>
FromCacheLocation(int cache_loc)192 std::array<NativeRefT, N> FromCacheLocation(int cache_loc) {
193   std::array<NativeRefT, N> input;
194   int cache_size_per_element = (is_complex ? kNumSubnormalSubstitutionValues *
195                                                  kNumSubnormalSubstitutionValues
196                                            : kNumSubnormalSubstitutionValues);
197   for (int i = N - 1; i >= 0; --i) {
198     input[i] = FromCacheLocationComponent<NativeRefT>(cache_loc %
199                                                       cache_size_per_element);
200     cache_loc /= cache_size_per_element;
201   }
202 
203   return input;
204 }
205 
206 // Returns a string that describes the test value for the actual value.
207 template <
208     typename NativeRefT,
209     typename std::enable_if<!is_complex_t<NativeRefT>::value>::type* = nullptr>
GetSubnormalDescription(NativeRefT test_val,NativeRefT actual_val)210 std::string GetSubnormalDescription(NativeRefT test_val,
211                                     NativeRefT actual_val) {
212   std::string sp_min_normal = "sign-preserving min-normal-float";
213   std::string sp_zero = "sign-preserving zero";
214   std::string nsp_zero = "non-sign-preserving zero";
215 
216   switch (GetCacheLocation<NativeRefT>(test_val)) {
217     case kNegativeMin:
218     case kPositiveMin:
219       return sp_min_normal;
220     case kNegativeZero:
221     case kPositiveZero:
222       return (std::signbit(test_val) == std::signbit(actual_val)) ? sp_zero
223                                                                   : nsp_zero;
224     default:
225       return "";
226   }
227 }
228 
229 template <
230     typename NativeRefT,
231     typename std::enable_if<is_complex_t<NativeRefT>::value>::type* = nullptr>
GetSubnormalDescription(NativeRefT test_val,NativeRefT actual_val)232 std::string GetSubnormalDescription(NativeRefT test_val,
233                                     NativeRefT actual_val) {
234   std::string real = GetSubnormalDescription<typename NativeRefT::value_type>(
235       test_val.real(), actual_val.real());
236   std::string imag = GetSubnormalDescription<typename NativeRefT::value_type>(
237       test_val.imag(), actual_val.imag());
238 
239   if (real.empty()) {
240     if (imag.empty()) {
241       return "";
242     }
243     real = "real";
244   } else if (imag.empty()) {
245     imag = "imag";
246   }
247 
248   return absl::StrCat("(", real, ", ", imag, ")");
249 }
250 
251 template <bool is_complex, typename NativeRefT, size_t N>
GetSubnormalDescription(std::array<NativeRefT,N> test_vals,std::array<NativeRefT,N> actual_vals)252 std::string GetSubnormalDescription(std::array<NativeRefT, N> test_vals,
253                                     std::array<NativeRefT, N> actual_vals) {
254   if (N == 1) {
255     return GetSubnormalDescription<NativeRefT>(test_vals[0], actual_vals[0]);
256   }
257 
258   std::array<std::string, N> str_vals;
259   for (int i = 0; i < N; ++i) {
260     str_vals[i] =
261         GetSubnormalDescription<NativeRefT>(test_vals[i], actual_vals[i]);
262     if (str_vals[i].empty()) {
263       str_vals[i] = "original";
264     }
265   }
266 
267   return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
268 }
269 
270 template <
271     typename NativeT, typename IntegralType,
272     typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
StringifyNum(NativeT x)273 std::string StringifyNum(NativeT x) {
274   return absl::StrFormat(ComponentStringifyFormat<NativeT>::value,
275                          static_cast<double>(x), BitCast<IntegralType>(x));
276 }
277 
278 template <
279     typename NativeT, typename IntegralType,
280     typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
StringifyNum(NativeT x)281 std::string StringifyNum(NativeT x) {
282   return absl::StrCat(
283       "(", StringifyNum<typename NativeT::value_type, IntegralType>(x.real()),
284       ", ", StringifyNum<typename NativeT::value_type, IntegralType>(x.imag()),
285       ")");
286 }
287 
288 template <typename NativeT, typename IntegralType, size_t N>
StringifyNum(const std::array<NativeT,N> & inputs)289 std::string StringifyNum(const std::array<NativeT, N>& inputs) {
290   if (N == 1) {
291     return StringifyNum<NativeT, IntegralType>(inputs[0]);
292   }
293 
294   std::array<std::string, N> str_vals;
295   for (int i = 0; i < N; ++i) {
296     str_vals[i] = StringifyNum<NativeT, IntegralType>(inputs[i]);
297   }
298 
299   return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
300 }
301 
302 template <typename ErrorGenerator>
PrintMismatch(int64_t * mismatches,const ErrorGenerator & err_generator)303 void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) {
304   // We send a few mismatches to gunit so they show up nicely in test logs.
305   // Then we send more to LOG(ERROR).  The remainder we squelch unless we're
306   // at vlog level 2.
307   constexpr int64_t kMaxMismatchesLoggedToGunit = 10;
308   constexpr int64_t kMaxMismatchesLoggedToErr = 1000;
309 
310   (*mismatches)++;
311   if (*mismatches < kMaxMismatchesLoggedToGunit) {
312     FAIL() << err_generator();
313   } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
314     LOG(ERROR) << err_generator();
315   } else if (*mismatches == kMaxMismatchesLoggedToErr) {
316     LOG(ERROR) << "Not printing any more mismatches; pass "
317                   "--vmodule=exhaustive_op_test=2 to see "
318                   "all of them.";
319   }
320 }
321 }  // namespace
322 
323 template <PrimitiveType T, size_t N>
ExpectNear(const InputLiterals & input_literals,const Literal & result_literal,EvaluateOp evaluate_op,ErrorSpecGen error_spec_gen)324 void ExhaustiveOpTestBase<T, N>::ExpectNear(const InputLiterals& input_literals,
325                                             const Literal& result_literal,
326                                             EvaluateOp evaluate_op,
327                                             ErrorSpecGen error_spec_gen) {
328   // Cache for when all components are subnormal testing values.
329   std::vector<NativeRefT> pure_subnormal_cache;
330   // Since we take the cross product of all possible test values, and each
331   // component has kNumSubnormalSubstitutionValues possible test values, then
332   // the total number of different cache locations are
333   // kNumSubnormalSubstitutionValues raised to the num_components.
334   // num_components = N for the reals, and 2*N for the complex.
335   int64_t max_cache_size =
336       pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1));
337   pure_subnormal_cache.reserve(max_cache_size);
338   for (int i = 0; i < max_cache_size; ++i) {
339     pure_subnormal_cache.push_back(CallOperation(
340         evaluate_op, FromCacheLocation<kIsComplex, NativeRefT, N>(i)));
341   }
342 
343   NativeInputsList inputs_arr;
344   for (int i = 0; i < N; ++i) {
345     const Literal& literal = input_literals[i];
346     inputs_arr[i] = literal.data<NativeT>();
347   }
348 
349   absl::Span<const NativeT> result_arr = result_literal.data<NativeT>();
350 
351   int64_t mismatches = 0;
352 
353   for (int64_t i = 0; i < result_arr.size(); ++i) {
354     NativeInputs inputs;
355     NativeRefInputs inputs_ref_ty;
356 
357     for (int j = 0; j < N; ++j) {
358       inputs[j] = inputs_arr[j][i];
359       inputs_ref_ty[j] = static_cast<NativeRefT>(inputs[j]);
360     }
361 
362     NativeT actual = result_arr[i];
363     NativeT expected =
364         static_cast<NativeT>(CallOperation(evaluate_op, inputs_ref_ty));
365     ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs);
366 
367     if (IsClose(static_cast<NativeRefT>(expected),
368                 static_cast<NativeRefT>(actual), error_spec)) {
369       continue;
370     }
371 
372     std::vector<NativeRefInputs> subnormal_test_inputs =
373         GetTestValuesWithSubnormalSubstitutions(inputs_ref_ty);
374 
375     // Easy case: If `input` is not subnormal and !IsClose(expected, actual,
376     // error_spec), print an error.
377     if (subnormal_test_inputs.size() == 1) {
378       PrintMismatch(&mismatches, [&] {
379         return absl::StrFormat(
380             "Mismatch on %s. Expected %s, but got %s.",
381             StringifyNum<NativeT, ComponentIntegralNativeT, N>(inputs),
382             StringifyNum<NativeT, ComponentIntegralNativeT>(expected),
383             StringifyNum<NativeT, ComponentIntegralNativeT>(actual));
384       });
385       continue;
386     }
387 
388     // Otherwise, we need to test the additional subnormal test values.
389     std::vector<NativeRefT> subnormal_test_results;
390     subnormal_test_results.reserve(subnormal_test_inputs.size());
391     bool passed_subnormal_test = false;
392 
393     for (NativeRefInputs test_value : subnormal_test_inputs) {
394       NativeRefT result;
395       int cache_loc =
396           GetCacheLocation<kIsComplex, typename NativeRefInputs::value_type, N>(
397               test_value);
398       if (cache_loc == kInvalidCacheIndex) {
399         result = CallOperation(evaluate_op, test_value);
400       } else {
401         result = pure_subnormal_cache[cache_loc];
402       }
403 
404       if (IsClose(result, static_cast<NativeRefT>(actual), error_spec)) {
405         passed_subnormal_test = true;
406         break;
407       }
408       subnormal_test_results.push_back(std::move(result));
409     }
410 
411     if (passed_subnormal_test) {
412       continue;
413     }
414 
415     std::string mismatch = absl::StrFormat(
416         "Mismatch on subnormal value %s.  Expected one of:\n"
417         "  %10s (evaluated at full-precision value)\n",
418         StringifyNum<NativeT, ComponentIntegralNativeT, N>(inputs),
419         StringifyNum<NativeT, ComponentIntegralNativeT>(expected));
420 
421     CHECK_EQ(subnormal_test_inputs.size(), subnormal_test_results.size());
422     for (int i = 0; i < subnormal_test_inputs.size(); ++i) {
423       using IntegralNativeRefT =
424           typename ExhaustiveOpTestBase<RefT::value,
425                                         N>::ComponentIntegralNativeT;
426       absl::StrAppend(
427           &mismatch,
428           absl::StrFormat("  %10s (evaluated at %s)\n",
429                           StringifyNum<NativeRefT, IntegralNativeRefT>(
430                               subnormal_test_results[i]),
431                           GetSubnormalDescription<kIsComplex, NativeRefT, N>(
432                               subnormal_test_inputs[i], inputs_ref_ty)));
433     }
434     absl::StrAppend(
435         &mismatch,
436         absl::StrFormat(
437             "but got %s",
438             StringifyNum<NativeT, ComponentIntegralNativeT>(actual)));
439 
440     PrintMismatch(&mismatches, [mismatch] { return mismatch; });
441   }
442   EXPECT_EQ(mismatches, 0);
443 }
444 
445 template class ExhaustiveOpTestBase<C128, 1>;
446 template class ExhaustiveOpTestBase<C64, 1>;
447 template class ExhaustiveOpTestBase<F64, 1>;
448 template class ExhaustiveOpTestBase<F32, 1>;
449 template class ExhaustiveOpTestBase<F16, 1>;
450 template class ExhaustiveOpTestBase<BF16, 1>;
451 
452 template class ExhaustiveOpTestBase<F64, 2>;
453 template class ExhaustiveOpTestBase<F32, 2>;
454 template class ExhaustiveOpTestBase<F16, 2>;
455 template class ExhaustiveOpTestBase<BF16, 2>;
456 
457 }  // namespace exhaustive_op_test
458 }  // namespace xla
459