xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/vec_test_all_types.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <c10/util/bit_cast.h>
#include <c10/util/irange.h>
#include <gtest/gtest.h>
#include <math.h>
#include <float.h>
#include <algorithm>
#include <chrono>
#include <complex>
#include <cstdint>
#include <exception>
#include <functional>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
18 
// Alignment, in bytes, requested for the scratch arrays declared with
// CACHE_ALIGN: 64 when built for AVX512, 32 for every other configuration.
#if defined(CPU_CAPABILITY_AVX512)
#define CACHE_LINE 64
#else
#define CACHE_LINE 32
#endif

// CACHE_ALIGN  - prefix for a variable declaration that requests CACHE_LINE
//                alignment using the compiler-specific spelling.
// not_inline   - prefix that asks the compiler not to inline a function.
#if defined(__GNUC__)
#define CACHE_ALIGN __attribute__((aligned(CACHE_LINE)))
#define not_inline __attribute__((noinline))
#elif defined(_WIN32)
#define CACHE_ALIGN __declspec(align(CACHE_LINE))
#define not_inline __declspec(noinline)
#else
// Unknown toolchain: both macros expand to nothing. The directive has to be
// spelled `#define CACHE_ALIGN`; with the words swapped the macro is never
// defined and every use of CACHE_ALIGN fails to compile on this branch.
#define CACHE_ALIGN
#define not_inline
#endif
// TEST_AGAINST_DEFAULT is set to 1 when the build is the DEFAULT capability,
// when compiling with MSVC, or when none of the AVX512 / AVX2 / VSX / ZVECTOR
// capabilities is selected. In every other configuration it is left undefined.
#if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) ||                 \
    (!defined(CPU_CAPABILITY_AVX512) && !defined(CPU_CAPABILITY_AVX2) &&    \
     !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_ZVECTOR))
#define TEST_AGAINST_DEFAULT 1
#else
#undef TEST_AGAINST_DEFAULT
#endif
// Drop any earlier NAME_INFO definition before providing ours.
#undef NAME_INFO
// STRINGIFY turns its argument into a string literal without expanding it;
// TOSTRING expands the argument first and then stringifies the result.
#define STRINGIFY(token) #token
#define TOSTRING(token) STRINGIFY(token)
// Builds a literal of the form: <name> <stringified __FILE__>:<line number>.
#define NAME_INFO(name) \
  TOSTRING(name) " " TOSTRING(__FILE__) ":" TOSTRING(__LINE__)

// Wraps an overloaded or templated callable name in a generic lambda that
// perfectly forwards its arguments, so the name can be passed as one object.
#define RESOLVE_OVERLOAD(...)                                  \
  [](auto&&... args) -> decltype(auto) {                       \
    return __VA_ARGS__(std::forward<decltype(args)>(args)...); \
  }
51 
// Selects which extra checks are enabled for the current CPU capability.
// `&&` binds tighter than `||`, so the compiler-family test applies to the
// AVX512 term only; the parentheses below make that grouping explicit.
#if defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_VSX) || \
    defined(CPU_CAPABILITY_AVX2) ||                                   \
    (defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__)))
#define CHECK_WITH_FMA 1
#undef CHECK_DEQUANT_WITH_LOW_PRECISION
#elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2)
#undef CHECK_WITH_FMA
#undef CHECK_DEQUANT_WITH_LOW_PRECISION
#else
// NOTE(review): VSX and AVX2 are both already caught by the first branch, so
// this branch looks unreachable - confirm whether that is intended.
#undef CHECK_WITH_FMA
#define CHECK_DEQUANT_WITH_LOW_PRECISION 1
#endif
63 
64 template<typename T>
65 using Complex = typename c10::complex<T>;
66 
67 template <typename T>
68 using VecType = typename at::vec::Vectorized<T>;
69 
70 using vfloat = VecType<float>;
71 using vdouble = VecType<double>;
72 using vcomplex = VecType<Complex<float>>;
73 using vcomplexDbl = VecType<Complex<double>>;
74 using vlong = VecType<int64_t>;
75 using vint = VecType<int32_t>;
76 using vshort = VecType<int16_t>;
77 using vqint8 = VecType<c10::qint8>;
78 using vquint8 = VecType<c10::quint8>;
79 using vqint = VecType<c10::qint32>;
80 using vBFloat16 = VecType<c10::BFloat16>;
81 using vHalf = VecType<c10::Half>;
82 
83 template <typename T>
84 using ValueType = typename T::value_type;
85 
// Maps a byte count N to an unsigned integer type: 8 -> uint64_t,
// 4 -> uint32_t, 2 -> uint16_t, 1 -> uint8_t, anything else -> uintmax_t.
template <int N>
struct BitStr
{
    using type =
        std::conditional_t<N == 8, uint64_t,
        std::conditional_t<N == 4, uint32_t,
        std::conditional_t<N == 2, uint16_t,
        std::conditional_t<N == 1, uint8_t, uintmax_t>>>>;
};

// Unsigned integer type selected by sizeof(T); used for bit-level comparison.
template <typename T>
using BitType = typename BitStr<sizeof(T)>::type;
118 
119 template<typename T>
120 struct VecTypeHelper {
121     using holdType = typename T::value_type;
122     using memStorageType = typename T::value_type;
123     static constexpr int holdCount = T::size();
124     static constexpr int unitStorageCount = 1;
125 };
126 
127 template<>
128 struct VecTypeHelper<vcomplex> {
129     using holdType = Complex<float>;
130     using memStorageType = float;
131     static constexpr int holdCount = vcomplex::size();
132     static constexpr int unitStorageCount = 2;
133 };
134 
135 template<>
136 struct VecTypeHelper<vcomplexDbl> {
137     using holdType = Complex<double>;
138     using memStorageType = double;
139     static constexpr int holdCount = vcomplexDbl::size();
140     static constexpr int unitStorageCount = 2;
141 };
142 
143 template<>
144 struct VecTypeHelper<vqint8> {
145     using holdType = c10::qint8;
146     using memStorageType = typename c10::qint8::underlying;
147     static constexpr int holdCount = vqint8::size();
148     static constexpr int unitStorageCount = 1;
149 };
150 
151 template<>
152 struct VecTypeHelper<vquint8> {
153     using holdType = c10::quint8;
154     using memStorageType = typename c10::quint8::underlying;
155     static constexpr int holdCount = vquint8::size();
156     static constexpr int unitStorageCount = 1;
157 };
158 
159 template<>
160 struct VecTypeHelper<vqint> {
161     using holdType = c10::qint32;
162     using memStorageType = typename c10::qint32::underlying;
163     static constexpr int holdCount = vqint::size();
164     static constexpr int unitStorageCount = 1;
165 };
166 
167 template<>
168 struct VecTypeHelper<vBFloat16> {
169     using holdType = c10::BFloat16;
170     using memStorageType = typename vBFloat16::value_type;
171     static constexpr int holdCount = vBFloat16::size();
172     static constexpr int unitStorageCount = 1;
173 };
174 
175 template<>
176 struct VecTypeHelper<vHalf> {
177     using holdType = c10::Half;
178     using memStorageType = typename vHalf::value_type;
179     static constexpr int holdCount = vHalf::size();
180     static constexpr int unitStorageCount = 1;
181 };
182 
183 template <typename T>
184 using UholdType = typename VecTypeHelper<T>::holdType;
185 
186 template <typename T>
187 using UvalueType = typename VecTypeHelper<T>::memStorageType;
188 
// Number of elements in a built-in array, available at compile time.
template <class T, size_t N>
constexpr auto size(T (&)[N]) -> size_t {
    return N;
}
193 
// Applies an optional input filter to one, two or three values in place.
//
// When Filter is std::nullptr_t (the "no filter" case) the call does nothing.
// Otherwise `filter` is invoked with the values by reference so that it can
// adjust them. Any value returned by the filter is discarded, so filters that
// return something other than void are accepted as well.
template <typename Filter, typename T>
void call_filter([[maybe_unused]] Filter filter, [[maybe_unused]] T& val) {
    if constexpr (!std::is_same_v<Filter, std::nullptr_t>) {
        filter(val);
    }
}

template <typename Filter, typename T>
void call_filter(
    [[maybe_unused]] Filter filter,
    [[maybe_unused]] T& first,
    [[maybe_unused]] T& second) {
    if constexpr (!std::is_same_v<Filter, std::nullptr_t>) {
        filter(first, second);
    }
}

template <typename Filter, typename T>
void call_filter(
    [[maybe_unused]] Filter filter,
    [[maybe_unused]] T& first,
    [[maybe_unused]] T& second,
    [[maybe_unused]] T& third) {
    if constexpr (!std::is_same_v<Filter, std::nullptr_t>) {
        filter(first, second, third);
    }
}
226 
// Half-open interval [start, end) from which test inputs are drawn.
template <class T>
struct DomainRange {
    // First value of the interval (included).
    T start;
    // One past the interval (excluded); use nextafter on the bound to have
    // the end value itself covered by a test.
    T end;
};
232 
233 template <typename T>
234 struct CustomCheck {
235     std::vector<UholdType<T>> Args;
236     UholdType<T> expectedResult;
237 };
238 
239 template <typename T>
240 struct CheckWithinDomains {
241     // each argument takes domain Range
242     std::vector<DomainRange<T>> ArgsDomain;
243     // check with error tolerance
244     bool CheckWithTolerance = false;
245     T ToleranceError = (T)0;
246 };
247 
248 template <typename T>
249 std::ostream& operator<<(std::ostream& stream, const CheckWithinDomains<T>& dmn) {
250     stream << "Domain: ";
251     if (dmn.ArgsDomain.size() > 0) {
252         for (const DomainRange<T>& x : dmn.ArgsDomain) {
253             if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>) {
254                 stream << "\n{ " << static_cast<int>(x.start) << ", " << static_cast<int>(x.end) << " }";
255             }
256             else {
257                 stream << "\n{ " << x.start << ", " << x.end << " }";
258             }
259         }
260     }
261     else {
262         stream << "default range";
263     }
264     if (dmn.CheckWithTolerance) {
265         stream << "\nError tolerance: " << dmn.ToleranceError;
266     }
267     return stream;
268 }
269 
// True only when T is a floating point type and both x and y are NaN.
template <typename T>
bool check_both_nan(T x, T y) {
    if constexpr (!std::is_floating_point_v<T>) {
        return false;
    } else {
        return std::isnan(x) && std::isnan(y);
    }
}
277 
// True only when T is a floating point type and both x and y are infinite.
// The signs of the two infinities are not compared.
template <typename T>
bool check_both_inf(T x, T y) {
    if constexpr (!std::is_floating_point_v<T>) {
        return false;
    } else {
        return std::isinf(x) && std::isinf(y);
    }
}
285 
// Treats an infinity and a huge finite value of the same sign as matching.
//
// For floating point T the result is true when one argument is +inf and the
// other is at least 1e30 (float) / 1e300 (other types), or when one is -inf
// and the other is at most the negated threshold. Always false for
// non floating point T.
template <typename T>
bool check_both_big(T x, T y) {
    if constexpr (std::is_floating_point_v<T>) {
        const T upper = std::is_same_v<T, float> ? static_cast<T>(1e+30) : static_cast<T>(1e+300);
        const T lower = std::is_same_v<T, float> ? static_cast<T>(-1e+30) : static_cast<T>(-1e+300);
        // `infinite` must really be infinite; `other` then has to lie beyond
        // the threshold on the same side.
        const auto pairs_up = [upper, lower](T infinite, T other) {
            if (!std::isinf(infinite)) {
                return false;
            }
            return (infinite > 0) ? (other >= upper) : (other <= lower);
        };
        return pairs_up(x, y) || pairs_up(y, x);
    } else {
        return false;
    }
}
303 
304 template<class T> struct is_complex : std::false_type {};
305 
306 template<class T> struct is_complex<Complex<T>> : std::true_type {};
307 
// Division that saturates instead of overflowing or underflowing (the
// approach is borrowed from boost).
//
// Returns numeric_limits<T>::max() when f1 / f2 would overflow, 0 when f1 is
// zero or the quotient would underflow, and f1 / f2 otherwise.
template<typename T>
T safe_fpt_division(T f1, T f2)
{
    using Limits = std::numeric_limits<T>;
    const T zero = static_cast<T>(0);
    const T one = static_cast<T>(1);

    // A divisor below one enlarges the quotient: guard against overflow.
    const bool overflows = (f2 < one) && (f1 > f2 * Limits::max());
    if (overflows) {
        return Limits::max();
    }

    // A divisor above one shrinks the quotient: guard against underflow.
    const bool underflows =
        (f1 == zero) || ((f2 > one) && (f1 < f2 * Limits::min()));
    if (underflows) {
        return zero;
    }

    return f1 / f2;
}
323 
324 template<class T>
325 std::enable_if_t<std::is_floating_point_v<T>, bool>
326 nearlyEqual(T a, T b, T tolerance) {
327     if (check_both_nan<T>(a, b)) return true;
328     if (check_both_big(a, b)) return true;
329     T absA = std::abs(a);
330     T absB = std::abs(b);
331     T diff = std::abs(a - b);
332     if (diff <= tolerance) {
333         return true;
334     }
335     T d1 = safe_fpt_division<T>(diff, absB);
336     T d2 = safe_fpt_division<T>(diff, absA);
337     return (d1 <= tolerance) || (d2 <= tolerance);
338 }
339 
// Non floating point overload: the tolerance is ignored and the two values
// must compare exactly equal.
template<class T>
std::enable_if_t<!std::is_floating_point_v<T>, bool>
nearlyEqual(T a, T b, T tolerance) {
    const bool identical = (a == b);
    return identical;
}
345 
// Scalar reference for reciprocal: 1 / x.
template <typename T>
T reciprocal(T x) {
    const T inverse = 1 / x;
    return inverse;
}
350 
// Scalar reference for reciprocal square root: 1 / sqrt(x).
template <typename T>
T rsqrt(T x) {
    const auto root = std::sqrt(x);
    return 1 / root;
}
355 
// Scalar reference for the fractional part: x minus x truncated toward zero.
template <typename T>
T frac(T x) {
  const auto whole = std::trunc(x);
  return x - whole;
}
360 
// Returns a when a > b, otherwise b.
template <class T>
T maximum(const T& a, const T& b) {
    if (a > b) {
        return a;
    }
    return b;
}
365 
// Returns a when a < b, otherwise b.
template <class T>
T minimum(const T& a, const T& b) {
    if (a < b) {
        return a;
    }
    return b;
}
370 
// Limits a to [min, max]. The lower bound is tested first, so it wins when
// a is below min.
template <class T>
T clamp(const T& a, const T& min, const T& max) {
    if (a < min) {
        return min;
    }
    if (a > max) {
        return max;
    }
    return a;
}
375 
// Limits a from above: returns max when a > max, otherwise a.
template <class T>
T clamp_max(const T& a, const T& max) {
    if (a > max) {
        return max;
    }
    return a;
}
380 
// Limits a from below: returns min when a < min, otherwise a.
template <class T>
T clamp_min(const T& a, const T& min) {
    if (a < min) {
        return min;
    }
    return a;
}
385 
// Interleaves the two halves of `vals` into `interleaved`:
// {a0..ak, b0..bk} becomes {a0, b0, a1, b1, ...}. N has to be even.
template <class VT, size_t N>
void copy_interleave(VT(&vals)[N], VT(&interleaved)[N]) {
    static_assert(N % 2 == 0, "should be even");
    constexpr size_t half = N / 2;
    for (size_t k = 0; k < half; ++k) {
        interleaved[2 * k] = vals[k];
        interleaved[2 * k + 1] = vals[half + k];
    }
}
396 
// Zero test: floating point values are classified with fpclassify (so both
// +0 and -0 qualify); other types are compared against 0.
template <typename T>
bool is_zero(T val) {
    if constexpr (!std::is_floating_point_v<T>) {
        return val == 0;
    } else {
        const int category = std::fpclassify(val);
        return category == FP_ZERO;
    }
}
405 
// Input filter for clamp(value, min, max): swaps the two bounds when the
// upper one (t) is below the lower one (s). The value itself (f) is untouched.
template <typename T>
void filter_clamp(T& f, T& s, T& t) {
    const bool bounds_in_order = !(t < s);
    if (bounds_in_order) {
        return;
    }
    std::swap(s, t);
}
412 
// Input filter for fmod(a, b): a divisor with magnitude below one is replaced
// by -1 (when negative) or +1, so the division inside fmod cannot overflow.
// The dividend `a` is left as is.
template <typename T>
std::enable_if_t<std::is_floating_point_v<T>, void> filter_fmod(T& a, T& b) {
    const bool tiny_divisor = std::abs(b) < static_cast<T>(1);
    if (!tiny_divisor) {
        return;
    }
    if (b < static_cast<T>(0)) {
        b = static_cast<T>(-1);
    } else {
        b = static_cast<T>(1);
    }
}
420 
// Input filter for fmadd (a * b + c): clamps every operand to
// [-sqrt(max)/2, +sqrt(max)/2] so that the fused result cannot overflow.
template <typename T>
std::enable_if_t<std::is_floating_point_v<T>, void> filter_fmadd(T& a, T& b, T& c) {
    const T upper = std::sqrt(std::numeric_limits<T>::max()) / T(2.0);
    const T lower = static_cast<T>(0) - upper;

    // Values already inside the range (and NaN) are left unchanged.
    const auto confine = [upper, lower](T& operand) {
        if (operand > upper) {
            operand = upper;
        } else if (operand < lower) {
            operand = lower;
        }
    };

    confine(a);
    confine(b);
    confine(c);
}
436 
// Input filter that replaces a zero value with one; other values are kept.
template <typename T>
void filter_zero(T& val) {
    if (is_zero(val)) {
        val = (T)1;
    } else {
        // Mirrors the self-assignment performed on the non-zero path.
        val = val;
    }
}
441 template <typename T>
442 std::enable_if_t<is_complex<Complex<T>>::value, void> filter_zero(Complex<T>& val) {
443     T rr = val.real();
444     T ii = val.imag();
445     rr = is_zero(rr) ? (T)1 : rr;
446     ii = is_zero(ii) ? (T)1 : ii;
447     val = Complex<T>(rr, ii);
448 }
449 
// Input filter for integral types: the most negative value of T is replaced
// by 0. Non-integral types pass through unchanged.
template <typename T>
void filter_int_minimum(T& val) {
    if constexpr (std::is_integral_v<T>) {
        const bool at_minimum = (val == std::numeric_limits<T>::min());
        if (at_minimum) {
            val = 0;
        }
    }
}
457 
458 template <typename T>
459 std::enable_if_t<is_complex<T>::value, void> filter_add_overflow(T& a, T& b)
460 {
461     //missing for complex
462 }
463 
464 template <typename T>
465 std::enable_if_t<is_complex<T>::value, void> filter_sub_overflow(T& a, T& b)
466 {
467     //missing for complex
468 }
469 
470 template <typename T>
471 std::enable_if_t < !is_complex<T>::value, void> filter_add_overflow(T& a, T& b) {
472     if constexpr (std::is_integral_v<T> == false) return;
473     T max = std::numeric_limits<T>::max();
474     T min = std::numeric_limits<T>::min();
475     // min <= (a +b) <= max;
476     // min - b <= a  <= max - b
477     if (b < 0) {
478         if (a < min - b) {
479             a = min - b;
480         }
481     }
482     else {
483         if (a > max - b) {
484             a = max - b;
485         }
486     }
487 }
488 
489 template <typename T>
490 std::enable_if_t < !is_complex<T>::value, void> filter_sub_overflow(T& a, T& b) {
491     if constexpr (std::is_integral_v<T> == false) return;
492     T max = std::numeric_limits<T>::max();
493     T min = std::numeric_limits<T>::min();
494     // min <= (a-b) <= max;
495     // min + b <= a  <= max +b
496     if (b < 0) {
497         if (a > max + b) {
498             a = max + b;
499         }
500     }
501     else {
502         if (a < min + b) {
503             a = min + b;
504         }
505     }
506 }
507 
508 template <typename T>
509 std::enable_if_t<is_complex<T>::value, void>
510 filter_mult_overflow(T& val1, T& val2) {
511     //missing
512 }
513 
514 template <typename T>
515 std::enable_if_t<is_complex<T>::value, void>
516 filter_div_ub(T& val1, T& val2) {
517     //missing
518     //at least consdier zero division
519     auto ret = std::abs(val2);
520     if (ret == 0) {
521         val2 = T(1, 2);
522     }
523 }
524 
525 template <typename T>
526 std::enable_if_t<!is_complex<T>::value, void>
527 filter_mult_overflow(T& val1, T& val2) {
528     if constexpr (std::is_integral_v<T> == false) return;
529     if (!is_zero(val2)) {
530         T c = (std::numeric_limits<T>::max() - 1) / val2;
531         if (std::abs(val1) >= c) {
532             // correct first;
533             val1 = c;
534         }
535     }  // is_zero
536 }
537 
538 template <typename T>
539 std::enable_if_t<!is_complex<T>::value, void>
540 filter_div_ub(T& val1, T& val2) {
541     if (is_zero(val2)) {
542         val2 = 1;
543     }
544     else if (std::is_integral_v<T> && val1 == std::numeric_limits<T>::min() && val2 == -1) {
545         val2 = 1;
546     }
547 }
548 
// Seed value for the random input generators.
//
// Default construction takes the seed from the high resolution clock; an
// explicit 64-bit value can be supplied to reproduce a run. The object
// converts implicitly from and to uint64_t.
struct TestSeed {
    // Seeds from the current high_resolution_clock tick count.
    TestSeed() : seed(std::chrono::high_resolution_clock::now().time_since_epoch().count()) {
    }
    // Uses the given value as the seed (implicit conversion on purpose).
    TestSeed(uint64_t seed) : seed(seed) {
    }
    // Returns the stored seed. Const so it can be called on const objects.
    uint64_t getSeed() const {
        return seed;
    }
    operator uint64_t () const {
        return seed;
    }

    // Returns a new TestSeed whose value is this seed plus `index`; the
    // current object is not modified.
    TestSeed add(uint64_t index) const {
        return TestSeed(seed + index);
    }
private:
    uint64_t seed;
};
567 
568 template <typename T, bool is_floating_point = std::is_floating_point_v<T>, bool is_complex = is_complex<T>::value>
569 struct ValueGen
570 {
571     std::uniform_int_distribution<int64_t> dis;
572     std::mt19937 gen;
573     ValueGen() : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max())
574     {
575     }
576     ValueGen(uint64_t seed) : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max(), seed)
577     {
578     }
579     ValueGen(T start, T stop, uint64_t seed = TestSeed())
580     {
581         gen = std::mt19937(seed);
582         dis = std::uniform_int_distribution<int64_t>(start, stop);
583     }
584     T get()
585     {
586         return static_cast<T>(dis(gen));
587     }
588 };
589 
590 template <typename T>
591 struct ValueGen<T, true, false>
592 {
593     std::mt19937 gen;
594     std::normal_distribution<T> normal;
595     std::uniform_int_distribution<int> roundChance;
596     T _start;
597     T _stop;
598     bool use_sign_change = false;
599     bool use_round = true;
600     ValueGen() : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max())
601     {
602     }
603     ValueGen(uint64_t seed) : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max(), seed)
604     {
605     }
606     ValueGen(T start, T stop, uint64_t seed = TestSeed())
607     {
608         gen = std::mt19937(seed);
609         T mean = start * static_cast<T>(0.5) + stop * static_cast<T>(0.5);
610         //make it  normal +-3sigma
611         T divRange = static_cast<T>(6.0);
612         T stdev = std::abs(stop / divRange - start / divRange);
613         normal = std::normal_distribution<T>{ mean, stdev };
614         // in real its hard to get rounded value
615         // so we will force it by  uniform chance
616         roundChance = std::uniform_int_distribution<int>(0, 5);
617         _start = start;
618         _stop = stop;
619     }
620     T get()
621     {
622         T a = normal(gen);
623         //make rounded value ,too
624         auto rChoice = roundChance(gen);
625         if (rChoice == 1)
626             a = std::round(a);
627         if (a < _start)
628             return nextafter(_start, _stop);
629         if (a >= _stop)
630             return nextafter(_stop, _start);
631         return a;
632     }
633 };
634 
635 template <typename T>
636 struct ValueGen<Complex<T>, false, true>
637 {
638     std::mt19937 gen;
639     std::normal_distribution<T> normal;
640     std::uniform_int_distribution<int> roundChance;
641     T _start;
642     T _stop;
643     bool use_sign_change = false;
644     bool use_round = true;
645     ValueGen() : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max())
646     {
647     }
648     ValueGen(uint64_t seed) : ValueGen(std::numeric_limits<T>::min(), std::numeric_limits<T>::max(), seed)
649     {
650     }
651     ValueGen(T start, T stop, uint64_t seed = TestSeed())
652     {
653         gen = std::mt19937(seed);
654         T mean = start * static_cast<T>(0.5) + stop * static_cast<T>(0.5);
655         //make it  normal +-3sigma
656         T divRange = static_cast<T>(6.0);
657         T stdev = std::abs(stop / divRange - start / divRange);
658         normal = std::normal_distribution<T>{ mean, stdev };
659         // in real its hard to get rounded value
660         // so we will force it by  uniform chance
661         roundChance = std::uniform_int_distribution<int>(0, 5);
662         _start = start;
663         _stop = stop;
664     }
665     Complex<T> get()
666     {
667         T a = normal(gen);
668         T b = normal(gen);
669         //make rounded value ,too
670         auto rChoice = roundChance(gen);
671         rChoice = rChoice & 3;
672         if (rChoice & 1)
673             a = std::round(a);
674         if (rChoice & 2)
675             b = std::round(b);
676         if (a < _start)
677             a = nextafter(_start, _stop);
678         else if (a >= _stop)
679             a = nextafter(_stop, _start);
680         if (b < _start)
681             b = nextafter(_start, _stop);
682         else if (b >= _stop)
683             b = nextafter(_stop, _start);
684         return Complex<T>(a, b);
685     }
686 };
687 
// Number of random trials to run per domain.
//
// A requested count below 1 selects a default: 128 for element types of at
// most two bytes, otherwise twice numeric_limits<uint16_t>::max(). With more
// than one domain the count is divided among them, but never drops below 1.
template<class T>
int getTrialCount(int test_trials, int domains_size) {
    const int fallback =
        (sizeof(T) <= 2) ? 128 : 2 * std::numeric_limits<uint16_t>::max();
    int perDomain = (test_trials < 1) ? fallback : test_trials;
    if (domains_size > 1) {
        perDomain /= domains_size;
        if (perDomain < 1) {
            perDomain = 1;
        }
    }
    return perDomain;
}
707 
708 template <typename T, typename U = UvalueType<T>>
709 class TestCaseBuilder;
710 
711 template <typename T, typename U = UvalueType<T>>
712 class TestingCase {
713 public:
714     friend class TestCaseBuilder<T, U>;
715     static TestCaseBuilder<T, U> getBuilder() { return TestCaseBuilder<T, U>{}; }
716     bool checkSpecialValues() const {
717         //this will be used to check nan, infs, and other special cases
718         return specialCheck;
719     }
720     size_t getTrialCount() const { return trials; }
721     bool isBitwise() const { return bitwise; }
722     const std::vector<CheckWithinDomains<U>>& getDomains() const {
723         return domains;
724     }
725     const std::vector<CustomCheck<T>>& getCustomChecks() const {
726         return customCheck;
727     }
728     TestSeed getTestSeed() const {
729         return testSeed;
730     }
731 private:
732     // if domains is empty we will test default
733     std::vector<CheckWithinDomains<U>> domains;
734     std::vector<CustomCheck<T>> customCheck;
735     // its not used for now
736     bool specialCheck = false;
737     bool bitwise = false;  // test bitlevel
738     size_t trials = 0;
739     TestSeed testSeed;
740 };
741 
742 template <typename T, typename U >
743 class TestCaseBuilder {
744 private:
745     TestingCase<T, U> _case;
746 public:
747     TestCaseBuilder<T, U>& set(bool bitwise, bool checkSpecialValues) {
748         _case.bitwise = bitwise;
749         _case.specialCheck = checkSpecialValues;
750         return *this;
751     }
752     TestCaseBuilder<T, U>& setTestSeed(TestSeed seed) {
753         _case.testSeed = seed;
754         return *this;
755     }
756     TestCaseBuilder<T, U>& setTrialCount(size_t trial_count) {
757         _case.trials = trial_count;
758         return *this;
759     }
760     TestCaseBuilder<T, U>& addDomain(const CheckWithinDomains<U>& domainCheck) {
761         _case.domains.emplace_back(domainCheck);
762         return *this;
763     }
764     TestCaseBuilder<T, U>& addCustom(const CustomCheck<T>& customArgs) {
765         _case.customCheck.emplace_back(customArgs);
766         return *this;
767     }
768     TestCaseBuilder<T, U>& checkSpecialValues() {
769         _case.specialCheck = true;
770         return *this;
771     }
772     TestCaseBuilder<T, U>& compareBitwise() {
773         _case.bitwise = true;
774         return *this;
775     }
776     operator TestingCase<T, U> && () { return std::move(_case); }
777 };
778 
779 template <typename T>
780 typename std::enable_if_t<!is_complex<T>::value&& std::is_unsigned<T>::value, T>
781 correctEpsilon(const T& eps)
782 {
783     return eps;
784 }
785 template <typename T>
786 typename std::enable_if_t<!is_complex<T>::value && !std::is_unsigned<T>::value, T>
787 correctEpsilon(const T& eps)
788 {
789     return std::abs(eps);
790 }
791 template <typename T>
792 typename std::enable_if_t<is_complex<Complex<T>>::value, T>
793 correctEpsilon(const Complex<T>& eps)
794 {
795     return std::abs(eps);
796 }
797 
798 template <typename T>
799 class AssertVectorized
800 {
801 public:
802     AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0)
803         : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), argSize(1)
804     {
805     }
806     AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1)
807         : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), argSize(2)
808     {
809     }
810     AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1, const T& input2)
811         : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), arg2(input2), argSize(3)
812     {
813     }
814     AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual) : additionalInfo(info), testSeed(seed), exp(expected), act(actual)
815     {
816     }
817     AssertVectorized(const std::string& info, const T& expected, const T& actual) : additionalInfo(info), exp(expected), act(actual), hasSeed(false)
818     {
819     }
820 
821     std::string getDetail(int index) const
822     {
823         using UVT = UvalueType<T>;
824         std::stringstream stream;
825         stream.precision(std::numeric_limits<UVT>::max_digits10);
826         stream << "Failure Details:\n";
827         stream << additionalInfo << "\n";
828         if (hasSeed)
829         {
830             stream << "Test Seed to reproduce: " << testSeed << "\n";
831         }
832         if (argSize > 0)
833         {
834             stream << "Arguments:\n";
835             stream << "#\t " << arg0 << "\n";
836             if (argSize == 2)
837             {
838                 stream << "#\t " << arg1 << "\n";
839             }
840             if (argSize == 3)
841             {
842                 stream << "#\t " << arg2 << "\n";
843             }
844         }
845         stream << "Expected:\n#\t" << exp << "\nActual:\n#\t" << act;
846         stream << "\nFirst mismatch Index: " << index;
847         return stream.str();
848     }
849 
850     bool check(bool bitwise = false, bool checkWithTolerance = false, ValueType<T> toleranceEps = {}) const
851     {
852         using UVT = UvalueType<T>;
853         using BVT = BitType<UVT>;
854         UVT absErr = correctEpsilon(toleranceEps);
855         constexpr int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount;
856         constexpr int unitStorageCount = VecTypeHelper<T>::unitStorageCount;
857         CACHE_ALIGN UVT expArr[sizeX];
858         CACHE_ALIGN UVT actArr[sizeX];
859         exp.store(expArr);
860         act.store(actArr);
861         if (bitwise)
862         {
863             for (const auto i : c10::irange(sizeX)) {
864                 BVT b_exp = c10::bit_cast<BVT>(expArr[i]);
865                 BVT b_act = c10::bit_cast<BVT>(actArr[i]);
866                 EXPECT_EQ(b_exp, b_act) << getDetail(i / unitStorageCount);
867                 if (::testing::Test::HasFailure())
868                     return true;
869             }
870         }
871         else if (checkWithTolerance)
872         {
873             for (const auto i : c10::irange(sizeX)) {
874                 EXPECT_EQ(nearlyEqual<UVT>(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << "\n" << getDetail(i / unitStorageCount);
875                 if (::testing::Test::HasFailure())
876                     return true;
877             }
878         }
879         else
880         {
881             for (const auto i : c10::irange(sizeX)) {
882                 if constexpr (std::is_same_v<UVT, float>)
883                 {
884                     if (!check_both_nan(expArr[i], actArr[i])) {
885                         EXPECT_FLOAT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
886                     }
887                 }
888                 else if constexpr (std::is_same_v<UVT, double>)
889                 {
890                     if (!check_both_nan(expArr[i], actArr[i]))
891                     {
892                         EXPECT_DOUBLE_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
893                     }
894                 }
895                 else
896                 {
897                     EXPECT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
898                 }
899                 if (::testing::Test::HasFailure())
900                     return true;
901             }
902         }
903         return false;
904     }
905 
906 private:
907     std::string additionalInfo;
908     TestSeed testSeed;
909     T exp;
910     T act;
911     T arg0;
912     T arg1;
913     T arg2;
914     int argSize = 0;
915     bool hasSeed = true;
916 };
917 
918 template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t>
919 void test_unary(
920     std::string testNameInfo,
921     Op1 expectedFunction,
922     Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) {
923     using vec_type = T;
924     using VT = ValueType<T>;
925     using UVT = UvalueType<T>;
926     constexpr int el_count = vec_type::size();
927     CACHE_ALIGN VT vals[el_count];
928     CACHE_ALIGN VT expected[el_count];
929     bool bitwise = testCase.isBitwise();
930     UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min();
931     UVT default_end = std::numeric_limits<UVT>::max();
932     auto domains = testCase.getDomains();
933     auto domains_size = domains.size();
934     auto test_trials = testCase.getTrialCount();
935     int trialCount = getTrialCount<UVT>(test_trials, domains_size);
936     TestSeed seed = testCase.getTestSeed();
937     uint64_t changeSeedBy = 0;
938     for (const CheckWithinDomains<UVT>& dmn : domains) {
939         size_t dmn_argc = dmn.ArgsDomain.size();
940         UVT start = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start;
941         UVT end = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end;
942         ValueGen<VT> generator(start, end, seed.add(changeSeedBy));
943         for (C10_UNUSED const auto trial : c10::irange(trialCount)) {
944             for (const auto k : c10::irange(el_count)) {
945                 vals[k] = generator.get();
946                 call_filter(filter, vals[k]);
947                 //map operator
948                 expected[k] = expectedFunction(vals[k]);
949             }
950             // test
951             auto input = vec_type::loadu(vals);
952             auto actual = actualFunction(input);
953             auto vec_expected = vec_type::loadu(expected);
954             AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input);
955             if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return;
956 
957         }// trial
958         //inrease Seed
959         changeSeedBy += 1;
960     }
961     for (auto& custom : testCase.getCustomChecks()) {
962         auto args = custom.Args;
963         if (args.size() > 0) {
964             auto input = vec_type{ args[0] };
965             auto actual = actualFunction(input);
966             auto vec_expected = vec_type{ custom.expectedResult };
967             AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input);
968             if (vecAssert.check()) return;
969         }
970     }
971 }
972 
973 template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t>
974 void test_binary(
975     std::string testNameInfo,
976     Op1 expectedFunction,
977     Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) {
978     using vec_type = T;
979     using VT = ValueType<T>;
980     using UVT = UvalueType<T>;
981     constexpr int el_count = vec_type::size();
982     CACHE_ALIGN VT vals0[el_count];
983     CACHE_ALIGN VT vals1[el_count];
984     CACHE_ALIGN VT expected[el_count];
985     bool bitwise = testCase.isBitwise();
986     UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min();
987     UVT default_end = std::numeric_limits<UVT>::max();
988     auto domains = testCase.getDomains();
989     auto domains_size = domains.size();
990     auto test_trials = testCase.getTrialCount();
991     int trialCount = getTrialCount<UVT>(test_trials, domains_size);
992     TestSeed seed = testCase.getTestSeed();
993     uint64_t changeSeedBy = 0;
994     for (const CheckWithinDomains<UVT>& dmn : testCase.getDomains()) {
995         size_t dmn_argc = dmn.ArgsDomain.size();
996         UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start;
997         UVT end0 = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end;
998         UVT start1 = dmn_argc > 1 ? dmn.ArgsDomain[1].start : default_start;
999         UVT end1 = dmn_argc > 1 ? dmn.ArgsDomain[1].end : default_end;
1000         ValueGen<VT> generator0(start0, end0, seed.add(changeSeedBy));
1001         ValueGen<VT> generator1(start1, end1, seed.add(changeSeedBy + 1));
1002         for (C10_UNUSED const auto trial : c10::irange(trialCount)) {
1003             for (const auto k : c10::irange(el_count)) {
1004                 vals0[k] = generator0.get();
1005                 vals1[k] = generator1.get();
1006                 call_filter(filter, vals0[k], vals1[k]);
1007                 //map operator
1008                 expected[k] = expectedFunction(vals0[k], vals1[k]);
1009             }
1010             // test
1011             auto input0 = vec_type::loadu(vals0);
1012             auto input1 = vec_type::loadu(vals1);
1013             auto actual = actualFunction(input0, input1);
1014             auto vec_expected = vec_type::loadu(expected);
1015             AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1);
1016             if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError))return;
1017         }// trial
1018         changeSeedBy += 1;
1019     }
1020     for (auto& custom : testCase.getCustomChecks()) {
1021         auto args = custom.Args;
1022         if (args.size() > 0) {
1023             auto input0 = vec_type{ args[0] };
1024             auto input1 = args.size() > 1 ? vec_type{ args[1] } : vec_type{ args[0] };
1025             auto actual = actualFunction(input0, input1);
1026             auto vec_expected = vec_type(custom.expectedResult);
1027             AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1);
1028             if (vecAssert.check()) return;
1029         }
1030     }
1031 }
1032 
1033 template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t>
1034 void test_ternary(
1035     std::string testNameInfo,
1036     Op1 expectedFunction,
1037     Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) {
1038     using vec_type = T;
1039     using VT = ValueType<T>;
1040     using UVT = UvalueType<T>;
1041     constexpr int el_count = vec_type::size();
1042     CACHE_ALIGN VT vals0[el_count];
1043     CACHE_ALIGN VT vals1[el_count];
1044     CACHE_ALIGN VT vals2[el_count];
1045     CACHE_ALIGN VT expected[el_count];
1046     bool bitwise = testCase.isBitwise();
1047     UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min();
1048     UVT default_end = std::numeric_limits<UVT>::max();
1049     auto domains = testCase.getDomains();
1050     auto domains_size = domains.size();
1051     auto test_trials = testCase.getTrialCount();
1052     int trialCount = getTrialCount<UVT>(test_trials, domains_size);
1053     TestSeed seed = testCase.getTestSeed();
1054     uint64_t changeSeedBy = 0;
1055     for (const CheckWithinDomains<UVT>& dmn : testCase.getDomains()) {
1056         size_t dmn_argc = dmn.ArgsDomain.size();
1057         UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start;
1058         UVT end0 = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end;
1059         UVT start1 = dmn_argc > 1 ? dmn.ArgsDomain[1].start : default_start;
1060         UVT end1 = dmn_argc > 1 ? dmn.ArgsDomain[1].end : default_end;
1061         UVT start2 = dmn_argc > 2 ? dmn.ArgsDomain[2].start : default_start;
1062         UVT end2 = dmn_argc > 2 ? dmn.ArgsDomain[2].end : default_end;
1063         ValueGen<VT> generator0(start0, end0, seed.add(changeSeedBy));
1064         ValueGen<VT> generator1(start1, end1, seed.add(changeSeedBy + 1));
1065         ValueGen<VT> generator2(start2, end2, seed.add(changeSeedBy + 2));
1066 
1067         for (C10_UNUSED const auto trial : c10::irange(trialCount)) {
1068             for (const auto k : c10::irange(el_count)) {
1069                 vals0[k] = generator0.get();
1070                 vals1[k] = generator1.get();
1071                 vals2[k] = generator2.get();
1072                 call_filter(filter, vals0[k], vals1[k], vals2[k]);
1073                 //map operator
1074                 expected[k] = expectedFunction(vals0[k], vals1[k], vals2[k]);
1075             }
1076             // test
1077             auto input0 = vec_type::loadu(vals0);
1078             auto input1 = vec_type::loadu(vals1);
1079             auto input2 = vec_type::loadu(vals2);
1080             auto actual = actualFunction(input0, input1, input2);
1081             auto vec_expected = vec_type::loadu(expected);
1082             AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1, input2);
1083             if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return;
1084         }// trial
1085         changeSeedBy += 1;
1086     }
1087 }
1088 
1089 template <typename T, typename Op>
1090 T func_cmp(Op call, T v0, T v1) {
1091     using bit_rep = BitType<T>;
1092     constexpr bit_rep mask = std::numeric_limits<bit_rep>::max();
1093     bit_rep  ret = call(v0, v1) ? mask : 0;
1094     return c10::bit_cast<T>(ret);
1095 }
1096 
1097 struct PreventFma
1098 {
1099     not_inline float sub(float a, float b)
1100     {
1101         return a - b;
1102     }
1103     not_inline double sub(double a, double b)
1104     {
1105         return a - b;
1106     }
1107     not_inline float add(float a, float b)
1108     {
1109         return a + b;
1110     }
1111     not_inline double add(double a, double b)
1112     {
1113         return a + b;
1114     }
1115 };
1116 
1117 template <typename T>
1118 std::enable_if_t<!is_complex<T>::value, T> local_log2(T x) {
1119     return std::log2(x);
1120 }
1121 
1122 template <typename T>
1123 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_log2(Complex<T> x) {
1124     T ret = std::log(x);
1125     T real = ret.real() / std::log(static_cast<T>(2));
1126     T imag = ret.imag() / std::log(static_cast<T>(2));
1127     return Complex<T>(real, imag);
1128 }
1129 
1130 template <typename T>
1131 std::enable_if_t<!is_complex<T>::value, T> local_abs(T x) {
1132     return std::abs(x);
1133 }
1134 
1135 template <typename T>
1136 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_abs(Complex<T> x) {
1137 #if defined(TEST_AGAINST_DEFAULT)
1138     return std::abs(x);
1139 #else
1140     PreventFma noFma;
1141     T real = x.real();
1142     T imag = x.imag();
1143     T rr = real * real;
1144     T ii = imag * imag;
1145     T abs = std::sqrt(noFma.add(rr, ii));
1146     return Complex<T>(abs, 0);
1147 #endif
1148 }
1149 
1150 template <typename T>
1151 std::enable_if_t<!is_complex<T>::value, T> local_multiply(T x, T y) {
1152     return x * y;
1153 }
1154 
1155 template <typename T>
1156 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_multiply(Complex<T> x, Complex<T> y) {
1157 #if defined(TEST_AGAINST_DEFAULT)
1158     return x * y;
1159 #else
1160     //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
1161     T x_real = x.real();
1162     T x_imag = x.imag();
1163     T y_real = y.real();
1164     T y_imag = y.imag();
1165 #if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR)
1166     //check multiplication considerin swap and fma
1167     T rr = x_real * y_real;
1168     T ii = x_imag * y_real;
1169     T neg_imag = -y_imag;
1170     rr = fma(x_imag, neg_imag, rr);
1171     ii = fma(x_real, y_imag, ii);
1172 #else
1173     // replicate order
1174     PreventFma noFma;
1175     T ac = x_real * y_real;
1176     T bd = x_imag * y_imag;
1177     T ad = x_real * y_imag;
1178     T bc = x_imag * (-y_real);
1179     T rr = noFma.sub(ac, bd);
1180     T ii = noFma.sub(ad, bc);
1181 #endif
1182     return Complex<T>(rr, ii);
1183 #endif
1184 }
1185 
1186 
1187 
1188 template <typename T>
1189 std::enable_if_t<!is_complex<T>::value, T> local_division(T x, T y) {
1190     return x / y;
1191 }
1192 
1193 template <typename T>
1194 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_division(Complex<T> x, Complex<T> y) {
1195 #if defined(TEST_AGAINST_DEFAULT)
1196     return x / y;
1197 #else /* defined(TEST_AGAINST_DEFAULT) */
1198     //re = (ac + bd)/abs_2()
1199     //im = (bc - ad)/abs_2()
1200     T x_real = x.real();
1201     T x_imag = x.imag();
1202     T y_real = y.real();
1203     T y_imag = y.imag();
1204     PreventFma noFma;
1205 #if defined(CPU_CAPABILITY_ZVECTOR)
1206     T abs_c = std::abs(y_real);
1207     T abs_d = std::abs(y_imag);
1208     T scale = 1.0 / std::max(abs_c, abs_d);
1209 
1210     T a_sc = x_real * scale; // a/sc
1211     T b_sc = x_imag * scale; // b/sc
1212     T c_sc = y_real * scale; // c/sc
1213     T d_sc = y_imag * scale; // d/sc
1214 
1215     T ac_sc2 = a_sc * c_sc; // ac/sc^2
1216     T bd_sc2 = b_sc * d_sc; // bd/sc^2
1217 
1218     T neg_d_sc = -1.0 * d_sc; // -d/sc^2
1219 
1220     T neg_ad_sc2 = a_sc * neg_d_sc; // -ad/sc^2
1221     T bc_sc2 = b_sc * c_sc; // bc/sc^2
1222 
1223     T ac_bd_sc2 = noFma.add(ac_sc2, bd_sc2); // (ac+bd)/sc^2
1224     T bc_ad_sc2 = noFma.add(bc_sc2, neg_ad_sc2); // (bc-ad)/sc^2
1225 
1226     T c2_sc2 = c_sc * c_sc; // c^2/sc^2
1227     T d2_sc2 = d_sc * d_sc; // d^2/sc^2
1228 
1229     T c2_d2_sc2 = noFma.add(c2_sc2, d2_sc2); // (c^2+d^2)/sc^2
1230 
1231     T rr = ac_bd_sc2 / c2_d2_sc2; // (ac+bd)/(c^2+d^2)
1232     T ii = bc_ad_sc2 / c2_d2_sc2; // (bc-ad)/(c^2+d^2)
1233 
1234     return Complex<T>(rr, ii);
1235 #else /* defined(CPU_CAPABILITY_ZVECTOR) */
1236 #if defined(CPU_CAPABILITY_VSX)
1237     //check multiplication considerin swap and fma
1238     T rr = x_real * y_real;
1239     T ii = x_imag * y_real;
1240     T neg_imag = -y_imag;
1241     rr = fma(x_imag, y_imag, rr);
1242     ii = fma(x_real, neg_imag, ii);
1243     //b.abs_2
1244 #else /* defined(CPU_CAPABILITY_VSX) */
1245     T ac = x_real * y_real;
1246     T bd = x_imag * y_imag;
1247     T ad = x_real * y_imag;
1248     T bc = x_imag * y_real;
1249     T rr = noFma.add(ac, bd);
1250     T ii = noFma.sub(bc, ad);
1251 #endif /* defined(CPU_CAPABILITY_VSX) */
1252     //b.abs_2()
1253     T abs_rr = y_real * y_real;
1254     T abs_ii = y_imag * y_imag;
1255     T abs_2 = noFma.add(abs_rr, abs_ii);
1256     rr = rr / abs_2;
1257     ii = ii / abs_2;
1258     return Complex<T>(rr, ii);
1259 #endif /* defined(CPU_CAPABILITY_ZVECTOR) */
1260 #endif /* defined(TEST_AGAINST_DEFAULT) */
1261 }
1262 
1263 
1264 template <typename T>
1265 std::enable_if_t<!is_complex<T>::value, T> local_fmadd(T a, T b, T c) {
1266     PreventFma noFma;
1267     T ab = a * b;
1268     return noFma.add(ab, c);
1269 }
1270 
1271 template <typename T>
1272 std::enable_if_t<!is_complex<T>::value, T> local_sqrt(T x) {
1273     return std::sqrt(x);
1274 }
1275 
1276 template <typename T>
1277 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_sqrt(Complex<T> x) {
1278     return std::sqrt(x);
1279 }
1280 
1281 template <typename T>
1282 std::enable_if_t<!is_complex<T>::value, T> local_asin(T x) {
1283     return std::asin(x);
1284 }
1285 
1286 template <typename T>
1287 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_asin(Complex<T> x) {
1288     return std::asin(x);
1289 }
1290 
1291 template <typename T>
1292 std::enable_if_t<!is_complex<T>::value, T> local_acos(T x) {
1293     return std::acos(x);
1294 }
1295 
1296 template <typename T>
1297 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_acos(Complex<T> x) {
1298     return std::acos(x);
1299 }
1300 
1301 template<typename T>
1302 std::enable_if_t<!is_complex<T>::value, T>
1303 local_and(const T& val0, const T& val1) {
1304     using bit_rep = BitType<T>;
1305     bit_rep ret = c10::bit_cast<bit_rep>(val0) & c10::bit_cast<bit_rep>(val1);
1306     return c10::bit_cast<T> (ret);
1307 }
1308 
1309 template <typename T>
1310 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>>
1311 local_and(const Complex<T>& val0, const Complex<T>& val1)
1312 {
1313     using bit_rep = BitType<T>;
1314     T real1 = val0.real();
1315     T imag1 = val0.imag();
1316     T real2 = val1.real();
1317     T imag2 = val1.imag();
1318     bit_rep real_ret = c10::bit_cast<bit_rep>(real1) & c10::bit_cast<bit_rep>(real2);
1319     bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) & c10::bit_cast<bit_rep>(imag2);
1320     return Complex<T>(c10::bit_cast<T>(real_ret), c10::bit_cast<T>(imag_ret));
1321 }
1322 
1323 template<typename T>
1324 std::enable_if_t<!is_complex<T>::value, T>
1325 local_or(const T& val0, const T& val1) {
1326     using bit_rep = BitType<T>;
1327     bit_rep ret = c10::bit_cast<bit_rep>(val0) | c10::bit_cast<bit_rep>(val1);
1328     return c10::bit_cast<T> (ret);
1329 }
1330 
1331 template<typename T>
1332 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>>
1333 local_or(const Complex<T>& val0, const Complex<T>& val1) {
1334     using bit_rep = BitType<T>;
1335     T real1 = val0.real();
1336     T imag1 = val0.imag();
1337     T real2 = val1.real();
1338     T imag2 = val1.imag();
1339     bit_rep real_ret = c10::bit_cast<bit_rep>(real1) | c10::bit_cast<bit_rep>(real2);
1340     bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) | c10::bit_cast<bit_rep>(imag2);
1341     return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret));
1342 }
1343 
1344 template<typename T>
1345 std::enable_if_t<!is_complex<T>::value, T>
1346 local_xor(const T& val0, const T& val1) {
1347     using bit_rep = BitType<T>;
1348     bit_rep ret = c10::bit_cast<bit_rep>(val0) ^ c10::bit_cast<bit_rep>(val1);
1349     return c10::bit_cast<T> (ret);
1350 }
1351 
1352 template<typename T>
1353 std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>>
1354 local_xor(const Complex<T>& val0, const Complex<T>& val1) {
1355     using bit_rep = BitType<T>;
1356     T real1 = val0.real();
1357     T imag1 = val0.imag();
1358     T real2 = val1.real();
1359     T imag2 = val1.imag();
1360     bit_rep real_ret = c10::bit_cast<bit_rep>(real1) ^ c10::bit_cast<bit_rep>(real2);
1361     bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) ^ c10::bit_cast<bit_rep>(imag2);
1362     return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret));
1363 }
1364 
1365 template <typename T>
1366 T quantize_val(float scale, int64_t zero_point, float value) {
1367     int64_t qvalue;
1368     constexpr int64_t qmin = std::numeric_limits<T>::min();
1369     constexpr int64_t qmax = std::numeric_limits<T>::max();
1370     float inv_scale = 1.0f / scale;
1371     qvalue = static_cast<int64_t>(zero_point + at::native::round_impl<float>(value * inv_scale));
1372     qvalue = std::max<int64_t>(qvalue, qmin);
1373     qvalue = std::min<int64_t>(qvalue, qmax);
1374     return static_cast<T>(qvalue);
1375 }
1376 
// Scalar reference requantization of an integer accumulator into type T.
//
// `src` is scaled by `multiplier`, rounded to an integer, offset by
// `zero_point`, and clamped to the representable range of T. The parameter
// types and the rounding call differ with TEST_AGAINST_DEFAULT:
// nearbyint on 32-bit inputs when it is defined, std::lrintf on 64-bit
// inputs otherwise.
template <typename T>
#if defined(TEST_AGAINST_DEFAULT)
T requantize_from_int(float multiplier, int32_t zero_point, int32_t src) {
    auto scaled = static_cast<float>(src) * multiplier;
    double rounded = nearbyint(scaled);
    int32_t quantize_down = rounded + zero_point;
#else
T requantize_from_int(float multiplier, int64_t zero_point, int64_t src) {
    const auto rounded = std::lrintf(src * multiplier);
    int64_t quantize_down = static_cast<int64_t>(zero_point + rounded);
#endif
    constexpr int64_t lower = std::numeric_limits<T>::min();
    constexpr int64_t upper = std::numeric_limits<T>::max();
    const int64_t at_least_lower = std::max<int64_t>(quantize_down, lower);
    const int64_t clamped = std::min<int64_t>(at_least_lower, upper);
    return static_cast<T>(clamped);
}
1392 
// Scalar reference dequantization: (value - zero_point) * scale.
//
// With CHECK_WITH_FMA the same quantity is computed as a fused
// multiply-add of value * scale with the negated product
// zero_point * scale as the addend.
template <typename T>
float dequantize_val(float scale, int64_t zero_point, T value) {
#if defined(CHECK_WITH_FMA)
    const float as_float = static_cast<float>(value);
    const float neg_offset = -(zero_point * scale);
    return fma(as_float, scale, neg_offset);
#else
    const float centered = static_cast<float>(value) - zero_point;
    return centered * scale;
#endif
}
1405 
// Returns the larger of `val` and `zero_point`; `val` is returned when the
// two compare equal.
template<typename T>
T relu(const T & val, const T & zero_point) {
    if (val < zero_point) {
        return zero_point;
    }
    return val;
}
1410 
// Clamps `val` from below by `zero_point` and then from above by `q_six`.
template<typename T>
T relu6(T val, T zero_point, T q_six) {
    const T floored = std::max<T>(val, zero_point);
    const T capped = std::min<T>(floored, q_six);
    return capped;
}
1415 
// Subtracts `b` from `val` after converting both operands to int32_t.
template<typename T>
int32_t widening_subtract(T val, T b) {
    const auto minuend = static_cast<int32_t>(val);
    const auto subtrahend = static_cast<int32_t>(b);
    return minuend - subtrahend;
}
1420 
1421 //default testing case
// Default absolute tolerance for a value type T.
//
// The primary template yields zero; float and double have dedicated
// non-zero tolerances.
template<typename T>
T getDefaultTolerance() {
    return static_cast<T>(0.0);
}

// An explicit specialization of a function template is an ordinary function
// definition, so it has to be `inline` to live in a header that may be
// included from more than one translation unit.
template<>
inline float getDefaultTolerance<float>() {
    return 5.e-5f;
}

template<>
inline double getDefaultTolerance<double>() {
    return 1.e-9;
}
1436 
1437 template<typename T, int N = 1>
1438 at::vec::VecMask<T, N> create_vec_mask(uint64_t bitmask) {
1439   constexpr auto size = at::vec::Vectorized<T>::size();
1440   std::array<int, N * size> mask;
1441   for (int n = 0; n < N; n++) {
1442       for (int i = 0; i < size; i++) {
1443         mask[n * size + i] = (bitmask >> i) & 1;
1444       }
1445   }
1446   return at::vec::VecMask<T, N>::from(mask.data());
1447 }
1448 
1449 template<typename T, int N = 1>
1450 at::vec::VecMask<T, N> generate_vec_mask(int seed) {
1451   constexpr auto size = at::vec::Vectorized<T>::size();
1452   ValueGen<uint64_t> generator(0, (1ULL << size) - 1, seed);
1453   auto bitmask = generator.get();
1454   return create_vec_mask<T, N>(bitmask);
1455 }
1456 
1457 template<typename T>
1458 TestingCase<T> createDefaultUnaryTestCase(TestSeed seed = TestSeed(), bool bitwise = false, bool checkWithTolerance = false, size_t trials = 0) {
1459     using UVT = UvalueType<T>;
1460     TestingCase<T> testCase;
1461     if (!bitwise && std::is_floating_point_v<UVT>) {
1462         //for float types lets add manual ranges
1463         UVT tolerance = getDefaultTolerance<UVT>();
1464         testCase = TestingCase<T>::getBuilder()
1465             .set(bitwise, false)
1466             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-10, (UVT)10}}, checkWithTolerance, tolerance})
1467             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)10, (UVT)100 }}, checkWithTolerance, tolerance})
1468             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)100, (UVT)1000 }}, checkWithTolerance, tolerance})
1469             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-100, (UVT)-10 }}, checkWithTolerance, tolerance})
1470             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-1000, (UVT)-100 }}, checkWithTolerance, tolerance})
1471             .addDomain(CheckWithinDomains<UVT>{ {}, checkWithTolerance, tolerance})
1472             .setTrialCount(trials)
1473             .setTestSeed(seed);
1474     }
1475     else {
1476         testCase = TestingCase<T>::getBuilder()
1477             .set(bitwise, false)
1478             .addDomain(CheckWithinDomains<UVT>{})
1479             .setTrialCount(trials)
1480             .setTestSeed(seed);
1481     }
1482     return testCase;
1483 }
1484 
1485 template<typename T>
1486 TestingCase<T> createDefaultBinaryTestCase(TestSeed seed = TestSeed(), bool bitwise = false, bool checkWithTolerance = false, size_t trials = 0) {
1487     using UVT = UvalueType<T>;
1488     TestingCase<T> testCase;
1489     if (!bitwise && std::is_floating_point_v<UVT>) {
1490         //for float types lets add manual ranges
1491         UVT tolerance = getDefaultTolerance<UVT>();
1492         testCase = TestingCase<T>::getBuilder()
1493             .set(bitwise, false)
1494             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-10, (UVT)10}, { (UVT)-10, (UVT)10 }}, checkWithTolerance, tolerance})
1495             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)10, (UVT)100 }, { (UVT)-10, (UVT)100 }}, checkWithTolerance, tolerance})
1496             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)100, (UVT)1000 }, { (UVT)-100, (UVT)1000 }}, checkWithTolerance, tolerance})
1497             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-100, (UVT)-10 }, { (UVT)-100, (UVT)10 }}, checkWithTolerance, tolerance})
1498             .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-1000, (UVT)-100 }, { (UVT)-1000, (UVT)100 }}, checkWithTolerance, tolerance})
1499             .addDomain(CheckWithinDomains<UVT>{ {}, checkWithTolerance, tolerance})
1500             .setTrialCount(trials)
1501             .setTestSeed(seed);
1502     }
1503     else {
1504         testCase = TestingCase<T>::getBuilder()
1505             .set(bitwise, false)
1506             .addDomain(CheckWithinDomains<UVT>{})
1507             .setTrialCount(trials)
1508             .setTestSeed(seed);
1509     }
1510     return testCase;
1511 }
1512 
1513 template<typename T>
1514 TestingCase<T> createDefaultTernaryTestCase(TestSeed seed = TestSeed(), bool bitwise = false, bool checkWithTolerance = false, size_t trials = 0) {
1515     TestingCase<T> testCase = TestingCase<T>::getBuilder()
1516         .set(bitwise, false)
1517         .addDomain(CheckWithinDomains<UvalueType<T>>{})
1518         .setTrialCount(trials)
1519         .setTestSeed(seed);
1520     return testCase;
1521 }
1522