xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec_half.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 
5 namespace at::vec {
6 // See Note [CPU_CAPABILITY namespace]
7 inline namespace CPU_CAPABILITY {
8 
9 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
10     !defined(__APPLE__)
// Converts one float to a 16-bit half-precision bit pattern using the
// hardware conversion intrinsics of the active CPU capability. Every path
// requests round-to-nearest with floating-point exceptions suppressed
// (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC).
//
// The packed paths broadcast `val` to every lane, convert the whole vector,
// and keep only the lowest 16 bits of the result.
static inline uint16_t float2half_scalar(float val) {
#if defined(CPU_CAPABILITY_AVX2) && defined(_MSC_VER)
  // MSVC + AVX2: go through the 256-bit packed conversion
  // (presumably because the scalar _cvtss_sh helper is not usable there).
  const __m256 broadcast = _mm256_set1_ps(val);
  const __m128i packed_halves = _mm256_cvtps_ph(
      broadcast, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  const int low_dword = _mm_cvtsi128_si32(packed_halves);
  return static_cast<uint16_t>(low_dword);
#elif defined(CPU_CAPABILITY_AVX2)
  // Non-MSVC + AVX2: the scalar conversion intrinsic does the whole job.
  return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#elif defined(CPU_CAPABILITY_AVX512)
  // AVX512: 512-bit packed conversion, then read lane 0 out of the low
  // 128 bits of the 256-bit result.
  const __m512 broadcast = _mm512_set1_ps(val);
  const __m256i packed_halves = _mm512_cvtps_ph(
      broadcast, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  const __m128i low_lanes = _mm256_castsi256_si128(packed_halves);
  return static_cast<uint16_t>(_mm_cvtsi128_si32(low_lanes));
#endif
}
29 
// Converts one 16-bit half-precision bit pattern to float using the
// hardware conversion intrinsics of the active CPU capability.
//
// The packed paths place `val` in the lowest lane of an integer vector
// (all other lanes zero), convert the whole vector, and return lane 0.
static inline float half2float_scalar(uint16_t val) {
#if defined(CPU_CAPABILITY_AVX2) && defined(_MSC_VER)
  // MSVC + AVX2: go through the packed half->float conversion
  // (presumably because the scalar _cvtsh_ss helper is not usable there).
  const __m128i half_in_lane0 = _mm_cvtsi32_si128(val);
  const __m256 widened = _mm256_cvtph_ps(half_in_lane0);
  return _mm256_cvtss_f32(widened);
#elif defined(CPU_CAPABILITY_AVX2)
  // Non-MSVC + AVX2: the scalar conversion intrinsic does the whole job.
  return _cvtsh_ss(val);
#elif defined(CPU_CAPABILITY_AVX512)
  // AVX512: `val` goes into element 0 (setr = lowest element first); the
  // remaining fifteen 16-bit elements are zero.
  const __m256i half_in_lane0 = _mm256_setr_epi16(
      val, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0);
  const __m512 widened = _mm512_cvtph_ps(half_in_lane0);
  return _mm512_cvtss_f32(widened);
#endif
}
46 
47 #endif
48 
49 } // namespace CPU_CAPABILITY
50 } // namespace at::vec
51