/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

#if defined(PLATFORM_WINDOWS)
#include "tensorflow/tsl/platform/windows/intrinsics_port.h"
#endif

namespace Eigen {
namespace internal {

// Return the float representation of the bfloat16 value
// in the lower 16 bits of the input.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}

// Return the float representation of the bfloat16 value
// in the upper 16 bits of the input.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}
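
// Worked example (little-endian): bfloat16 is the upper 16 bits of an
// IEEE-754 float, so expansion just moves the payload into the high half
// of a 32-bit word and zeroes the low mantissa bits. If the input word is
// 0x40003f80 (low half 0x3f80 == 1.0f, high half 0x4000 == 2.0f), then
// pexpand_bf16_l shifts up to produce 0x3f800000, i.e. 1.0f, while
// pexpand_bf16_u masks to produce 0x40000000, i.e. 2.0f.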

// Specializations of the non-scalar versions for non-SSE platforms.
// Enable vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[2] << 16) & 0xffff0000;
  p[1] = ir[2] & 0xffff0000;
  p[2] = (ir[3] << 16) & 0xffff0000;
  p[3] = ir[3] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
#endif

// Generic case: no interleaving is performed.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
  return from;
}

// Return a packet with the first value of the input replicated.
// For a scalar, this is the value itself.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

// Load 4 bfloat16 values, expanded to floats. The generic (scalar)
// version is never valid and asserts; vectorized specializations follow.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload4bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}

// Load 2 bfloat16 values, expanded to floats. The generic (scalar)
// version is never valid and asserts; vectorized specializations follow.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload2bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}
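
// Note on the convention used below: `from` is typed as float* to match
// Eigen's packet interfaces, but it actually points at packed bfloat16
// storage. pload4bf16 reads 64 bits (four bfloat16 values) and widens each
// to a float; pload2bf16 reads 32 bits (two bfloat16 values) and widens
// both (the specializations below replicate the two expanded values across
// the low 128 bits of the result).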

// Specializations of pload4bf16 and pload2bf16 for non-SSE platforms.
// Enable vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[0] << 16) & 0xffff0000;
  p[3] = ir[0] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
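
// In pload2bf16 above, only ir[0] (two bfloat16 values) is read; the two
// expanded floats are duplicated into p[2] and p[3] so that every lane of
// the returned Packet4f is defined.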
#endif

#if defined(EIGEN_VECTORIZE_NEON)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(pfirst(a));
}
template <>
EIGEN_STRONG_INLINE Packet2f pbroadcast_first<Packet2f>(const Packet2f& a) {
  return pset1<Packet2f>(pfirst(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 1));
}
template <>
EIGEN_STRONG_INLINE Packet2f pbroadcast_second<Packet2f>(const Packet2f& a) {
  return pset1<Packet2f>(vget_lane_f32(a, 1));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 2));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 3));
}
#endif

#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 0);
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 1);
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 2);
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 3);
}
#endif

#ifdef EIGEN_VECTORIZE_SSE2
// For a PacketSize of 4 floats, the Packet is not modified
template <>
EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
  return from;
}

// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}
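
// The SSE2 loads above rely on a single trick: _mm_unpacklo_epi16(zero, x)
// interleaves a zero word below each 16-bit element of x, so each bfloat16
// payload lands in the upper half of a 32-bit word with the low mantissa
// bits cleared, which is exactly the expanded float. pload4bf16 broadcasts
// 64 bits (four bfloat16 values) via _mm_load_pd1, and pload2bf16
// broadcasts 32 bits (two values) via _mm_load_ps1, before unpacking.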

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the lower half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the upper half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
}
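
// A 128-bit register holds eight bfloat16 values: pexpand_bf16_l expands
// the four in the low 64 bits and pexpand_bf16_u the four in the high
// 64 bits, using the same zero-interleave trick as the loads above.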

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(pfirst<Packet4f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
}

#endif

#ifdef EIGEN_VECTORIZE_AVX512
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_first<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_second<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_third<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm512_castpd512_pd128(a_in);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
  Packet2d a =
      _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_first<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_second<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_third<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
}
#endif

#ifdef EIGEN_VECTORIZE_AVX
// For a Packet of size 8 floats (256 bits), swap the 2nd and 3rd quadwords
template <>
EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
                                                      _MM_SHUFFLE(3, 1, 2, 0)));
#else
  auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
  auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
  auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
  auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
  auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
  tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
  tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
  tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
  return _mm256_castsi256_ps(tmp5);
#endif
}
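
// Why the interleave: pexpand_bf16_l/u (below) expand bfloat16 values per
// 128-bit lane. Swapping the middle two quadwords first means that, for a
// register holding sixteen consecutive bfloat16 values, pexpand_bf16_l
// yields the first eight floats in order and pexpand_bf16_u the next eight.
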
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}

#ifdef EIGEN_VECTORIZE_AVX512
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
#endif

// For each 128-bit lane, convert the 4 bfloat16 values in the lower half
// of the lane to 4 float values
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpacklo_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpacklo_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}

// For each 128-bit lane, convert the 4 bfloat16 values in the upper half
// of the lane to 4 float values
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpackhi_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpackhi_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}
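
// Note that _mm256_unpacklo_epi16 / _mm256_unpackhi_epi16 interleave
// within each 128-bit lane rather than across the full 256-bit register,
// which is why these functions are documented as operating per lane (and
// why pinterleave4x64 above is needed to restore ordering).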

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(pfirst<Packet8f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
}

#endif

#ifdef EIGEN_VECTORIZE_AVX512

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
      16));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
  Packet16i tmp = _mm512_castps_si512(from);
  Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
}
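
// In the AVX-512 versions, the low 256 bits of the register hold sixteen
// bfloat16 values; _mm512_cvtepu16_epi32 zero-extends each to 32 bits and
// the left shift by 16 moves the payload into the high half of each word,
// producing sixteen floats. pexpand_bf16_u first rotates the upper 256
// bits into the lower half with _mm512_alignr_epi32.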

#endif
}  // namespace internal
}  // namespace Eigen
#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_