/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_

#include <algorithm>
#include <array>
#include <cmath>
#include <complex>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/bit_cast.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace exhaustive_op_test {

struct ErrorSpec {
  float abs_err;
  float rel_err;

  // If true, will consider -0 not near to +0 and vice versa.  Note that
  // +epsilon may still be considered close to -0, depending on the error
  // spec; this only covers the case when both `expected` and `actual` are
  // equal to 0.
  bool strict_signed_zeros = false;

  ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
};

// Representations of the reference function passed in by the user.
template <typename NativeRefT, size_t K>
struct EvaluateOpWrapper {};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 1> {
  using type = NativeRefT (*)(NativeRefT);
};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 2> {
  using type = NativeRefT (*)(NativeRefT, NativeRefT);
};

// Representations of the enqueue function passed in by the user.
template <typename XlaInputs, size_t K>
struct EnqueueOpWrapper {};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 1> {
  using type = std::function<XlaOp(XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0]);
  }
};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 2> {
  using type = std::function<XlaOp(XlaOp, XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0], inputs[1]);
  }
};

// Representations of the ErrorSpecGen function passed in by the user.
template <PrimitiveType T, size_t K>
struct ErrorSpecGenWrapper {};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 1> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT);
};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 2> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT, NativeT);
};

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator();

// T: The primitive type being tested.
// N: The number of operands that the function being tested takes.
template <PrimitiveType T, size_t N>
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
 public:
  // Definitions depending on the primitive type T.

  static constexpr bool kIsComplex = (T == C128 || T == C64);

  // The primitive type used to compute the reference output.
  struct RefT {
    static constexpr PrimitiveType value = (T == F16 || T == BF16) ? F32 : T;
  };

  // The primitive type of the component of T. If T is not complex, then
  // ComponentT = T.
  struct ComponentT {
    static constexpr PrimitiveType value = !kIsComplex ? T
                                           : T == C128 ? F64
                                           : T == C64  ? F32
                                                       : PRIMITIVE_TYPE_INVALID;
  };

  // Same as ComponentT, but for the RefT primitive type.
  struct ComponentRefT {
    static constexpr PrimitiveType value = !kIsComplex           ? RefT::value
                                           : RefT::value == C128 ? F64
                                           : RefT::value == C64
                                               ? F32
                                               : PRIMITIVE_TYPE_INVALID;
  };

  // The primitive type of an unsigned integer that can be bitcasted to and
  // from ComponentT.
  struct ComponentIntegralT {
    static constexpr PrimitiveType value = (T == C128 || T == F64)  ? U64
                                           : (T == C64 || T == F32) ? U32
                                           : (T == F16 || T == BF16)
                                               ? U16
                                               : PRIMITIVE_TYPE_INVALID;
  };

  // Native types that correspond to the primitive types above.
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using NativeRefT =
      typename primitive_util::PrimitiveTypeToNative<RefT::value>::type;
  using ComponentNativeT =
      typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type;
  using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative<
      ComponentRefT::value>::type;
  using ComponentIntegralNativeT =
      typename primitive_util::PrimitiveTypeToNative<
          ComponentIntegralT::value>::type;

  using InputLiterals = std::array<Literal, N>;

 private:
  // N spans corresponding to the list of literal data values.
  using NativeInputsList = std::array<absl::Span<const NativeT>, N>;

  // N data items representing a single input to an XLA function.
  using NativeInputs = std::array<NativeT, N>;

  // N data items representing a single input to an interpreter backend
  // function.
  using NativeRefInputs = std::array<NativeRefT, N>;

  // N XlaOps representing the inputs to an XLA function.
  using XlaInputs = std::array<XlaOp, N>;

 public:
  using ErrorSpecGen = typename ErrorSpecGenWrapper<T, N>::type;
  using EvaluateOp = typename EvaluateOpWrapper<NativeRefT, N>::type;
  using EnqueueOp = typename EnqueueOpWrapper<XlaInputs, N>::type;

  explicit ExhaustiveOpTestBase()
      : ty_(T), platform_(client_->platform()->Name()) {
    SetFastMathDisabled(true);

    // Run all HLO passes.  In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();
  }

  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
    Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator<T, N>());
  }

  // A helper for implementing the Run method for exhaustive op tests. It
  // constructs the HLO module, compiles and runs the module, and checks the
  // result.
  //
  // We use a function pointer for evaluate_op for performance, because it is
  // called once per output element inside a loop in ExpectNear.
  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
           ErrorSpecGen error_spec_gen) {
    InputLiterals input_literals = CreateInputLiterals();
    FillInput(&input_literals);

    XlaBuilder builder(TestName());
    XlaInputs xla_inputs;
    for (int i = 0; i < N; ++i) {
      xla_inputs[i] =
          Parameter(&builder, i, input_literals[i].shape(), "input");
    }
    EnqueueOpWrapper<XlaInputs, N>::BuildFromInputs(xla_inputs, enqueue_op);

    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            RunComputationHelper(comp, input_literals));
    ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen);
  }
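
  // Example (illustrative; the concrete calls live in the exhaustive_*_test
  // files): an exhaustive unary log test would invoke something like
  //   Run(Log, std::log);
  // where xla::Log enqueues the XLA op and std::log computes the reference
  // output on the host.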

  StatusOr<Literal> RunComputationHelper(const XlaComputation& comp,
                                         const Literal& literal) {
    return RunComputation(comp, {&literal});
  }

  StatusOr<Literal> RunComputationHelper(
      const XlaComputation& comp, const std::array<Literal, N>& literals) {
    std::array<const Literal*, N> lit_ptrs;
    for (int i = 0; i < N; ++i) {
      lit_ptrs[i] = &literals[i];
    }
    return RunComputation(comp, lit_ptrs);
  }

  // We essentially reimplement LiteralTestUtil::Near here because
  //  a) this streamlined implementation is much faster,
  //  b) we can print out better error messages (namely, we can print out
  //     which floating-point input value failed, while LiteralTestUtil::Near
  //     can only print out the input index that failed), and
  //  c) we need special handling of certain inputs.  For example, we say that
  //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
  //     and just needs to be close to one of them.
  void ExpectNear(const InputLiterals& input_literals,
                  const Literal& result_literal, EvaluateOp evaluate_op,
                  ErrorSpecGen error_spec_gen);

  // Builds and runs the computation using the LocalClient API, rather than the
  // plain Client API, which is used by ClientLibraryTestBase.  This is because
  // the plain Client API does more memcpys to/from Literals, and that's slow
  // given that we're touching a lot of data here.
  StatusOr<Literal> RunComputation(
      const XlaComputation& computation,
      absl::Span<const Literal* const> input_literals) {
    // Copy debug options from ClientLibraryTestBase.  In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();

    std::vector<ScopedShapedBuffer> input_buffers;
    absl::c_transform(input_literals, std::back_inserter(input_buffers),
                      [&](const Literal* input_literal) {
                        return client_
                            ->LiteralToShapedBuffer(*input_literal,
                                                    /*device_ordinal=*/0)
                            .value();
                      });
    std::vector<const Shape*> input_shapes;
    absl::c_transform(input_buffers, std::back_inserter(input_shapes),
                      [&](const ScopedShapedBuffer& buffer) {
                        return &buffer.on_device_shape();
                      });

    TF_ASSIGN_OR_RETURN(
        auto executables,
        client_->Compile(computation, input_shapes, build_opts));

    std::vector<const ShapedBuffer*> input_buffer_pointers;
    absl::c_transform(
        input_buffers, std::back_inserter(input_buffer_pointers),
        [&](const ScopedShapedBuffer& buffer) { return &buffer; });

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executables[0]->Run(input_buffer_pointers, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

  const std::string& Platform() { return platform_; }

  // Returns the number of elements in each input literal.
  virtual int64_t GetInputSize() = 0;

  // Fills the literals with values to test for.
  virtual void FillInput(InputLiterals* literals) = 0;

  // Replaces infinities with the max value to help compute errors.
  static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) {
    if (std::isinf(value)) {
      return std::copysign(std::numeric_limits<ComponentNativeRefT>::max(),
                           value);
    }
    return value;
  }

  // Returns true if both components are 0, but their sign bits differ.
  static bool CheckSignedZeroError(ComponentNativeRefT expected,
                                   ComponentNativeRefT actual) {
    return expected == 0 && actual == 0 &&
           std::signbit(expected) != std::signbit(actual);
  }

  // Sets the components to 0 if both are NaNs.
  static void RemoveCorrespondingNaNs(ComponentNativeRefT* expected,
                                      ComponentNativeRefT* actual) {
    if (std::isnan(*expected) && std::isnan(*actual)) {
      *expected = 0;
      *actual = 0;
    }
  }

  // Implementations of the functions above for complex inputs.

  static std::complex<ComponentNativeRefT> ReplaceInfWithMax(
      std::complex<ComponentNativeRefT> value) {
    value.real(ReplaceInfWithMax(value.real()));
    value.imag(ReplaceInfWithMax(value.imag()));
    return value;
  }

  static bool CheckSignedZeroError(std::complex<ComponentNativeRefT> expected,
                                   std::complex<ComponentNativeRefT> actual) {
    return CheckSignedZeroError(expected.real(), actual.real()) ||
           CheckSignedZeroError(expected.imag(), actual.imag());
  }

  static void RemoveCorrespondingNaNs(
      std::complex<ComponentNativeRefT>* expected,
      std::complex<ComponentNativeRefT>* actual) {
    ComponentNativeRefT expected_real = expected->real();
    ComponentNativeRefT expected_imag = expected->imag();
    ComponentNativeRefT actual_real = actual->real();
    ComponentNativeRefT actual_imag = actual->imag();
    RemoveCorrespondingNaNs(&expected_real, &actual_real);
    RemoveCorrespondingNaNs(&expected_imag, &actual_imag);
    expected->real(expected_real);
    expected->imag(expected_imag);
    actual->real(actual_real);
    actual->imag(actual_imag);
  }

  // Returns a list of inputs that should be tested for closeness given some
  // original input values.
  //
  // For denormal component inputs, we accept answers that are close to any of:
  //
  //   - evaluate_op(input),
  //   - evaluate_op(+/-0), where the sign of 0 is equal to the sign of
  //     `input`,
  //   - evaluate_op(+/-min_normal_float), where the sign of
  //     min_normal_float matches `input`, and
  //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
  //     0 is the opposite of `input`.
  //
  // (In particular, the XLA:CPU implementation of log flushes positive
  // denormals to min-normal-float.  This seems kind of reasonable if our
  // goal is to avoid infinities because they cause nans?)
  std::vector<ComponentNativeRefT> GetTestValuesWithSubnormalSubstitutions(
      ComponentNativeRefT value) {
    std::vector<ComponentNativeRefT> test_values;
    if (std::fpclassify(value) == FP_SUBNORMAL) {
      test_values.reserve(relaxed_denormal_signs_ ? 3 : 2);
      test_values.push_back(std::copysign(0, value));
      test_values.push_back(std::copysign(
          std::numeric_limits<ComponentNativeRefT>::min(), value));
      if (relaxed_denormal_signs_) {
        test_values.push_back(std::copysign(0, -value));
      }
    } else {
      test_values.push_back(value);
    }
    return test_values;
  }
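
  // For example, for a positive denormal input d with relaxed_denormal_signs_
  // set, this returns {+0.0, min_normal, -0.0}; for a normal, zero, inf, or
  // nan input x it returns just {x}.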

  // For complex numbers, we only need to substitute the components that are
  // subnormal. We find the subnormal testing values for each component, then
  // take the Cartesian product of the two sets of component values.
  std::vector<std::complex<ComponentNativeRefT>>
  GetTestValuesWithSubnormalSubstitutions(
      std::complex<ComponentNativeRefT> value) {
    using complex = std::complex<ComponentNativeRefT>;

    auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real());
    auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag());

    std::vector<complex> test_values;
    test_values.reserve(real_values.size() * imag_values.size());
    for (auto real : real_values) {
      for (auto imag : imag_values) {
        test_values.push_back(complex(real, imag));
      }
    }

    return test_values;
  }

  // The test values for an XLA function with N operands are the Cartesian
  // product of the test values for each of the N operands.
  std::vector<std::array<NativeRefT, N>>
  GetTestValuesWithSubnormalSubstitutions(
      const std::array<NativeRefT, N>& value) {
    std::vector<std::array<NativeRefT, N>> test_values;

    std::array<std::vector<NativeRefT>, N> component_test_values;
    int total = 1;
    for (int i = 0; i < N; ++i) {
      component_test_values[i] =
          GetTestValuesWithSubnormalSubstitutions(value[i]);
      if (!component_test_values[i].empty()) {
        total *= component_test_values[i].size();
      }
    }

    // If total == 1, then value has no subnormal components, so we can just
    // return a vector with value in it.
    if (total == 1) {
      test_values.push_back(value);
      return test_values;
    }

    test_values.reserve(total);

    // Perform a Cartesian product of the vectors in component_test_values.
    // We can calculate this by uniquely mapping each integer from 0 to
    // (total - 1) to a list of component indices. The function that maps an
    // integer i to the index of component j is:
    //    component_index(j) = (i / NumValues(0, j-1)) % NumValues(j, j)
    // where NumValues(x, y) is the number of values in the Cartesian product
    // of component_test_values[x], component_test_values[x+1], ...,
    // component_test_values[y].
    for (int i = 0; i < total; ++i) {
      int accumulated_num_values = 1;
      std::array<NativeRefT, N> test_value;
      for (int j = 0; j < N; ++j) {
        int num_indices = component_test_values[j].size();
        int component_index = (i / accumulated_num_values) % num_indices;
        test_value[j] = component_test_values[j][component_index];
        accumulated_num_values *= num_indices;
      }
      test_values.push_back(std::move(test_value));
    }
    return test_values;
  }
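
  // For example, with N = 2 and component value counts {2, 3} (total = 6),
  // i = 4 maps to the component indices (4 / 1) % 2 == 0 and (4 / 2) % 3 == 2,
  // i.e. test_value = {component_test_values[0][0],
  //                    component_test_values[1][2]}.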

  InputLiterals CreateInputLiterals() {
    InputLiterals literals;
    for (int i = 0; i < N; ++i) {
      literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()});
    }
    return std::move(literals);
  }

  // Determines if two output values are sufficiently close to each other based
  // on an error spec.
  bool IsClose(NativeRefT expected, NativeRefT actual, ErrorSpec spec) {
    // When two corresponding values are both NaN, they can be considered to
    // have the same value, so both are just set to 0.
    RemoveCorrespondingNaNs(&expected, &actual);

    if (spec.strict_signed_zeros) {
      if (CheckSignedZeroError(expected, actual)) {
        return false;
      }
    }

    // Replace Inf with Max when calculating absolute or relative errors. This
    // allows the test to pass when one value is close to Inf and the specified
    // absolute or relative errors are not zero.
    double abs_err =
        std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual));
    double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));

    return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
  }
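
  // For example, with ErrorSpec{0.001, 0.001}, expected = inf and
  // actual = std::numeric_limits<float>::max() compare as close: both are
  // mapped to max() above, so abs_err == 0.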

  // Converts part or all of the bits in a uint64_t to the value of the
  // floating point data type being tested.
  //
  // To exhaustively test an operation on data type T, we always use an
  // integral type I with the same number of bits as T to enumerate the input
  // bit patterns for T. Each bit pattern is zero-extended and stored as a
  // uint64_t. This function converts such a bit pattern stored as a uint64_t
  // to the input value for T.
  static ComponentNativeT ConvertValue(uint64_t bits) {
    using I = ComponentIntegralNativeT;
    I used_bits = static_cast<I>(bits);
    return BitCast<ComponentNativeT>(used_bits);
  }
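
  // For example, when T is BF16 only the low 16 bits of `bits` are kept, so
  // ConvertValue(0x3F80) bitcasts the pattern 0x3F80 to bfloat16 and yields
  // 1.0.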

  ComponentNativeT ConvertAndReplaceKnownIncorrectValueWith(
      uint64_t bits, int replacement_value = 0) {
    if (known_incorrect_fn_ && known_incorrect_fn_(bits)) {
      return static_cast<ComponentNativeT>(replacement_value);
    }
    return ConvertValue(bits);
  }

 protected:
  // The primitive type being tested.
  const PrimitiveType ty_;

  // The platform under test.
  const std::string platform_;

  // Testing will ignore inputs for which known_incorrect_fn_ returns true. The
  // argument to the function is the raw bits for the data being tested, zero
  // extended to 64 bits if the data type is less than 64 bits.
  std::function<bool(int64_t)> known_incorrect_fn_;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt
  // of a negative number) or -inf (flush the denormal to sign-preserving zero,
  // then take sqrt(-0)).  But with this set to true, we'll also accept 0
  // (i.e. sqrt(+0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};
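
// A minimal usage sketch (hypothetical fixture, assuming a 32-bit unary op;
// the real fixtures live in the exhaustive_*_test.cc files and also handle
// test sharding and FpValues-driven inputs):
//
//   class SqrtF32Test : public ExhaustiveOpTestBase<F32, 1> {
//     int64_t GetInputSize() override { return int64_t{1} << 20; }
//     void FillInput(InputLiterals* literals) override {
//       absl::Span<float> data = (*literals)[0].data<float>();
//       for (int64_t i = 0; i < GetInputSize(); ++i) {
//         data[i] = ConvertAndReplaceKnownIncorrectValueWith(i);
//       }
//     }
//   };
//   // In a test body: Run(Sqrt, std::sqrt);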

// Represents a set of 64 bit chunks by representing the starting bit chunk,
// the last bit chunk, and the spacing between two adjacent bit chunks, without
// actually storing all the bit chunks being generated. The bit chunk iterator
// is provided to retrieve all the bit chunks.
//
// This data structure is used to generate the bit representations to test
// operations that require more than 64 bits of input data. In this case,
// truly exhaustive testing is not possible, so we instead test a value every
// n values, where n == spacing_.
//
// Currently, the iterator of BitChunks adds the `spacing_` to a bit chunk to
// compute the next bit chunk. We can change this to use values generated
// by a random number generator that can achieve the average spacing
// statistically, if we find this is necessary.
class BitChunks {
 public:
  class iterator {
   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = uint64_t;
    using difference_type = uint64_t;
    using pointer = const uint64_t*;
    using reference = uint64_t;

    iterator() = default;

    explicit iterator(const BitChunks* bit_chunks)
        : bit_chunks_(bit_chunks), next_bit_chunk_(bit_chunks->start_) {}

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      return bit_chunks_ == other.bit_chunks_ &&
             next_bit_chunk_ == other.next_bit_chunk_;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      MoveNextBitChunkToOnePassEnd();
      return *this;
    }

    reference operator*() const {
      CHECK(*this != this->bit_chunks_->end());
      return next_bit_chunk_;
    }

    const BitChunks* GetBitChunks() const { return bit_chunks_; }

    void Reset() { next_bit_chunk_ = bit_chunks_->start_; }

    void Next() {
      CHECK(*this != this->bit_chunks_->end());
      if (next_bit_chunk_ == bit_chunks_->end_) {
        MoveNextBitChunkToOnePassEnd();
      } else {
        next_bit_chunk_ += bit_chunks_->spacing_;
        if (next_bit_chunk_ > bit_chunks_->end_) {
          next_bit_chunk_ = bit_chunks_->end_;
        }
      }
    }

    std::string ToString() const {
      return absl::StrFormat("0x%08x", next_bit_chunk_);
    }

   private:
    // Moves next_bit_chunk_ to one past bit_chunks_->end_, to mark that the
    // iterator has reached the end. When spacing_ is not one, or if we later
    // change Next() to use a random value instead of spacing_, normalizing
    // the end-of-iteration representation this way simplifies the check for
    // whether the iterator has ended.
    void MoveNextBitChunkToOnePassEnd() {
      next_bit_chunk_ = bit_chunks_->end_ + 1;
    }

    const BitChunks* bit_chunks_;
    uint64_t next_bit_chunk_;
  };

  iterator begin() const { return iterator(this); }
  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  explicit BitChunks(uint64_t start = 0, uint64_t end = 0, uint64_t spacing = 1)
      : start_(start), end_(end), spacing_(spacing) {
    CHECK_GE(end_, start_);
    CHECK_NE(spacing, 0) << ToString();
  }

  int64_t GetTotalBitChunks() const {
    if (start_ == end_) {
      return 1;
    }

    return 1 + (end_ - start_ + spacing_ - 1) / spacing_;
  }

  std::string ToString() const {
    return absl::StrFormat("(0x%08x, 0x%08x, 0x%08x)", start_, end_, spacing_);
  }

  uint64_t start_;
  uint64_t end_;
  uint64_t spacing_;
};
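
// For example, BitChunks(0, 16, 4) iterates over 0, 4, 8, 12, 16, and its
// GetTotalBitChunks() is 1 + (16 + 4 - 1) / 4 == 5. When spacing_ does not
// evenly divide the range, the last step clamps to end_: BitChunks(0, 10, 4)
// iterates over 0, 4, 8, 10.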

inline std::string StringifyNum(BitChunks c) { return c.ToString(); }

inline std::string StringifyNum(BitChunks::iterator c) { return c.ToString(); }

template <typename T>
void AppendStringifyNum(std::string* s, T x) {
  absl::StrAppend(s, StringifyNum(x));
}

// Represents a set of floating point values through the possible values for
// the three components: mantissa, exponent, and sign. Also implements an
// iterator for retrieving all the represented floating point values.
class FpValues {
 public:
  static constexpr int kTotalBitChunks = 3;

  class iterator {
   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = uint64_t;
    using difference_type = uint64_t;
    using pointer = const uint64_t*;
    using reference = uint64_t;

    explicit iterator(const FpValues* fp_values) : fp_values_(fp_values) {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i] = BitChunks::iterator(&fp_values->GetBitChunks(i));
      }
    }

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        if (iters_[i] != other.GetBitChunksIter(i)) {
          return false;
        }
      }
      return true;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i].MoveToEnd();
      }
      return *this;
    }

    uint64_t operator*() const {
      uint64_t value = 0;
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        value = value | (*iters_[i]) << fp_values_->offsets_[i];
      }
      return value;
    }

    const BitChunks::iterator& GetBitChunksIter(int i) { return iters_[i]; }

    std::string ToString() const {
      return absl::StrJoin(iters_, ",",
                           AppendStringifyNum<BitChunks::iterator>);
    }

   private:
    // Moves the iterator for the ith BitChunks to the next value, and
    // returns true if the new state is not the end of the iterator.
    bool Next(int i = 0) {
      iters_[i].Next();
      if (iters_[i] == iters_[i].GetBitChunks()->end()) {
        if (i == FpValues::kTotalBitChunks - 1) {
          return false;
        }
        if (Next(i + 1)) {
          iters_[i].Reset();
          return true;
        }
        return false;
      }
      return true;
    }

    std::array<BitChunks::iterator, FpValues::kTotalBitChunks> iters_;
    const FpValues* fp_values_;
  };

  FpValues() : bit_chunks_(), offsets_() {}
  FpValues(absl::Span<const BitChunks> chunks, absl::Span<const int> offsets) {
    CHECK_EQ(chunks.size(), offsets.size() - 1);
    CHECK_EQ(chunks.size(), kTotalBitChunks);
    std::copy_n(chunks.begin(), kTotalBitChunks, bit_chunks_.begin());
    std::copy_n(offsets.begin(), kTotalBitChunks, offsets_.begin());

    // The last value in `offsets` is the total number of bits.
    offsets_[kTotalBitChunks] = offsets[kTotalBitChunks];
    // Validate the input values.
    for (int i = 0; i < kTotalBitChunks; ++i) {
      int total_bits = offsets[i + 1] - offsets[i];
      if (total_bits < 64) {
        uint64_t bound = 1ull << total_bits;
        CHECK_LT(chunks[i].start_, bound);
        CHECK_LT(chunks[i].end_, bound);
      } else {
        CHECK_EQ(total_bits, 64);
      }
    }
  }

  iterator begin() const { return iterator(this); }

  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  int64_t GetTotalNumValues() const {
    int64_t total = 1;
    absl::c_for_each(bit_chunks_, [&](const BitChunks& chunks) {
      total *= chunks.GetTotalBitChunks();
    });
    return total;
  }

  const BitChunks& GetBitChunks(int i) const { return bit_chunks_[i]; }

  std::string ToString() const {
    return absl::StrCat(
        "[", absl::StrJoin(bit_chunks_, ",", AppendStringifyNum<BitChunks>),
        "]");
  }

  std::array<BitChunks, kTotalBitChunks> bit_chunks_;
  std::array<int, kTotalBitChunks + 1> offsets_;
};
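
// For example, for F32 the three BitChunks hold the mantissa, exponent, and
// sign fields with offsets_ = {0, 23, 31, 32}, so iterator::operator*()
// assembles a bit pattern as
//   bits = mantissa | (exponent << 23) | (sign << 31).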

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
int GetMantissaTotalBits() {
  return std::numeric_limits<T>::digits - 1;
}

template <typename T>
int GetFpTotalBits() {
  return sizeof(T) * 8;
}

template <typename T>
int GetExponentTotalBits() {
  return GetFpTotalBits<T>() - GetMantissaTotalBits<T>() - 1;
}
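
// For example, for float: std::numeric_limits<float>::digits == 24, so
// GetMantissaTotalBits<float>() == 23, GetFpTotalBits<float>() == 32, and
// GetExponentTotalBits<float>() == 32 - 23 - 1 == 8.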

template <typename T>
uint64_t GetAllOneMantissa() {
  return (1ull << GetMantissaTotalBits<T>()) - 1ull;
}

template <typename T>
uint64_t GetAllOneExponent() {
  return (1ull << GetExponentTotalBits<T>()) - 1ull;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
  int total_bits = GetFpTotalBits<T>();
  return FpValues({mantissa, exponent, sign},
                  {0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
}

template <typename T>
FpValues GetZeros() {
  return GetFpValues<T>(BitChunks(0, 0, 1), BitChunks(0, 0, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetSubnormals(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64_t mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(0, 0, 1), BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetInfinites() {
  uint64_t all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(BitChunks(0, 0, 1),
                        BitChunks(all_one_exp, all_one_exp, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetNans(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64_t mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  uint64_t all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(all_one_exp, all_one_exp, 1), BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetNormals(int approx_num_values) {
  float component_total = std::sqrt(static_cast<float>(approx_num_values));
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(),
                (1ull << (GetMantissaTotalBits<T>() + 1)) / component_total),
      BitChunks(0x1, GetAllOneExponent<T>() - 1,
                (1ull << (GetExponentTotalBits<T>() + 1)) / component_total),
      BitChunks(0, 1, 1));
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` floating point values of type `T`, with each FpValues
// representing about `num_values_per_group` floating point values.
template <typename T>
std::vector<FpValues> GetFpValuesWithExponents(uint64_t first_exponent,
                                               uint64_t exponent_spacing,
                                               uint64_t num_exponents,
                                               uint64_t approx_num_values,
                                               uint64_t num_values_per_group) {
  const uint64_t num_signs = 2;
  uint64_t approx_num_mantissa =
      approx_num_values / (num_exponents * num_signs);
  uint64_t num_mantissa_per_group =
      num_values_per_group / (num_exponents * num_signs);
  CHECK_GT(approx_num_mantissa, 0);
  CHECK_GT(num_mantissa_per_group, 0);

  CHECK_LT(first_exponent + num_exponents - 1ull, GetAllOneExponent<T>());
  int mantissa = GetMantissaTotalBits<T>();
  uint64_t mantissa_spacing = (1ull << mantissa) / approx_num_mantissa;

  std::vector<FpValues> result;
  for (uint64_t group_start = 0; group_start < GetAllOneMantissa<T>();
       group_start += mantissa_spacing * num_mantissa_per_group) {
    uint64_t group_end =
        group_start + (num_mantissa_per_group - 1) * mantissa_spacing;
    if (group_end > GetAllOneMantissa<T>()) {
      group_end = GetAllOneMantissa<T>();
    }
    result.push_back(GetFpValues<T>(
        BitChunks(group_start, group_end, mantissa_spacing),
        BitChunks(first_exponent, first_exponent + num_exponents - 1, 1),
        BitChunks(0, 1, 1)));
  }
  return result;
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` "very large" floating point values and
// `approx_num_values` "very small" floating point values of type `T`, with
// each FpValues representing about `num_values_per_group` floating point
// values. Because we use FpValues as a parameter for parameterized testing,
// the number of floating point values represented by each FpValues affects
// the input size for each sub-test and hence the peak memory usage of the
// test.
template <typename T>
std::vector<FpValues> GetFpValuesForMagnitudeExtremeNormals(
    uint64_t approx_num_values = 40000, uint64_t num_values_per_group = 4000) {
  std::vector<FpValues> large =
      GetFpValuesWithExponents<T>(GetAllOneExponent<T>() - 5, 1, 5,
                                  approx_num_values / 2, num_values_per_group);
  std::vector<FpValues> small = GetFpValuesWithExponents<T>(
      1, 1, 5, approx_num_values / 2, num_values_per_group);
  large.insert(large.end(), small.begin(), small.end());
  return large;
}

template <typename T>
std::vector<FpValues> CreateFpValuesForBoundaryTest() {
  return {GetZeros<T>(), GetSubnormals<T>(1000), GetInfinites<T>(),
          GetNans<T>(1000)};
}

inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into small'ish chunks to keep peak
  // memory usage low.
  std::vector<std::pair<int64_t, int64_t>> result;
  const int64_t step = 1 << 25;
  for (int64_t i = 0; i < (int64_t{1} << 32); i += step) {
    result.push_back({i, i + step});
  }
  return result;
}
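
// With step == 2^25, CreateExhaustiveF32Ranges() yields 2^32 / 2^25 == 128
// ranges of 2^25 bit patterns each.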

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT,
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <>
inline ErrorSpec DefaultSpecGenerator<C128, 1>(complex128) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<C64, 1>(complex64) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 1>(double) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 1>(float) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 1>(Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 1>(bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 2>(double, double) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 2>(float, float) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 2>(Eigen::half, Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
  return DefaultSpecGenerator<T, N>;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMax(T x, T y) {
  // We need to propagate NAN here because std::max may not propagate NAN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::max<T>(x, y);
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMin(T x, T y) {
  // We need to propagate NAN here because std::min may not propagate NAN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::min<T>(x, y);
}

// Returns a wrapper of the given build method, which builds an HLO operation
// with an empty broadcast dimension.
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
    std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64_t>)>
        build_method) {
  // Capture by value: the returned std::function may outlive this call, so
  // capturing `build_method` by reference would leave a dangling reference.
  return [build_method](XlaOp src0, XlaOp src1) -> XlaOp {
    return build_method(src0, src1, {});
  };
}
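
// For example, AddEmptyBroadcastDimension(Max) adapts xla::Max (which takes
// broadcast dimensions as a third argument) for use as a binary EnqueueOp,
// e.g. (illustrative): Run(AddEmptyBroadcastDimension(Max),
// ReferenceMax<float>);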

template <PrimitiveType T>
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
 public:
  using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
  static ErrorSpecGen GetDefaultSpecGenerator() {
    return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
  }
};

template <PrimitiveType T>
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;

}  // namespace exhaustive_op_test
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_