1 #pragma once
2
3 #include <ATen/cpu/vec/intrinsics.h>
4
5 namespace at::vec {
6 // See Note [CPU_CAPABILITY namespace]
7 inline namespace CPU_CAPABILITY {
8
9 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
10 !defined(__APPLE__)
// Converts a single float to the bit pattern of an IEEE half-precision value.
// Rounding uses round-to-nearest (_MM_FROUND_TO_NEAREST_INT) with
// floating-point exceptions suppressed (_MM_FROUND_NO_EXC).
// The vector paths convert a broadcast of `val` and keep only lane 0.
static inline uint16_t float2half_scalar(float val) {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
  // GCC/Clang provide a direct scalar conversion intrinsic.
  return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#elif defined(CPU_CAPABILITY_AVX2)
  // MSVC path: no _cvtss_sh, so convert a 256-bit broadcast instead.
  const __m256 broadcast = _mm256_set1_ps(val);
  const __m128i halves = _mm256_cvtps_ph(
      broadcast, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  // The low 32 bits hold two halves; the cast keeps only the first one.
  const int low_bits = _mm_cvtsi128_si32(halves);
  return static_cast<std::uint16_t>(low_bits);
#elif defined(CPU_CAPABILITY_AVX512)
  // AVX-512 path: convert a 512-bit broadcast, then narrow to lane 0.
  const __m512 broadcast = _mm512_set1_ps(val);
  const __m256i halves = _mm512_cvtps_ph(
      broadcast, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  const __m128i low_lane = _mm256_castsi256_si128(halves);
  return static_cast<std::uint16_t>(_mm_cvtsi128_si32(low_lane));
#endif
}
29
// Converts the bit pattern of an IEEE half-precision value to a float.
// The vector paths place `val` in element 0, convert the whole register and
// return the lowest float lane.
static inline float half2float_scalar(uint16_t val) {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
  // GCC/Clang provide a direct scalar conversion intrinsic.
  return _cvtsh_ss(val);
#elif defined(CPU_CAPABILITY_AVX2)
  // MSVC path: load the half into the low 32 bits of an XMM register
  // (zero-extended), widen eight halves to eight floats, take lane 0.
  const __m128i packed = _mm_cvtsi32_si128(val);
  const __m256 widened = _mm256_cvtph_ps(packed);
  return _mm256_cvtss_f32(widened);
#elif defined(CPU_CAPABILITY_AVX512)
  // AVX-512 path: element 0 holds `val`, the remaining 15 halves are zero.
  const __m256i packed =
      _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
  const __m512 widened = _mm512_cvtph_ps(packed);
  return _mm512_cvtss_f32(widened);
#endif
}
46
47 #endif
48
49 } // namespace CPU_CAPABILITY
50 } // namespace at::vec
51