xref: /aosp_15_r20/external/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Mehdi Goli    Codeplay Software Ltd.
5 // Ralph Potter  Codeplay Software Ltd.
6 // Luke Iwanski  Codeplay Software Ltd.
7 // Contact: <[email protected]>
8 //
9 // This Source Code Form is subject to the terms of the Mozilla
10 // Public License v. 2.0. If a copy of the MPL was not distributed
11 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
12 
13 /*****************************************************************
14  * PacketMath.h
15  *
16  * \brief:
17  *  PacketMath
18  *
19  *****************************************************************/
20 
21 #ifndef EIGEN_PACKET_MATH_SYCL_H
22 #define EIGEN_PACKET_MATH_SYCL_H
23 #include <type_traits>
24 namespace Eigen {
25 
26 namespace internal {
27 #ifdef SYCL_DEVICE_ONLY
28 
29 #define SYCL_PLOADT_RO(address_space_target)                                 \
30   template <typename packet_type, int Alignment>                             \
31   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt_ro(               \
32       typename cl::sycl::multi_ptr<                                          \
33           const typename unpacket_traits<packet_type>::type,                 \
34           cl::sycl::access::address_space::address_space_target>::pointer_t  \
35           from) {                                                            \
36     typedef typename unpacket_traits<packet_type>::type scalar;              \
37     typedef cl::sycl::multi_ptr<                                             \
38         scalar, cl::sycl::access::address_space::address_space_target>       \
39         multi_ptr;                                                           \
40     auto res = packet_type(                                                  \
41         static_cast<typename unpacket_traits<packet_type>::type>(0));        \
42     res.load(0, multi_ptr(const_cast<typename multi_ptr::pointer_t>(from))); \
43     return res;                                                              \
44   }
45 
46 SYCL_PLOADT_RO(global_space)
SYCL_PLOADT_RO(local_space)47 SYCL_PLOADT_RO(local_space)
48 #undef SYCL_PLOADT_RO
49 #endif
50 
51 template <typename packet_type, int Alignment, typename T>
52 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type
53 ploadt_ro(const Eigen::TensorSycl::internal::RangeAccess<
54           cl::sycl::access::mode::read_write, T>& from) {
55   return ploadt_ro<packet_type, Alignment>(from.get_pointer());
56 }
57 
58 #ifdef SYCL_DEVICE_ONLY
59 #define SYCL_PLOAD(address_space_target, Alignment, AlignedType)            \
60   template <typename packet_type>                                           \
61   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType(     \
62       typename cl::sycl::multi_ptr<                                         \
63           const typename unpacket_traits<packet_type>::type,                \
64           cl::sycl::access::address_space::address_space_target>::pointer_t \
65           from) {                                                           \
66     return ploadt_ro<packet_type, Alignment>(from);                         \
67   }
68 
69 // global space
SYCL_PLOAD(global_space,Unaligned,u)70 SYCL_PLOAD(global_space, Unaligned, u)
71 SYCL_PLOAD(global_space, Aligned, )
72 // local space
73 SYCL_PLOAD(local_space, Unaligned, u)
74 SYCL_PLOAD(local_space, Aligned, )
75 
76 #undef SYCL_PLOAD
77 #endif
78 
79 #define SYCL_PLOAD(Alignment, AlignedType)                              \
80   template <typename packet_type>                                       \
81   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType( \
82       const Eigen::TensorSycl::internal::RangeAccess<                   \
83           cl::sycl::access::mode::read_write,                           \
84           typename unpacket_traits<packet_type>::type>                  \
85           from) {                                                       \
86     return ploadt_ro<packet_type, Alignment>(from);                     \
87   }
88 SYCL_PLOAD(Unaligned, u)
89 SYCL_PLOAD(Aligned, )
90 #undef SYCL_PLOAD
91 
92 #ifdef SYCL_DEVICE_ONLY
93 /** \internal \returns a packet version of \a *from.
94  * The pointer \a from must be aligned on a \a Alignment bytes boundary. */
95 #define SYCL_PLOADT(address_space_target)                                   \
96   template <typename packet_type, int Alignment>                            \
97   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt(                 \
98       typename cl::sycl::multi_ptr<                                         \
99           const typename unpacket_traits<packet_type>::type,                \
100           cl::sycl::access::address_space::address_space_target>::pointer_t \
101           from) {                                                           \
102     if (Alignment >= unpacket_traits<packet_type>::alignment)               \
103       return pload<packet_type>(from);                                      \
104     else                                                                    \
105       return ploadu<packet_type>(from);                                     \
106   }
107 
108 // global space
109 SYCL_PLOADT(global_space)
110 // local space
111 SYCL_PLOADT(local_space)
112 #undef SYCL_PLOADT
113 #endif
114 
115 template <typename packet_type, int Alignment>
116 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type
117 ploadt(const Eigen::TensorSycl::internal::RangeAccess<
118        cl::sycl::access::mode::read_write,
119        typename unpacket_traits<packet_type>::type>& from) {
120   return ploadt<packet_type, Alignment>(from.get_pointer());
121 }
122 #ifdef SYCL_DEVICE_ONLY
123 
124 // private_space
125 #define SYCL_PLOADT_RO_SPECIAL(packet_type, Alignment)                 \
126   template <>                                                          \
127   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type                    \
128   ploadt_ro<packet_type, Alignment>(                                   \
129       const typename unpacket_traits<packet_type>::type* from) {       \
130     typedef typename unpacket_traits<packet_type>::type scalar;        \
131     auto res = packet_type(static_cast<scalar>(0));                    \
132     res.template load<cl::sycl::access::address_space::private_space>( \
133         0, const_cast<scalar*>(from));                                 \
134     return res;                                                        \
135   }
136 
137 SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Aligned)
138 SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Aligned)
139 SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Unaligned)
140 SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Unaligned)
141 
142 #define SYCL_PLOAD_SPECIAL(packet_type, alignment_type)                    \
143   template <>                                                              \
144   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##alignment_type( \
145       const typename unpacket_traits<packet_type>::type* from) {           \
146     typedef typename unpacket_traits<packet_type>::type scalar;            \
147     auto res = packet_type(static_cast<scalar>(0));                        \
148     res.template load<cl::sycl::access::address_space::private_space>(     \
149         0, const_cast<scalar*>(from));                                     \
150     return res;                                                            \
151   }
152 SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, )
153 SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, )
154 SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, u)
155 SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, u)
156 
157 #undef SYCL_PLOAD_SPECIAL
158 
159 #define SYCL_PSTORE(scalar, packet_type, address_space_target, alignment)   \
160   template <>                                                               \
161   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment(             \
162       typename cl::sycl::multi_ptr<                                         \
163           scalar,                                                           \
164           cl::sycl::access::address_space::address_space_target>::pointer_t \
165           to,                                                               \
166       const packet_type& from) {                                            \
167     typedef cl::sycl::multi_ptr<                                            \
168         scalar, cl::sycl::access::address_space::address_space_target>      \
169         multi_ptr;                                                          \
170     from.store(0, multi_ptr(to));                                           \
171   }
172 
173 // global space
174 SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, )
175 SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, u)
176 SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, )
177 SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, u)
178 SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, )
179 SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, u)
180 SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, )
181 SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, u)
182 
183 SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, )
184 SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, u)
185 SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, )
186 SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, u)
187 #undef SYCL_PSTORE
188 
189 #define SYCL_PSTORE_T(address_space_target)                                 \
190   template <typename scalar, typename packet_type, int Alignment>           \
191   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret(                       \
192       typename cl::sycl::multi_ptr<                                         \
193           scalar,                                                           \
194           cl::sycl::access::address_space::address_space_target>::pointer_t \
195           to,                                                               \
196       const packet_type& from) {                                            \
197     if (Alignment)                                                          \
198       pstore(to, from);                                                     \
199     else                                                                    \
200       pstoreu(to, from);                                                    \
201   }
202 
203 SYCL_PSTORE_T(global_space)
204 
205 SYCL_PSTORE_T(local_space)
206 
207 #undef SYCL_PSTORE_T
208 
209 #define SYCL_PSET1(packet_type)                                         \
210   template <>                                                           \
211   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pset1<packet_type>( \
212       const typename unpacket_traits<packet_type>::type& from) {        \
213     return packet_type(from);                                           \
214   }
215 
216 // global space
217 SYCL_PSET1(cl::sycl::cl_float4)
218 SYCL_PSET1(cl::sycl::cl_double2)
219 
220 #undef SYCL_PSET1
221 
222 template <typename packet_type>
223 struct get_base_packet {
224   template <typename sycl_multi_pointer>
225   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type
get_ploaddupget_base_packet226   get_ploaddup(sycl_multi_pointer) {}
227 
228   template <typename sycl_multi_pointer>
229   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type
get_pgatherget_base_packet230   get_pgather(sycl_multi_pointer, Index) {}
231 };
232 
233 template <>
234 struct get_base_packet<cl::sycl::cl_float4> {
235   template <typename sycl_multi_pointer>
236   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_ploaddup(
237       sycl_multi_pointer from) {
238     return cl::sycl::cl_float4(from[0], from[0], from[1], from[1]);
239   }
240   template <typename sycl_multi_pointer>
241   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_pgather(
242       sycl_multi_pointer from, Index stride) {
243     return cl::sycl::cl_float4(from[0 * stride], from[1 * stride],
244                                from[2 * stride], from[3 * stride]);
245   }
246 
247   template <typename sycl_multi_pointer>
248   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter(
249       sycl_multi_pointer to, const cl::sycl::cl_float4& from, Index stride) {
250     auto tmp = stride;
251     to[0] = from.x();
252     to[tmp] = from.y();
253     to[tmp += stride] = from.z();
254     to[tmp += stride] = from.w();
255   }
256   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 set_plset(
257       const float& a) {
258     return cl::sycl::cl_float4(static_cast<float>(a), static_cast<float>(a + 1),
259                                static_cast<float>(a + 2),
260                                static_cast<float>(a + 3));
261   }
262 };
263 
264 template <>
265 struct get_base_packet<cl::sycl::cl_double2> {
266   template <typename sycl_multi_pointer>
267   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2
268   get_ploaddup(const sycl_multi_pointer from) {
269     return cl::sycl::cl_double2(from[0], from[0]);
270   }
271 
272   template <typename sycl_multi_pointer, typename Index>
273   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 get_pgather(
274       const sycl_multi_pointer from, Index stride) {
275     return cl::sycl::cl_double2(from[0 * stride], from[1 * stride]);
276   }
277 
278   template <typename sycl_multi_pointer>
279   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter(
280       sycl_multi_pointer to, const cl::sycl::cl_double2& from, Index stride) {
281     to[0] = from.x();
282     to[stride] = from.y();
283   }
284 
285   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 set_plset(
286       const double& a) {
287     return cl::sycl::cl_double2(static_cast<double>(a),
288                                 static_cast<double>(a + 1));
289   }
290 };
291 
292 #define SYCL_PLOAD_DUP(address_space_target)                                \
293   template <typename packet_type>                                           \
294   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup(               \
295       typename cl::sycl::multi_ptr<                                         \
296           const typename unpacket_traits<packet_type>::type,                \
297           cl::sycl::access::address_space::address_space_target>::pointer_t \
298           from) {                                                           \
299     return get_base_packet<packet_type>::get_ploaddup(from);                \
300   }
301 
302 // global space
303 SYCL_PLOAD_DUP(global_space)
304 // local_space
305 SYCL_PLOAD_DUP(local_space)
306 #undef SYCL_PLOAD_DUP
307 
308 #define SYCL_PLOAD_DUP_SPECILIZE(packet_type)                              \
309   template <>                                                              \
310   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup<packet_type>( \
311       const typename unpacket_traits<packet_type>::type* from) {           \
312     return get_base_packet<packet_type>::get_ploaddup(from);               \
313   }
314 
315 SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_float4)
316 SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_double2)
317 
318 #undef SYCL_PLOAD_DUP_SPECILIZE
319 
320 #define SYCL_PLSET(packet_type)                                         \
321   template <>                                                           \
322   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type plset<packet_type>( \
323       const typename unpacket_traits<packet_type>::type& a) {           \
324     return get_base_packet<packet_type>::set_plset(a);                  \
325   }
326 
327 SYCL_PLSET(cl::sycl::cl_float4)
328 SYCL_PLSET(cl::sycl::cl_double2)
329 
330 #undef SYCL_PLSET
331 
332 #define SYCL_PGATHER(address_space_target)                                  \
333   template <typename Scalar, typename packet_type>                          \
334   EIGEN_DEVICE_FUNC inline packet_type pgather(                             \
335       typename cl::sycl::multi_ptr<                                         \
336           const typename unpacket_traits<packet_type>::type,                \
337           cl::sycl::access::address_space::address_space_target>::pointer_t \
338           from,                                                             \
339       Index stride) {                                                       \
340     return get_base_packet<packet_type>::get_pgather(from, stride);         \
341   }
342 
343 // global space
344 SYCL_PGATHER(global_space)
345 // local space
346 SYCL_PGATHER(local_space)
347 
348 #undef SYCL_PGATHER
349 
350 #define SYCL_PGATHER_SPECILIZE(scalar, packet_type)                            \
351   template <>                                                                  \
352   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type                            \
353   pgather<scalar, packet_type>(                                                \
354       const typename unpacket_traits<packet_type>::type* from, Index stride) { \
355     return get_base_packet<packet_type>::get_pgather(from, stride);            \
356   }
357 
358 SYCL_PGATHER_SPECILIZE(float, cl::sycl::cl_float4)
359 SYCL_PGATHER_SPECILIZE(double, cl::sycl::cl_double2)
360 
361 #undef SYCL_PGATHER_SPECILIZE
362 
363 #define SYCL_PSCATTER(address_space_target)                                 \
364   template <typename Scalar, typename packet_type>                          \
365   EIGEN_DEVICE_FUNC inline void pscatter(                                   \
366       typename cl::sycl::multi_ptr<                                         \
367           typename unpacket_traits<packet_type>::type,                      \
368           cl::sycl::access::address_space::address_space_target>::pointer_t \
369           to,                                                               \
370       const packet_type& from, Index stride) {                              \
371     get_base_packet<packet_type>::set_pscatter(to, from, stride);           \
372   }
373 
374 // global space
375 SYCL_PSCATTER(global_space)
376 // local space
377 SYCL_PSCATTER(local_space)
378 
379 #undef SYCL_PSCATTER
380 
381 #define SYCL_PSCATTER_SPECILIZE(scalar, packet_type)                        \
382   template <>                                                               \
383   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<scalar, packet_type>( \
384       typename unpacket_traits<packet_type>::type * to,                     \
385       const packet_type& from, Index stride) {                              \
386     get_base_packet<packet_type>::set_pscatter(to, from, stride);           \
387   }
388 
389 SYCL_PSCATTER_SPECILIZE(float, cl::sycl::cl_float4)
390 SYCL_PSCATTER_SPECILIZE(double, cl::sycl::cl_double2)
391 
392 #undef SYCL_PSCATTER_SPECILIZE
393 
394 #define SYCL_PMAD(packet_type)                                            \
395   template <>                                                             \
396   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pmadd(                \
397       const packet_type& a, const packet_type& b, const packet_type& c) { \
398     return cl::sycl::mad(a, b, c);                                        \
399   }
400 
401 SYCL_PMAD(cl::sycl::cl_float4)
402 SYCL_PMAD(cl::sycl::cl_double2)
403 #undef SYCL_PMAD
404 
405 template <>
406 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float pfirst<cl::sycl::cl_float4>(
407     const cl::sycl::cl_float4& a) {
408   return a.x();
409 }
410 template <>
411 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double pfirst<cl::sycl::cl_double2>(
412     const cl::sycl::cl_double2& a) {
413   return a.x();
414 }
415 
416 template <>
417 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux<cl::sycl::cl_float4>(
418     const cl::sycl::cl_float4& a) {
419   return a.x() + a.y() + a.z() + a.w();
420 }
421 
422 template <>
423 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux<cl::sycl::cl_double2>(
424     const cl::sycl::cl_double2& a) {
425   return a.x() + a.y();
426 }
427 
428 template <>
429 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_max<cl::sycl::cl_float4>(
430     const cl::sycl::cl_float4& a) {
431   return cl::sycl::fmax(cl::sycl::fmax(a.x(), a.y()),
432                         cl::sycl::fmax(a.z(), a.w()));
433 }
434 template <>
435 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_max<cl::sycl::cl_double2>(
436     const cl::sycl::cl_double2& a) {
437   return cl::sycl::fmax(a.x(), a.y());
438 }
439 
440 template <>
441 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_min<cl::sycl::cl_float4>(
442     const cl::sycl::cl_float4& a) {
443   return cl::sycl::fmin(cl::sycl::fmin(a.x(), a.y()),
444                         cl::sycl::fmin(a.z(), a.w()));
445 }
446 template <>
447 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_min<cl::sycl::cl_double2>(
448     const cl::sycl::cl_double2& a) {
449   return cl::sycl::fmin(a.x(), a.y());
450 }
451 
452 template <>
453 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_mul<cl::sycl::cl_float4>(
454     const cl::sycl::cl_float4& a) {
455   return a.x() * a.y() * a.z() * a.w();
456 }
457 template <>
458 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_mul<cl::sycl::cl_double2>(
459     const cl::sycl::cl_double2& a) {
460   return a.x() * a.y();
461 }
462 
463 template <>
464 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4
465 pabs<cl::sycl::cl_float4>(const cl::sycl::cl_float4& a) {
466   return cl::sycl::cl_float4(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y()),
467                              cl::sycl::fabs(a.z()), cl::sycl::fabs(a.w()));
468 }
469 template <>
470 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_double2
471 pabs<cl::sycl::cl_double2>(const cl::sycl::cl_double2& a) {
472   return cl::sycl::cl_double2(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y()));
473 }
474 
475 template <typename Packet>
476 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_le(const Packet &a,
477                                                           const Packet &b) {
478   return ((a <= b)
479               .template convert<typename unpacket_traits<Packet>::type,
480                                 cl::sycl::rounding_mode::automatic>());
481 }
482 
483 template <typename Packet>
484 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_lt(const Packet &a,
485                                                           const Packet &b) {
486   return ((a < b)
487               .template convert<typename unpacket_traits<Packet>::type,
488                                 cl::sycl::rounding_mode::automatic>());
489 }
490 
491 template <typename Packet>
492 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_eq(const Packet &a,
493                                                           const Packet &b) {
494   return ((a == b)
495               .template convert<typename unpacket_traits<Packet>::type,
496                                 cl::sycl::rounding_mode::automatic>());
497 }
498 
499 #define SYCL_PCMP(OP, TYPE)                                                    \
500   template <>                                                                  \
501   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE TYPE pcmp_##OP<TYPE>(const TYPE &a,    \
502                                                              const TYPE &b) {  \
503     return sycl_pcmp_##OP<TYPE>(a, b);                                         \
504   }
505 
506 SYCL_PCMP(le, cl::sycl::cl_float4)
507 SYCL_PCMP(lt, cl::sycl::cl_float4)
508 SYCL_PCMP(eq, cl::sycl::cl_float4)
509 SYCL_PCMP(le, cl::sycl::cl_double2)
510 SYCL_PCMP(lt, cl::sycl::cl_double2)
511 SYCL_PCMP(eq, cl::sycl::cl_double2)
512 #undef SYCL_PCMP
513 
514 template <typename T> struct convert_to_integer;
515 
516 template <> struct convert_to_integer<float> {
517   using type = std::int32_t;
518   using packet_type = cl::sycl::cl_int4;
519 };
520 template <> struct convert_to_integer<double> {
521   using type = std::int64_t;
522   using packet_type = cl::sycl::cl_long2;
523 };
524 
525 template <typename PacketIn>
526 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename convert_to_integer<
527     typename unpacket_traits<PacketIn>::type>::packet_type
528 vector_as_int(const PacketIn &p) {
529   return (
530       p.template convert<typename convert_to_integer<
531                              typename unpacket_traits<PacketIn>::type>::type,
532                          cl::sycl::rounding_mode::automatic>());
533 }
534 
535 template <typename packetOut, typename PacketIn>
536 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packetOut
537 convert_vector(const PacketIn &p) {
538   return (p.template convert<typename unpacket_traits<packetOut>::type,
539                              cl::sycl::rounding_mode::automatic>());
540 }
541 
542 #define SYCL_PAND(TYPE)                                                        \
543   template <>                                                                  \
544   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pand<TYPE>(const TYPE &a,         \
545                                                         const TYPE &b) {       \
546     return convert_vector<TYPE>(vector_as_int(a) & vector_as_int(b));          \
547   }
548 SYCL_PAND(cl::sycl::cl_float4)
549 SYCL_PAND(cl::sycl::cl_double2)
550 #undef SYCL_PAND
551 
552 #define SYCL_POR(TYPE)                                                         \
553   template <>                                                                  \
554   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE por<TYPE>(const TYPE &a,          \
555                                                        const TYPE &b) {        \
556     return convert_vector<TYPE>(vector_as_int(a) | vector_as_int(b));          \
557   }
558 
559 SYCL_POR(cl::sycl::cl_float4)
560 SYCL_POR(cl::sycl::cl_double2)
561 #undef SYCL_POR
562 
563 #define SYCL_PXOR(TYPE)                                                        \
564   template <>                                                                  \
565   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pxor<TYPE>(const TYPE &a,         \
566                                                         const TYPE &b) {       \
567     return convert_vector<TYPE>(vector_as_int(a) ^ vector_as_int(b));          \
568   }
569 
570 SYCL_PXOR(cl::sycl::cl_float4)
571 SYCL_PXOR(cl::sycl::cl_double2)
572 #undef SYCL_PXOR
573 
574 #define SYCL_PANDNOT(TYPE)                                                     \
575   template <>                                                                  \
576   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pandnot<TYPE>(const TYPE &a,      \
577                                                            const TYPE &b) {    \
578     return convert_vector<TYPE>(vector_as_int(a) & (~vector_as_int(b)));       \
579   }
580 SYCL_PANDNOT(cl::sycl::cl_float4)
581 SYCL_PANDNOT(cl::sycl::cl_double2)
582 #undef SYCL_PANDNOT
583 
584 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose(
585     PacketBlock<cl::sycl::cl_float4, 4>& kernel) {
586   float tmp = kernel.packet[0].y();
587   kernel.packet[0].y() = kernel.packet[1].x();
588   kernel.packet[1].x() = tmp;
589 
590   tmp = kernel.packet[0].z();
591   kernel.packet[0].z() = kernel.packet[2].x();
592   kernel.packet[2].x() = tmp;
593 
594   tmp = kernel.packet[0].w();
595   kernel.packet[0].w() = kernel.packet[3].x();
596   kernel.packet[3].x() = tmp;
597 
598   tmp = kernel.packet[1].z();
599   kernel.packet[1].z() = kernel.packet[2].y();
600   kernel.packet[2].y() = tmp;
601 
602   tmp = kernel.packet[1].w();
603   kernel.packet[1].w() = kernel.packet[3].y();
604   kernel.packet[3].y() = tmp;
605 
606   tmp = kernel.packet[2].w();
607   kernel.packet[2].w() = kernel.packet[3].z();
608   kernel.packet[3].z() = tmp;
609 }
610 
611 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose(
612     PacketBlock<cl::sycl::cl_double2, 2>& kernel) {
613   double tmp = kernel.packet[0].y();
614   kernel.packet[0].y() = kernel.packet[1].x();
615   kernel.packet[1].x() = tmp;
616 }
617 
618 template <>
619 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4 pblend(
620     const Selector<unpacket_traits<cl::sycl::cl_float4>::size>& ifPacket,
621     const cl::sycl::cl_float4& thenPacket,
622     const cl::sycl::cl_float4& elsePacket) {
623   cl::sycl::cl_int4 condition(
624       ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1,
625       ifPacket.select[2] ? 0 : -1, ifPacket.select[3] ? 0 : -1);
626   return cl::sycl::select(thenPacket, elsePacket, condition);
627 }
628 
629 template <>
630 inline cl::sycl::cl_double2 pblend(
631     const Selector<unpacket_traits<cl::sycl::cl_double2>::size>& ifPacket,
632     const cl::sycl::cl_double2& thenPacket,
633     const cl::sycl::cl_double2& elsePacket) {
634   cl::sycl::cl_long2 condition(ifPacket.select[0] ? 0 : -1,
635                                ifPacket.select[1] ? 0 : -1);
636   return cl::sycl::select(thenPacket, elsePacket, condition);
637 }
638 #endif  // SYCL_DEVICE_ONLY
639 
640 #define SYCL_PSTORE(alignment)                                  \
641   template <typename packet_type>                               \
642   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \
643       const Eigen::TensorSycl::internal::RangeAccess<           \
644           cl::sycl::access::mode::read_write,                   \
645           typename unpacket_traits<packet_type>::type>& to,     \
646       const packet_type& from) {                                \
647     pstore##alignment(to.get_pointer(), from);                  \
648   }
649 
650 // global space
651 SYCL_PSTORE()
652 SYCL_PSTORE(u)
653 
654 #undef SYCL_PSTORE
655 
656 template <typename scalar, typename packet_type, int Alignment>
657 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret(
658     Eigen::TensorSycl::internal::RangeAccess<
659         cl::sycl::access::mode::read_write,
660         typename unpacket_traits<packet_type>::type>
661         to,
662     const packet_type& from) {
663   pstoret<scalar, packet_type, Alignment>(to.get_pointer(), from);
664 }
665 
666 }  // end namespace internal
667 
668 }  // end namespace Eigen
669 
670 #endif  // EIGEN_PACKET_MATH_SYCL_H
671