#![allow(dead_code, unused_imports)]
use crate::leading_zeros::leading_zeros_u16;
use core::mem;

macro_rules! convert_fn {
    (fn $name:ident($($var:ident : $vartype:ty),+) -> $restype:ty {
        if feature("f16c") { $f16c:expr }
        else { $fallback:expr }}) => {
        #[inline]
        pub(crate) fn $name($($var: $vartype),+) -> $restype {
            // Use CPU feature detection if using std
            #[cfg(all(
                feature = "use-intrinsics",
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64"),
                not(target_feature = "f16c")
            ))]
            {
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            // Use intrinsics directly when f16c is enabled as a compile-time target feature
            // (which also covers no_std builds)
            #[cfg(all(
                feature = "use-intrinsics",
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))]
            {
                $f16c
            }
            // Fallback to software
            #[cfg(any(
                not(feature = "use-intrinsics"),
                not(any(target_arch = "x86", target_arch = "x86_64")),
                all(not(feature = "std"), not(target_feature = "f16c"))
            ))]
            {
                $fallback
            }
        }
    };
}
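
// Each `convert_fn!` invocation below expands to a `pub(crate)` function that uses the x86
// f16c intrinsic when it is available (detected at runtime with std, or enabled at compile
// time) and otherwise calls the software fallback defined further down in this file.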

convert_fn! {
    fn f32_to_f16(f: f32) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f64_to_f16(f: f64) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f16_to_f32(i: u16) -> f32 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}

convert_fn! {
    fn f16_to_f64(i: u16) -> f64 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}

convert_fn! {
    fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}

convert_fn! {
    fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}

convert_fn! {
    fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] {
        if feature("f16c") {
            unsafe { x86::f32x8_to_f16x8_x86_f16c(f) }
        } else {
            f32x8_to_f16x8_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] {
        if feature("f16c") {
            unsafe { x86::f16x8_to_f32x8_x86_f16c(i) }
        } else {
            f16x8_to_f32x8_fallback(i)
        }
    }
}

convert_fn! {
    fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] {
        if feature("f16c") {
            unsafe { x86::f64x8_to_f16x8_x86_f16c(f) }
        } else {
            f64x8_to_f16x8_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] {
        if feature("f16c") {
            unsafe { x86::f16x8_to_f64x8_x86_f16c(i) }
        } else {
            f16x8_to_f64x8_fallback(i)
        }
    }
}

convert_fn! {
    fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) -> () {
        if feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f32x8_to_f16x8_x86_f16c,
                x86::f32x4_to_f16x4_x86_f16c)
        } else {
            slice_fallback(src, dst, f32_to_f16_fallback)
        }
    }
}

convert_fn! {
    fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) -> () {
        if feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f32x8_x86_f16c,
                x86::f16x4_to_f32x4_x86_f16c)
        } else {
            slice_fallback(src, dst, f16_to_f32_fallback)
        }
    }
}

convert_fn! {
    fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) -> () {
        if feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f64x8_to_f16x8_x86_f16c,
                x86::f64x4_to_f16x4_x86_f16c)
        } else {
            slice_fallback(src, dst, f64_to_f16_fallback)
        }
    }
}

convert_fn! {
    fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) -> () {
        if feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f64x8_x86_f16c,
                x86::f16x4_to_f64x4_x86_f16c)
        } else {
            slice_fallback(src, dst, f16_to_f64_fallback)
        }
    }
}

/// Converts a slice by chunking it into x8 arrays, using an x4 array for small remainders
#[inline]
fn convert_chunked_slice_8<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    fn8: unsafe fn(&[S; 8]) -> [D; 8],
    fn4: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(8);
    let mut dst_chunks = dst.chunks_exact_mut(8);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 8] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &fn8(chunk) });
    }

    // Process remainder
    if src_remainder.len() > 4 {
        let mut buf: [S; 8] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn8(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    } else if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn4(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}

/// Converts a slice by chunking it into x4 arrays
#[inline]
fn convert_chunked_slice_4<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    f: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(4);
    let mut dst_chunks = dst.chunks_exact_mut(4);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 4] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &f(chunk) });
    }

    // Process remainder
    if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { f(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}

/////////////// Fallbacks ////////////////

// In the below functions, round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) Removed part > tie.
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) Removed part == tie, and retained part is odd.
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// (If removed part == tie and retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
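//
// As a quick worked illustration for the f32 -> f16 normal path below, where
// round_bit = 0x1000 and 3 * round_bit - 1 = 0x2FFF:
//   man = 0x1000: removed part is exactly the tie and the retained part is even,
//                 so man & 0x2FFF == 0 and no rounding happens.
//   man = 0x1001: removed part is greater than the tie (a low bit is set), so round up.
//   man = 0x3000: removed part is exactly the tie but the retained LSB (0x2000) is odd,
//                 so round up to the even neighbour.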

#[inline]
pub(crate) const fn f32_to_f16_fallback(value: f32) -> u16 {
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    // Convert to raw bytes
    let x: u32 = unsafe { mem::transmute(value) };

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7F80_0000u32;
    let man = x & 0x007F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7F80_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 23) as i32) - 127;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 14 - half_exp > 24 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0080_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 13;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

#[inline]
pub(crate) const fn f64_to_f16_fallback(value: f64) -> u16 {
    // Convert to raw bytes, truncating the last 32 bits of mantissa; that precision will always
    // be lost on half-precision.
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    let val: u64 = unsafe { mem::transmute(value) };
    let x = (val >> 32) as u32;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // We also have to check the last 32 bits.
        let nan_bit = if man == 0 && (val as u32 == 0) {
            0
        } else {
            0x0200u32
        };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (10 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

#[inline]
pub(crate) const fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute((i as u32) << 16) };
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe { mem::transmute((half_sign << 16) | 0x7F80_0000u32) };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13))
            };
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return unsafe { mem::transmute(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    unsafe { mem::transmute(sign | exp | man) }
}

#[inline]
pub(crate) const fn f16_to_f64_fallback(i: u16) -> f64 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute((i as u64) << 48) };
    }

    let half_sign = (i & 0x8000u16) as u64;
    let half_exp = (i & 0x7C00u16) as u64;
    let half_man = (i & 0x03FFu16) as u64;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u64 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe { mem::transmute((half_sign << 48) | 0x7FF0_0000_0000_0000u64) };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42))
            };
        }
    }

    // Calculate double-precision components with adjusted exponent
    let sign = half_sign << 48;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i64) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = ((1023 - 15 - e) as u64) << 52;
        let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
        return unsafe { mem::transmute(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 1023) as u64) << 52;
    let man = (half_man & 0x03FFu64) << 42;
    unsafe { mem::transmute(sign | exp | man) }
}

#[inline]
fn f16x4_to_f32x4_fallback(v: &[u16; 4]) -> [f32; 4] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
    ]
}

#[inline]
fn f32x4_to_f16x4_fallback(v: &[f32; 4]) -> [u16; 4] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x4_to_f64x4_fallback(v: &[u16; 4]) -> [f64; 4] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
    ]
}

#[inline]
fn f64x4_to_f16x4_fallback(v: &[f64; 4]) -> [u16; 4] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x8_to_f32x8_fallback(v: &[u16; 8]) -> [f32; 8] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
        f16_to_f32_fallback(v[4]),
        f16_to_f32_fallback(v[5]),
        f16_to_f32_fallback(v[6]),
        f16_to_f32_fallback(v[7]),
    ]
}

#[inline]
fn f32x8_to_f16x8_fallback(v: &[f32; 8]) -> [u16; 8] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
        f32_to_f16_fallback(v[4]),
        f32_to_f16_fallback(v[5]),
        f32_to_f16_fallback(v[6]),
        f32_to_f16_fallback(v[7]),
    ]
}

#[inline]
fn f16x8_to_f64x8_fallback(v: &[u16; 8]) -> [f64; 8] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
        f16_to_f64_fallback(v[4]),
        f16_to_f64_fallback(v[5]),
        f16_to_f64_fallback(v[6]),
        f16_to_f64_fallback(v[7]),
    ]
}

#[inline]
fn f64x8_to_f16x8_fallback(v: &[f64; 8]) -> [u16; 8] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
        f64_to_f16_fallback(v[4]),
        f64_to_f16_fallback(v[5]),
        f64_to_f16_fallback(v[6]),
        f64_to_f16_fallback(v[7]),
    ]
}

#[inline]
fn slice_fallback<S: Copy, D>(src: &[S], dst: &mut [D], f: fn(S) -> D) {
    assert_eq!(src.len(), dst.len());
    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
        *d = f(s);
    }
}

/////////////// x86/x86_64 f16c ////////////////
#[cfg(all(
    feature = "use-intrinsics",
    any(target_arch = "x86", target_arch = "x86_64")
))]
mod x86 {
    use core::{mem::MaybeUninit, ptr};

    #[cfg(target_arch = "x86")]
    use core::arch::x86::{
        __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps, _mm_cvtps_ph,
        _MM_FROUND_TO_NEAREST_INT,
    };
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{
        __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps, _mm_cvtps_ph,
        _MM_FROUND_TO_NEAREST_INT,
    };

    use super::convert_chunked_slice_8;

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        vec.as_mut_ptr().cast::<u16>().write(i);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
        let mut vec = MaybeUninit::<__m128>::zeroed();
        vec.as_mut_ptr().cast::<f32>().write(f);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16; 4]) -> [f32; 4] {
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32; 4]) -> [u16; 4] {
        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16; 4]) -> [f64; 4] {
        let array = f16x4_to_f32x4_x86_f16c(v);
        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        [
            array[0] as f64,
            array[1] as f64,
            array[2] as f64,
            array[3] as f64,
        ]
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64; 4]) -> [u16; 4] {
        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];
        f32x4_to_f16x4_x86_f16c(&v)
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x8_to_f32x8_x86_f16c(v: &[u16; 8]) -> [f32; 8] {
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 8);
        let retval = _mm256_cvtph_ps(vec.assume_init());
        *(&retval as *const __m256).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32x8_to_f16x8_x86_f16c(v: &[f32; 8]) -> [u16; 8] {
        let mut vec = MaybeUninit::<__m256>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 8);
        let retval = _mm256_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x8_to_f64x8_x86_f16c(v: &[u16; 8]) -> [f64; 8] {
        let array = f16x8_to_f32x8_x86_f16c(v);
        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        [
            array[0] as f64,
            array[1] as f64,
            array[2] as f64,
            array[3] as f64,
            array[4] as f64,
            array[5] as f64,
            array[6] as f64,
            array[7] as f64,
        ]
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f64x8_to_f16x8_x86_f16c(v: &[f64; 8]) -> [u16; 8] {
        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        let v = [
            v[0] as f32,
            v[1] as f32,
            v[2] as f32,
            v[3] as f32,
            v[4] as f32,
            v[5] as f32,
            v[6] as f32,
            v[7] as f32,
        ];
        f32x8_to_f16x8_x86_f16c(&v)
    }
}
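
// A minimal sanity-check sketch for the software fallbacks above. The module name and the
// specific cases are illustrative; they rely only on well-known IEEE 754 half-precision
// encodings (1.0 == 0x3C00, +infinity == 0x7C00, max finite == 0x7BFF) and on the
// ties-to-even rounding rule documented before the fallback functions.
#[cfg(test)]
mod fallback_sanity_tests {
    use super::*;

    #[test]
    fn scalar_conversions() {
        assert_eq!(f32_to_f16_fallback(0.0), 0x0000);
        assert_eq!(f32_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f32_to_f16_fallback(65504.0), 0x7BFF);
        assert_eq!(f32_to_f16_fallback(f32::INFINITY), 0x7C00);
        assert_eq!(f16_to_f32_fallback(0x3C00), 1.0);
        assert_eq!(f16_to_f64_fallback(0x3C00), 1.0);
        // NaN keeps all exponent bits set and a non-zero mantissa
        let nan = f32_to_f16_fallback(f32::NAN);
        assert_eq!(nan & 0x7C00, 0x7C00);
        assert_ne!(nan & 0x03FF, 0);
    }

    #[test]
    fn ties_round_to_even() {
        // 1 + 2^-11 is exactly halfway between 1.0 and the next f16 value; the retained
        // mantissa is even, so it rounds down to 1.0 (0x3C00).
        assert_eq!(f32_to_f16_fallback(1.00048828125), 0x3C00);
        // 1 + 3 * 2^-11 is halfway between 0x3C01 (odd mantissa) and 0x3C02 (even mantissa);
        // ties-to-even picks 0x3C02.
        assert_eq!(f32_to_f16_fallback(1.00146484375), 0x3C02);
    }

    #[test]
    fn slice_fallback_matches_scalar() {
        let src = [0.0f32, 1.0, -2.0, 65504.0, 0.5];
        let mut dst = [0u16; 5];
        slice_fallback(&src, &mut dst, f32_to_f16_fallback);
        for (s, d) in src.iter().zip(dst.iter()) {
            assert_eq!(*d, f32_to_f16_fallback(*s));
        }
    }
}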