// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/vec_test_all_types.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/test/vec_test_all_types.h>
2 #include <c10/util/irange.h>
3 namespace {
4 #if GTEST_HAS_TYPED_TEST
5     template <typename T>
6     class Memory : public ::testing::Test {};
7     template <typename T>
8     class Arithmetics : public ::testing::Test {};
9     template <typename T>
10     class Comparison : public ::testing::Test {};
11     template <typename T>
12     class Bitwise : public ::testing::Test {};
13     template <typename T>
14     class MinMax : public ::testing::Test {};
15     template <typename T>
16     class Nan : public ::testing::Test {};
17     template <typename T>
18     class Interleave : public ::testing::Test {};
19     template <typename T>
20     class SignManipulation : public ::testing::Test {};
21     template <typename T>
22     class SignManipulationHalfPrecision : public ::testing::Test {};
23     template <typename T>
24     class Rounding : public ::testing::Test {};
25     template <typename T>
26     class SqrtAndReciprocal : public ::testing::Test {};
27     template <typename T>
28     class SqrtAndReciprocalReal : public ::testing::Test {};
29     template <typename T>
30     class FractionAndRemainderReal : public ::testing::Test {};
31     template <typename T>
32     class Trigonometric : public ::testing::Test {};
33     template <typename T>
34     class ErrorFunctions : public ::testing::Test {};
35     template <typename T>
36     class Exponents : public ::testing::Test {};
37     template <typename T>
38     class Hyperbolic : public ::testing::Test {};
39     template <typename T>
40     class InverseTrigonometric : public ::testing::Test {};
41     template <typename T>
42     class InverseTrigonometricReal : public ::testing::Test {};
43     template <typename T>
44     class LGamma : public ::testing::Test {};
45     template <typename T>
46     class Logarithm : public ::testing::Test {};
47     template <typename T>
48     class LogarithmReals : public ::testing::Test {};
49     template <typename T>
50     class Pow : public ::testing::Test {};
51     template <typename T>
52     class RangeFactories : public ::testing::Test {};
53     template <typename T>
54     class BitwiseFloatsAdditional : public ::testing::Test {};
55     template <typename T>
56     class BitwiseFloatsAdditional2 : public ::testing::Test {};
57     template <typename T>
58     class RealTests : public ::testing::Test {};
59     template <typename T>
60     class ComplexTests : public ::testing::Test {};
61     template <typename T>
62     class QuantizationTests : public ::testing::Test {};
63     template <typename T>
64     class Quantization8BitWithTailTests : public ::testing::Test {};
65     template <typename T>
66     class FunctionalTests : public ::testing::Test {};
67     template <typename T>
68     class FunctionalTestsReducedFloat : public ::testing::Test {};
69     template <typename T>
70     class InfiniteTests : public ::testing::Test {};
71     template <typename T>
72     class VecConvertTests : public ::testing::Test {};
73     template <typename T>
74     class VecMaskTests : public ::testing::Test {};
75     using RealFloatTestedTypes = ::testing::Types<vfloat, vdouble>;
76     using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
77     using ALLTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vlong, vint, vshort, vqint8, vquint8, vqint>;
78     using QuantTestedTypes = ::testing::Types<vqint8, vquint8, vqint>;
79 #if (defined(CPU_CAPABILITY_AVX2) ||  defined(CPU_CAPABILITY_AVX512))  && !defined(_MSC_VER)
80     using Quantization8BitWithTailTestedTypes =
81         ::testing::Types<vqint8, vquint8>;
82 #endif
83     using RealFloatIntTestedTypes = ::testing::Types<vfloat, vdouble, vlong, vint, vshort>;
84     using FloatIntTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl, vlong, vint, vshort>;
85     using ComplexTypes = ::testing::Types<vcomplex, vcomplexDbl>;
86     using ReducedFloatTestedTypes = ::testing::Types<vBFloat16, vHalf>;
87     TYPED_TEST_SUITE(Memory, ALLTestedTypes);
88     TYPED_TEST_SUITE(Arithmetics, FloatIntTestedTypes);
89     TYPED_TEST_SUITE(Comparison, RealFloatIntTestedTypes);
90     TYPED_TEST_SUITE(Bitwise, FloatIntTestedTypes);
91     TYPED_TEST_SUITE(MinMax, RealFloatIntTestedTypes);
92     TYPED_TEST_SUITE(Nan, RealFloatTestedTypes);
93     TYPED_TEST_SUITE(Interleave, RealFloatIntTestedTypes);
94     TYPED_TEST_SUITE(SignManipulation, FloatIntTestedTypes);
95     TYPED_TEST_SUITE(SignManipulationHalfPrecision, ReducedFloatTestedTypes);
96     TYPED_TEST_SUITE(Rounding, RealFloatTestedTypes);
97     TYPED_TEST_SUITE(SqrtAndReciprocal, FloatTestedTypes);
98     TYPED_TEST_SUITE(SqrtAndReciprocalReal, RealFloatTestedTypes);
99     TYPED_TEST_SUITE(FractionAndRemainderReal, RealFloatTestedTypes);
100     TYPED_TEST_SUITE(Trigonometric, RealFloatTestedTypes);
101     TYPED_TEST_SUITE(ErrorFunctions, RealFloatTestedTypes);
102     TYPED_TEST_SUITE(Exponents, RealFloatTestedTypes);
103     TYPED_TEST_SUITE(Hyperbolic, RealFloatTestedTypes);
104     TYPED_TEST_SUITE(InverseTrigonometricReal, RealFloatTestedTypes);
105     TYPED_TEST_SUITE(InverseTrigonometric, FloatTestedTypes);
106     TYPED_TEST_SUITE(LGamma, RealFloatTestedTypes);
107     TYPED_TEST_SUITE(Logarithm, FloatTestedTypes);
108     TYPED_TEST_SUITE(LogarithmReals, RealFloatTestedTypes);
109     TYPED_TEST_SUITE(Pow, RealFloatTestedTypes);
110     TYPED_TEST_SUITE(RealTests, RealFloatTestedTypes);
111     TYPED_TEST_SUITE(RangeFactories, FloatIntTestedTypes);
112     TYPED_TEST_SUITE(BitwiseFloatsAdditional, RealFloatTestedTypes);
113     TYPED_TEST_SUITE(BitwiseFloatsAdditional2, FloatTestedTypes);
114     TYPED_TEST_SUITE(QuantizationTests, QuantTestedTypes);
115     TYPED_TEST_SUITE(InfiniteTests, RealFloatTestedTypes);
116 #if (defined(CPU_CAPABILITY_AVX2) ||  defined(CPU_CAPABILITY_AVX512))  && !defined(_MSC_VER)
117     TYPED_TEST_SUITE(
118         Quantization8BitWithTailTests,
119         Quantization8BitWithTailTestedTypes);
120 #endif
121     TYPED_TEST_SUITE(FunctionalTests, RealFloatIntTestedTypes);
122     TYPED_TEST_SUITE(FunctionalTestsReducedFloat, ReducedFloatTestedTypes);
123     TYPED_TEST_SUITE(VecConvertTests, RealFloatIntTestedTypes);
124     TYPED_TEST_SUITE(VecMaskTests, RealFloatIntTestedTypes);
TYPED_TEST(Memory, UnAlignedLoadStore) {
    // Round-trips random bytes through vec::loadu / store at every byte
    // offset inside one vector width, and checks the copy is bit-exact.
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    constexpr size_t b_size = vec::size() * sizeof(VT);  // bytes per vector
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN unsigned char src[128 * b_size];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN unsigned char dst[128 * b_size];
    auto seed = TestSeed();
    ValueGen<unsigned char> byte_gen(seed);
    for (unsigned char& byte : src) {
        byte = byte_gen.get();
    }
#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR)
    // Counted loads: only the first min(count, size()) elements must match.
    for (int count = 1; count < 2 * vec::size(); count++) {
        vec::loadu(src, count).store(dst);
        const size_t checked = std::min(count * sizeof(VT), b_size);
        const bool same = std::memcmp(src, dst, checked) == 0;
        ASSERT_TRUE(same) << "Failure Details:\nTest Seed to reproduce: " << seed
            << "\nCount: " << count;
        if (::testing::Test::HasFailure()) {
            break;
        }
        std::memset(dst, 0, b_size);
    }
#endif
    // Unaligned loads/stores: walk whole vectors starting at each offset.
    for (size_t offset = 0; offset < b_size; offset += 1) {
        size_t copied = 0;
        while (offset + copied + b_size <= sizeof src) {
            vec::loadu(src + offset + copied).store(dst + offset + copied);
            copied += b_size;
        }
        const bool same = std::memcmp(src + offset, dst + offset, copied) == 0;
        ASSERT_TRUE(same) << "Failure Details:\nTest Seed to reproduce: " << seed
            << "\nMismatch at unaligned offset: " << offset;
        if (::testing::Test::HasFailure()) {
            break;
        }
        std::memset(dst, 0, sizeof dst);
    }
}
TYPED_TEST(SignManipulation, Absolute) {
    // vec.abs() versus local_abs. The relative-error flag is enabled only
    // for complex element types; filter_int_minimum screens inputs
    // (presumably the integer minimum, whose abs overflows).
    using vec = TypeParam;
    const bool relative_err = is_complex<ValueType<TypeParam>>();
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed(), false, relative_err);
    test_unary<vec>(
        NAME_INFO(absolute),
        RESOLVE_OVERLOAD(local_abs),
        [](vec v) { return v.abs(); },
        cases,
        RESOLVE_OVERLOAD(filter_int_minimum));
}
TYPED_TEST(SignManipulation, Negate) {
    // vec.neg() versus std::negate on the element type. Negating the
    // minimum int/long overflows, so filter_int_minimum screens inputs.
    using vec = TypeParam;
    using Elem = ValueType<vec>;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(negate),
        std::negate<Elem>(),
        [](vec v) { return v.neg(); },
        cases,
        RESOLVE_OVERLOAD(filter_int_minimum));
}
TYPED_TEST(SignManipulationHalfPrecision, AbsNegate) {
    // Checks Vectorized<BFloat16/Half>::abs() and ::neg() against the same
    // operation applied in fp32 through vfloat. Each lane must satisfy
    // |ref - val| <= atol + rtol * |ref|.
    enum class SignOpType { ABS, NEGATE };
    using vec = TypeParam;
    using VT = UholdType<TypeParam>;
    using RT = float; // reference
    float atol = 0.01f;
    float rtol = 0.01f;

    // The relative tolerance is scaled by the fp32 reference value.
    auto cmp = [&](RT ref, VT val) {
        return std::abs(ref - RT(val)) <= atol + rtol * std::abs(ref);
    };

#define APPLY_FN_AND_STORE(VEC_TYPE)                            \
      [&](SignOpType op_type, VEC_TYPE& x_fp_vec, void *x_fp) { \
        if (op_type == SignOpType::NEGATE) {                    \
          x_fp_vec.neg().store(x_fp);                           \
        } else {                                                \
          x_fp_vec.abs().store(x_fp);                           \
        }                                                       \
      }

    auto apply_fn_and_store_ref = APPLY_FN_AND_STORE(vfloat);
    auto apply_fn_and_store_half = APPLY_FN_AND_STORE(vec);

    auto half_precision_ut = [&](SignOpType op_type) {
        constexpr auto N = vec::size();
        constexpr auto kFloatLanes = vfloat::size();
        // Round the fp32 buffer up to a whole number of vfloat vectors. The
        // reference pass below then never loads or stores past the end,
        // whatever the ratio of vec::size() to vfloat::size() is on this CPU
        // capability. The original code assumed exactly two vfloats per vec.
        constexpr auto kPaddedN =
            ((N + kFloatLanes - 1) / kFloatLanes) * kFloatLanes;
        CACHE_ALIGN RT x_fp[kPaddedN] = {};
        CACHE_ALIGN VT x_hp[N];
        auto seed = TestSeed();
        ValueGen<RT> generator(RT(-1), RT(1), seed);
        for (const auto i : c10::irange(N)) {
            x_fp[i] = generator.get();
            x_hp[i] = VT(x_fp[i]);
        }
        // fp32 reference, one vfloat at a time. Padding lanes stay zero and
        // are never compared.
        for (int64_t offset = 0; offset < N; offset += kFloatLanes) {
            auto x_fp_vec = vfloat::loadu(x_fp + offset);
            apply_fn_and_store_ref(op_type, x_fp_vec, x_fp + offset);
        }

        auto x_hp_vec = vec::loadu(x_hp);
        apply_fn_and_store_half(op_type, x_hp_vec, x_hp);

        const char* op_name = (op_type == SignOpType::ABS) ? "abs" : "negate";
        for (int64_t len = 0; len < N; len++) {
            ASSERT_TRUE(cmp(x_fp[len], x_hp[len])) << "Failure Details:\nTest Seed to reproduce: " << seed
                << "\nabs/negate (" << op_name << "), Length: " << len << "; fp32: " << x_fp[len] << "; bf16/fp16: " << RT(x_hp[len]);
        }
    };

    half_precision_ut(SignOpType::ABS);
    half_precision_ut(SignOpType::NEGATE);
}
TYPED_TEST(Rounding, Round) {
    // vec.round() versus at::native::round_impl on [-1000, 1000], plus two
    // hand-picked halfway cases. Their expected values (-658.5 -> -658 and
    // -657.5 -> -658) pin ties-to-even behaviour.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    const UVT tie_a = -658.5f;
    const UVT tie_a_expected = -658.f;
    const UVT tie_b = -657.5f;
    const UVT tie_b_expected = -658.f;
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-1000, 1000}} })
                     .addCustom({ {tie_a}, tie_a_expected })
                     .addCustom({ {tie_b}, tie_b_expected })
                     .setTrialCount(64000)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(round),
        RESOLVE_OVERLOAD(at::native::round_impl),
        [](vec v) { return v.round(); },
        cases);
}
TYPED_TEST(Rounding, Ceil) {
    // vec.ceil() versus scalar std::ceil on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(ceil), RESOLVE_OVERLOAD(std::ceil),
        [](vec v) { return v.ceil(); }, cases);
}
TYPED_TEST(Rounding, Floor) {
    // vec.floor() versus scalar std::floor on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(floor), RESOLVE_OVERLOAD(std::floor),
        [](vec v) { return v.floor(); }, cases);
}
TYPED_TEST(Rounding, Trunc) {
    // vec.trunc() versus scalar std::trunc on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(trunc), RESOLVE_OVERLOAD(std::trunc),
        [](vec v) { return v.trunc(); }, cases);
}
TYPED_TEST(SqrtAndReciprocal, Sqrt) {
    // vec.sqrt() versus local_sqrt, with the relative-error flag
    // (third argument) enabled.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed(), false, true);
    test_unary<vec>(
        NAME_INFO(sqrt), RESOLVE_OVERLOAD(local_sqrt),
        [](vec v) { return v.sqrt(); }, cases);
}
TYPED_TEST(SqrtAndReciprocalReal, RSqrt) {
    // vec.rsqrt() versus the rsqrt<> reference; filter_zero screens inputs
    // (presumably dropping zeros).
    using vec = TypeParam;
    using Elem = ValueType<vec>;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(rsqrt), rsqrt<Elem>,
        [](vec v) { return v.rsqrt(); }, cases,
        RESOLVE_OVERLOAD(filter_zero));
}
TYPED_TEST(SqrtAndReciprocalReal, Reciprocal) {
    // vec.reciprocal() versus the reciprocal<> reference; filter_zero
    // screens inputs (presumably dropping zeros).
    using vec = TypeParam;
    using Elem = ValueType<vec>;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(reciprocal), reciprocal<Elem>,
        [](vec v) { return v.reciprocal(); }, cases,
        RESOLVE_OVERLOAD(filter_zero));
}
TYPED_TEST(FractionAndRemainderReal, Frac) {
    // vec.frac() versus the scalar frac reference, with the relative-error
    // flag enabled.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed(), false, true);
    test_unary<vec>(
        NAME_INFO(frac), RESOLVE_OVERLOAD(frac),
        [](vec v) { return v.frac(); }, cases);
}
TYPED_TEST(FractionAndRemainderReal, Fmod) {
    // v0.fmod(v1) versus std::fmod; filter_fmod screens input pairs.
    using vec = TypeParam;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed());
    test_binary<vec>(
        NAME_INFO(fmod), RESOLVE_OVERLOAD(std::fmod),
        [](vec lhs, vec rhs) { return lhs.fmod(rhs); }, cases,
        RESOLVE_OVERLOAD(filter_fmod));
}
TYPED_TEST(Trigonometric, Sin) {
    // vec.sin() versus std::sin on [-4096, 4096] (tolerance 1.2e-7) and
    // [-8192, 8192] (tolerance 3.0e-7), 8000 trials. The `true` flag
    // presumably selects relative-error checking.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-4096, 4096}}, true, 1.2e-7f})
                     .addDomain(CheckWithinDomains<UVT>{ { {-8192, 8192}}, true, 3.0e-7f})
                     .setTrialCount(8000)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(sin), RESOLVE_OVERLOAD(std::sin),
        [](vec v) { return v.sin(); }, cases);
}
TYPED_TEST(Trigonometric, Cos) {
    // vec.cos() versus std::cos on [-4096, 4096] (tolerance 1.2e-7) and
    // [-8192, 8192] (tolerance 3.0e-7), 8000 trials.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-4096, 4096}}, true, 1.2e-7f})
                     .addDomain(CheckWithinDomains<UVT>{ { {-8192, 8192}}, true, 3.0e-7f})
                     .setTrialCount(8000)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(cos), RESOLVE_OVERLOAD(std::cos),
        [](vec v) { return v.cos(); }, cases);
}
TYPED_TEST(Trigonometric, Tan) {
    // vec.tan() versus std::tan on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(tan), RESOLVE_OVERLOAD(std::tan),
        [](vec v) { return v.tan(); }, cases);
}
TYPED_TEST(Hyperbolic, Tanh) {
    // vec.tanh() versus std::tanh on the default unary test case.
    // The failure label was "tanH"; it is now "tanh" to match the
    // lower-case labels used by every other test.
    using vec = TypeParam;
    test_unary<vec>(
        NAME_INFO(tanh),
        RESOLVE_OVERLOAD(std::tanh),
        [](vec v) { return v.tanh(); },
        createDefaultUnaryTestCase<vec>(TestSeed()));
}
TYPED_TEST(Hyperbolic, Sinh) {
    // vec.sinh() versus std::sinh on [-88, 88] with the default tolerance
    // for the element type, 65536 trials.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    const auto tol = getDefaultTolerance<UVT>();
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-88, 88}}, true, tol})
                     .setTrialCount(65536)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(sinh), RESOLVE_OVERLOAD(std::sinh),
        [](vec v) { return v.sinh(); }, cases);
}
TYPED_TEST(Hyperbolic, Cosh) {
    // vec.cosh() versus std::cosh on [-88, 88] with the default tolerance
    // for the element type, 65536 trials.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    const auto tol = getDefaultTolerance<UVT>();
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-88, 88}}, true, tol})
                     .setTrialCount(65536)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(cosh), RESOLVE_OVERLOAD(std::cosh),
        [](vec v) { return v.cosh(); }, cases);
}
TYPED_TEST(InverseTrigonometric, Asin) {
    // vec.asin() versus local_asin on [-10, 10], 125536 trials. The
    // relative-error flag is set only for complex element types.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    const bool relative_err = is_complex<ValueType<TypeParam>>();
    const auto tol = getDefaultTolerance<UVT>();
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-10, 10}}, relative_err, tol })
                     .setTrialCount(125536)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(asin), RESOLVE_OVERLOAD(local_asin),
        [](vec v) { return v.asin(); }, cases);
}
TYPED_TEST(InverseTrigonometric, ACos) {
    // vec.acos() versus local_acos on [-10, 10], 125536 trials. The
    // relative-error flag is set only for complex element types.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    const bool relative_err = is_complex<ValueType<TypeParam>>();
    const auto tol = getDefaultTolerance<UVT>();
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-10, 10}}, relative_err, tol })
                     .setTrialCount(125536)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(acos), RESOLVE_OVERLOAD(local_acos),
        [](vec v) { return v.acos(); }, cases);
}
TYPED_TEST(InverseTrigonometric, ATan) {
    // vec.atan() versus std::atan on [-100, 100], 65536 trials;
    // filter_zero screens inputs. The relative-error flag is set only for
    // complex element types.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    const bool relative_err = is_complex<ValueType<TypeParam>>();
    const auto tol = getDefaultTolerance<UVT>();
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-100, 100}}, relative_err, tol})
                     .setTrialCount(65536)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(atan), RESOLVE_OVERLOAD(std::atan),
        [](vec v) { return v.atan(); }, cases,
        RESOLVE_OVERLOAD(filter_zero));
}
TYPED_TEST(Logarithm, Log) {
    // vec.log() versus std::log on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(log), RESOLVE_OVERLOAD(std::log),
        [](const vec& x) { return x.log(); }, cases);
}
TYPED_TEST(LogarithmReals, Log2) {
    // vec.log2() versus local_log2 on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(log2), RESOLVE_OVERLOAD(local_log2),
        [](const vec& x) { return x.log2(); }, cases);
}
TYPED_TEST(Logarithm, Log10) {
    // vec.log10() versus std::log10 on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(log10), RESOLVE_OVERLOAD(std::log10),
        [](const vec& x) { return x.log10(); }, cases);
}
TYPED_TEST(LogarithmReals, Log1p) {
    // vec.log1p() versus std::log1p on [-1, 1000] and [1000, 1e30] with the
    // default tolerance for the element type, 65536 trials.
    using vec = TypeParam;
    using UVT = UvalueType<TypeParam>;
    const auto tol = getDefaultTolerance<UVT>();
    auto cases = TestingCase<vec>::getBuilder()
                     .addDomain(CheckWithinDomains<UVT>{ { {-1, 1000}}, true, tol})
                     .addDomain(CheckWithinDomains<UVT>{ { {1000, 1.e+30}}, true, tol})
                     .setTrialCount(65536)
                     .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(log1p), RESOLVE_OVERLOAD(std::log1p),
        [](const vec& x) { return x.log1p(); }, cases);
}
TYPED_TEST(Exponents, Exp) {
    // vec.exp() versus std::exp on the default unary test case.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed());
    test_unary<vec>(
        NAME_INFO(exp), RESOLVE_OVERLOAD(std::exp),
        [](const vec& x) { return x.exp(); }, cases);
}
TYPED_TEST(Exponents, Expm1) {
    // vec.expm1() versus std::expm1, with the relative-error flag enabled.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed(), false, true);
    test_unary<vec>(
        NAME_INFO(expm1), RESOLVE_OVERLOAD(std::expm1),
        [](const vec& x) { return x.expm1(); }, cases);
}
TYPED_TEST(ErrorFunctions, Erf) {
    // vec.erf() versus std::erf, with the relative-error flag enabled.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed(), false, true);
    test_unary<vec>(
        NAME_INFO(erf), RESOLVE_OVERLOAD(std::erf),
        [](const vec& x) { return x.erf(); }, cases);
}
TYPED_TEST(ErrorFunctions, Erfc) {
    // vec.erfc() versus std::erfc, with the relative-error flag enabled.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed(), false, true);
    test_unary<vec>(
        NAME_INFO(erfc), RESOLVE_OVERLOAD(std::erfc),
        [](const vec& x) { return x.erfc(); }, cases);
}
TYPED_TEST(ErrorFunctions, Erfinv) {
    // vec.erfinv() versus calc_erfinv, with the relative-error flag enabled.
    using vec = TypeParam;
    auto cases = createDefaultUnaryTestCase<vec>(TestSeed(), false, true);
    test_unary<vec>(
        NAME_INFO(erfinv), RESOLVE_OVERLOAD(calc_erfinv),
        [](const vec& x) { return x.erfinv(); }, cases);
}
TYPED_TEST(Nan, IsNan) {
    // Exhaustively checks vec.isnan() over every NaN / non-NaN lane pattern.
    // Bit i of `pattern` puts a quiet NaN in lane i. The expected mask lane is
    // all-ones for NaN and all-zeros otherwise, following the same rule as
    // at::Vectorized<T>::binary_pred.
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    // The pattern count is 2^size(). The original `int` shifts overflowed once
    // size() >= 31; 64-bit shifts are well defined up to 62.
    static_assert(vec::size() < 63,
                  "Nan.IsNan enumerates 2^size() lane patterns; size() is too large");
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN VT test_vals[vec::size()];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN VT expected_vals[vec::size()];
    const int64_t num_patterns = int64_t{1} << vec::size();
    for (const auto pattern : c10::irange(num_patterns)) {
        for (int i = 0; i < vec::size(); ++i) {
            if ((pattern >> i) & 1) {
                test_vals[i] = std::numeric_limits<VT>::quiet_NaN();
                std::memset(static_cast<void*>(&expected_vals[i]), 0xFF, sizeof(VT));
            } else {
                test_vals[i] = static_cast<VT>(0.123);
                std::memset(static_cast<void*>(&expected_vals[i]), 0, sizeof(VT));
            }
        }
        vec actual = vec::loadu(test_vals).isnan();
        vec expected = vec::loadu(expected_vals);
        AssertVectorized<vec>(NAME_INFO(isnan), expected, actual).check();
    }
}
TYPED_TEST(LGamma, LGamma) {
    // vec.lgamma() versus std::lgamma on [-100, 0], [0, 1000] and
    // [1000, maxCorrect]. maxCorrect is where SLEEF's lgamma stays accurate:
    // double 2e+305, float 4e+36 (https://sleef.org/purec.xhtml#eg).
    using vec = TypeParam;
    using UVT = UvalueType<vec>;
    const UVT tol = getDefaultTolerance<UVT>();
    UVT maxCorrect;
    if (std::is_same_v<UVT, float>) {
        maxCorrect = (UVT)4e+36;
    } else {
        maxCorrect = (UVT)2e+305;
    }
    TestingCase<vec> cases = TestingCase<vec>::getBuilder()
        .addDomain(CheckWithinDomains<UVT>{ { {(UVT)-100, (UVT)0}}, true, tol})
        .addDomain(CheckWithinDomains<UVT>{ { {(UVT)0, (UVT)1000 }}, true, tol})
        .addDomain(CheckWithinDomains<UVT>{ { {(UVT)1000, maxCorrect }}, true, tol})
        .setTestSeed(TestSeed());
    test_unary<vec>(
        NAME_INFO(lgamma), RESOLVE_OVERLOAD(std::lgamma),
        [](vec v) { return v.lgamma(); }, cases);
}
TYPED_TEST(InverseTrigonometricReal, ATan2) {
    // y.atan2(x) versus std::atan2 on the default binary test case.
    using vec = TypeParam;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed());
    test_binary<vec>(
        NAME_INFO(atan2), RESOLVE_OVERLOAD(std::atan2),
        [](vec y, vec x) { return y.atan2(x); }, cases);
}
TYPED_TEST(Pow, Pow) {
    // base.pow(exp) versus std::pow, with the relative-error flag enabled.
    using vec = TypeParam;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed(), false, true);
    test_binary<vec>(
        NAME_INFO(pow), RESOLVE_OVERLOAD(std::pow),
        [](vec base, vec exponent) { return base.pow(exponent); }, cases);
}
TYPED_TEST(RealTests, Hypot) {
    // a.hypot(b) versus std::hypot, with the relative-error flag enabled.
    using vec = TypeParam;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed(), false, true);
    test_binary<vec>(
        NAME_INFO(hypot), RESOLVE_OVERLOAD(std::hypot),
        [](vec a, vec b) { return a.hypot(b); }, cases);
}
TYPED_TEST(RealTests, NextAfter) {
    // from.nextafter(to) versus std::nextafter, with the relative-error flag
    // enabled.
    using vec = TypeParam;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed(), false, true);
    test_binary<vec>(
        NAME_INFO(nextafter), RESOLVE_OVERLOAD(std::nextafter),
        [](vec from, vec to) { return from.nextafter(to); }, cases);
}
TYPED_TEST(Interleave, Interleave) {
    // interleave2(a, b) must match a scalar copy_interleave of the same
    // 2 * size() random values, split into two vectors.
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    constexpr auto lanes = vec::size();
    constexpr auto N = lanes * 2LL;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN VT vals[N];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN VT interleaved[N];
    auto seed = TestSeed();
    ValueGen<VT> generator(seed);
    for (const auto i : c10::irange(N)) {
        vals[i] = generator.get();
    }
    copy_interleave(vals, interleaved);
    auto [lo_half, hi_half] = interleave2(vec::loadu(vals), vec::loadu(vals + lanes));
    AssertVectorized<vec>(NAME_INFO(Interleave FirstHalf), lo_half, vec::loadu(interleaved)).check(true);
    AssertVectorized<vec>(NAME_INFO(Interleave SecondHalf), hi_half, vec::loadu(interleaved + lanes)).check(true);
}
TYPED_TEST(Interleave, DeInterleave) {
    // deinterleave2 applied to the scalar-interleaved buffer must recover
    // the original two vectors.
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    constexpr auto lanes = vec::size();
    constexpr auto N = lanes * 2LL;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN VT vals[N];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN VT interleaved[N];
    auto seed = TestSeed();
    ValueGen<VT> generator(seed);
    for (const auto i : c10::irange(N)) {
        vals[i] = generator.get();
    }
    copy_interleave(vals, interleaved);
    // This time the interleaved buffer is the input.
    auto [first, second] = deinterleave2(vec::loadu(interleaved), vec::loadu(interleaved + lanes));
    AssertVectorized<vec>(NAME_INFO(DeInterleave FirstHalf), first, vec::loadu(vals)).check(true);
    AssertVectorized<vec>(NAME_INFO(DeInterleave SecondHalf), second, vec::loadu(vals + lanes)).check(true);
}
TYPED_TEST(Arithmetics, Plus) {
    // v0 + v1 versus std::plus; filter_add_overflow screens input pairs
    // (presumably those whose integer sum would overflow).
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed());
    auto vector_add = [](const vec& lhs, const vec& rhs) -> vec { return lhs + rhs; };
    test_binary<vec>(
        NAME_INFO(plus), std::plus<VT>(), vector_add, cases,
        RESOLVE_OVERLOAD(filter_add_overflow));
}
TYPED_TEST(Arithmetics, Minus) {
    // v0 - v1 versus std::minus; filter_sub_overflow screens input pairs
    // (presumably those whose integer difference would overflow).
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed());
    auto vector_sub = [](const vec& lhs, const vec& rhs) -> vec { return lhs - rhs; };
    test_binary<vec>(
        NAME_INFO(minus), std::minus<VT>(), vector_sub, cases,
        RESOLVE_OVERLOAD(filter_sub_overflow));
}
TYPED_TEST(Arithmetics, Multiplication) {
    // v0 * v1 versus local_multiply, with the relative-error flag enabled;
    // filter_mult_overflow screens input pairs.
    using vec = TypeParam;
    auto cases = createDefaultBinaryTestCase<vec>(TestSeed(), false, true);
    auto vector_mul = [](const vec& lhs, const vec& rhs) { return lhs * rhs; };
    test_binary<vec>(
        NAME_INFO(mult), RESOLVE_OVERLOAD(local_multiply), vector_mul, cases,
        RESOLVE_OVERLOAD(filter_mult_overflow));
}
TYPED_TEST(Arithmetics,Division)678     TYPED_TEST(Arithmetics, Division) {
679         using vec = TypeParam;
680         TestSeed seed;
681         test_binary<vec>(
682             NAME_INFO(division),
683             RESOLVE_OVERLOAD(local_division),
684             [](const vec& v0, const vec& v1) { return v0 / v1; },
685             createDefaultBinaryTestCase<vec>(seed),
686             RESOLVE_OVERLOAD(filter_div_ub));
687     }
TYPED_TEST(Bitwise,BitAnd)688     TYPED_TEST(Bitwise, BitAnd) {
689         using vec = TypeParam;
690         test_binary<vec>(
691             NAME_INFO(bit_and),
692             RESOLVE_OVERLOAD(local_and),
693             [](const vec& v0, const vec& v1) { return v0 & v1; },
694             createDefaultBinaryTestCase<vec>(TestSeed(), true));
695     }
TYPED_TEST(Bitwise,BitOr)696     TYPED_TEST(Bitwise, BitOr) {
697         using vec = TypeParam;
698         test_binary<vec>(
699             NAME_INFO(bit_or),
700             RESOLVE_OVERLOAD(local_or),
701             [](const vec& v0, const vec& v1) { return v0 | v1; },
702             createDefaultBinaryTestCase<vec>(TestSeed(), true));
703     }
TYPED_TEST(Bitwise,BitXor)704     TYPED_TEST(Bitwise, BitXor) {
705         using vec = TypeParam;
706         test_binary<vec>(
707             NAME_INFO(bit_xor),
708             RESOLVE_OVERLOAD(local_xor),
709             [](const vec& v0, const vec& v1) { return v0 ^ v1; },
710             createDefaultBinaryTestCase<vec>(TestSeed(), true));
711     }
TYPED_TEST(Comparison,Equal)712     TYPED_TEST(Comparison, Equal) {
713         using vec = TypeParam;
714         using VT = ValueType<TypeParam>;
715         test_binary<vec>(
716             NAME_INFO(== ),
717             [](const VT& v1, const VT& v2) {return func_cmp(std::equal_to<VT>(), v1, v2); },
718             [](const vec& v0, const vec& v1) { return v0 == v1; },
719             createDefaultBinaryTestCase<vec>(TestSeed(), true));
720     }
TYPED_TEST(Comparison,NotEqual)721     TYPED_TEST(Comparison, NotEqual) {
722         using vec = TypeParam;
723         using VT = ValueType<TypeParam>;
724         test_binary<vec>(
725             NAME_INFO(!= ),
726             [](const VT& v1, const VT& v2) {return func_cmp(std::not_equal_to<VT>(), v1, v2); },
727             [](const vec& v0, const vec& v1) { return v0 != v1; },
728             createDefaultBinaryTestCase<vec>(TestSeed(), true));
729     }
TYPED_TEST(Comparison,Greater)730     TYPED_TEST(Comparison, Greater) {
731         using vec = TypeParam;
732         using VT = ValueType<TypeParam>;
733         test_binary<vec>(
734             NAME_INFO(> ),
735             [](const VT& v1, const VT& v2) {return func_cmp(std::greater<VT>(), v1, v2); },
736             [](const vec& v0, const vec& v1) { return v0 > v1; },
737             createDefaultBinaryTestCase<vec>(TestSeed(), true));
738     }
TYPED_TEST(Comparison,Less)739     TYPED_TEST(Comparison, Less) {
740         using vec = TypeParam;
741         using VT = ValueType<TypeParam>;
742         test_binary<vec>(
743             NAME_INFO(< ),
744             [](const VT& v1, const VT& v2) {return func_cmp(std::less<VT>(), v1, v2); },
745             [](const vec& v0, const vec& v1) { return v0 < v1; },
746             createDefaultBinaryTestCase<vec>(TestSeed(), true));
747     }
TYPED_TEST(Comparison,GreaterEqual)748     TYPED_TEST(Comparison, GreaterEqual) {
749         using vec = TypeParam;
750         using VT = ValueType<TypeParam>;
751         test_binary<vec>(
752             NAME_INFO(>= ),
753             [](const VT& v1, const VT& v2) {return func_cmp(std::greater_equal<VT>(), v1, v2); },
754             [](const vec& v0, const vec& v1) { return v0 >= v1; },
755             createDefaultBinaryTestCase<vec>(TestSeed(), true));
756     }
TYPED_TEST(Comparison,LessEqual)757     TYPED_TEST(Comparison, LessEqual) {
758         using vec = TypeParam;
759         using VT = ValueType<TypeParam>;
760         test_binary<vec>(
761             NAME_INFO(<= ),
762             [](const VT& v1, const VT& v2) {return func_cmp(std::less_equal<VT>(), v1, v2); },
763             [](const vec& v0, const vec& v1) { return v0 <= v1; },
764             createDefaultBinaryTestCase<vec>(TestSeed(), true));
765     }
TYPED_TEST(MinMax,Minimum)766     TYPED_TEST(MinMax, Minimum) {
767         using vec = TypeParam;
768         using VT = ValueType<TypeParam>;
769         test_binary<vec>(
770             NAME_INFO(minimum),
771             minimum<VT>,
772             [](const vec& v0, const vec& v1) {
773                 return minimum(v0, v1);
774             },
775             createDefaultBinaryTestCase<vec>(TestSeed()));
776     }
TYPED_TEST(MinMax,Maximum)777     TYPED_TEST(MinMax, Maximum) {
778         using vec = TypeParam;
779         using VT = ValueType<TypeParam>;
780         test_binary<vec>(
781             NAME_INFO(maximum),
782             maximum<VT>,
783             [](const vec& v0, const vec& v1) {
784                 return maximum(v0, v1);
785             },
786             createDefaultBinaryTestCase<vec>(TestSeed()));
787     }
TYPED_TEST(MinMax,ClampMin)788     TYPED_TEST(MinMax, ClampMin) {
789         using vec = TypeParam;
790         using VT = ValueType<TypeParam>;
791         test_binary<vec>(
792             NAME_INFO(clamp min),
793             clamp_min<VT>,
794             [](const vec& v0, const vec& v1) {
795                 return clamp_min(v0, v1);
796             },
797             createDefaultBinaryTestCase<vec>(TestSeed()));
798     }
TYPED_TEST(MinMax,ClampMax)799     TYPED_TEST(MinMax, ClampMax) {
800         using vec = TypeParam;
801         using VT = ValueType<TypeParam>;
802         test_binary<vec>(
803             NAME_INFO(clamp max),
804             clamp_max<VT>,
805             [](const vec& v0, const vec& v1) {
806                 return clamp_max(v0, v1);
807             },
808             createDefaultBinaryTestCase<vec>(TestSeed()));
809     }
TYPED_TEST(MinMax,Clamp)810     TYPED_TEST(MinMax, Clamp) {
811         using vec = TypeParam;
812         using VT = ValueType<TypeParam>;
813         test_ternary<vec>(
814             NAME_INFO(clamp), clamp<VT>,
815             [](const vec& v0, const vec& v1, const vec& v2) {
816                 return clamp(v0, v1, v2);
817             },
818             createDefaultTernaryTestCase<vec>(TestSeed()),
819                 RESOLVE_OVERLOAD(filter_clamp));
820     }
TYPED_TEST(BitwiseFloatsAdditional,ZeroMask)821     TYPED_TEST(BitwiseFloatsAdditional, ZeroMask) {
822         using vec = TypeParam;
823         using VT = ValueType<TypeParam>;
824         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
825         CACHE_ALIGN VT test_vals[vec::size()];
826         //all sets will be within 0  2^(n-1)
827         auto power_sets = 1 << (vec::size());
828         for (const auto expected : c10::irange(power_sets)) {
829             // generate test_val based on expected
830             for (int i = 0; i < vec::size(); ++i)
831             {
832                 if (expected & (1 << i)) {
833                     test_vals[i] = (VT)0;
834                 }
835                 else {
836                     test_vals[i] = (VT)0.897;
837                 }
838             }
839             int actual = vec::loadu(test_vals).zero_mask();
840             ASSERT_EQ(expected, actual) << "Failure Details:\n"
841                 << std::hex << "Expected:\n#\t" << expected
842                 << "\nActual:\n#\t" << actual;
843         }
844     }
TYPED_TEST(BitwiseFloatsAdditional,Convert)845     TYPED_TEST(BitwiseFloatsAdditional, Convert) {
846         using vec = TypeParam;
847         using VT = ValueType<TypeParam>;
848         using IntVT = at::vec::int_same_size_t<VT>;
849 
850         // verify float to int
851         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
852         CACHE_ALIGN VT input1[vec::size()];
853         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
854         CACHE_ALIGN IntVT expected_vals1[vec::size()];
855         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
856         CACHE_ALIGN IntVT actual_vals1[vec::size()];
857         for (int64_t i = 0; i < vec::size(); i++) {
858             input1[i] = (VT)i * (VT)2.1 + (VT)0.5;
859             expected_vals1[i] = static_cast<IntVT>(input1[i]);
860         }
861         at::vec::convert(input1, actual_vals1, vec::size());
862         auto expected1 = VecType<IntVT>::loadu(expected_vals1);
863         auto actual1 = VecType<IntVT>::loadu(actual_vals1);
864         if (AssertVectorized<VecType<IntVT>>(NAME_INFO(test_convert_to_int), expected1, actual1).check()) {
865           return;
866         }
867 
868         // verify int to float
869         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
870         CACHE_ALIGN IntVT input2[vec::size()];
871         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
872         CACHE_ALIGN VT expected_vals2[vec::size()];
873         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
874         CACHE_ALIGN VT actual_vals2[vec::size()];
875         for (int64_t i = 0; i < vec::size(); i++) {
876             input2[i] = (IntVT)i * (IntVT)2 + (IntVT)1;
877             expected_vals2[i] = (VT)input2[i];
878         }
879         at::vec::convert(input2, actual_vals2, vec::size());
880         auto expected2 = vec::loadu(expected_vals2);
881         auto actual2 = vec::loadu(actual_vals2);
882         AssertVectorized<vec>(NAME_INFO(test_convert_to_float), expected2, actual2).check();
883     }
TYPED_TEST(BitwiseFloatsAdditional,Fmadd)884     TYPED_TEST(BitwiseFloatsAdditional, Fmadd) {
885         using vec = TypeParam;
886         using VT = ValueType<TypeParam>;
887 
888         auto test_case = TestingCase<vec>::getBuilder()
889           .addDomain(CheckWithinDomains<VT>{
890               {{(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}},
891               true, getDefaultTolerance<VT>()})
892           .setTestSeed(TestSeed());
893 
894         test_ternary<vec>(
895             NAME_INFO(clamp), RESOLVE_OVERLOAD(local_fmadd),
896             [](const vec& v0, const vec& v1, const vec& v2) {
897                 return at::vec::fmadd(v0, v1, v2);
898             },
899             test_case,
900             RESOLVE_OVERLOAD(filter_fmadd));
901     }
902     template<typename vec, typename VT, int64_t mask>
903     typename std::enable_if_t<(mask < 0 || mask> 255), void>
904     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blend(VT expected_val[vec::size ()],VT a[vec::size ()],VT b[vec::size ()])905     test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()])
906     {
907     }
908     template<typename vec, typename VT, int64_t mask>
909     typename std::enable_if_t<(mask >= 0 && mask <= 255), void>
910     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blend(VT expected_val[vec::size ()],VT a[vec::size ()],VT b[vec::size ()])911     test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()]) {
912         // generate expected_val
913         int64_t m = mask;
914         for (int64_t i = 0; i < vec::size(); i++) {
915             expected_val[i] = (m & 0x01) ? b[i] : a[i];
916             m = m >> 1;
917         }
918         // test with blend
919         auto vec_a = vec::loadu(a);
920         auto vec_b = vec::loadu(b);
921         auto expected = vec::loadu(expected_val);
922         auto actual = vec::template blend<mask>(vec_a, vec_b);
923         auto mask_str = std::string("\nblend mask: ") + std::to_string(mask);
924         if (AssertVectorized<vec>(std::string(NAME_INFO(test_blend)) + mask_str, expected, actual).check()) return;
925         test_blend<vec, VT, mask - 1>(expected_val, a, b);
926     }
927     template<typename vec, typename VT, int64_t idx, int64_t N>
928     std::enable_if_t<(!is_complex<VT>::value && idx == N), bool>
929     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blendv(VT expected_val[vec::size ()],VT a[vec::size ()],VT b[vec::size ()],VT mask[vec::size ()])930     test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) {
931         using bit_rep = BitType<VT>;
932         // generate expected_val
933         for (int64_t i = 0; i < vec::size(); i++) {
934             bit_rep hex_mask = 0;
935             hex_mask=c10::bit_cast<bit_rep>(mask[i]);
936             expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i];
937         }
938         // test with blendv
939         auto vec_a = vec::loadu(a);
940         auto vec_b = vec::loadu(b);
941         auto vec_m = vec::loadu(mask);
942         auto expected = vec::loadu(expected_val);
943         auto actual = vec::blendv(vec_a, vec_b, vec_m);
944         auto mask_str = std::string("\nblendv mask: ");
945         for (int64_t i = 0; i < vec::size(); i++) {
946             mask_str += std::to_string(mask[i]) + " ";
947         }
948         if (AssertVectorized<vec>(std::string(NAME_INFO(test_blendv)) + mask_str, expected, actual).check()) {
949             return false;
950         }
951         return true;
952     }
953     template<typename vec, typename VT, int64_t idx, int64_t N>
954     std::enable_if_t<(!is_complex<VT>::value && idx != N), bool>
955     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blendv(VT expected_val[vec::size ()],VT a[vec::size ()],VT b[vec::size ()],VT mask[vec::size ()])956     test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) {
957         // shuffle mask and do blendv test
958         VT m = mask[idx];
959         if (!test_blendv<vec, VT, idx+1, N>(expected_val, a, b, mask)) return false;
960         if (m != (VT)0) {
961           mask[idx] = (VT)0;
962         }
963         else {
964           int64_t hex_mask = 0xFFFFFFFFFFFFFFFF;
965           std::memcpy(&mask[idx], &hex_mask, sizeof(VT));
966         }
967         if (!test_blendv<vec, VT, idx+1, N>(expected_val, a, b, mask)) return false;
968         mask[idx] = m;
969         return true;
970     }
971     template<typename T, int N>
972     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
blend_init(T (& a)[N],T (& b)[N])973     void blend_init(T(&a)[N], T(&b)[N]) {
974         a[0] = (T)1.0;
975         b[0] = a[0] + (T)N;
976         for (const auto i : c10::irange(1, N)) {
977             a[i] = a[i - 1] + (T)(1.0);
978             b[i] = b[i - 1] + (T)(1.0);
979         }
980     }
TYPED_TEST(BitwiseFloatsAdditional,Blendv)981     TYPED_TEST(BitwiseFloatsAdditional, Blendv) {
982         using vec = TypeParam;
983         using VT = ValueType<TypeParam>;
984         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
985         CACHE_ALIGN VT a[vec::size()];
986         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
987         CACHE_ALIGN VT b[vec::size()];
988         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
989         CACHE_ALIGN VT mask[vec::size()] = {0};
990         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
991         CACHE_ALIGN VT expected_val[vec::size()];
992         blend_init(a, b);
993         test_blendv<vec, VT, 0, vec::size()>(expected_val, a, b, mask);
994     }
TYPED_TEST(BitwiseFloatsAdditional2,Blend)995     TYPED_TEST(BitwiseFloatsAdditional2, Blend) {
996         using vec = TypeParam;
997         using VT = ValueType<TypeParam>;
998         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
999         CACHE_ALIGN VT a[vec::size()];
1000         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1001         CACHE_ALIGN VT b[vec::size()];
1002         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1003         CACHE_ALIGN VT expected_val[vec::size()];
1004         blend_init(a, b);
1005         constexpr int64_t power_sets = 1LL << (vec::size());
1006         test_blend<vec, VT, power_sets - 1>(expected_val, a, b);
1007     }
1008     template<typename vec, typename VT>
1009     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_set(VT expected_val[vec::size ()],VT a[vec::size ()],VT b[vec::size ()],int64_t count)1010     void test_set(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], int64_t count){
1011         if (count < 0) return;
1012         //generate expected_val
1013         for (int64_t i = 0; i < vec::size(); i++) {
1014             expected_val[i] = (i < count) ? b[i] : a[i];
1015         }
1016         // test with set
1017         auto vec_a = vec::loadu(a);
1018         auto vec_b = vec::loadu(b);
1019         auto expected = vec::loadu(expected_val);
1020         auto actual = vec::set(vec_a, vec_b, count);
1021 
1022         auto count_str = std::string("\ncount: ") + std::to_string(count);
1023         if (AssertVectorized<vec>(std::string(NAME_INFO(test_set)) + count_str, expected, actual).check()) {
1024           return;
1025         }
1026         test_set<vec, VT>(expected_val, a, b, (count == 0 ? -1 : count / 2));
1027     }
TYPED_TEST(BitwiseFloatsAdditional2,Set)1028     TYPED_TEST(BitwiseFloatsAdditional2, Set) {
1029         using vec = TypeParam;
1030         using VT = ValueType<TypeParam>;
1031         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1032         CACHE_ALIGN VT a[vec::size()];
1033         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1034         CACHE_ALIGN VT b[vec::size()];
1035         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1036         CACHE_ALIGN VT expected_val[vec::size()];
1037         blend_init(a, b);
1038         test_set<vec, VT>(expected_val, a, b, vec::size());
1039     }
1040     template<typename T>
1041     std::enable_if_t<!is_complex<T>::value, void>
arange_init(T & base,T & step)1042     arange_init(T& base, T& step) {
1043         base = (T)5.0;
1044         step = (T)2.0;
1045     }
1046     template<typename T>
1047     std::enable_if_t<is_complex<T>::value, void>
arange_init(T & base,T & step)1048     arange_init(T& base, T& step) {
1049        base = T(5.0, 5.0);
1050        step = T(2.0, 3.0);
1051     }
TYPED_TEST(RangeFactories,Arange)1052     TYPED_TEST(RangeFactories, Arange) {
1053         using vec = TypeParam;
1054         using VT = ValueType<TypeParam>;
1055         using UVT = UvalueType<TypeParam>;
1056         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1057         CACHE_ALIGN VT expected_val[vec::size()];
1058         VT base, step;
1059         arange_init(base, step);
1060         for (int64_t i = 0; i < vec::size(); i++) {
1061             expected_val[i] = base + VT((UVT)i) * step;
1062         }
1063         auto expected = vec::loadu(expected_val);
1064         auto actual = vec::arange(base, step);
1065         AssertVectorized<vec>(NAME_INFO(test_arange), expected, actual).check();
1066     }
TEST(ComplexTests,TestComplexFloatImagRealConj)1067     TEST(ComplexTests, TestComplexFloatImagRealConj) {
1068         // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1069         float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28,
1070                        9.5488e-28,10.5488e-28,11.5488e-28,12.5488e-28,13.5488e-28,14.5488e-28,15.5488e-28,16.5488e-28};
1071         // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1072         float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0,aa[8],0,aa[10],0,aa[12],0,aa[14],0 };
1073         // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1074         float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0,aa[9],0,aa[11],0,aa[13],0,aa[15],0 };
1075         // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1076         float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28,
1077                          5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28,
1078                          9.5488e-28,-10.5488e-28,11.5488e-28,-12.5488e-28,
1079                          13.5488e-28,-14.5488e-28,15.5488e-28,-16.5488e-28 };
1080         auto a = vcomplex::loadu(aa);
1081         auto actual1 = a.real();
1082         auto actual3 = a.imag();
1083         auto actual4 = a.conj();
1084         auto expected1 = vcomplex::loadu(exp);
1085         auto expected3 = vcomplex::loadu(exp3);
1086         auto expected4 = vcomplex::loadu(exp4);
1087         AssertVectorized<vcomplex>(NAME_INFO(complex real), expected1, actual1).check();
1088         AssertVectorized<vcomplex>(NAME_INFO(complex imag), expected3, actual3).check();
1089         AssertVectorized<vcomplex>(NAME_INFO(complex conj), expected4, actual4).check();
1090     }
TEST(ComplexTests,TestComplexConstructor)1091     TEST(ComplexTests, TestComplexConstructor) {
1092         auto actual1 = vcomplex(1.0);
1093         auto expected1 = vcomplex(Complex<float>(1.0));
1094         AssertVectorized<vcomplex>(NAME_INFO(complex constructor), expected1, actual1).check();
1095     }
TYPED_TEST(QuantizationTests,Quantize)1096     TYPED_TEST(QuantizationTests, Quantize) {
1097         using vec = TypeParam;
1098         using underlying = ValueType<vec>;
1099         constexpr int trials = 4000;
1100         // NOLINTNEXTLINE(bugprone-signed-char-misuse)
1101         constexpr int min_val = std::numeric_limits<underlying>::min();
1102         constexpr int max_val = std::numeric_limits<underlying>::max();
1103         constexpr int el_count = vfloat::size();
1104         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1105         CACHE_ALIGN float unit_float_vec[el_count];
1106         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1107         CACHE_ALIGN underlying expected_qint_vals[vec::size()];
1108         typename vec::float_vec_return_type  float_ret;
1109         auto seed = TestSeed();
1110         //zero point
1111         ValueGen<int> generator_zp(min_val, max_val, seed);
1112         //scale
1113         ValueGen<float> generator_sc(1.f, 15.f, seed.add(1));
1114         //value
1115         float minv = static_cast<float>(static_cast<double>(min_val) * 2.0);
1116         float maxv = static_cast<float>(static_cast<double>(max_val) * 2.0);
1117         ValueGen<float> gen(minv, maxv, seed.add(2));
1118         for (C10_UNUSED const auto i : c10::irange(trials)) {
1119             float scale = generator_sc.get();
1120             float inv_scale = 1.0f / static_cast<float>(scale);
1121             auto zero_point_val = generator_zp.get();
1122             int index = 0;
1123             for (int j = 0; j < vec::float_num_vecs(); j++) {
1124                 //generate vals
1125                 for (auto& v : unit_float_vec) {
1126                     v = gen.get();
1127                     expected_qint_vals[index] = quantize_val<underlying>(scale, zero_point_val, v);
1128                     index++;
1129                 }
1130                 float_ret[j] = vfloat::loadu(unit_float_vec);
1131             }
1132             auto expected = vec::loadu(expected_qint_vals);
1133             auto actual = vec::quantize(float_ret, scale, zero_point_val, inv_scale);
1134             if (AssertVectorized<vec>(NAME_INFO(Quantize), expected, actual).check()) return;
1135         } //trials;
1136     }
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
    // QuantizeTile: exercises at::vec::QuantizeAvx512 / at::vec::QuantizeAvx2.
    // Neither has a CPU_CAPABILITY_DEFAULT implementation, hence the guard.
    // The kernel processes a tile one element shorter than the vector. The
    // final lane is zeroed in both the expected and actual buffers, so only the
    // tile is compared. Stops at the first failing trial.
    TYPED_TEST(Quantization8BitWithTailTests, QuantizeTile) {
        using vec = TypeParam;
        using underlying = ValueType<vec>;
        constexpr int trials = 4000;
        // NOLINTNEXTLINE(bugprone-signed-char-misuse)
        constexpr int min_val = std::numeric_limits<underlying>::min();
        constexpr int max_val = std::numeric_limits<underlying>::max();
        constexpr int el_count = vfloat::size();
        // Every lane except the last one is handed to the kernel.
        constexpr int tile_size = vec::size() - 1;
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
        CACHE_ALIGN float unit_float_vec[el_count];
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
        CACHE_ALIGN underlying expected_qint_vals[vec::size()];
        CACHE_ALIGN underlying actual_qint_vals[vec::size()];
        typename vec::float_vec_return_type float_ret;
        auto seed = TestSeed();
        ValueGen<int> zp_gen(min_val, max_val, seed);
        ValueGen<float> scale_gen(1.f, 15.f, seed.add(1));
        float lo = static_cast<float>(static_cast<double>(min_val) * 2.0);
        float hi = static_cast<float>(static_cast<double>(max_val) * 2.0);
        ValueGen<float> value_gen(lo, hi, seed.add(2));
        for (C10_UNUSED const auto trial : c10::irange(trials)) {
            float scale = scale_gen.get();
            float inv_scale = 1.0f / scale;
            auto zero_point_val = zp_gen.get();
            int out_idx = 0;
            for (int j = 0; j < vec::float_num_vecs(); j++) {
                for (auto& v : unit_float_vec) {
                    v = value_gen.get();
                    expected_qint_vals[out_idx++] =
                        quantize_val<underlying>(scale, zero_point_val, v);
                }
                float_ret[j] = vfloat::loadu(unit_float_vec);
            }
            // NOTE(review): treats the array of vfloat as contiguous floats.
            float* src = (float*)float_ret.data();
#if defined(CPU_CAPABILITY_AVX512)
            at::vec::QuantizeAvx512(src, actual_qint_vals, tile_size, inv_scale, zero_point_val);
#endif
#if defined(CPU_CAPABILITY_AVX2)
            at::vec::QuantizeAvx2(src, actual_qint_vals, tile_size, inv_scale, zero_point_val);
#endif
            // Neutralise the lane the kernel did not write.
            expected_qint_vals[tile_size] = 0;
            actual_qint_vals[tile_size] = 0;
            auto expected = vec::loadu(expected_qint_vals);
            auto actual = vec::loadu(actual_qint_vals);
            if (AssertVectorized<vec>(NAME_INFO(QuantizeTile), expected, actual).check()) {
                return;
            }
        }
    }
#endif
TYPED_TEST(QuantizationTests,DeQuantize)1205     TYPED_TEST(QuantizationTests, DeQuantize) {
1206         using vec = TypeParam;
1207         using underlying = ValueType<vec>;
1208         constexpr bool is_large = sizeof(underlying) > 1;
1209         constexpr int trials = is_large ? 4000 : std::numeric_limits<underlying>::max() / 2;
1210         constexpr int min_val = is_large ? -2190 : std::numeric_limits<underlying>::min();
1211         constexpr int max_val = is_large ? 2199 : std::numeric_limits<underlying>::max();
1212         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1213         CACHE_ALIGN float unit_exp_vals[vfloat::size()];
1214         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1215         CACHE_ALIGN underlying qint_vals[vec::size()];
1216 #if  defined(CHECK_DEQUANT_WITH_LOW_PRECISION)
1217         std::cout << "Dequant will be tested with relative error " << 1.e-3f << std::endl;
1218 #endif
1219         auto seed = TestSeed();
1220         ValueGen<int> generator(min_val, max_val, seed.add(1));
1221         //scale
1222         ValueGen<float> generator_sc(1.f, 15.f, seed.add(2));
1223         for (C10_UNUSED const auto i : c10::irange(trials)) {
1224             float scale = generator_sc.get();
1225             int32_t zero_point_val = generator.get();
1226             float scale_zp_premul = -(scale * zero_point_val);
1227             vfloat vf_scale = vfloat{ scale };
1228             vfloat vf_zp = vfloat{ static_cast<float>(zero_point_val) };
1229             vfloat vf_scale_zp = vfloat{ scale_zp_premul };
1230             //generate vals
1231             for (auto& x : qint_vals) {
1232                 x = generator.get();
1233             }
1234             //get expected
1235             int index = 0;
1236             auto qint_vec = vec::loadu(qint_vals);
1237             auto actual_float_ret = qint_vec.dequantize(vf_scale, vf_zp, vf_scale_zp);
1238             for (int j = 0; j < vec::float_num_vecs(); j++) {
1239                 for (auto& v : unit_exp_vals) {
1240                     v = dequantize_val(scale, zero_point_val, qint_vals[index]);
1241                     index++;
1242                 }
1243                 vfloat expected = vfloat::loadu(unit_exp_vals);
1244                 const auto& actual = actual_float_ret[j];
1245 #if  defined(CHECK_DEQUANT_WITH_LOW_PRECISION)
1246                 if (AssertVectorized<vfloat>(NAME_INFO(DeQuantize), seed, expected, actual).check(false, true, 1.e-3f)) return;
1247 #else
1248                 if (AssertVectorized<vfloat>(NAME_INFO(DeQuantize), seed, expected, actual).check()) return;
1249 #endif
1250             }
1251         } //trials;
1252     }
TYPED_TEST(QuantizationTests, ReQuantizeFromInt) {
    // Compares vec::requantize_from_int with the scalar reference
    // requantize_from_int<underlying>() on random qint32 inputs, a random zero
    // point and a multiplier of 1/scale with scale drawn from [1, 15).
    using vec = TypeParam;
    using underlying = ValueType<vec>;
    constexpr int trials = 4000;
    constexpr int min_val = -65535;
    constexpr int max_val = 65535;
    constexpr int lanes_per_int_vec = vint::size();
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN c10::qint32 int_chunk[lanes_per_int_vec];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN underlying reference_vals[vec::size()];
    typename vec::int_vec_return_type int_inputs;
    auto seed = TestSeed();
    // One stream feeds both the zero point and the int32 inputs ...
    ValueGen<int32_t> int_gen(min_val, max_val, seed);
    // ... and an independent stream feeds the scale.
    ValueGen<float> scale_gen(1.f, 15.f, seed.add(1));
    for (C10_UNUSED const auto trial : c10::irange(trials)) {
        // Draw order (scale, zero point, then inputs) keeps seeds reproducible.
        const float multiplier = 1.f / (scale_gen.get());
        const auto zero_point_val = int_gen.get();
        int out_pos = 0;
        for (int chunk = 0; chunk < vec::float_num_vecs(); chunk++) {
            for (int lane = 0; lane < lanes_per_int_vec; lane++) {
                int_chunk[lane] = c10::qint32(int_gen.get());
                reference_vals[out_pos++] = requantize_from_int<underlying>(
                    multiplier, zero_point_val, int_chunk[lane].val_);
            }
            int_inputs[chunk] = vqint::loadu(int_chunk);
        }
        const auto expected = vec::loadu(reference_vals);
        const auto actual = vec::requantize_from_int(int_inputs, multiplier, zero_point_val);
        if (AssertVectorized<vec>(NAME_INFO(ReQuantizeFromInt), seed, expected, actual).check()) {
            return;
        }
    }
}
TYPED_TEST(QuantizationTests, WideningSubtract) {
    // Compares vec::widening_subtract with the scalar widening_subtract()
    // reference: the lane-wise difference of two quantized vectors, widened to
    // int32 and returned as float_num_vecs() int vectors.
    using vec = TypeParam;
    using underlying = ValueType<vec>;
    constexpr bool is_large = sizeof(underlying) > 1;
    // 1-byte types get a trial count derived from their value range.
    constexpr int trials = is_large ? 4000 : std::numeric_limits<underlying>::max() / 2;
    // NOLINTNEXTLINE(bugprone-signed-char-misuse)
    constexpr int min_val = std::numeric_limits<underlying>::min();
    constexpr int max_val = std::numeric_limits<underlying>::max();
    constexpr int lanes_per_int_vec = vfloat::size();
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN int32_t reference_chunk[lanes_per_int_vec];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN underlying lhs_vals[vec::size()];
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    CACHE_ALIGN underlying rhs_vals[vec::size()];
    auto seed = TestSeed();
    ValueGen<underlying> gen(min_val, max_val, seed);
    for (C10_UNUSED const auto trial : c10::irange(trials)) {
        for (int k = 0; k < vec::size(); k++) {
            lhs_vals[k] = gen.get();
            rhs_vals[k] = gen.get();
            if constexpr (std::is_same_v<underlying, int>) {
                // For int32 inputs, drop pairs whose difference overflows.
                filter_sub_overflow(lhs_vals[k], rhs_vals[k]);
            }
        }
        auto lhs = vec::loadu(lhs_vals);
        auto rhs = vec::loadu(rhs_vals);
        auto actual_chunks = lhs.widening_subtract(rhs);
        int src_pos = 0;
        for (int chunk = 0; chunk < vec::float_num_vecs(); chunk++) {
            for (int lane = 0; lane < lanes_per_int_vec; lane++, src_pos++) {
                reference_chunk[lane] = widening_subtract(lhs_vals[src_pos], rhs_vals[src_pos]);
            }
            auto expected = vqint::loadu(reference_chunk);
            const auto& actual = actual_chunks[chunk];
            if (AssertVectorized<vqint>(NAME_INFO(WideningSubtract), seed, expected, actual).check()) return;
        }
    }
}
TYPED_TEST(QuantizationTests, Relu) {
    // Compares vec::relu(zero_point) with the scalar relu reference. The input
    // operand covers the whole value range of VT; the zero-point operand is
    // drawn from [0, fake_zp].
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    constexpr VT lowest = std::numeric_limits<VT>::min();
    constexpr VT highest = std::numeric_limits<VT>::max();
    constexpr VT fake_zp = sizeof(VT) > 1 ? static_cast<VT>(65535) : static_cast<VT>(47);
    const auto vector_relu = [](const vec& v0, const vec& v1) { return v0.relu(v1); };
    auto test_case = TestingCase<vec>::getBuilder()
        .addDomain(CheckWithinDomains<VT>{{
            DomainRange<VT>{lowest, highest},
            DomainRange<VT>{(VT)0, (VT)fake_zp}}})
        .setTestSeed(TestSeed());
    test_binary<vec>(NAME_INFO(relu), RESOLVE_OVERLOAD(relu), vector_relu, test_case);
}
TYPED_TEST(QuantizationTests, Relu6) {
    // Compares vec::relu6(zero_point, q_six) with the scalar relu6 reference.
    // Operand domains: the full VT range for the input, [0, fake_zp] for the
    // zero point and [fake_zp, fake_zp + temp] for the quantized "six".
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    constexpr VT lowest = std::numeric_limits<VT>::min();
    constexpr VT highest = std::numeric_limits<VT>::max();
    constexpr VT fake_zp = sizeof(VT) > 1 ? static_cast<VT>(65535) : static_cast<VT>(47);
    constexpr VT temp = sizeof(VT) > 1 ? static_cast<VT>(12345) : static_cast<VT>(32);
    constexpr VT fake_qsix = fake_zp + temp;
    // v0 is taken by non-const reference, matching the original signature.
    const auto vector_relu6 = [](/*const*/ vec& v0, const vec& v1, const vec& v2) {
        return v0.relu6(v1, v2);
    };
    auto test_case = TestingCase<vec>::getBuilder()
        .addDomain(CheckWithinDomains<VT>{{
            DomainRange<VT>{lowest, highest},
            DomainRange<VT>{(VT)0, (VT)fake_zp},
            DomainRange<VT>{(VT)fake_zp, (VT)fake_qsix}}})
        .setTestSeed(TestSeed());
    test_ternary<vec>(NAME_INFO(relu6), RESOLVE_OVERLOAD(relu6), vector_relu6, test_case);
}
TYPED_TEST(FunctionalTests, Map) {
    // Exercises at::vec::map / map2 / map3 / map4 over N = vec::size() + R
    // elements, which covers the full-vector path and an R-element tail.
    // Each result is compared bitwise with a scalar reference.
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    constexpr auto R = 2LL; // residual (tail) elements past one full vector
    constexpr auto N = vec::size() + R;
    CACHE_ALIGN VT x1[N];
    CACHE_ALIGN VT x2[N];
    CACHE_ALIGN VT x3[N];
    CACHE_ALIGN VT x4[N];
    CACHE_ALIGN VT y[N];
    CACHE_ALIGN VT ref_y[N];
    auto seed = TestSeed();
    ValueGen<VT> generator(VT(-100), VT(100), seed);
    for (const auto i : c10::irange(N)) {
        x1[i] = generator.get();
        x2[i] = generator.get();
        x3[i] = generator.get();
        x4[i] = generator.get();
    }
    // Check the full-vector head and the R-element tail separately. The seed
    // is passed along so a failure report identifies the random inputs.
    auto cmp = [&](VT* out, VT* ref) {
        AssertVectorized<vec>(NAME_INFO(Map), seed, vec::loadu(out), vec::loadu(ref)).check(true);
        AssertVectorized<vec>(NAME_INFO(Map), seed, vec::loadu(out + vec::size(), R), vec::loadu(ref + vec::size(), R)).check(true);
    };
    // test map: y = x1
    at::vec::map<VT>([](vec a) { return a; }, y, x1, N);
    for (const auto i : c10::irange(N)) { ref_y[i] = x1[i]; }
    cmp(y, ref_y);
    // test map2: y = x1 + x2
    at::vec::map2<VT>([](vec a, vec b) { return a + b; }, y, x1, x2, N);
    for (const auto i : c10::irange(N)) { ref_y[i] = x1[i] + x2[i]; }
    cmp(y, ref_y);
    // test map3: y = x1 + x2 + x3
    at::vec::map3<VT>([](vec a, vec b, vec c) { return a + b + c; }, y, x1, x2, x3, N);
    for (const auto i : c10::irange(N)) { ref_y[i] = x1[i] + x2[i] + x3[i]; }
    cmp(y, ref_y);
    // test map4: y = x1 + x2 + x3 + x4
    at::vec::map4<VT>([](vec a, vec b, vec c, vec d) { return a + b + c + d; }, y, x1, x2, x3, x4, N);
    for (const auto i : c10::irange(N)) { ref_y[i] = x1[i] + x2[i] + x3[i] + x4[i]; }
    cmp(y, ref_y);
}
TYPED_TEST(FunctionalTestsReducedFloat,Reduce)1413       TYPED_TEST(FunctionalTestsReducedFloat, Reduce) {
1414       using vec = TypeParam;
1415       // Can't use ValueType<TypeParam> here:
1416       // Vectorized<BFloat16>::value_type returns uint16_t on AVX2/AVX512
1417       using VT = UholdType<TypeParam>;
1418       using RT = float; // reference
1419       constexpr auto R = 2LL; // residual
1420       constexpr auto N = vec::size() * 2 + R;
1421       CACHE_ALIGN RT x_f1[N];
1422       CACHE_ALIGN RT x_f2[N];
1423       CACHE_ALIGN RT x_f3[N];
1424       CACHE_ALIGN VT x_b1[N];
1425       CACHE_ALIGN VT x_b2[N];
1426       CACHE_ALIGN VT x_b3[N];
1427       auto seed = TestSeed();
1428       ValueGen<RT> generator(RT(-1), RT(1), seed);
1429       for (const auto i : c10::irange(N)) {
1430         x_f1[i] = generator.get();
1431         x_f2[i] = generator.get();
1432         x_f3[i] = generator.get();
1433         x_b1[i] = VT(x_f1[i]);
1434         x_b2[i] = VT(x_f2[i]);
1435         x_b3[i] = VT(x_f3[i]);
1436       }
1437       float atol = 0.01f;
1438       float rtol = 0.01f;
1439       auto cmp = [=](RT ref, VT val) { return std::abs(ref - val) <= atol + rtol * std::abs(val); };
1440       auto sum = [](auto& x, auto& y) { return x + y; };
1441       auto max = [](auto& x, auto& y) { return at::vec::maximum(x, y); };
1442       // ReduceAll
1443       for (int64_t len = 1; len <= N; len++) {
1444         auto y1 = at::vec::reduce_all<RT>(sum, x_f1, len);
1445         auto y2 = at::vec::reduce_all<VT>(sum, x_b1, len);
1446         ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed
1447             << "\nreduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2);
1448       }
1449       // Reduce2All
1450       for (int64_t len = 1; len <= N; len++) {
1451         auto y1 = at::vec::reduce2_all<RT>(sum, max, x_f1, len);
1452         auto y2 = at::vec::reduce2_all<VT>(sum, max, x_b1, len);
1453         ASSERT_TRUE(cmp(y1.first, y2.first) && cmp(y1.second, y2.second)) << "Failure Details:\nTest Seed to reproduce: " << seed
1454             << "\nreduce2_all, Length: " << len << "; fp32(fun1): " << y1.first << "; bf16(fun1): " << RT(y2.first)
1455             << "; fp32(fun2): " << y1.second << "; bf16(fun2): " << y2.second;
1456       }
1457       // MapReduceAll
1458       for (int64_t len = 1; len <= N; len++) {
1459         auto y1 = at::vec::map_reduce_all<RT>([](auto x) { return x - x.exp(); }, sum, x_f1, len);
1460         auto y2 = at::vec::map_reduce_all<VT>([](auto x) { return x - x.exp(); }, sum, x_b1, len);
1461         ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed
1462             << "\nmap_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2);
1463       }
1464       // Map2ReduceAll
1465       for (int64_t len = 1; len <= N; len++) {
1466         auto y1 = at::vec::map2_reduce_all<RT>([](auto x, auto y) { return x * y; }, sum, x_f1, x_f2, len);
1467         auto y2 = at::vec::map2_reduce_all<VT>([](auto x, auto y) { return x * y; }, sum, x_b1, x_b2, len);
1468         ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed
1469             << "\nmap2_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2);
1470       }
1471       // Map3ReduceAll
1472       for (int64_t len = 1; len <= N; len++) {
1473         auto y1 = at::vec::map3_reduce_all<RT>([](auto x, auto y, auto z) { return x * y + z; }, sum, x_f1, x_f2, x_f3, len);
1474         auto y2 = at::vec::map3_reduce_all<VT>([](auto x, auto y, auto z) { return x * y + z; }, sum, x_b1, x_b2, x_b3, len);
1475         ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed
1476             << "\nmap3_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2);
1477       }
1478     }
TYPED_TEST(FunctionalTestsReducedFloat,Map)1479     TYPED_TEST(FunctionalTestsReducedFloat, Map) {
1480       using vec = TypeParam;
1481       using VT = UholdType<TypeParam>;
1482       using RT = float; // reference
1483       constexpr auto R = 2LL; // residual
1484       constexpr auto N = vec::size() * 2 + R;
1485       CACHE_ALIGN RT x_f1[N];
1486       CACHE_ALIGN RT x_f2[N];
1487       CACHE_ALIGN RT x_f3[N];
1488       CACHE_ALIGN RT x_f4[N];
1489       CACHE_ALIGN VT x_b1[N];
1490       CACHE_ALIGN VT x_b2[N];
1491       CACHE_ALIGN VT x_b3[N];
1492       CACHE_ALIGN VT x_b4[N];
1493       CACHE_ALIGN RT y_f[N];
1494       CACHE_ALIGN VT y_b[N];
1495       auto seed = TestSeed();
1496       ValueGen<RT> generator(RT(-1), RT(1), seed);
1497       for (const auto i : c10::irange(N)) {
1498         x_f1[i] = generator.get();
1499         x_f2[i] = generator.get();
1500         x_f3[i] = generator.get();
1501         x_f4[i] = generator.get();
1502         x_b1[i] = VT(x_f1[i]);
1503         x_b2[i] = VT(x_f2[i]);
1504         x_b3[i] = VT(x_f3[i]);
1505         x_b4[i] = VT(x_f4[i]);
1506       }
1507       float atol = 0.01f;
1508       float rtol = 0.01f;
1509       auto cmp = [=](RT ref, VT val) { return std::abs(ref - val) <= atol + rtol * std::abs(val); };
1510       // Map
1511       for (int64_t len = 1; len <= N; len++) {
1512         at::vec::map<RT>([](auto x) { return x; }, y_f, x_f1, len);
1513         at::vec::map<VT>([](auto x) { return x; }, y_b, x_b1, len);
1514         for (const auto i : c10::irange(len)) {
1515           ASSERT_TRUE(cmp(y_f[i], y_b[i])) << "Failure Details:\nTest Seed to reproduce: " << seed
1516               << "\nmap, Length: " << len << "; index: " << i << "; fp32 reference: " << y_f[i] << "; bf16 value: " << RT(y_b[i]);
1517         }
1518       }
1519       // Map - For float32 in, reduced floating points out
1520       for (int64_t len = 1; len <= N; len++) {
1521         at::vec::map<RT>([](auto x) { return x; }, y_f, x_f1, len);
1522         at::vec::map<VT>([](auto x) { return x; }, y_b, x_f1, len);
1523         for (const auto i : c10::irange(len)) {
1524           ASSERT_TRUE(cmp(y_f[i], y_b[i])) << "Failure Details:\nTest Seed to reproduce: " << seed
1525               << "\nmap, Length: " << len << "; index: " << i << "; fp32 reference: " << y_f[i] << "; bf16 value: " << RT(y_b[i]);
1526         }
1527       }
1528       // Map2
1529       for (int64_t len = 1; len <= N; len++) {
1530         at::vec::map2<RT>([](auto x, auto y) { return x + y; }, y_f, x_f1, x_f2, len);
1531         at::vec::map2<VT>([](auto x, auto y) { return x + y; }, y_b, x_b1, x_b2, len);
1532         for (const auto i : c10::irange(len)) {
1533           ASSERT_TRUE(cmp(y_f[i], y_b[i])) << "Failure Details:\nTest Seed to reproduce: " << seed
1534               << "\nmap2, Length: " << len << "; index: " << i << "; fp32 reference: " << y_f[i] << "; bf16 value: " << RT(y_b[i]);
1535         }
1536       }
1537       // Map3
1538       for (int64_t len = 1; len <= N; len++) {
1539         at::vec::map3<RT>([](auto x, auto y, auto z) { return x + y * z; }, y_f, x_f1, x_f2, x_f3, len);
1540         at::vec::map3<VT>([](auto x, auto y, auto z) { return x + y * z; }, y_b, x_b1, x_b2, x_b3, len);
1541         for (const auto i : c10::irange(len)) {
1542           ASSERT_TRUE(cmp(y_f[i], y_b[i])) << "Failure Details:\nTest Seed to reproduce: " << seed
1543               << "\nmap3, Length: " << len << "; index: " << i << "; fp32 reference: " << y_f[i] << "; bf16 value: " << RT(y_b[i]);
1544         }
1545       }
1546       // Map4
1547       for (int64_t len = 1; len <= N; len++) {
1548          at::vec::map4<RT>([](auto x, auto y, auto z, auto w) { return x + y * z - w; }, y_f, x_f1, x_f2, x_f3, x_f4, len);
1549          at::vec::map4<VT>([](auto x, auto y, auto z, auto w) { return x + y * z - w; }, y_b, x_b1, x_b2, x_b3, x_b4, len);
1550          for (const auto i : c10::irange(len)) {
1551            ASSERT_TRUE(cmp(y_f[i], y_b[i])) << "Failure Details:\nTest Seed to reproduce: " << seed
1552                << "\nmap4, Length: " << len << "; index: " << i << "; fp32 reference: " << y_f[i] << "; bf16 value: " << RT(y_b[i]);
1553          }
1554       }
1555     }
TEST(HalfConversionTest, HalfFloat) {
    // Round-trips the floats k + 0.3 (k = 0..99) through fp16 and checks both
    // directions against c10's scalar fp16 reference conversions. On AVX2 and
    // AVX512 builds (not Apple), the conversions under test are
    // at::vec::float2half_scalar / half2float_scalar. Otherwise the reference
    // functions are used for both sides.
    for (const auto k : c10::irange(100)) {
        const float src = k + 0.3;
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
    !defined(__APPLE__)
        const uint16_t u16 = at::vec::float2half_scalar(src);
        const float x = at::vec::half2float_scalar(u16);
#else
        const uint16_t u16 = c10::detail::fp16_ieee_from_fp32_value(src);
        const float x = c10::detail::fp16_ieee_to_fp32_value(u16);
#endif
        EXPECT_EQ(u16, c10::detail::fp16_ieee_from_fp32_value(src))
            << "Test failed for float to uint16 " << src << "\n";
        EXPECT_EQ(x, c10::detail::fp16_ieee_to_fp32_value(u16))
            << "Test failed for uint16 to float " << u16 << "\n";
    }
}
TYPED_TEST(InfiniteTests, HasInfNan) {
    // Checks Vectorized::has_inf_nan(). It must be false for a vector of
    // finite values, and true once one randomly chosen lane is replaced by
    // NaN, +inf or -inf.
    using vec = TypeParam;
    using VT = UholdType<TypeParam>;
    constexpr auto vec_size = vec::size();
    // The buffer must hold a full vector: loadu() reads vec_size elements, and
    // the poisoned index can be as large as vec_size - 1. A fixed 20-element
    // buffer is overrun by wider vectors (e.g. 16-bit lanes in a 512-bit
    // register give 32 lanes).
    CACHE_ALIGN VT values[vec_size];
    for (const auto i : c10::irange(vec_size)) {
        values[i] = i + 0.3;
    }
    auto vec_val = vec::loadu(values);
    auto seed = TestSeed();
    ValueGen<int> generator(int(0), int(vec_size - 1), seed);
    int index = generator.get();
    // std::numeric_limits yields the IEEE-754 quiet NaN (0x7FC00000), +inf
    // (0x7F800000) and -inf (0xFF800000). Reading an int through a float* is
    // undefined behaviour (strict aliasing), and 0xFF800000 does not fit in int.
    const VT v_nan = static_cast<VT>(std::numeric_limits<float>::quiet_NaN());
    values[index] = v_nan;
    auto vec_nan = vec::loadu(values);
    const VT v_pinf = static_cast<VT>(std::numeric_limits<float>::infinity());
    values[index] = v_pinf;
    auto vec_pinf = vec::loadu(values);
    const VT v_ninf = static_cast<VT>(-std::numeric_limits<float>::infinity());
    values[index] = v_ninf;
    auto vec_ninf = vec::loadu(values);

    ASSERT_TRUE(!(vec_val.has_inf_nan())) << "Test failed for normal value\n";
    ASSERT_TRUE(vec_nan.has_inf_nan()) << "Test failed for NAN\n";
    ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n";
    ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n";
}
TYPED_TEST(VecConvertTests,Convert)1609     TYPED_TEST(VecConvertTests, Convert) {
1610       using vec = TypeParam;
1611       using src_t = ValueType<TypeParam>;
1612       constexpr auto N = vec::size();
1613     #define TEST_CONVERT_TO(dst_t)                                     \
1614       do {                                                             \
1615         CACHE_ALIGN src_t x[N];                                        \
1616         CACHE_ALIGN dst_t y[N];                                        \
1617         CACHE_ALIGN dst_t ref[N];                                      \
1618         auto seed = TestSeed();                                        \
1619         auto low = std::is_signed_v<dst_t> ? src_t(-100) : 0;          \
1620         ValueGen<src_t> generator(low, src_t(100), seed);              \
1621         for (const auto i : c10::irange(N)) {                          \
1622           x[i] = generator.get();                                      \
1623         }                                                              \
1624         for (const auto i : c10::irange(N)) {                          \
1625           ref[i] = static_cast<dst_t>(x[i]);                           \
1626         }                                                              \
1627         auto x_vec = vec::loadu(x);                                    \
1628         auto y_vec = at::vec::convert<dst_t>(x_vec);                   \
1629         constexpr int num_dst_elements =                               \
1630             std::min(N, at::vec::Vectorized<dst_t>::size());           \
1631         y_vec.store(y, num_dst_elements);                              \
1632         for (const auto i : c10::irange(num_dst_elements)) {           \
1633           ASSERT_EQ(y[i], ref[i])                                      \
1634               << "Failure Details:\nTest Seed to reproduce: " << seed  \
1635               << " x[" << i << "]=" << x[i] << " dst_t=" #dst_t;       \
1636         }                                                              \
1637         constexpr int dst_n = N / num_dst_elements;                    \
1638         auto y_vec_n = at::vec::convert<dst_t, dst_n, src_t, 1>(       \
1639             at::vec::VectorizedN<src_t, 1>(x_vec));                    \
1640         y_vec_n.store(y, N);                                           \
1641         for (const auto i : c10::irange(N)) {                          \
1642           ASSERT_EQ(y[i], ref[i])                                      \
1643               << "Failure Details:\nTest Seed to reproduce: " << seed  \
1644               << " x[" << i << "]=" << x[i] << " dst_t=" #dst_t;       \
1645         }                                                              \
1646       } while (0)
1647       TEST_CONVERT_TO(int8_t);
1648       TEST_CONVERT_TO(uint8_t);
1649       TEST_CONVERT_TO(int16_t);
1650       TEST_CONVERT_TO(uint16_t);
1651       TEST_CONVERT_TO(int32_t);
1652       TEST_CONVERT_TO(uint32_t);
1653       TEST_CONVERT_TO(int64_t);
1654       TEST_CONVERT_TO(uint64_t);
1655       TEST_CONVERT_TO(c10::BFloat16);
1656       TEST_CONVERT_TO(c10::Half);
1657       TEST_CONVERT_TO(float);
1658       TEST_CONVERT_TO(double);
1659     #undef TEST_CONVERT_TO
1660     }
TYPED_TEST(VecMaskTests,MaskedLoad)1661     TYPED_TEST(VecMaskTests, MaskedLoad) {
1662       using vec = TypeParam;
1663       using src_t = ValueType<TypeParam>;
1664       constexpr auto size = vec::size();
1665 
1666     #define TEST_MASK_LOAD(dst_t, mask_t, mask_n)                           \
1667       do {                                                                  \
1668         CACHE_ALIGN dst_t x[mask_n * size];                                 \
1669         CACHE_ALIGN dst_t y[mask_n * size];                                 \
1670         CACHE_ALIGN dst_t ref[mask_n * size];                               \
1671         auto seed = TestSeed();                                             \
1672         ValueGen<dst_t> generator(dst_t(-100), dst_t(100), seed);           \
1673         for (const auto i : c10::irange(mask_n * size)) {                   \
1674           x[i] = generator.get();                                           \
1675         }                                                                   \
1676         auto vec_mask = generate_vec_mask<mask_t, mask_n>(seed);            \
1677         constexpr int dst_size = at::vec::Vectorized<dst_t>::size();        \
1678         constexpr int dst_n = mask_n * size / dst_size;                     \
1679         constexpr int rnd_n = (mask_n * size + dst_size - 1) / dst_size;    \
1680         if constexpr(dst_n * dst_size >= mask_n * size) {                   \
1681             auto x_vec = vec_mask.template loadu<dst_t, rnd_n>(x);          \
1682             x_vec.store(y);                                                 \
1683             for (const auto i : c10::irange(mask_n * size)) {               \
1684                 if (vec_mask.is_masked(i)) {                                \
1685                     ref[i] = x[i];                                          \
1686                 } else {                                                    \
1687                     ref[i] = 0;                                             \
1688                 }                                                           \
1689             }                                                               \
1690             for (const auto i : c10::irange(mask_n * size)) {               \
1691             ASSERT_EQ(y[i], ref[i])                                         \
1692                 << "Failure Details:\nTest Seed to reproduce: " << seed;    \
1693             }                                                               \
1694         }                                                                   \
1695       } while (0)
1696 
1697 
1698     #define TEST_MASK_LOAD_N(N)                                      \
1699       TEST_MASK_LOAD(int8_t, src_t, N);                              \
1700       TEST_MASK_LOAD(uint8_t, src_t, N);                             \
1701       TEST_MASK_LOAD(int16_t, src_t, N);                             \
1702       TEST_MASK_LOAD(uint16_t, src_t, N);                            \
1703       TEST_MASK_LOAD(int32_t, src_t, N);                             \
1704       TEST_MASK_LOAD(uint32_t, src_t, N);                            \
1705       TEST_MASK_LOAD(int64_t, src_t, N);                             \
1706       TEST_MASK_LOAD(uint64_t, src_t, N);                            \
1707       TEST_MASK_LOAD(c10::BFloat16, src_t, N);                       \
1708       TEST_MASK_LOAD(c10::Half, src_t, N);                           \
1709       TEST_MASK_LOAD(float, src_t, N);                               \
1710       TEST_MASK_LOAD(double, src_t, N);
1711 
1712       TEST_MASK_LOAD_N(1)
1713       TEST_MASK_LOAD_N(2)
1714       TEST_MASK_LOAD_N(4)
1715 
1716     #undef TEST_MASK_LOAD
1717     #undef TEST_MASK_LOAD_N
1718     }
TYPED_TEST(VecMaskTests, MaskedCheck) {
    // Builds VecMasks spanning 1, 2 and 4 vectors from constant values and
    // checks all_zero(), all_masked() and per-lane is_masked(). With value 2,
    // lane 1 of every vector must be masked and lane 0 must not be.
    using VT = ValueType<TypeParam>;
    using vec = TypeParam;
    constexpr auto size = vec::size();
    // A macro rather than a helper so a failing ASSERT_TRUE ends the test.
#define TEST_MASK_CHECK_N(N)                                                        \
    do {                                                                            \
        auto mask = create_vec_mask<VT, N>(0);                                      \
        ASSERT_TRUE(mask.all_zero()) << "all_zero check failed";                    \
        mask = create_vec_mask<VT, N>(-1);                                          \
        ASSERT_TRUE(mask.all_masked()) << "all_masked check failed";                \
        mask = create_vec_mask<VT, N>(2);                                           \
        for (int block = 0; block < N; ++block) {                                   \
            const int base = block * size;                                          \
            ASSERT_TRUE(mask.is_masked(base + 1)) << "is_masked(1) check failed";   \
            ASSERT_TRUE(!mask.is_masked(base + 0)) << "!is_masked(0) check failed"; \
        }                                                                           \
    } while (0)

    TEST_MASK_CHECK_N(1);
    TEST_MASK_CHECK_N(2);
    TEST_MASK_CHECK_N(4);

#undef TEST_MASK_CHECK_N
}
TYPED_TEST(VecMaskTests, ToFrom) {
    // Checks VecMask::from / to. from(1) must give an all-masked mask and
    // from(0) an all-zero one. A mask built from a vector must round-trip its
    // zero / non-zero pattern back through to().
    using vec = TypeParam;
    using VT = ValueType<TypeParam>;
    using mask_type = at::vec::VecMask<VT, 1>;
    constexpr auto N = vec::size();
    auto vec_mask = mask_type::from(1);
    ASSERT_TRUE(vec_mask.all_masked()) << "expect all_masked with from(1)";
    vec_mask = mask_type::from(0);
    ASSERT_TRUE(vec_mask.all_zero()) << "expect all_zero with from(0)";

    CACHE_ALIGN VT src[N];
    CACHE_ALIGN VT dst[N];
    auto seed = TestSeed();
    ValueGen<VT> generator(VT(0), VT(2), seed);
    for (auto& elem : src) {
        elem = generator.get();
    }
    auto src_vec = vec::loadu(src);
    vec_mask = mask_type::template from<VT, 1>(src_vec);
    vec_mask.template to<VT, 1>().store(dst);
    for (const auto i : c10::irange(N)) {
        ASSERT_EQ(dst[i] != 0, src[i] != 0)
            << "Failure Details:\nTest Seed to reproduce: " << seed;
    }
}
TYPED_TEST(VecMaskTests, Cast) {
    // For each destination type and mask width (1, 2, 4), casts a random
    // VecMask<src_t> to the destination element type. Materializing the
    // original and the cast mask with to<>() must yield equal lane values.
    using vec = TypeParam;
    using src_t = ValueType<TypeParam>;
    constexpr auto size = vec::size();

#define TEST_MASK_CAST(dst_t, mask_t, mask_n)                               \
    do {                                                                    \
        constexpr int lanes = mask_n * size;                                \
        constexpr int num_dst_elements =                                    \
            std::min(size, at::vec::Vectorized<dst_t>::size());             \
        constexpr int dst_n = lanes / num_dst_elements;                     \
        CACHE_ALIGN mask_t src_lanes[lanes];                                \
        CACHE_ALIGN dst_t dst_lanes[lanes];                                 \
        auto seed = TestSeed();                                             \
        auto vec_mask = generate_vec_mask<mask_t, mask_n>(seed);            \
        auto cast_mask = vec_mask.template cast<dst_t, dst_n>();            \
        vec_mask.template to<mask_t, mask_n>().store(src_lanes);            \
        cast_mask.template to<dst_t, dst_n>().store(dst_lanes);             \
        for (const auto i : c10::irange(lanes)) {                           \
            ASSERT_EQ(dst_lanes[i], src_lanes[i])                           \
                << "Failure Details:\nTest Seed to reproduce: " << seed;    \
        }                                                                   \
    } while (0)

#define TEST_MASK_CAST_N(N)                                      \
    TEST_MASK_CAST(int8_t, src_t, N);                            \
    TEST_MASK_CAST(uint8_t, src_t, N);                           \
    TEST_MASK_CAST(int16_t, src_t, N);                           \
    TEST_MASK_CAST(uint16_t, src_t, N);                          \
    TEST_MASK_CAST(int32_t, src_t, N);                           \
    TEST_MASK_CAST(uint32_t, src_t, N);                          \
    TEST_MASK_CAST(int64_t, src_t, N);                           \
    TEST_MASK_CAST(uint64_t, src_t, N);                          \
    TEST_MASK_CAST(c10::BFloat16, src_t, N);                     \
    TEST_MASK_CAST(c10::Half, src_t, N);                         \
    TEST_MASK_CAST(float, src_t, N);                             \
    TEST_MASK_CAST(double, src_t, N);

    TEST_MASK_CAST_N(1)
    TEST_MASK_CAST_N(2)
    TEST_MASK_CAST_N(4)

#undef TEST_MASK_CAST
#undef TEST_MASK_CAST_N
}
1811 #else
1812 #error GTEST does not have TYPED_TEST
1813 #endif
1814 }  // namespace
1815