/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <functional>
#include <iterator>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/bit_cast.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace exhaustive_op_test {

struct ErrorSpec {
  float abs_err;
  float rel_err;

  // If true, will consider -0 not near to +0 and vice versa. Note that
  // +epsilon may still be considered close to -0, depending on the error
  // spec; this only covers the case when both `expected` and `actual` are
  // equal to 0.
  bool strict_signed_zeros = false;

  ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
};

// Representations of the reference function passed in by the user.
template <typename NativeRefT, size_t K>
struct EvaluateOpWrapper {};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 1> {
  using type = NativeRefT (*)(NativeRefT);
};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 2> {
  using type = NativeRefT (*)(NativeRefT, NativeRefT);
};

// Representations of the XLA op builder function passed in by the user.
template <typename XlaInputs, size_t K>
struct EnqueueOpWrapper {};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 1> {
  using type = std::function<XlaOp(XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0]);
  }
};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 2> {
  using type = std::function<XlaOp(XlaOp, XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0], inputs[1]);
  }
};
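// Example (illustrative, not part of this header's API): an exhaustive test
// of xla::Log over F32 would pair
//   EnqueueOpWrapper<XlaInputs, 1>::type enqueue_op = Log;
//   EvaluateOpWrapper<float, 1>::type evaluate_op =
//       static_cast<float (*)(float)>(std::log);
// so that each output element produced by the XLA computation can be
// compared against the host-library result for the same input.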
// Representations of the ErrorSpecGen function passed in by the user.
template <PrimitiveType T, size_t K>
struct ErrorSpecGenWrapper {};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 1> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT);
};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 2> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT, NativeT);
};

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator();

// T: The primitive type being tested.
// N: The number of operands that the function being tested takes.
template <PrimitiveType T, size_t N>
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
 public:
  // Definitions depending on the primitive type T.

  static constexpr bool kIsComplex = (T == C128 || T == C64);

  // The primitive type used to compute the reference output.
  struct RefT {
    static constexpr PrimitiveType value = (T == F16 || T == BF16) ? F32 : T;
  };

  // The primitive type of the component of T. If T is not complex, then
  // ComponentT = T.
  struct ComponentT {
    static constexpr PrimitiveType value = !kIsComplex ? T
                                           : T == C128 ? F64
                                           : T == C64  ? F32
                                                       : PRIMITIVE_TYPE_INVALID;
  };

  // Same as ComponentT, but for the RefT primitive type.
  struct ComponentRefT {
    static constexpr PrimitiveType value = !kIsComplex           ? RefT::value
                                           : RefT::value == C128 ? F64
                                           : RefT::value == C64
                                               ? F32
                                               : PRIMITIVE_TYPE_INVALID;
  };

  // The primitive type of an unsigned integer that can be bitcasted to and
  // from ComponentT.
  struct ComponentIntegralT {
    static constexpr PrimitiveType value = (T == C128 || T == F64)  ? U64
                                           : (T == C64 || T == F32) ? U32
                                           : (T == F16 || T == BF16)
                                               ? U16
                                               : PRIMITIVE_TYPE_INVALID;
  };

  // Native types that correspond to the primitive types above.
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using NativeRefT =
      typename primitive_util::PrimitiveTypeToNative<RefT::value>::type;
  using ComponentNativeT =
      typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type;
  using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative<
      ComponentRefT::value>::type;
  using ComponentIntegralNativeT =
      typename primitive_util::PrimitiveTypeToNative<
          ComponentIntegralT::value>::type;
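  // For example (illustrative), with T == F16 these aliases resolve to:
  //   NativeT                  = Eigen::half  (the type under test)
  //   NativeRefT               = float        (reference math is done in F32)
  //   ComponentNativeT         = Eigen::half  (F16 is not complex)
  //   ComponentIntegralNativeT = uint16_t     (same bit width as F16)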
  using InputLiterals = std::array<Literal, N>;

 private:
  // N spans corresponding to the list of literal data values.
  using NativeInputsList = std::array<absl::Span<const NativeT>, N>;

  // N data items representing a single input to an XLA function.
  using NativeInputs = std::array<NativeT, N>;

  // N data items representing a single input to an interpreter backend
  // function.
  using NativeRefInputs = std::array<NativeRefT, N>;

  // N data items representing a single input to an XLA function.
  using XlaInputs = std::array<XlaOp, N>;

 public:
  using ErrorSpecGen = typename ErrorSpecGenWrapper<T, N>::type;
  using EvaluateOp = typename EvaluateOpWrapper<NativeRefT, N>::type;
  using EnqueueOp = typename EnqueueOpWrapper<XlaInputs, N>::type;

  explicit ExhaustiveOpTestBase()
      : ty_(T), platform_(client_->platform()->Name()) {
    SetFastMathDisabled(true);

    // Run all HLO passes. In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();
  }

  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
    Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator<T, N>());
  }

  // A helper for implementing the Run method for exhaustive op tests. It
  // constructs the HLO module, compiles and runs the module, and checks the
  // result.
  //
  // We use a function pointer for evaluate_op for performance, because it is
  // called once per output element inside the comparison loop in ExpectNear.
  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
           ErrorSpecGen error_spec_gen) {
    InputLiterals input_literals = CreateInputLiterals();
    FillInput(&input_literals);

    XlaBuilder builder(TestName());
    XlaInputs xla_inputs;
    for (int i = 0; i < N; ++i) {
      xla_inputs[i] =
          Parameter(&builder, i, input_literals[i].shape(), "input");
    }
    EnqueueOpWrapper<XlaInputs, N>::BuildFromInputs(xla_inputs, enqueue_op);

    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            RunComputationHelper(comp, input_literals));
    ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen);
  }

  StatusOr<Literal> RunComputationHelper(const XlaComputation& comp,
                                         const Literal& literal) {
    return RunComputation(comp, {&literal});
  }

  StatusOr<Literal> RunComputationHelper(
      const XlaComputation& comp, const std::array<Literal, N>& literals) {
    std::array<const Literal*, N> lit_ptrs;
    for (int i = 0; i < N; ++i) {
      lit_ptrs[i] = &literals[i];
    }
    return RunComputation(comp, lit_ptrs);
  }

  // We essentially reimplement LiteralTestUtil::Near here because
  // a) this streamlined implementation is much faster;
  // b) we can print out better error messages (namely, we can print out
  //    which floating-point input failed, while LiteralTestUtil::Near can
  //    only print out the input index that failed); and
  // c) we need special handling of certain inputs. For example, we say that
  //    a denormal input has multiple correct outputs (namely, f(x) and f(0))
  //    and just needs to be close to one of them.
  void ExpectNear(const InputLiterals& input_literals,
                  const Literal& result_literal, EvaluateOp evaluate_op,
                  ErrorSpecGen error_spec_gen);
  // Builds and runs the computation using the LocalClient API, rather than
  // the plain Client API, which is used by ClientLibraryTestBase. This is
  // because the plain Client API does more memcpys to/from Literals, and
  // that's slow given that we're touching a lot of data here.
  StatusOr<Literal> RunComputation(
      const XlaComputation& computation,
      absl::Span<const Literal* const> input_literals) {
    // Copy debug options from ClientLibraryTestBase. In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();

    std::vector<ScopedShapedBuffer> input_buffers;
    absl::c_transform(input_literals, std::back_inserter(input_buffers),
                      [&](const Literal* input_literal) {
                        return client_
                            ->LiteralToShapedBuffer(*input_literal,
                                                    /*device_ordinal=*/0)
                            .value();
                      });
    std::vector<const Shape*> input_shapes;
    absl::c_transform(input_buffers, std::back_inserter(input_shapes),
                      [&](const ScopedShapedBuffer& buffer) {
                        return &buffer.on_device_shape();
                      });

    TF_ASSIGN_OR_RETURN(
        auto executables,
        client_->Compile(computation, input_shapes, build_opts));

    std::vector<const ShapedBuffer*> input_buffer_pointers;
    absl::c_transform(
        input_buffers, std::back_inserter(input_buffer_pointers),
        [&](const ScopedShapedBuffer& buffer) { return &buffer; });

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executables[0]->Run(input_buffer_pointers, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

  const std::string& Platform() { return platform_; }

  // Returns the number of elements in each input literal.
  virtual int64_t GetInputSize() = 0;

  // Fills the literals with values to test for.
  virtual void FillInput(InputLiterals* literals) = 0;

  // Replaces infinities with the largest finite value, to help compute
  // errors.
  static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) {
    if (std::isinf(value)) {
      return std::copysign(std::numeric_limits<ComponentNativeRefT>::max(),
                           value);
    }
    return value;
  }

  // Returns true if both components are 0, but their sign bits differ.
  static bool CheckSignedZeroError(ComponentNativeRefT expected,
                                   ComponentNativeRefT actual) {
    return expected == 0 && actual == 0 &&
           std::signbit(expected) != std::signbit(actual);
  }

  // Sets the components to 0 if both are NaNs.
  static void RemoveCorrespondingNaNs(ComponentNativeRefT* expected,
                                      ComponentNativeRefT* actual) {
    if (std::isnan(*expected) && std::isnan(*actual)) {
      *expected = 0;
      *actual = 0;
    }
  }
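  // Examples (illustrative): for F32,
  //   ReplaceInfWithMax(INFINITY)  == std::numeric_limits<float>::max()
  //   ReplaceInfWithMax(-INFINITY) == -std::numeric_limits<float>::max()
  //   CheckSignedZeroError(+0.0f, -0.0f) is true (equal values, differing
  //   sign bits), while CheckSignedZeroError(+0.0f, +0.0f) is false.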
  // Implementations of the functions above for complex inputs.

  static std::complex<ComponentNativeRefT> ReplaceInfWithMax(
      std::complex<ComponentNativeRefT> value) {
    value.real(ReplaceInfWithMax(value.real()));
    value.imag(ReplaceInfWithMax(value.imag()));
    return value;
  }

  static bool CheckSignedZeroError(std::complex<ComponentNativeRefT> expected,
                                   std::complex<ComponentNativeRefT> actual) {
    return CheckSignedZeroError(expected.real(), actual.real()) ||
           CheckSignedZeroError(expected.imag(), actual.imag());
  }

  static void RemoveCorrespondingNaNs(
      std::complex<ComponentNativeRefT>* expected,
      std::complex<ComponentNativeRefT>* actual) {
    ComponentNativeRefT expected_real = expected->real();
    ComponentNativeRefT expected_imag = expected->imag();
    ComponentNativeRefT actual_real = actual->real();
    ComponentNativeRefT actual_imag = actual->imag();
    RemoveCorrespondingNaNs(&expected_real, &actual_real);
    RemoveCorrespondingNaNs(&expected_imag, &actual_imag);
    expected->real(expected_real);
    expected->imag(expected_imag);
    actual->real(actual_real);
    actual->imag(actual_imag);
  }

  // Returns a list of inputs that should be tested for closeness given some
  // original input values.
  //
  // For denormal component inputs, we accept answers that are close to any
  // of:
  //
  //   - evaluate_op(input),
  //   - evaluate_op(+/-0), where the sign of 0 equals the sign of `input`,
  //   - evaluate_op(+/-min_normal_float), where the sign of min_normal_float
  //     matches `input`, and
  //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of 0
  //     is the opposite of `input`.
  //
  // (In particular, the XLA:CPU implementation of log flushes positive
  // denormals to min-normal-float. This seems kind of reasonable if our goal
  // is to avoid infinities, because they cause NaNs?)
  std::vector<ComponentNativeRefT> GetTestValuesWithSubnormalSubstitutions(
      ComponentNativeRefT value) {
    std::vector<ComponentNativeRefT> test_values;
    if (std::fpclassify(value) == FP_SUBNORMAL) {
      test_values.reserve(relaxed_denormal_signs_ ? 3 : 2);
      test_values.push_back(std::copysign(0, value));
      test_values.push_back(std::copysign(
          std::numeric_limits<ComponentNativeRefT>::min(), value));
      if (relaxed_denormal_signs_) {
        test_values.push_back(std::copysign(0, -value));
      }
    } else {
      test_values.push_back(value);
    }
    return test_values;
  }

  // For complex numbers, we only need to substitute the components that are
  // subnormal. We find the subnormal testing values for each component, then
  // take the Cartesian product of the two sets of component values.
  std::vector<std::complex<ComponentNativeRefT>>
  GetTestValuesWithSubnormalSubstitutions(
      std::complex<ComponentNativeRefT> value) {
    using complex = std::complex<ComponentNativeRefT>;

    auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real());
    auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag());

    std::vector<complex> test_values;
    test_values.reserve(real_values.size() * imag_values.size());
    for (auto real : real_values) {
      for (auto imag : imag_values) {
        test_values.push_back(complex(real, imag));
      }
    }

    return test_values;
  }
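  // For example (illustrative): if value = complex(d, 1.0f) for some positive
  // denormal d, and relaxed_denormal_signs_ is true, the real component
  // yields {+0, std::numeric_limits<float>::min(), -0} and the imaginary
  // component yields {1.0f}, so the Cartesian product is
  // {(+0, 1.0f), (min, 1.0f), (-0, 1.0f)}.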
  // The test values for an XLA function with N operands are the Cartesian
  // product of the test values for each of the N operands.
  std::vector<std::array<NativeRefT, N>>
  GetTestValuesWithSubnormalSubstitutions(
      const std::array<NativeRefT, N>& value) {
    std::vector<std::array<NativeRefT, N>> test_values;

    std::array<std::vector<NativeRefT>, N> component_test_values;
    int total = 1;
    for (int i = 0; i < N; ++i) {
      component_test_values[i] =
          GetTestValuesWithSubnormalSubstitutions(value[i]);
      if (!component_test_values[i].empty()) {
        total *= component_test_values[i].size();
      }
    }

    // If total == 1, then value has no subnormal components, so we can just
    // return a vector with value in it.
    if (total == 1) {
      test_values.push_back(value);
      return test_values;
    }

    test_values.reserve(total);

    // Perform a Cartesian product of the vectors in component_test_values.
    // We can calculate this by uniquely mapping each integer in [0, total)
    // to a list of component indices. The function that maps an integer i to
    // the index of component j is:
    //   component_index(j) = (i / NumValues(0, j - 1)) % NumValues(j, j)
    // where NumValues(x, y) is the number of values in the Cartesian product
    // of component_test_values[x], component_test_values[x + 1], ...,
    // component_test_values[y], and NumValues(0, -1) is taken to be 1. For
    // example, with component sizes {2, 3}, i = 4 maps to indices (0, 2).
    for (int i = 0; i < total; ++i) {
      int accumulated_num_values = 1;
      std::array<NativeRefT, N> test_value;
      for (int j = 0; j < N; ++j) {
        int num_indices = component_test_values[j].size();
        int component_index = (i / accumulated_num_values) % num_indices;
        test_value[j] = component_test_values[j][component_index];
        accumulated_num_values *= num_indices;
      }
      test_values.push_back(std::move(test_value));
    }
    return test_values;
  }

  InputLiterals CreateInputLiterals() {
    InputLiterals literals;
    for (int i = 0; i < N; ++i) {
      literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()});
    }
    return std::move(literals);
  }

  // Determines if two output values are sufficiently close to each other
  // based on an error spec.
  bool IsClose(NativeRefT expected, NativeRefT actual, ErrorSpec spec) {
    // When two corresponding values are both NaN, they can be considered to
    // have the same value, so both are set to 0.
    RemoveCorrespondingNaNs(&expected, &actual);

    if (spec.strict_signed_zeros) {
      if (CheckSignedZeroError(expected, actual)) {
        return false;
      }
    }

    // Replace Inf with Max when calculating absolute or relative errors. This
    // allows the test to pass when either value is close to Inf and the
    // specified absolute or relative errors are not zero.
    double abs_err =
        std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual));
    double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));

    return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
  }
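  // Example (illustrative): with spec.abs_err == spec.rel_err == 0.0001,
  // IsClose(1.0f, 1.00005f, spec) is true, since abs_err == 5e-5 <= 1e-4.
  // IsClose(INFINITY, FLT_MAX, spec) is also true, because both values are
  // mapped to FLT_MAX by ReplaceInfWithMax before the errors are computed.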
  // Converts part or all of the bits in a uint64_t to the value of the
  // floating point data type being tested.
  //
  // When exhaustively testing an operation on data type T, we always use an
  // integral type I with the same number of bits as T to enumerate the input
  // bit patterns for T. Each such bit pattern is zero-extended and stored as
  // a uint64_t. This function converts a bit pattern stored that way into
  // the input value for T.
  static ComponentNativeT ConvertValue(uint64_t bits) {
    using I = ComponentIntegralNativeT;
    I used_bits = static_cast<I>(bits);
    return BitCast<ComponentNativeT>(used_bits);
  }

  ComponentNativeT ConvertAndReplaceKnownIncorrectValueWith(
      uint64_t bits, int replacement_value = 0) {
    if (known_incorrect_fn_ && known_incorrect_fn_(bits)) {
      return static_cast<ComponentNativeT>(replacement_value);
    }
    return ConvertValue(bits);
  }
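  // Example (illustrative): in an F32 test, ConvertValue(0x3F800000) keeps
  // the low 32 bits and bit-casts them to float, yielding 1.0f; in an F16
  // test, only the low 16 bits of `bits` would be used.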
 protected:
  // The primitive type being tested.
  const PrimitiveType ty_;

  // The platform under test.
  const std::string platform_;

  // Testing will ignore inputs for which known_incorrect_fn_ returns true.
  // The argument to the function is the raw bits for the data being tested,
  // zero-extended to 64 bits if the data type is less than 64 bits.
  std::function<bool(int64_t)> known_incorrect_fn_;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt
  // of a negative number) or -inf (flush the denormal to sign-preserving
  // zero, then take sqrt(-0)). But with this set to true, we'll also accept
  // 0 (i.e. sqrt(0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};

// Represents a set of 64-bit chunks by storing the first chunk, the last
// chunk, and the spacing between two adjacent chunks, without materializing
// all the chunks in between. An iterator is provided to retrieve all the bit
// chunks.
//
// This data structure is used to generate the bit representations for testing
// operations that require more than 64 bits of input data. In that case,
// truly exhaustive testing is not possible, so we test one value out of every
// `spacing_` values.
//
// Currently, the BitChunks iterator adds `spacing_` to a bit chunk to compute
// the next bit chunk. We can change this to use values generated by a random
// number generator that achieves the same average spacing statistically, if
// we find this necessary.
class BitChunks {
 public:
  class iterator {
   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = uint64_t;
    using difference_type = uint64_t;
    using pointer = const uint64_t*;
    using reference = uint64_t;

    iterator() = default;

    explicit iterator(const BitChunks* bit_chunks)
        : bit_chunks_(bit_chunks), next_bit_chunk_(bit_chunks->start_) {}

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      return bit_chunks_ == other.bit_chunks_ &&
             next_bit_chunk_ == other.next_bit_chunk_;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      MoveNextBitChunkToOnePassEnd();
      return *this;
    }

    reference operator*() const {
      CHECK(*this != this->bit_chunks_->end());
      return next_bit_chunk_;
    }

    const BitChunks* GetBitChunks() const { return bit_chunks_; }

    void Reset() { next_bit_chunk_ = bit_chunks_->start_; }

    void Next() {
      CHECK(*this != this->bit_chunks_->end());
      if (next_bit_chunk_ == bit_chunks_->end_) {
        MoveNextBitChunkToOnePassEnd();
      } else {
        next_bit_chunk_ += bit_chunks_->spacing_;
        if (next_bit_chunk_ > bit_chunks_->end_) {
          next_bit_chunk_ = bit_chunks_->end_;
        }
      }
    }

    std::string ToString() const {
      return absl::StrFormat("0x%08x", next_bit_chunk_);
    }

   private:
    // Moves next_bit_chunk_ to one past bit_chunks_->end_, to mark that the
    // iterator has reached the end. When spacing_ is not one, or if we later
    // switch to using a random value instead of spacing_ in Next(),
    // normalizing the end-of-iteration representation this way simplifies
    // the check for whether the iterator has ended.
    void MoveNextBitChunkToOnePassEnd() {
      next_bit_chunk_ = bit_chunks_->end_ + 1;
    }

    const BitChunks* bit_chunks_;
    uint64_t next_bit_chunk_;
  };

  iterator begin() const { return iterator(this); }
  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  explicit BitChunks(uint64_t start = 0, uint64_t end = 0, uint64_t spacing = 1)
      : start_(start), end_(end), spacing_(spacing) {
    CHECK_GE(end_, start_);
    CHECK_NE(spacing, 0) << ToString();
  }

  int64_t GetTotalBitChunks() const {
    if (start_ == end_) {
      return 1;
    }

    return 1 + (end_ - start_ + spacing_ - 1) / spacing_;
  }

  std::string ToString() const {
    return absl::StrFormat("(0x%08x, 0x%08x, 0x%08x)", start_, end_, spacing_);
  }

  uint64_t start_;
  uint64_t end_;
  uint64_t spacing_;
};

inline std::string StringifyNum(BitChunks c) { return c.ToString(); }

inline std::string StringifyNum(BitChunks::iterator c) { return c.ToString(); }

template <typename T>
void AppendStringifyNum(std::string* s, T x) {
  absl::StrAppend(s, StringifyNum(x));
}
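// Example (illustrative): BitChunks(/*start=*/0x00, /*end=*/0x40,
// /*spacing=*/0x10) represents {0x00, 0x10, 0x20, 0x30, 0x40}; iterating
// visits each of these in turn, and GetTotalBitChunks() returns 5.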
// Represents a set of floating point values through the possible values for
// the three components: mantissa, exponent, and sign. Also implements an
// iterator for retrieving all the represented floating point values.
class FpValues {
 public:
  static constexpr int kTotalBitChunks = 3;

  class iterator {
   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = uint64_t;
    using difference_type = uint64_t;
    using pointer = const uint64_t*;
    using reference = uint64_t;

    explicit iterator(const FpValues* fp_values) : fp_values_(fp_values) {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i] = BitChunks::iterator(&fp_values->GetBitChunks(i));
      }
    }

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        if (iters_[i] != other.GetBitChunksIter(i)) {
          return false;
        }
      }
      return true;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i].MoveToEnd();
      }
      return *this;
    }

    uint64_t operator*() const {
      uint64_t value = 0;
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        value = value | (*iters_[i]) << fp_values_->offsets_[i];
      }
      return value;
    }

    const BitChunks::iterator& GetBitChunksIter(int i) { return iters_[i]; }

    std::string ToString() const {
      return absl::StrJoin(iters_, ",",
                           AppendStringifyNum<BitChunks::iterator>);
    }

   private:
    // Moves the iterator for the ith BitChunks to the next value, and
    // returns true if the new state is not the end of the iterator.
    bool Next(int i = 0) {
      iters_[i].Next();
      if (iters_[i] == iters_[i].GetBitChunks()->end()) {
        if (i == FpValues::kTotalBitChunks - 1) {
          return false;
        }
        if (Next(i + 1)) {
          iters_[i].Reset();
          return true;
        }
        return false;
      }
      return true;
    }

    std::array<BitChunks::iterator, FpValues::kTotalBitChunks> iters_;
    const FpValues* fp_values_;
  };
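  // Example (illustrative): for float, offsets_ is {0, 23, 31, 32} (see
  // GetFpValues below), so operator*() composes each 32-bit pattern as
  //   value = mantissa | (exponent << 23) | (sign << 31).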
  FpValues() : bit_chunks_(), offsets_() {}
  FpValues(absl::Span<const BitChunks> chunks, absl::Span<const int> offsets) {
    CHECK_EQ(chunks.size(), offsets.size() - 1);
    CHECK_EQ(chunks.size(), kTotalBitChunks);
    std::copy_n(chunks.begin(), kTotalBitChunks, bit_chunks_.begin());
    std::copy_n(offsets.begin(), kTotalBitChunks, offsets_.begin());

    // The last value in `offsets` is the total number of bits.
    offsets_[kTotalBitChunks] = offsets[kTotalBitChunks];
    // Validate the input values.
    for (int i = 0; i < kTotalBitChunks; ++i) {
      int total_bits = offsets[i + 1] - offsets[i];
      if (total_bits < 64) {
        uint64_t bound = 1ull << total_bits;
        CHECK_LT(chunks[i].start_, bound);
        CHECK_LT(chunks[i].end_, bound);
      } else {
        CHECK_EQ(total_bits, 64);
      }
    }
  }

  iterator begin() const { return iterator(this); }

  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  int64_t GetTotalNumValues() const {
    int64_t total = 1;
    absl::c_for_each(bit_chunks_, [&](const BitChunks& chunks) {
      total *= chunks.GetTotalBitChunks();
    });
    return total;
  }

  const BitChunks& GetBitChunks(int i) const { return bit_chunks_[i]; }

  std::string ToString() const {
    return absl::StrCat(
        "[", absl::StrJoin(bit_chunks_, ",", AppendStringifyNum<BitChunks>),
        "]");
  }

  std::array<BitChunks, kTotalBitChunks> bit_chunks_;
  std::array<int, kTotalBitChunks + 1> offsets_;
};

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
int GetMantissaTotalBits() {
  return std::numeric_limits<T>::digits - 1;
}

template <typename T>
int GetFpTotalBits() {
  return sizeof(T) * 8;
}

template <typename T>
int GetExponentTotalBits() {
  return GetFpTotalBits<T>() - GetMantissaTotalBits<T>() - 1;
}

template <typename T>
uint64_t GetAllOneMantissa() {
  return (1ull << GetMantissaTotalBits<T>()) - 1ull;
}

template <typename T>
uint64_t GetAllOneExponent() {
  return (1ull << GetExponentTotalBits<T>()) - 1ull;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
  int total_bits = GetFpTotalBits<T>();
  return FpValues({mantissa, exponent, sign},
                  {0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
}

template <typename T>
FpValues GetZeros() {
  return GetFpValues<T>(BitChunks(0, 0, 1), BitChunks(0, 0, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetSubnormals(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64_t mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(0, 0, 1), BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetInfinites() {
  uint64_t all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(BitChunks(0, 0, 1),
                        BitChunks(all_one_exp, all_one_exp, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetNans(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64_t mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  uint64_t all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(all_one_exp, all_one_exp, 1), BitChunks(0, 1, 1));
}
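// Examples (illustrative): for T = float,
//   GetZeros()     represents {+0.0f, -0.0f}  (bit patterns 0x00000000 and
//                                              0x80000000), and
//   GetInfinites() represents {+inf, -inf}    (0x7F800000 and 0xFF800000).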
template <typename T>
FpValues GetNormals(int approx_num_values) {
  float component_total = std::sqrt(static_cast<float>(approx_num_values));
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(),
                (1ull << (GetMantissaTotalBits<T>() + 1)) / component_total),
      BitChunks(0x1, GetAllOneExponent<T>() - 1,
                (1ull << (GetExponentTotalBits<T>() + 1)) / component_total),
      BitChunks(0, 1, 1));
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` floating point values of type `T`, with each FpValues
// representing about `num_values_per_group` floating point values.
template <typename T>
std::vector<FpValues> GetFpValuesWithExponents(uint64_t first_exponent,
                                               uint64_t exponent_spacing,
                                               uint64_t num_exponents,
                                               uint64_t approx_num_values,
                                               uint64_t num_values_per_group) {
  const uint64_t num_signs = 2;
  uint64_t approx_num_mantissa =
      approx_num_values / (num_exponents * num_signs);
  uint64_t num_mantissa_per_group =
      num_values_per_group / (num_exponents * num_signs);
  CHECK_GT(approx_num_mantissa, 0);
  CHECK_GT(num_mantissa_per_group, 0);

  CHECK_LT(first_exponent + num_exponents - 1ull, GetAllOneExponent<T>());
  int mantissa = GetMantissaTotalBits<T>();
  uint64_t mantissa_spacing = (1ull << mantissa) / approx_num_mantissa;

  std::vector<FpValues> result;
  for (uint64_t group_start = 0; group_start < GetAllOneMantissa<T>();
       group_start += mantissa_spacing * num_mantissa_per_group) {
    uint64_t group_end =
        group_start + (num_mantissa_per_group - 1) * mantissa_spacing;
    if (group_end > GetAllOneMantissa<T>()) {
      group_end = GetAllOneMantissa<T>();
    }
    result.push_back(GetFpValues<T>(
        BitChunks(group_start, group_end, mantissa_spacing),
        BitChunks(first_exponent, first_exponent + num_exponents - 1, 1),
        BitChunks(0, 1, 1)));
  }
  return result;
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` "very large" floating point values and
// `approx_num_values` "very small" floating point values of type `T`, with
// each FpValues representing about `num_values_per_group` floating point
// values. Because we use FpValues as a parameter for parameterized testing,
// the number of floating point values represented by each FpValues affects
// the input size for each sub-test and hence the peak memory usage of the
// test.
template <typename T>
std::vector<FpValues> GetFpValuesForMagnitudeExtremeNormals(
    uint64_t approx_num_values = 40000, uint64_t num_values_per_group = 4000) {
  std::vector<FpValues> large =
      GetFpValuesWithExponents<T>(GetAllOneExponent<T>() - 5, 1, 5,
                                  approx_num_values / 2, num_values_per_group);
  std::vector<FpValues> small = GetFpValuesWithExponents<T>(
      1, 1, 5, approx_num_values / 2, num_values_per_group);
  large.insert(large.end(), small.begin(), small.end());
  return large;
}
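// Worked example (illustrative): with the defaults approx_num_values = 40000
// and num_values_per_group = 4000, each call to GetFpValuesWithExponents uses
// num_exponents = 5 and num_signs = 2, so approx_num_mantissa =
// 20000 / (5 * 2) = 2000 and num_mantissa_per_group = 4000 / (5 * 2) = 400.
// Each returned FpValues therefore covers about 400 * 5 * 2 = 4000 values.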
template <typename T>
std::vector<FpValues> CreateFpValuesForBoundaryTest() {
  return {GetZeros<T>(), GetSubnormals<T>(1000), GetInfinites<T>(),
          GetNans<T>(1000)};
}

inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into small'ish chunks to keep peak
  // memory usage low.
  std::vector<std::pair<int64_t, int64_t>> result;
  const int64_t step = 1 << 25;
  for (int64_t i = 0; i < (int64_t{1} << 32); i += step) {
    result.push_back({i, i + step});
  }
  return result;
}

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT,
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <>
inline ErrorSpec DefaultSpecGenerator<C128, 1>(complex128) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<C64, 1>(complex64) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 1>(double) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 1>(float) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 1>(Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 1>(bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 2>(double, double) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 2>(float, float) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 2>(Eigen::half, Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
  return DefaultSpecGenerator<T, N>;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMax(T x, T y) {
  // We need to propagate NaN here, because std::max may not propagate NaN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::max<T>(x, y);
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMin(T x, T y) {
  // We need to propagate NaN here, because std::min may not propagate NaN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::min<T>(x, y);
}
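// Example (illustrative): std::max(1.0, NAN) returns 1.0, because the
// underlying comparison NAN < 1.0 is false, whereas ReferenceMax(1.0, NAN)
// returns NaN as a reference implementation should.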
// Returns a wrapper around the given build method that builds an HLO
// operation with an empty broadcast dimension; e.g.
// AddEmptyBroadcastDimension(Add) adapts Add for use as a binary EnqueueOp.
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
    std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64_t>)>
        build_method) {
  // Capture build_method by value so the returned closure does not dangle
  // once the parameter goes out of scope.
  return [build_method](XlaOp src0, XlaOp src1) -> XlaOp {
    return build_method(src0, src1, {});
  };
}

template <PrimitiveType T>
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
 public:
  using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
  static ErrorSpecGen GetDefaultSpecGenerator() {
    return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
  }
};

template <PrimitiveType T>
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;

}  // namespace exhaustive_op_test
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_