1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
18
19 #include "third_party/eigen3/Eigen/Core"
20 #include "tensorflow/core/platform/byte_order.h"
21 #include "tensorflow/core/platform/types.h"
22
23 #if defined(PLATFORM_WINDOWS)
24 #include "tensorflow/tsl/platform/windows/intrinsics_port.h"
25 #endif
26
27 namespace Eigen {
28 namespace internal {
29
30 // Return the float representation of the bfloat16 value
31 // in the lower 16-bits of input
32 template <typename Packet>
pexpand_bf16_l(const Packet & from)33 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
34 tensorflow::uint32 tmp;
35 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
36 tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
37 #else
38 tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
39 #endif
40 return reinterpret_cast<const float&>(tmp);
41 }
42
43 // Return the float representation of the bfloat16 value
44 // in the upper 16-bits of input
45 template <typename Packet>
pexpand_bf16_u(const Packet & from)46 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
47 tensorflow::uint32 tmp;
48 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
49 tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
50 #else
51 tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
52 #endif
53 return reinterpret_cast<const float&>(tmp);
54 }
55
// Specialization non-scalar version on non-sse.
// Enable vectorization on z13 and higher
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
// Expand the 4 bfloat16 values packed in the lower 64 bits of |from|
// into 4 floats. Each 32-bit word of the input holds two bfloat16 values;
// each one is moved into the top 16 bits of its own 32-bit float.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  float buf[4];
  pstoreu(buf, from);
  const tensorflow::uint32* words =
      reinterpret_cast<const tensorflow::uint32*>(buf);
  tensorflow::uint32 expanded[4];
  expanded[0] = (words[0] << 16) & 0xffff0000;
  expanded[1] = words[0] & 0xffff0000;
  expanded[2] = (words[1] << 16) & 0xffff0000;
  expanded[3] = words[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(expanded));
}

// Expand the 4 bfloat16 values packed in the upper 64 bits of |from|
// into 4 floats (same layout as above, reading words 2 and 3).
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  float buf[4];
  pstoreu(buf, from);
  const tensorflow::uint32* words =
      reinterpret_cast<const tensorflow::uint32*>(buf);
  tensorflow::uint32 expanded[4];
  expanded[0] = (words[2] << 16) & 0xffff0000;
  expanded[1] = words[2] & 0xffff0000;
  expanded[2] = (words[3] << 16) & 0xffff0000;
  expanded[3] = words[3] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(expanded));
}
#endif
86
// Interleave the 64-bit quadwords of a packet. For scalar (and other
// non-specialized) packets this is the identity.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
  return from;
}

// Broadcast the first element of a packet to all lanes. For a scalar
// packet the value itself is the first (and only) element.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
  return a;
}

// Broadcast the second element of a packet to all lanes. A scalar packet
// has no second element, so the generic version must never be reached;
// only vector specializations below are valid.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

// Broadcast the third element of a packet to all lanes (vector
// specializations only; invalid for scalars).
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

// Broadcast the fourth element of a packet to all lanes (vector
// specializations only; invalid for scalars).
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

// Load 4 bfloat16 values from |from| and expand them to floats. Only the
// vector specializations below are meaningful; the generic version is a
// compile-time placeholder that must never execute.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload4bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}

// Load 2 bfloat16 values from |from| and expand them to floats. As with
// pload4bf16, only vector specializations are meaningful.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload2bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}
128
// Specialization for pload4bf16 and pload2bf16 for non-sse.
// Enable vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
// Load 4 bfloat16 values (64 bits) from |from| and expand each into the
// high 16 bits of a float.
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  const tensorflow::uint32* src =
      reinterpret_cast<const tensorflow::uint32*>(from);
  tensorflow::uint32 out[4];
  out[0] = (src[0] << 16) & 0xffff0000;
  out[1] = src[0] & 0xffff0000;
  out[2] = (src[1] << 16) & 0xffff0000;
  out[3] = src[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(out));
}

// Load 2 bfloat16 values (32 bits) from |from|, expand them to floats and
// duplicate the pair so that all four lanes are populated.
template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  const tensorflow::uint32* src =
      reinterpret_cast<const tensorflow::uint32*>(from);
  tensorflow::uint32 out[4];
  out[0] = (src[0] << 16) & 0xffff0000;
  out[1] = src[0] & 0xffff0000;
  out[2] = (src[0] << 16) & 0xffff0000;
  out[3] = src[0] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(out));
}
#endif
157
#if defined(EIGEN_VECTORIZE_NEON)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(pfirst(a));
}
template <>
EIGEN_STRONG_INLINE Packet2f pbroadcast_first<Packet2f>(const Packet2f& a) {
  return pset1<Packet2f>(pfirst(a));
}

// Return a packet with the second value of the input Packet replicated.
// vgetq_lane_f32 / vget_lane_f32 extract lane 1 of the 128-/64-bit vector.
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 1));
}
template <>
EIGEN_STRONG_INLINE Packet2f pbroadcast_second<Packet2f>(const Packet2f& a) {
  return pset1<Packet2f>(vget_lane_f32(a, 1));
}

// Return a packet with the third value of the input Packet replicated.
// Packet2f holds only two lanes, so no Packet2f overload exists here.
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 2));
}

// Return a packet with the fourth value of the input Packet replicated
// (again Packet4f only; Packet2f has no lane 3).
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return pset1<Packet4f>(vgetq_lane_f32(a, 3));
}
#endif
191
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
// Return a packet with the first value of the input Packet replicated.
// vec_splat(a, i) replicates element i across all lanes.
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 0);
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 1);
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 2);
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 3);
}
#endif
217
#ifdef EIGEN_VECTORIZE_SSE2
// For PacketSize of 4 floats the Packet is not modified
template <>
EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
  return from;
}

// Return a Packet with 4 floats loaded from 4 bfloat16 values.
// _mm_load_pd1 replicates the 64 bits (4 bfloat16s) at |from|;
// _mm_unpacklo_epi16(zero, x) interleaves zeros below each 16-bit value,
// i.e. it places each bfloat16 in the high half of a 32-bit word
// (equivalent to value << 16), which is the float expansion of bfloat16.
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 2 floats loaded from 2 bfloat16 values.
// _mm_load_ps1 replicates the 32 bits (2 bfloat16s), so the expanded
// pair appears twice across the 4 lanes.
template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the lower half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  // Interleaving with zero shifts each low-half bfloat16 into the high
  // 16 bits of its own 32-bit word.
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the upper half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
}

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(pfirst<Packet4f>(a));
}

// Return a packet with the second value of the input Packet replicated.
// Shuffle immediate 1 moves lane 1 into lane 0; _mm_cvtss_f32 reads lane 0.
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
}

#endif
284
#ifdef EIGEN_VECTORIZE_AVX512
// Return a packet with the first float of the input replicated to all
// 16 lanes. The 512-bit register is narrowed to its low 128 bits, then
// lane 0 is broadcast.
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_first<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(a);
}
// Return a packet with the second float replicated (lane 1 is shuffled
// into lane 0 before the broadcast).
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_second<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
}
// Return a packet with the third float replicated.
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_third<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
}
// Return a packet with the fourth float replicated.
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
}
// Double variants: replicate one of the first four doubles of a Packet8d.
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm512_castpd512_pd128(a_in);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
  // permute_pd imm 3 selects element 1 into both lanes of the 128-bit pair.
  Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
  // The third double lives in the upper 128 bits of the low 256-bit half.
  Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
  Packet2d a =
      _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
  return _mm512_broadcastsd_pd(a);
}
// Integer variants: replicate one of the first four 32-bit ints.
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_first<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_second<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_third<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
}
#endif
356
#ifdef EIGEN_VECTORIZE_AVX
// For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords
// (64-bit units), i.e. float lanes [2,3] and [4,5].
template <>
EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  // One-instruction quadword permute: order 0,2,1,3.
  return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
                                                      _MM_SHUFFLE(3, 1, 2, 0)));
#else
  // AVX1 has no cross-lane 64-bit permute; swap lanes 2..5 manually with
  // 32-bit extracts/inserts. The insert order matters: tmp1..tmp4 are
  // extracted first so the later inserts do not clobber source lanes.
  auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
  auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
  auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
  auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
  auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
  tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
  tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
  tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
  return _mm256_castsi256_ps(tmp5);
#endif
}
// Return a Packet with 4 floats loaded from 4 bfloat16 values; the upper
// 128 bits of the result are unspecified (castps128_ps256).
template <>
EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  // unpacklo with zero == value << 16 per 32-bit word: bfloat16 -> float.
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}

#ifdef EIGEN_VECTORIZE_AVX512
// Return a Packet with 4 floats loaded from 4 bfloat16 values
// (512-bit packet; upper 384 bits unspecified).
template <>
EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
#endif
411
// For each 128-bit lane convert 4 bfloat to 4 float values from the lower half
// of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  // unpacklo with zero expands each low-half bfloat16 to a float
  // (value << 16) within each 128-bit lane.
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
#else
  // AVX1 lacks 256-bit integer unpack: process the two 128-bit halves
  // separately and recombine.
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpacklo_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpacklo_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}

// For each 128-bit lane convert 4 bfloat to 4 float values from the upper half
// of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
#else
  // Same two-half strategy as pexpand_bf16_l, but taking the high 16-bit
  // values of each 128-bit lane.
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpackhi_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpackhi_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}
451
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(pfirst<Packet8f>(a));
}

// Return a packet with the second value of the input Packet replicated.
// The in-lane permute with imm 1 moves lane 1 to lane 0 of the low half,
// which is then read with cvtss and splatted.
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
}
478
479 #endif
480
#ifdef EIGEN_VECTORIZE_AVX512

// Expand the 16 bfloat16 values in the lower 256 bits of |from| to
// 16 floats: zero-extend each 16-bit value to 32 bits, then shift left
// by 16 so it occupies the float's sign/exponent/high-mantissa bits.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
      16));
}

// Expand the 16 bfloat16 values in the upper 256 bits of |from|:
// rotate the upper 8 dwords into the lower half first
// (_mm512_alignr_epi32 by 8), then expand as in pexpand_bf16_l.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
  Packet16i tmp = _mm512_castps_si512(from);
  Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
}

#endif
499 } // namespace internal
500 } // namespace Eigen
501 #endif // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
502