1 #pragma once 2 #include <ATen/cpu/vec/vec.h> 3 #include <ATen/cpu/vec/functional.h> 4 #include <c10/util/bit_cast.h> 5 #include <c10/util/irange.h> 6 #include <gtest/gtest.h> 7 #include <chrono> 8 #include <exception> 9 #include <functional> 10 #include <iostream> 11 #include <limits> 12 #include <random> 13 #include <vector> 14 #include <complex> 15 #include <math.h> 16 #include <float.h> 17 #include <algorithm> 18 19 #if defined(CPU_CAPABILITY_AVX512) 20 #define CACHE_LINE 64 21 #else 22 #define CACHE_LINE 32 23 #endif 24 25 #if defined(__GNUC__) 26 #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE))) 27 #define not_inline __attribute__((noinline)) 28 #elif defined(_WIN32) 29 #define CACHE_ALIGN __declspec(align(CACHE_LINE)) 30 #define not_inline __declspec(noinline) 31 #else 32 CACHE_ALIGN #define 33 #define not_inline 34 #endif 35 #if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) 36 #define TEST_AGAINST_DEFAULT 1 37 #elif !defined(CPU_CAPABILITY_AVX512) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_ZVECTOR) 38 #define TEST_AGAINST_DEFAULT 1 39 #else 40 #undef TEST_AGAINST_DEFAULT 41 #endif 42 #undef NAME_INFO 43 #define STRINGIFY(x) #x 44 #define TOSTRING(x) STRINGIFY(x) 45 #define NAME_INFO(name) TOSTRING(name) " " TOSTRING(__FILE__) ":" TOSTRING(__LINE__) 46 47 #define RESOLVE_OVERLOAD(...) \ 48 [](auto&&... args) -> decltype(auto) { \ 49 return __VA_ARGS__(std::forward<decltype(args)>(args)...); \ 50 } 51 52 #if defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) || \ 53 defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__)) 54 #undef CHECK_DEQUANT_WITH_LOW_PRECISION 55 #define CHECK_WITH_FMA 1 56 #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2) 57 #undef CHECK_DEQUANT_WITH_LOW_PRECISION 58 #undef CHECK_WITH_FMA 59 #else 60 #define CHECK_DEQUANT_WITH_LOW_PRECISION 1 61 #undef CHECK_WITH_FMA 62 #endif 63 64 template<typename T> 65 using Complex = typename c10::complex<T>; 66 67 template <typename T> 68 using VecType = typename at::vec::Vectorized<T>; 69 70 using vfloat = VecType<float>; 71 using vdouble = VecType<double>; 72 using vcomplex = VecType<Complex<float>>; 73 using vcomplexDbl = VecType<Complex<double>>; 74 using vlong = VecType<int64_t>; 75 using vint = VecType<int32_t>; 76 using vshort = VecType<int16_t>; 77 using vqint8 = VecType<c10::qint8>; 78 using vquint8 = VecType<c10::quint8>; 79 using vqint = VecType<c10::qint32>; 80 using vBFloat16 = VecType<c10::BFloat16>; 81 using vHalf = VecType<c10::Half>; 82 83 template <typename T> 84 using ValueType = typename T::value_type; 85 86 template <int N> 87 struct BitStr 88 { 89 using type = uintmax_t; 90 }; 91 92 template <> 93 struct BitStr<8> 94 { 95 using type = uint64_t; 96 }; 97 98 template <> 99 struct BitStr<4> 100 { 101 using type = uint32_t; 102 }; 103 104 template <> 105 struct BitStr<2> 106 { 107 using type = uint16_t; 108 }; 109 110 template <> 111 struct BitStr<1> 112 { 113 using type = uint8_t; 114 }; 115 116 template <typename T> 117 using BitType = typename BitStr<sizeof(T)>::type; 118 119 template<typename T> 120 struct VecTypeHelper { 121 using holdType = typename T::value_type; 122 using memStorageType = typename T::value_type; 123 static constexpr int holdCount = T::size(); 124 static constexpr int unitStorageCount = 1; 125 }; 126 127 template<> 128 struct VecTypeHelper<vcomplex> { 129 using holdType = Complex<float>; 130 using memStorageType = float; 131 static constexpr int holdCount = vcomplex::size(); 132 static constexpr int unitStorageCount = 2; 133 }; 134 135 template<> 136 struct VecTypeHelper<vcomplexDbl> { 137 using holdType = Complex<double>; 138 using memStorageType = double; 139 static constexpr int holdCount = vcomplexDbl::size(); 140 static constexpr int unitStorageCount = 2; 141 }; 142 143 template<> 144 struct VecTypeHelper<vqint8> { 145 using holdType = c10::qint8; 146 using memStorageType = typename c10::qint8::underlying; 147 static constexpr int holdCount = vqint8::size(); 148 static constexpr int unitStorageCount = 1; 149 }; 150 151 template<> 152 struct VecTypeHelper<vquint8> { 153 using holdType = c10::quint8; 154 using memStorageType = typename c10::quint8::underlying; 155 static constexpr int holdCount = vquint8::size(); 156 static constexpr int unitStorageCount = 1; 157 }; 158 159 template<> 160 struct VecTypeHelper<vqint> { 161 using holdType = c10::qint32; 162 using memStorageType = typename c10::qint32::underlying; 163 static constexpr int holdCount = vqint::size(); 164 static constexpr int unitStorageCount = 1; 165 }; 166 167 template<> 168 struct VecTypeHelper<vBFloat16> { 169 using holdType = c10::BFloat16; 170 using memStorageType = typename vBFloat16::value_type; 171 static constexpr int holdCount = vBFloat16::size(); 172 static constexpr int unitStorageCount = 1; 173 }; 174 175 template<> 176 struct VecTypeHelper<vHalf> { 177 using holdType = c10::Half; 178 using memStorageType = typename vHalf::value_type; 179 static constexpr int holdCount = vHalf::size(); 180 static constexpr int unitStorageCount = 1; 181 }; 182 183 template <typename T> 184 using UholdType = typename VecTypeHelper<T>::holdType; 185 186 template <typename T> 187 using UvalueType = typename VecTypeHelper<T>::memStorageType; 188 189 template <class T, size_t N> 190 constexpr size_t size(T(&)[N]) { 191 return N; 192 } 193 194 template <typename Filter, typename T> 195 typename std::enable_if_t<std::is_same_v<Filter, std::nullptr_t>, void> 196 call_filter(Filter filter, T& val) {} 197 198 template <typename Filter, typename T> 199 typename std::enable_if_t< std::is_same_v<Filter, std::nullptr_t>, void> 200 call_filter(Filter filter, T& first, T& second) { } 201 202 template <typename Filter, typename T> 203 typename std::enable_if_t< std::is_same_v<Filter, std::nullptr_t>, void> 204 call_filter(Filter filter, T& first, T& second, T& third) { } 205 206 template <typename Filter, typename T> 207 typename std::enable_if_t< 208 !std::is_same_v<Filter, std::nullptr_t>, void> 209 call_filter(Filter filter, T& val) { 210 return filter(val); 211 } 212 213 template <typename Filter, typename T> 214 typename std::enable_if_t< 215 !std::is_same_v<Filter, std::nullptr_t>, void> 216 call_filter(Filter filter, T& first, T& second) { 217 return filter(first, second); 218 } 219 220 template <typename Filter, typename T> 221 typename std::enable_if_t< 222 !std::is_same_v<Filter, std::nullptr_t>, void> 223 call_filter(Filter filter, T& first, T& second, T& third) { 224 return filter(first, second, third); 225 } 226 227 template <typename T> 228 struct DomainRange { 229 T start; // start [ 230 T end; // end is not included. one could use nextafter for including his end case for tests 231 }; 232 233 template <typename T> 234 struct CustomCheck { 235 std::vector<UholdType<T>> Args; 236 UholdType<T> expectedResult; 237 }; 238 239 template <typename T> 240 struct CheckWithinDomains { 241 // each argument takes domain Range 242 std::vector<DomainRange<T>> ArgsDomain; 243 // check with error tolerance 244 bool CheckWithTolerance = false; 245 T ToleranceError = (T)0; 246 }; 247 248 template <typename T> 249 std::ostream& operator<<(std::ostream& stream, const CheckWithinDomains<T>& dmn) { 250 stream << "Domain: "; 251 if (dmn.ArgsDomain.size() > 0) { 252 for (const DomainRange<T>& x : dmn.ArgsDomain) { 253 if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>) { 254 stream << "\n{ " << static_cast<int>(x.start) << ", " << static_cast<int>(x.end) << " }"; 255 } 256 else { 257 stream << "\n{ " << x.start << ", " << x.end << " }"; 258 } 259 } 260 } 261 else { 262 stream << "default range"; 263 } 264 if (dmn.CheckWithTolerance) { 265 stream << "\nError tolerance: " << dmn.ToleranceError; 266 } 267 return stream; 268 } 269 270 template <typename T> 271 bool check_both_nan(T x, T y) { 272 if constexpr (std::is_floating_point_v<T>) { 273 return std::isnan(x) && std::isnan(y); 274 } 275 return false; 276 } 277 278 template <typename T> 279 bool check_both_inf(T x, T y) { 280 if constexpr (std::is_floating_point_v<T>) { 281 return std::isinf(x) && std::isinf(y); 282 } 283 return false; 284 } 285 286 template<typename T> 287 std::enable_if_t<!std::is_floating_point_v<T>, bool> check_both_big(T x, T y) { 288 return false; 289 } 290 291 template<typename T> 292 std::enable_if_t<std::is_floating_point_v<T>, bool> check_both_big(T x, T y) { 293 T cmax = std::is_same_v<T, float> ? static_cast<T>(1e+30) : static_cast<T>(1e+300); 294 T cmin = std::is_same_v<T, float> ? static_cast<T>(-1e+30) : static_cast<T>(-1e+300); 295 //only allow when one is inf 296 bool x_inf = std::isinf(x); 297 bool y_inf = std::isinf(y); 298 bool px = x > 0; 299 bool py = y > 0; 300 return (px && x_inf && y >= cmax) || (py && y_inf && x >= cmax) || 301 (!px && x_inf && y <= cmin) || (!py && y_inf && x <= cmin); 302 } 303 304 template<class T> struct is_complex : std::false_type {}; 305 306 template<class T> struct is_complex<Complex<T>> : std::true_type {}; 307 308 template<typename T> 309 T safe_fpt_division(T f1, T f2) 310 { 311 //code was taken from boost 312 // Avoid overflow. 313 if ((f2 < static_cast<T>(1)) && (f1 > f2 * std::numeric_limits<T>::max())) { 314 return std::numeric_limits<T>::max(); 315 } 316 // Avoid underflow. 317 if ((f1 == static_cast<T>(0)) || 318 ((f2 > static_cast<T>(1)) && (f1 < f2 * std::numeric_limits<T>::min()))) { 319 return static_cast<T>(0); 320 } 321 return f1 / f2; 322 } 323 324 template<class T> 325 std::enable_if_t<std::is_floating_point_v<T>, bool> 326 nearlyEqual(T a, T b, T tolerance) { 327 if (check_both_nan<T>(a, b)) return true; 328 if (check_both_big(a, b)) return true; 329 T absA = std::abs(a); 330 T absB = std::abs(b); 331 T diff = std::abs(a - b); 332 if (diff <= tolerance) { 333 return true; 334 } 335 T d1 = safe_fpt_division<T>(diff, absB); 336 T d2 = safe_fpt_division<T>(diff, absA); 337 return (d1 <= tolerance) || (d2 <= tolerance); 338 } 339 340 template<class T> 341 std::enable_if_t<!std::is_floating_point_v<T>, bool> 342 nearlyEqual(T a, T b, T tolerance) { 343 return a == b; 344 } 345 346 template <typename T> 347 T reciprocal(T x) { 348 return 1 / x; 349 } 350 351 template <typename T> 352 T rsqrt(T x) { 353 return 1 / std::sqrt(x); 354 } 355 356 template <typename T> 357 T frac(T x) { 358 return x - std::trunc(x); 359 } 360 361 template <class T> 362 T maximum(const T& a, const T& b) { 363 return (a > b) ? a : b; 364 } 365 366 template <class T> 367 T minimum(const T& a, const T& b) { 368 return (a < b) ? a : b; 369 } 370 371 template <class T> 372 T clamp(const T& a, const T& min, const T& max) { 373 return a < min ? min : (a > max ? max : a); 374 } 375 376 template <class T> 377 T clamp_max(const T& a, const T& max) { 378 return a > max ? max : a; 379 } 380 381 template <class T> 382 T clamp_min(const T& a, const T& min) { 383 return a < min ? min : a; 384 } 385 386 template <class VT, size_t N> 387 void copy_interleave(VT(&vals)[N], VT(&interleaved)[N]) { 388 static_assert(N % 2 == 0, "should be even"); 389 auto ptr1 = vals; 390 auto ptr2 = vals + N / 2; 391 for (size_t i = 0; i < N; i += 2) { 392 interleaved[i] = *ptr1++; 393 interleaved[i + 1] = *ptr2++; 394 } 395 } 396 397 template <typename T> 398 bool is_zero(T val) { 399 if constexpr (std::is_floating_point_v<T>) { 400 return std::fpclassify(val) == FP_ZERO; 401 } else { 402 return val == 0; 403 } 404 } 405 406 template <typename T> 407 void filter_clamp(T& f, T& s, T& t) { 408 if (t < s) { 409 std::swap(s, t); 410 } 411 } 412 413 template <typename T> 414 std::enable_if_t<std::is_floating_point_v<T>, void> filter_fmod(T& a, T& b) { 415 // This is to make sure fmod won't cause overflow when doing the div 416 if (std::abs(b) < (T)1) { 417 b = b < (T)0 ? (T)-1 : T(1); 418 } 419 } 420 421 template <typename T> 422 std::enable_if_t<std::is_floating_point_v<T>, void> filter_fmadd(T& a, T& b, T& c) { 423 // This is to setup a limit to make sure fmadd (a * b + c) won't overflow 424 T max = std::sqrt(std::numeric_limits<T>::max()) / T(2.0); 425 T min = ((T)0 - max); 426 427 if (a > max) a = max; 428 else if (a < min) a = min; 429 430 if (b > max) b = max; 431 else if (b < min) b = min; 432 433 if (c > max) c = max; 434 else if (c < min) c = min; 435 } 436 437 template <typename T> 438 void filter_zero(T& val) { 439 val = is_zero(val) ? (T)1 : val; 440 } 441 template <typename T> 442 std::enable_if_t<is_complex<Complex<T>>::value, void> filter_zero(Complex<T>& val) { 443 T rr = val.real(); 444 T ii = val.imag(); 445 rr = is_zero(rr) ? (T)1 : rr; 446 ii = is_zero(ii) ? (T)1 : ii; 447 val = Complex<T>(rr, ii); 448 } 449 450 template <typename T> 451 void filter_int_minimum(T& val) { 452 if constexpr (!std::is_integral_v<T>) return; 453 if (val == std::numeric_limits<T>::min()) { 454 val = 0; 455 } 456 } 457 458 template <typename T> 459 std::enable_if_t<is_complex<T>::value, void> filter_add_overflow(T& a, T& b) 460 { 461 //missing for complex 462 } 463 464 template <typename T> 465 std::enable_if_t<is_complex<T>::value, void> filter_sub_overflow(T& a, T& b) 466 { 467 //missing for complex 468 } 469 470 template <typename T> 471 std::enable_if_t < !is_complex<T>::value, void> filter_add_overflow(T& a, T& b) { 472 if constexpr (std::is_integral_v<T> == false) return; 473 T max = std::numeric_limits<T>::max(); 474 T min = std::numeric_limits<T>::min(); 475 // min <= (a +b) <= max; 476 // min - b <= a <= max - b 477 if (b < 0) { 478 if (a < min - b) { 479 a = min - b; 480 } 481 } 482 else { 483 if (a > max - b) { 484 a = max - b; 485 } 486 } 487 } 488 489 template <typename T> 490 std::enable_if_t < !is_complex<T>::value, void> filter_sub_overflow(T& a, T& b) { 491 if constexpr (std::is_integral_v<T> == false) return; 492 T max = std::numeric_limits<T>::max(); 493 T min = std::numeric_limits<T>::min(); 494 // min <= (a-b) <= max; 495 // min + b <= a <= max +b 496 if (b < 0) { 497 if (a > max + b) { 498 a = max + b; 499 } 500 } 501 else { 502 if (a < min + b) { 503 a = min + b; 504 } 505 } 506 } 507 508 template <typename T> 509 std::enable_if_t<is_complex<T>::value, void> 510 filter_mult_overflow(T& val1, T& val2) { 511 //missing 512 } 513 514 template <typename T> 515 std::enable_if_t<is_complex<T>::value, void> 516 filter_div_ub(T& val1, T& val2) { 517 //missing 518 //at least consdier zero division 519 auto ret = std::abs(val2); 520 if (ret == 0) { 521 val2 = T(1, 2); 522 } 523 } 524 525 template <typename T> 526 std::enable_if_t<!is_complex<T>::value, void> 527 filter_mult_overflow(T& val1, T& val2) { 528 if constexpr (std::is_integral_v<T> == false) return; 529 if (!is_zero(val2)) { 530 T c = (std::numeric_limits<T>::max() - 1) / val2; 531 if (std::abs(val1) >= c) { 532 // correct first; 533 val1 = c; 534 } 535 } // is_zero 536 } 537 538 template <typename T> 539 std::enable_if_t<!is_complex<T>::value, void> 540 filter_div_ub(T& val1, T& val2) { 541 if (is_zero(val2)) { 542 val2 = 1; 543 } 544 else if (std::is_integral_v<T> && val1 == std::numeric_limits<T>::min() && val2 == -1) { 545 val2 = 1; 546 } 547 } 548 549 struct TestSeed { 550 TestSeed() : seed(std::chrono::high_resolution_clock::now().time_since_epoch().count()) { 551 } 552 TestSeed(uint64_t seed) : seed(seed) { 553 } 554 uint64_t getSeed() { 555 return seed; 556 } 557 operator uint64_t () const { 558 return seed; 559 } 560 561 TestSeed add(uint64_t index) { 562 return TestSeed(seed + index); 563 } 564 private: 565 uint64_t seed; 566 }; 567 568 template <typename T, bool is_floating_point = std::is_floating_point_v<T>, bool is_complex = is_complex<T>::value> 569 struct ValueGen 570 { 571 std::uniform_int_distribution<int64_t> dis; 572 std::mt19937 gen; 573 ValueGen() : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max()) 574 { 575 } 576 ValueGen(uint64_t seed) : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max(), seed) 577 { 578 } 579 ValueGen(T start, T stop, uint64_t seed = TestSeed()) 580 { 581 gen = std::mt19937(seed); 582 dis = std::uniform_int_distribution<int64_t>(start, stop); 583 } 584 T get() 585 { 586 return static_cast<T>(dis(gen)); 587 } 588 }; 589 590 template <typename T> 591 struct ValueGen<T, true, false> 592 { 593 std::mt19937 gen; 594 std::normal_distribution<T> normal; 595 std::uniform_int_distribution<int> roundChance; 596 T _start; 597 T _stop; 598 bool use_sign_change = false; 599 bool use_round = true; 600 ValueGen() : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max()) 601 { 602 } 603 ValueGen(uint64_t seed) : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max(), seed) 604 { 605 } 606 ValueGen(T start, T stop, uint64_t seed = TestSeed()) 607 { 608 gen = std::mt19937(seed); 609 T mean = start * static_cast<T>(0.5) + stop * static_cast<T>(0.5); 610 //make it normal +-3sigma 611 T divRange = static_cast<T>(6.0); 612 T stdev = std::abs(stop / divRange - start / divRange); 613 normal = std::normal_distribution<T>{ mean, stdev }; 614 // in real its hard to get rounded value 615 // so we will force it by uniform chance 616 roundChance = std::uniform_int_distribution<int>(0, 5); 617 _start = start; 618 _stop = stop; 619 } 620 T get() 621 { 622 T a = normal(gen); 623 //make rounded value ,too 624 auto rChoice = roundChance(gen); 625 if (rChoice == 1) 626 a = std::round(a); 627 if (a < _start) 628 return nextafter(_start, _stop); 629 if (a >= _stop) 630 return nextafter(_stop, _start); 631 return a; 632 } 633 }; 634 635 template <typename T> 636 struct ValueGen<Complex<T>, false, true> 637 { 638 std::mt19937 gen; 639 std::normal_distribution<T> normal; 640 std::uniform_int_distribution<int> roundChance; 641 T _start; 642 T _stop; 643 bool use_sign_change = false; 644 bool use_round = true; 645 ValueGen() : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max()) 646 { 647 } 648 ValueGen(uint64_t seed) : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max(), seed) 649 { 650 } 651 ValueGen(T start, T stop, uint64_t seed = TestSeed()) 652 { 653 gen = std::mt19937(seed); 654 T mean = start * static_cast<T>(0.5) + stop * static_cast<T>(0.5); 655 //make it normal +-3sigma 656 T divRange = static_cast<T>(6.0); 657 T stdev = std::abs(stop / divRange - start / divRange); 658 normal = std::normal_distribution<T>{ mean, stdev }; 659 // in real its hard to get rounded value 660 // so we will force it by uniform chance 661 roundChance = std::uniform_int_distribution<int>(0, 5); 662 _start = start; 663 _stop = stop; 664 } 665 Complex<T> get() 666 { 667 T a = normal(gen); 668 T b = normal(gen); 669 //make rounded value ,too 670 auto rChoice = roundChance(gen); 671 rChoice = rChoice & 3; 672 if (rChoice & 1) 673 a = std::round(a); 674 if (rChoice & 2) 675 b = std::round(b); 676 if (a < _start) 677 a = nextafter(_start, _stop); 678 else if (a >= _stop) 679 a = nextafter(_stop, _start); 680 if (b < _start) 681 b = nextafter(_start, _stop); 682 else if (b >= _stop) 683 b = nextafter(_stop, _start); 684 return Complex<T>(a, b); 685 } 686 }; 687 688 template<class T> 689 int getTrialCount(int test_trials, int domains_size) { 690 int trialCount; 691 int trial_default = 1; 692 if (sizeof(T) <= 2) { 693 //half coverage for byte 694 trial_default = 128; 695 } 696 else { 697 //2*65536 698 trial_default = 2 * std::numeric_limits<uint16_t>::max(); 699 } 700 trialCount = test_trials < 1 ? trial_default : test_trials; 701 if (domains_size > 1) { 702 trialCount = trialCount / domains_size; 703 trialCount = trialCount < 1 ? 1 : trialCount; 704 } 705 return trialCount; 706 } 707 708 template <typename T, typename U = UvalueType<T>> 709 class TestCaseBuilder; 710 711 template <typename T, typename U = UvalueType<T>> 712 class TestingCase { 713 public: 714 friend class TestCaseBuilder<T, U>; 715 static TestCaseBuilder<T, U> getBuilder() { return TestCaseBuilder<T, U>{}; } 716 bool checkSpecialValues() const { 717 //this will be used to check nan, infs, and other special cases 718 return specialCheck; 719 } 720 size_t getTrialCount() const { return trials; } 721 bool isBitwise() const { return bitwise; } 722 const std::vector<CheckWithinDomains<U>>& getDomains() const { 723 return domains; 724 } 725 const std::vector<CustomCheck<T>>& getCustomChecks() const { 726 return customCheck; 727 } 728 TestSeed getTestSeed() const { 729 return testSeed; 730 } 731 private: 732 // if domains is empty we will test default 733 std::vector<CheckWithinDomains<U>> domains; 734 std::vector<CustomCheck<T>> customCheck; 735 // its not used for now 736 bool specialCheck = false; 737 bool bitwise = false; // test bitlevel 738 size_t trials = 0; 739 TestSeed testSeed; 740 }; 741 742 template <typename T, typename U > 743 class TestCaseBuilder { 744 private: 745 TestingCase<T, U> _case; 746 public: 747 TestCaseBuilder<T, U>& set(bool bitwise, bool checkSpecialValues) { 748 _case.bitwise = bitwise; 749 _case.specialCheck = checkSpecialValues; 750 return *this; 751 } 752 TestCaseBuilder<T, U>& setTestSeed(TestSeed seed) { 753 _case.testSeed = seed; 754 return *this; 755 } 756 TestCaseBuilder<T, U>& setTrialCount(size_t trial_count) { 757 _case.trials = trial_count; 758 return *this; 759 } 760 TestCaseBuilder<T, U>& addDomain(const CheckWithinDomains<U>& domainCheck) { 761 _case.domains.emplace_back(domainCheck); 762 return *this; 763 } 764 TestCaseBuilder<T, U>& addCustom(const CustomCheck<T>& customArgs) { 765 _case.customCheck.emplace_back(customArgs); 766 return *this; 767 } 768 TestCaseBuilder<T, U>& checkSpecialValues() { 769 _case.specialCheck = true; 770 return *this; 771 } 772 TestCaseBuilder<T, U>& compareBitwise() { 773 _case.bitwise = true; 774 return *this; 775 } 776 operator TestingCase<T, U> && () { return std::move(_case); } 777 }; 778 779 template <typename T> 780 typename std::enable_if_t<!is_complex<T>::value&& std::is_unsigned<T>::value, T> 781 correctEpsilon(const T& eps) 782 { 783 return eps; 784 } 785 template <typename T> 786 typename std::enable_if_t<!is_complex<T>::value && !std::is_unsigned<T>::value, T> 787 correctEpsilon(const T& eps) 788 { 789 return std::abs(eps); 790 } 791 template <typename T> 792 typename std::enable_if_t<is_complex<Complex<T>>::value, T> 793 correctEpsilon(const Complex<T>& eps) 794 { 795 return std::abs(eps); 796 } 797 798 template <typename T> 799 class AssertVectorized 800 { 801 public: 802 AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0) 803 : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), argSize(1) 804 { 805 } 806 AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1) 807 : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), argSize(2) 808 { 809 } 810 AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1, const T& input2) 811 : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), arg2(input2), argSize(3) 812 { 813 } 814 AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual) : additionalInfo(info), testSeed(seed), exp(expected), act(actual) 815 { 816 } 817 AssertVectorized(const std::string& info, const T& expected, const T& actual) : additionalInfo(info), exp(expected), act(actual), hasSeed(false) 818 { 819 } 820 821 std::string getDetail(int index) const 822 { 823 using UVT = UvalueType<T>; 824 std::stringstream stream; 825 stream.precision(std::numeric_limits<UVT>::max_digits10); 826 stream << "Failure Details:\n"; 827 stream << additionalInfo << "\n"; 828 if (hasSeed) 829 { 830 stream << "Test Seed to reproduce: " << testSeed << "\n"; 831 } 832 if (argSize > 0) 833 { 834 stream << "Arguments:\n"; 835 stream << "#\t " << arg0 << "\n"; 836 if (argSize == 2) 837 { 838 stream << "#\t " << arg1 << "\n"; 839 } 840 if (argSize == 3) 841 { 842 stream << "#\t " << arg2 << "\n"; 843 } 844 } 845 stream << "Expected:\n#\t" << exp << "\nActual:\n#\t" << act; 846 stream << "\nFirst mismatch Index: " << index; 847 return stream.str(); 848 } 849 850 bool check(bool bitwise = false, bool checkWithTolerance = false, ValueType<T> toleranceEps = {}) const 851 { 852 using UVT = UvalueType<T>; 853 using BVT = BitType<UVT>; 854 UVT absErr = correctEpsilon(toleranceEps); 855 constexpr int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount; 856 constexpr int unitStorageCount = VecTypeHelper<T>::unitStorageCount; 857 CACHE_ALIGN UVT expArr[sizeX]; 858 CACHE_ALIGN UVT actArr[sizeX]; 859 exp.store(expArr); 860 act.store(actArr); 861 if (bitwise) 862 { 863 for (const auto i : c10::irange(sizeX)) { 864 BVT b_exp = c10::bit_cast<BVT>(expArr[i]); 865 BVT b_act = c10::bit_cast<BVT>(actArr[i]); 866 EXPECT_EQ(b_exp, b_act) << getDetail(i / unitStorageCount); 867 if (::testing::Test::HasFailure()) 868 return true; 869 } 870 } 871 else if (checkWithTolerance) 872 { 873 for (const auto i : c10::irange(sizeX)) { 874 EXPECT_EQ(nearlyEqual<UVT>(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << "\n" << getDetail(i / unitStorageCount); 875 if (::testing::Test::HasFailure()) 876 return true; 877 } 878 } 879 else 880 { 881 for (const auto i : c10::irange(sizeX)) { 882 if constexpr (std::is_same_v<UVT, float>) 883 { 884 if (!check_both_nan(expArr[i], actArr[i])) { 885 EXPECT_FLOAT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount); 886 } 887 } 888 else if constexpr (std::is_same_v<UVT, double>) 889 { 890 if (!check_both_nan(expArr[i], actArr[i])) 891 { 892 EXPECT_DOUBLE_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount); 893 } 894 } 895 else 896 { 897 EXPECT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount); 898 } 899 if (::testing::Test::HasFailure()) 900 return true; 901 } 902 } 903 return false; 904 } 905 906 private: 907 std::string additionalInfo; 908 TestSeed testSeed; 909 T exp; 910 T act; 911 T arg0; 912 T arg1; 913 T arg2; 914 int argSize = 0; 915 bool hasSeed = true; 916 }; 917 918 template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t> 919 void test_unary( 920 std::string testNameInfo, 921 Op1 expectedFunction, 922 Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) { 923 using vec_type = T; 924 using VT = ValueType<T>; 925 using UVT = UvalueType<T>; 926 constexpr int el_count = vec_type::size(); 927 CACHE_ALIGN VT vals[el_count]; 928 CACHE_ALIGN VT expected[el_count]; 929 bool bitwise = testCase.isBitwise(); 930 UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min(); 931 UVT default_end = std::numeric_limits<UVT>::max(); 932 auto domains = testCase.getDomains(); 933 auto domains_size = domains.size(); 934 auto test_trials = testCase.getTrialCount(); 935 int trialCount = getTrialCount<UVT>(test_trials, domains_size); 936 TestSeed seed = testCase.getTestSeed(); 937 uint64_t changeSeedBy = 0; 938 for (const CheckWithinDomains<UVT>& dmn : domains) { 939 size_t dmn_argc = dmn.ArgsDomain.size(); 940 UVT start = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start; 941 UVT end = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end; 942 ValueGen<VT> generator(start, end, seed.add(changeSeedBy)); 943 for (C10_UNUSED const auto trial : c10::irange(trialCount)) { 944 for (const auto k : c10::irange(el_count)) { 945 vals[k] = generator.get(); 946 call_filter(filter, vals[k]); 947 //map operator 948 expected[k] = expectedFunction(vals[k]); 949 } 950 // test 951 auto input = vec_type::loadu(vals); 952 auto actual = actualFunction(input); 953 auto vec_expected = vec_type::loadu(expected); 954 AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input); 955 if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return; 956 957 }// trial 958 //inrease Seed 959 changeSeedBy += 1; 960 } 961 for (auto& custom : testCase.getCustomChecks()) { 962 auto args = custom.Args; 963 if (args.size() > 0) { 964 auto input = vec_type{ args[0] }; 965 auto actual = actualFunction(input); 966 auto vec_expected = vec_type{ custom.expectedResult }; 967 AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input); 968 if (vecAssert.check()) return; 969 } 970 } 971 } 972 973 template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t> 974 void test_binary( 975 std::string testNameInfo, 976 Op1 expectedFunction, 977 Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) { 978 using vec_type = T; 979 using VT = ValueType<T>; 980 using UVT = UvalueType<T>; 981 constexpr int el_count = vec_type::size(); 982 CACHE_ALIGN VT vals0[el_count]; 983 CACHE_ALIGN VT vals1[el_count]; 984 CACHE_ALIGN VT expected[el_count]; 985 bool bitwise = testCase.isBitwise(); 986 UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min(); 987 UVT default_end = std::numeric_limits<UVT>::max(); 988 auto domains = testCase.getDomains(); 989 auto domains_size = domains.size(); 990 auto test_trials = testCase.getTrialCount(); 991 int trialCount = getTrialCount<UVT>(test_trials, domains_size); 992 TestSeed seed = testCase.getTestSeed(); 993 uint64_t changeSeedBy = 0; 994 for (const CheckWithinDomains<UVT>& dmn : testCase.getDomains()) { 995 size_t dmn_argc = dmn.ArgsDomain.size(); 996 UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start; 997 UVT end0 = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end; 998 UVT start1 = dmn_argc > 1 ? dmn.ArgsDomain[1].start : default_start; 999 UVT end1 = dmn_argc > 1 ? dmn.ArgsDomain[1].end : default_end; 1000 ValueGen<VT> generator0(start0, end0, seed.add(changeSeedBy)); 1001 ValueGen<VT> generator1(start1, end1, seed.add(changeSeedBy + 1)); 1002 for (C10_UNUSED const auto trial : c10::irange(trialCount)) { 1003 for (const auto k : c10::irange(el_count)) { 1004 vals0[k] = generator0.get(); 1005 vals1[k] = generator1.get(); 1006 call_filter(filter, vals0[k], vals1[k]); 1007 //map operator 1008 expected[k] = expectedFunction(vals0[k], vals1[k]); 1009 } 1010 // test 1011 auto input0 = vec_type::loadu(vals0); 1012 auto input1 = vec_type::loadu(vals1); 1013 auto actual = actualFunction(input0, input1); 1014 auto vec_expected = vec_type::loadu(expected); 1015 AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1); 1016 if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError))return; 1017 }// trial 1018 changeSeedBy += 1; 1019 } 1020 for (auto& custom : testCase.getCustomChecks()) { 1021 auto args = custom.Args; 1022 if (args.size() > 0) { 1023 auto input0 = vec_type{ args[0] }; 1024 auto input1 = args.size() > 1 ? vec_type{ args[1] } : vec_type{ args[0] }; 1025 auto actual = actualFunction(input0, input1); 1026 auto vec_expected = vec_type(custom.expectedResult); 1027 AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1); 1028 if (vecAssert.check()) return; 1029 } 1030 } 1031 } 1032 1033 template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t> 1034 void test_ternary( 1035 std::string testNameInfo, 1036 Op1 expectedFunction, 1037 Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) { 1038 using vec_type = T; 1039 using VT = ValueType<T>; 1040 using UVT = UvalueType<T>; 1041 constexpr int el_count = vec_type::size(); 1042 CACHE_ALIGN VT vals0[el_count]; 1043 CACHE_ALIGN VT vals1[el_count]; 1044 CACHE_ALIGN VT vals2[el_count]; 1045 CACHE_ALIGN VT expected[el_count]; 1046 bool bitwise = testCase.isBitwise(); 1047 UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min(); 1048 UVT default_end = std::numeric_limits<UVT>::max(); 1049 auto domains = testCase.getDomains(); 1050 auto domains_size = domains.size(); 1051 auto test_trials = testCase.getTrialCount(); 1052 int trialCount = getTrialCount<UVT>(test_trials, domains_size); 1053 TestSeed seed = testCase.getTestSeed(); 1054 uint64_t changeSeedBy = 0; 1055 for (const CheckWithinDomains<UVT>& dmn : testCase.getDomains()) { 1056 size_t dmn_argc = dmn.ArgsDomain.size(); 1057 UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start; 1058 UVT end0 = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end; 1059 UVT start1 = dmn_argc > 1 ? dmn.ArgsDomain[1].start : default_start; 1060 UVT end1 = dmn_argc > 1 ? dmn.ArgsDomain[1].end : default_end; 1061 UVT start2 = dmn_argc > 2 ? dmn.ArgsDomain[2].start : default_start; 1062 UVT end2 = dmn_argc > 2 ? dmn.ArgsDomain[2].end : default_end; 1063 ValueGen<VT> generator0(start0, end0, seed.add(changeSeedBy)); 1064 ValueGen<VT> generator1(start1, end1, seed.add(changeSeedBy + 1)); 1065 ValueGen<VT> generator2(start2, end2, seed.add(changeSeedBy + 2)); 1066 1067 for (C10_UNUSED const auto trial : c10::irange(trialCount)) { 1068 for (const auto k : c10::irange(el_count)) { 1069 vals0[k] = generator0.get(); 1070 vals1[k] = generator1.get(); 1071 vals2[k] = generator2.get(); 1072 call_filter(filter, vals0[k], vals1[k], vals2[k]); 1073 //map operator 1074 expected[k] = expectedFunction(vals0[k], vals1[k], vals2[k]); 1075 } 1076 // test 1077 auto input0 = vec_type::loadu(vals0); 1078 auto input1 = vec_type::loadu(vals1); 1079 auto input2 = vec_type::loadu(vals2); 1080 auto actual = actualFunction(input0, input1, input2); 1081 auto vec_expected = vec_type::loadu(expected); 1082 AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1, input2); 1083 if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return; 1084 }// trial 1085 changeSeedBy += 1; 1086 } 1087 } 1088 1089 template <typename T, typename Op> 1090 T func_cmp(Op call, T v0, T v1) { 1091 using bit_rep = BitType<T>; 1092 constexpr bit_rep mask = std::numeric_limits<bit_rep>::max(); 1093 bit_rep ret = call(v0, v1) ? mask : 0; 1094 return c10::bit_cast<T>(ret); 1095 } 1096 1097 struct PreventFma 1098 { 1099 not_inline float sub(float a, float b) 1100 { 1101 return a - b; 1102 } 1103 not_inline double sub(double a, double b) 1104 { 1105 return a - b; 1106 } 1107 not_inline float add(float a, float b) 1108 { 1109 return a + b; 1110 } 1111 not_inline double add(double a, double b) 1112 { 1113 return a + b; 1114 } 1115 }; 1116 1117 template <typename T> 1118 std::enable_if_t<!is_complex<T>::value, T> local_log2(T x) { 1119 return std::log2(x); 1120 } 1121 1122 template <typename T> 1123 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_log2(Complex<T> x) { 1124 T ret = std::log(x); 1125 T real = ret.real() / std::log(static_cast<T>(2)); 1126 T imag = ret.imag() / std::log(static_cast<T>(2)); 1127 return Complex<T>(real, imag); 1128 } 1129 1130 template <typename T> 1131 std::enable_if_t<!is_complex<T>::value, T> local_abs(T x) { 1132 return std::abs(x); 1133 } 1134 1135 template <typename T> 1136 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_abs(Complex<T> x) { 1137 #if defined(TEST_AGAINST_DEFAULT) 1138 return std::abs(x); 1139 #else 1140 PreventFma noFma; 1141 T real = x.real(); 1142 T imag = x.imag(); 1143 T rr = real * real; 1144 T ii = imag * imag; 1145 T abs = std::sqrt(noFma.add(rr, ii)); 1146 return Complex<T>(abs, 0); 1147 #endif 1148 } 1149 1150 template <typename T> 1151 std::enable_if_t<!is_complex<T>::value, T> local_multiply(T x, T y) { 1152 return x * y; 1153 } 1154 1155 template <typename T> 1156 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_multiply(Complex<T> x, Complex<T> y) { 1157 #if defined(TEST_AGAINST_DEFAULT) 1158 return x * y; 1159 #else 1160 //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i 1161 T x_real = x.real(); 1162 T x_imag = x.imag(); 1163 T y_real = y.real(); 1164 T y_imag = y.imag(); 1165 #if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR) 1166 //check multiplication considerin swap and fma 1167 T rr = x_real * y_real; 1168 T ii = x_imag * y_real; 1169 T neg_imag = -y_imag; 1170 rr = fma(x_imag, neg_imag, rr); 1171 ii = fma(x_real, y_imag, ii); 1172 #else 1173 // replicate order 1174 PreventFma noFma; 1175 T ac = x_real * y_real; 1176 T bd = x_imag * y_imag; 1177 T ad = x_real * y_imag; 1178 T bc = x_imag * (-y_real); 1179 T rr = noFma.sub(ac, bd); 1180 T ii = noFma.sub(ad, bc); 1181 #endif 1182 return Complex<T>(rr, ii); 1183 #endif 1184 } 1185 1186 1187 1188 template <typename T> 1189 std::enable_if_t<!is_complex<T>::value, T> local_division(T x, T y) { 1190 return x / y; 1191 } 1192 1193 template <typename T> 1194 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_division(Complex<T> x, Complex<T> y) { 1195 #if defined(TEST_AGAINST_DEFAULT) 1196 return x / y; 1197 #else /* defined(TEST_AGAINST_DEFAULT) */ 1198 //re = (ac + bd)/abs_2() 1199 //im = (bc - ad)/abs_2() 1200 T x_real = x.real(); 1201 T x_imag = x.imag(); 1202 T y_real = y.real(); 1203 T y_imag = y.imag(); 1204 PreventFma noFma; 1205 #if defined(CPU_CAPABILITY_ZVECTOR) 1206 T abs_c = std::abs(y_real); 1207 T abs_d = std::abs(y_imag); 1208 T scale = 1.0 / std::max(abs_c, abs_d); 1209 1210 T a_sc = x_real * scale; // a/sc 1211 T b_sc = x_imag * scale; // b/sc 1212 T c_sc = y_real * scale; // c/sc 1213 T d_sc = y_imag * scale; // d/sc 1214 1215 T ac_sc2 = a_sc * c_sc; // ac/sc^2 1216 T bd_sc2 = b_sc * d_sc; // bd/sc^2 1217 1218 T neg_d_sc = -1.0 * d_sc; // -d/sc^2 1219 1220 T neg_ad_sc2 = a_sc * neg_d_sc; // -ad/sc^2 1221 T bc_sc2 = b_sc * c_sc; // bc/sc^2 1222 1223 T ac_bd_sc2 = noFma.add(ac_sc2, bd_sc2); // (ac+bd)/sc^2 1224 T bc_ad_sc2 = noFma.add(bc_sc2, neg_ad_sc2); // (bc-ad)/sc^2 1225 1226 T c2_sc2 = c_sc * c_sc; // c^2/sc^2 1227 T d2_sc2 = d_sc * d_sc; // d^2/sc^2 1228 1229 T c2_d2_sc2 = noFma.add(c2_sc2, d2_sc2); // (c^2+d^2)/sc^2 1230 1231 T rr = ac_bd_sc2 / c2_d2_sc2; // (ac+bd)/(c^2+d^2) 1232 T ii = bc_ad_sc2 / c2_d2_sc2; // (bc-ad)/(c^2+d^2) 1233 1234 return Complex<T>(rr, ii); 1235 #else /* defined(CPU_CAPABILITY_ZVECTOR) */ 1236 #if defined(CPU_CAPABILITY_VSX) 1237 //check multiplication considerin swap and fma 1238 T rr = x_real * y_real; 1239 T ii = x_imag * y_real; 1240 T neg_imag = -y_imag; 1241 rr = fma(x_imag, y_imag, rr); 1242 ii = fma(x_real, neg_imag, ii); 1243 //b.abs_2 1244 #else /* defined(CPU_CAPABILITY_VSX) */ 1245 T ac = x_real * y_real; 1246 T bd = x_imag * y_imag; 1247 T ad = x_real * y_imag; 1248 T bc = x_imag * y_real; 1249 T rr = noFma.add(ac, bd); 1250 T ii = noFma.sub(bc, ad); 1251 #endif /* defined(CPU_CAPABILITY_VSX) */ 1252 //b.abs_2() 1253 T abs_rr = y_real * y_real; 1254 T abs_ii = y_imag * y_imag; 1255 T abs_2 = noFma.add(abs_rr, abs_ii); 1256 rr = rr / abs_2; 1257 ii = ii / abs_2; 1258 return Complex<T>(rr, ii); 1259 #endif /* defined(CPU_CAPABILITY_ZVECTOR) */ 1260 #endif /* defined(TEST_AGAINST_DEFAULT) */ 1261 } 1262 1263 1264 template <typename T> 1265 std::enable_if_t<!is_complex<T>::value, T> local_fmadd(T a, T b, T c) { 1266 PreventFma noFma; 1267 T ab = a * b; 1268 return noFma.add(ab, c); 1269 } 1270 1271 template <typename T> 1272 std::enable_if_t<!is_complex<T>::value, T> local_sqrt(T x) { 1273 return std::sqrt(x); 1274 } 1275 1276 template <typename T> 1277 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_sqrt(Complex<T> x) { 1278 return std::sqrt(x); 1279 } 1280 1281 template <typename T> 1282 std::enable_if_t<!is_complex<T>::value, T> local_asin(T x) { 1283 return std::asin(x); 1284 } 1285 1286 template <typename T> 1287 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_asin(Complex<T> x) { 1288 return std::asin(x); 1289 } 1290 1291 template <typename T> 1292 std::enable_if_t<!is_complex<T>::value, T> local_acos(T x) { 1293 return std::acos(x); 1294 } 1295 1296 template <typename T> 1297 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_acos(Complex<T> x) { 1298 return std::acos(x); 1299 } 1300 1301 template<typename T> 1302 std::enable_if_t<!is_complex<T>::value, T> 1303 local_and(const T& val0, const T& val1) { 1304 using bit_rep = BitType<T>; 1305 bit_rep ret = c10::bit_cast<bit_rep>(val0) & c10::bit_cast<bit_rep>(val1); 1306 return c10::bit_cast<T> (ret); 1307 } 1308 1309 template <typename T> 1310 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> 1311 local_and(const Complex<T>& val0, const Complex<T>& val1) 1312 { 1313 using bit_rep = BitType<T>; 1314 T real1 = val0.real(); 1315 T imag1 = val0.imag(); 1316 T real2 = val1.real(); 1317 T imag2 = val1.imag(); 1318 bit_rep real_ret = c10::bit_cast<bit_rep>(real1) & c10::bit_cast<bit_rep>(real2); 1319 bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) & c10::bit_cast<bit_rep>(imag2); 1320 return Complex<T>(c10::bit_cast<T>(real_ret), c10::bit_cast<T>(imag_ret)); 1321 } 1322 1323 template<typename T> 1324 std::enable_if_t<!is_complex<T>::value, T> 1325 local_or(const T& val0, const T& val1) { 1326 using bit_rep = BitType<T>; 1327 bit_rep ret = c10::bit_cast<bit_rep>(val0) | c10::bit_cast<bit_rep>(val1); 1328 return c10::bit_cast<T> (ret); 1329 } 1330 1331 template<typename T> 1332 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> 1333 local_or(const Complex<T>& val0, const Complex<T>& val1) { 1334 using bit_rep = BitType<T>; 1335 T real1 = val0.real(); 1336 T imag1 = val0.imag(); 1337 T real2 = val1.real(); 1338 T imag2 = val1.imag(); 1339 bit_rep real_ret = c10::bit_cast<bit_rep>(real1) | c10::bit_cast<bit_rep>(real2); 1340 bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) | c10::bit_cast<bit_rep>(imag2); 1341 return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret)); 1342 } 1343 1344 template<typename T> 1345 std::enable_if_t<!is_complex<T>::value, T> 1346 local_xor(const T& val0, const T& val1) { 1347 using bit_rep = BitType<T>; 1348 bit_rep ret = c10::bit_cast<bit_rep>(val0) ^ c10::bit_cast<bit_rep>(val1); 1349 return c10::bit_cast<T> (ret); 1350 } 1351 1352 template<typename T> 1353 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> 1354 local_xor(const Complex<T>& val0, const Complex<T>& val1) { 1355 using bit_rep = BitType<T>; 1356 T real1 = val0.real(); 1357 T imag1 = val0.imag(); 1358 T real2 = val1.real(); 1359 T imag2 = val1.imag(); 1360 bit_rep real_ret = c10::bit_cast<bit_rep>(real1) ^ c10::bit_cast<bit_rep>(real2); 1361 bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) ^ c10::bit_cast<bit_rep>(imag2); 1362 return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret)); 1363 } 1364 1365 template <typename T> 1366 T quantize_val(float scale, int64_t zero_point, float value) { 1367 int64_t qvalue; 1368 constexpr int64_t qmin = std::numeric_limits<T>::min(); 1369 constexpr int64_t qmax = std::numeric_limits<T>::max(); 1370 float inv_scale = 1.0f / scale; 1371 qvalue = static_cast<int64_t>(zero_point + at::native::round_impl<float>(value * inv_scale)); 1372 qvalue = std::max<int64_t>(qvalue, qmin); 1373 qvalue = std::min<int64_t>(qvalue, qmax); 1374 return static_cast<T>(qvalue); 1375 } 1376 1377 template <typename T> 1378 #if defined(TEST_AGAINST_DEFAULT) 1379 T requantize_from_int(float multiplier, int32_t zero_point, int32_t src) { 1380 auto xx = static_cast<float>(src) * multiplier; 1381 double xx2 = nearbyint(xx); 1382 int32_t quantize_down = xx2 + zero_point; 1383 #else 1384 T requantize_from_int(float multiplier, int64_t zero_point, int64_t src) { 1385 int64_t quantize_down = static_cast<int64_t>(zero_point + std::lrintf(src * multiplier)); 1386 #endif 1387 constexpr int64_t min = std::numeric_limits<T>::min(); 1388 constexpr int64_t max = std::numeric_limits<T>::max(); 1389 auto ret = static_cast<T>(std::min<int64_t>(std::max<int64_t>(quantize_down, min), max)); 1390 return ret; 1391 } 1392 1393 template <typename T> 1394 float dequantize_val(float scale, int64_t zero_point, T value) { 1395 //when negated scale is used as addition 1396 #if defined(CHECK_WITH_FMA) 1397 float neg_p = -(zero_point * scale); 1398 float v = static_cast<float>(value); 1399 float ret = fma(v, scale, neg_p); 1400 #else 1401 float ret = (static_cast<float>(value) - zero_point) * scale; 1402 #endif 1403 return ret; 1404 } 1405 1406 template<typename T> 1407 T relu(const T & val, const T & zero_point) { 1408 return std::max(val, zero_point); 1409 } 1410 1411 template<typename T> 1412 T relu6(T val, T zero_point, T q_six) { 1413 return std::min<T>(std::max<T>(val, zero_point), q_six); 1414 } 1415 1416 template<typename T> 1417 int32_t widening_subtract(T val, T b) { 1418 return static_cast<int32_t>(val) - static_cast<int32_t>(b); 1419 } 1420 1421 //default testing case 1422 template<typename T> 1423 T getDefaultTolerance() { 1424 return static_cast<T>(0.0); 1425 } 1426 1427 template<> 1428 float getDefaultTolerance() { 1429 return 5.e-5f; 1430 } 1431 1432 template<> 1433 double getDefaultTolerance() { 1434 return 1.e-9; 1435 } 1436 1437 template<typename T, int N = 1> 1438 at::vec::VecMask<T, N> create_vec_mask(uint64_t bitmask) { 1439 constexpr auto size = at::vec::Vectorized<T>::size(); 1440 std::array<int, N * size> mask; 1441 for (int n = 0; n < N; n++) { 1442 for (int i = 0; i < size; i++) { 1443 mask[n * size + i] = (bitmask >> i) & 1; 1444 } 1445 } 1446 return at::vec::VecMask<T, N>::from(mask.data()); 1447 } 1448 1449 template<typename T, int N = 1> 1450 at::vec::VecMask<T, N> generate_vec_mask(int seed) { 1451 constexpr auto size = at::vec::Vectorized<T>::size(); 1452 ValueGen<uint64_t> generator(0, (1ULL << size) - 1, seed); 1453 auto bitmask = generator.get(); 1454 return create_vec_mask<T, N>(bitmask); 1455 } 1456 1457 template<typename T> 1458 TestingCase<T> createDefaultUnaryTestCase(TestSeed seed = TestSeed(), bool bitwise = false, bool checkWithTolerance = false, size_t trials = 0) { 1459 using UVT = UvalueType<T>; 1460 TestingCase<T> testCase; 1461 if (!bitwise && std::is_floating_point_v<UVT>) { 1462 //for float types lets add manual ranges 1463 UVT tolerance = getDefaultTolerance<UVT>(); 1464 testCase = TestingCase<T>::getBuilder() 1465 .set(bitwise, false) 1466 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-10, (UVT)10}}, checkWithTolerance, tolerance}) 1467 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)10, (UVT)100 }}, checkWithTolerance, tolerance}) 1468 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)100, (UVT)1000 }}, checkWithTolerance, tolerance}) 1469 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-100, (UVT)-10 }}, checkWithTolerance, tolerance}) 1470 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-1000, (UVT)-100 }}, checkWithTolerance, tolerance}) 1471 .addDomain(CheckWithinDomains<UVT>{ {}, checkWithTolerance, tolerance}) 1472 .setTrialCount(trials) 1473 .setTestSeed(seed); 1474 } 1475 else { 1476 testCase = TestingCase<T>::getBuilder() 1477 .set(bitwise, false) 1478 .addDomain(CheckWithinDomains<UVT>{}) 1479 .setTrialCount(trials) 1480 .setTestSeed(seed); 1481 } 1482 return testCase; 1483 } 1484 1485 template<typename T> 1486 TestingCase<T> createDefaultBinaryTestCase(TestSeed seed = TestSeed(), bool bitwise = false, bool checkWithTolerance = false, size_t trials = 0) { 1487 using UVT = UvalueType<T>; 1488 TestingCase<T> testCase; 1489 if (!bitwise && std::is_floating_point_v<UVT>) { 1490 //for float types lets add manual ranges 1491 UVT tolerance = getDefaultTolerance<UVT>(); 1492 testCase = TestingCase<T>::getBuilder() 1493 .set(bitwise, false) 1494 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-10, (UVT)10}, { (UVT)-10, (UVT)10 }}, checkWithTolerance, tolerance}) 1495 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)10, (UVT)100 }, { (UVT)-10, (UVT)100 }}, checkWithTolerance, tolerance}) 1496 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)100, (UVT)1000 }, { (UVT)-100, (UVT)1000 }}, checkWithTolerance, tolerance}) 1497 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-100, (UVT)-10 }, { (UVT)-100, (UVT)10 }}, checkWithTolerance, tolerance}) 1498 .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-1000, (UVT)-100 }, { (UVT)-1000, (UVT)100 }}, checkWithTolerance, tolerance}) 1499 .addDomain(CheckWithinDomains<UVT>{ {}, checkWithTolerance, tolerance}) 1500 .setTrialCount(trials) 1501 .setTestSeed(seed); 1502 } 1503 else { 1504 testCase = TestingCase<T>::getBuilder() 1505 .set(bitwise, false) 1506 .addDomain(CheckWithinDomains<UVT>{}) 1507 .setTrialCount(trials) 1508 .setTestSeed(seed); 1509 } 1510 return testCase; 1511 } 1512 1513 template<typename T> 1514 TestingCase<T> createDefaultTernaryTestCase(TestSeed seed = TestSeed(), bool bitwise = false, bool checkWithTolerance = false, size_t trials = 0) { 1515 TestingCase<T> testCase = TestingCase<T>::getBuilder() 1516 .set(bitwise, false) 1517 .addDomain(CheckWithinDomains<UvalueType<T>>{}) 1518 .setTrialCount(trials) 1519 .setTestSeed(seed); 1520 return testCase; 1521 } 1522