xref: /aosp_15_r20/external/FP16/include/fp16/avx2.py (revision 5f32b7105932ea8520a0e8811c640f936367d707)
1*5f32b710SXin Lifrom peachpy import *
2*5f32b710SXin Lifrom peachpy.x86_64 import *
3*5f32b710SXin Li
4*5f32b710SXin Li
def fp16_alt_xmm_to_fp32_ymm(xmm_half):
	"""Emit AVX2 code widening eight 16-bit half floats to eight float32 values.

	The helper only issues PeachPy instruction calls and allocates virtual
	registers; it is presumably meant to be called while a PeachPy function
	is being generated (confirm against the callers).

	No branch of the sequence treats the maximum half exponent specially, so
	such inputs are converted like ordinary normalized numbers instead of
	becoming Inf/NaN. This matches what the "alt" in the name suggests (an
	alternative half format without Inf/NaN) -- NOTE(review): confirm.

	Args:
		xmm_half: XMM register whose 128 bits are read as eight 16-bit
			half-precision values.

	Returns:
		A newly allocated YMMRegister holding the eight converted
		single-precision values, in the same element order as the input.
	"""
	# Qword selectors 0,0,1,1: each 128-bit lane gets one input qword (four
	# halves) in its low 64 bits, which is the part the lane-wise
	# VPUNPCKLWD instructions below consume.
	ymm_half = YMMRegister()
	VPERMQ(ymm_half, xmm_half.as_ymm, 0b01010000)

	# XOR on the XMM view; a VEX-encoded 128-bit write also clears the upper
	# half of the YMM register, so the whole register becomes zero.
	ymm_zero = YMMRegister()
	VPXOR(ymm_zero.as_xmm, ymm_zero.as_xmm, ymm_zero.as_xmm)

	# Interleave zero words (low) with the halves (high): every 32-bit
	# element becomes half << 16, i.e. the sign lands in bit 31.
	ymm_word = YMMRegister()
	VPUNPCKLWD(ymm_word, ymm_zero, ymm_half)

	# x + x == x << 1; in 16-bit elements this shifts the sign bit out.
	ymm_shl1_half = YMMRegister()
	VPADDW(ymm_shl1_half, ymm_half, ymm_half)

	# Same trick in 32-bit elements: drops the sign bit from bit 31.
	ymm_shl1_nonsign = YMMRegister()
	VPADDD(ymm_shl1_nonsign, ymm_word, ymm_word)

	# -0.0 has only bit 31 set, so it serves as the float32 sign mask.
	sign_mask = Constant.float32x8(-0.0)

	ymm_sign = YMMRegister()
	VANDPS(ymm_sign, ymm_word, sign_mask)

	# (<< 1 then >> 4) is a net logical shift right by 3 of the sign-free
	# bits: the 5-bit half exponent ends at bits 27..23 and the 10-bit
	# mantissa at bits 22..13, i.e. at the float32 field positions.
	ymm_shr3_nonsign = YMMRegister()
	VPSRLD(ymm_shr3_nonsign, ymm_shl1_nonsign, 4)

	# 0x38000000 == 112 << 23: rebias the exponent from 15 to 127.
	exp_offset = Constant.uint32x8(0x38000000)

	ymm_norm_nonsign = YMMRegister()
	VPADDD(ymm_norm_nonsign, ymm_shr3_nonsign, exp_offset)

	# Denormal path: build the float32 bit pattern 0x3E800000 | (half << 1).
	# 0x3E800000 is 0.25f, whose ulp is 2**-25, so after subtracting 0.25
	# the result is (half << 1) * 2**-25 == mantissa * 2**-24 whenever the
	# half exponent field is zero. For other inputs this value is
	# meaningless and is discarded by the blend below.
	magic_mask = Constant.uint16x16(0x3E80)
	ymm_denorm_nonsign = YMMRegister()
	VPUNPCKLWD(ymm_denorm_nonsign, ymm_shl1_half, magic_mask)

	magic_bias = Constant.float32x8(0.25)
	VSUBPS(ymm_denorm_nonsign, ymm_denorm_nonsign, magic_bias)

	# 0x00800000 is the lowest pattern with a non-zero exponent field;
	# anything below it (before rebiasing) is zero or denormal.
	ymm_denorm_cutoff = YMMRegister()
	VMOVDQA(ymm_denorm_cutoff, Constant.uint32x8(0x00800000))

	# Signed compare is safe: the preceding logical shift cleared the top bits.
	ymm_denorm_mask = YMMRegister()
	VPCMPGTD(ymm_denorm_mask, ymm_denorm_cutoff, ymm_shr3_nonsign)

	# Take the denormal result where the mask is set, the normalized one
	# elsewhere.
	ymm_nonsign = YMMRegister()
	VBLENDVPS(ymm_nonsign, ymm_norm_nonsign, ymm_denorm_nonsign, ymm_denorm_mask)

	# Re-attach the sign bit extracted earlier.
	ymm_float = YMMRegister()
	VORPS(ymm_float, ymm_nonsign, ymm_sign)

	return ymm_float
54