1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009-2010 Gael Guennebaud <[email protected]> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_BLASUTIL_H 11 #define EIGEN_BLASUTIL_H 12 13 // This file contains many lightweight helper classes used to 14 // implement and control fast level 2 and level 3 BLAS-like routines. 15 16 namespace Eigen { 17 18 namespace internal { 19 20 // forward declarations 21 template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false> 22 struct gebp_kernel; 23 24 template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false> 25 struct gemm_pack_rhs; 26 27 template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, int StorageOrder, bool Conjugate = false, bool PanelMode = false> 28 struct gemm_pack_lhs; 29 30 template< 31 typename Index, 32 typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, 33 typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, 34 int ResStorageOrder, int ResInnerStride> 35 struct general_matrix_matrix_product; 36 37 template<typename Index, 38 typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs, 39 typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized> 40 struct general_matrix_vector_product; 41 42 template<typename From,typename To> struct get_factor { runget_factor43 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); } 44 }; 45 46 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> { 47 EIGEN_DEVICE_FUNC 48 static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); } 49 }; 50 51 52 template<typename Scalar, typename Index> 53 class BlasVectorMapper { 54 public: 55 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {} 56 57 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { 58 return m_data[i]; 59 } 60 template <typename Packet, int AlignmentType> 61 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const { 62 return ploadt<Packet, AlignmentType>(m_data + i); 63 } 64 65 template <typename Packet> 66 EIGEN_DEVICE_FUNC bool aligned(Index i) const { 67 return (UIntPtr(m_data+i)%sizeof(Packet))==0; 68 } 69 70 protected: 71 Scalar* m_data; 72 }; 73 74 template<typename Scalar, typename Index, int AlignmentType, int Incr=1> 75 class BlasLinearMapper; 76 77 template<typename Scalar, typename Index, int AlignmentType> 78 class BlasLinearMapper<Scalar,Index,AlignmentType> 79 { 80 public: 81 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data, Index incr=1) 82 : m_data(data) 83 { 84 EIGEN_ONLY_USED_FOR_DEBUG(incr); 85 eigen_assert(incr==1); 86 } 87 88 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { 89 internal::prefetch(&operator()(i)); 90 } 91 92 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { 93 return m_data[i]; 94 } 95 96 template<typename PacketType> 97 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const { 98 return ploadt<PacketType, AlignmentType>(m_data + i); 99 } 100 101 template<typename PacketType> 102 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType &p) const { 103 pstoret<Scalar, PacketType, AlignmentType>(m_data + i, p); 104 } 105 106 protected: 107 Scalar *m_data; 108 }; 109 110 // Lightweight helper class to access matrix coefficients. 111 template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned, int Incr = 1> 112 class blas_data_mapper; 113 114 // TMP to help PacketBlock store implementation. 115 // There's currently no known use case for PacketBlock load. 116 // The default implementation assumes ColMajor order. 117 // It always store each packet sequentially one `stride` apart. 118 template<typename Index, typename Scalar, typename Packet, int n, int idx, int StorageOrder> 119 struct PacketBlockManagement 120 { 121 PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, StorageOrder> pbm; 122 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const { 123 pbm.store(to, stride, i, j, block); 124 pstoreu<Scalar>(to + i + (j + idx)*stride, block.packet[idx]); 125 } 126 }; 127 128 // PacketBlockManagement specialization to take care of RowMajor order without ifs. 129 template<typename Index, typename Scalar, typename Packet, int n, int idx> 130 struct PacketBlockManagement<Index, Scalar, Packet, n, idx, RowMajor> 131 { 132 PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, RowMajor> pbm; 133 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const { 134 pbm.store(to, stride, i, j, block); 135 pstoreu<Scalar>(to + j + (i + idx)*stride, block.packet[idx]); 136 } 137 }; 138 139 template<typename Index, typename Scalar, typename Packet, int n, int StorageOrder> 140 struct PacketBlockManagement<Index, Scalar, Packet, n, -1, StorageOrder> 141 { 142 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const { 143 EIGEN_UNUSED_VARIABLE(to); 144 EIGEN_UNUSED_VARIABLE(stride); 145 EIGEN_UNUSED_VARIABLE(i); 146 EIGEN_UNUSED_VARIABLE(j); 147 EIGEN_UNUSED_VARIABLE(block); 148 } 149 }; 150 151 template<typename Index, typename Scalar, typename Packet, int n> 152 struct PacketBlockManagement<Index, Scalar, Packet, n, -1, RowMajor> 153 { 154 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar *to, const Index stride, Index i, Index j, const PacketBlock<Packet, n> &block) const { 155 EIGEN_UNUSED_VARIABLE(to); 156 EIGEN_UNUSED_VARIABLE(stride); 157 EIGEN_UNUSED_VARIABLE(i); 158 EIGEN_UNUSED_VARIABLE(j); 159 EIGEN_UNUSED_VARIABLE(block); 160 } 161 }; 162 163 template<typename Scalar, typename Index, int StorageOrder, int AlignmentType> 164 class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1> 165 { 166 public: 167 typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper; 168 typedef BlasVectorMapper<Scalar, Index> VectorMapper; 169 170 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1) 171 : m_data(data), m_stride(stride) 172 { 173 EIGEN_ONLY_USED_FOR_DEBUG(incr); 174 eigen_assert(incr==1); 175 } 176 177 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType> 178 getSubMapper(Index i, Index j) const { 179 return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride); 180 } 181 182 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { 183 return LinearMapper(&operator()(i, j)); 184 } 185 186 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { 187 return VectorMapper(&operator()(i, j)); 188 } 189 190 191 EIGEN_DEVICE_FUNC 192 EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const { 193 return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; 194 } 195 196 template<typename PacketType> 197 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const { 198 return ploadt<PacketType, AlignmentType>(&operator()(i, j)); 199 } 200 201 template <typename PacketT, int AlignmentT> 202 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const { 203 return ploadt<PacketT, AlignmentT>(&operator()(i, j)); 204 } 205 206 template<typename SubPacket> 207 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const { 208 pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride); 209 } 210 211 template<typename SubPacket> 212 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const { 213 return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride); 214 } 215 216 EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; } 217 EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; } 218 219 EIGEN_DEVICE_FUNC Index firstAligned(Index size) const { 220 if (UIntPtr(m_data)%sizeof(Scalar)) { 221 return -1; 222 } 223 return internal::first_default_aligned(m_data, size); 224 } 225 226 template<typename SubPacket, int n> 227 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j, const PacketBlock<SubPacket, n> &block) const { 228 PacketBlockManagement<Index, Scalar, SubPacket, n, n-1, StorageOrder> pbm; 229 pbm.store(m_data, m_stride, i, j, block); 230 } 231 protected: 232 Scalar* EIGEN_RESTRICT m_data; 233 const Index m_stride; 234 }; 235 236 // Implementation of non-natural increment (i.e. inner-stride != 1) 237 // The exposed API is not complete yet compared to the Incr==1 case 238 // because some features makes less sense in this case. 239 template<typename Scalar, typename Index, int AlignmentType, int Incr> 240 class BlasLinearMapper 241 { 242 public: 243 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,Index incr) : m_data(data), m_incr(incr) {} 244 245 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { 246 internal::prefetch(&operator()(i)); 247 } 248 249 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { 250 return m_data[i*m_incr.value()]; 251 } 252 253 template<typename PacketType> 254 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const { 255 return pgather<Scalar,PacketType>(m_data + i*m_incr.value(), m_incr.value()); 256 } 257 258 template<typename PacketType> 259 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType &p) const { 260 pscatter<Scalar, PacketType>(m_data + i*m_incr.value(), p, m_incr.value()); 261 } 262 263 protected: 264 Scalar *m_data; 265 const internal::variable_if_dynamic<Index,Incr> m_incr; 266 }; 267 268 template<typename Scalar, typename Index, int StorageOrder, int AlignmentType,int Incr> 269 class blas_data_mapper 270 { 271 public: 272 typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper; 273 274 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {} 275 276 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper 277 getSubMapper(Index i, Index j) const { 278 return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value()); 279 } 280 281 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { 282 return LinearMapper(&operator()(i, j), m_incr.value()); 283 } 284 285 EIGEN_DEVICE_FUNC 286 EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const { 287 return m_data[StorageOrder==RowMajor ? j*m_incr.value() + i*m_stride : i*m_incr.value() + j*m_stride]; 288 } 289 290 template<typename PacketType> 291 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const { 292 return pgather<Scalar,PacketType>(&operator()(i, j),m_incr.value()); 293 } 294 295 template <typename PacketT, int AlignmentT> 296 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const { 297 return pgather<Scalar,PacketT>(&operator()(i, j),m_incr.value()); 298 } 299 300 template<typename SubPacket> 301 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const { 302 pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride); 303 } 304 305 template<typename SubPacket> 306 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const { 307 return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride); 308 } 309 310 // storePacketBlock_helper defines a way to access values inside the PacketBlock, this is essentially required by the Complex types. 311 template<typename SubPacket, typename ScalarT, int n, int idx> 312 struct storePacketBlock_helper 313 { 314 storePacketBlock_helper<SubPacket, ScalarT, n, idx-1> spbh; 315 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const { 316 spbh.store(sup, i,j,block); 317 for(int l = 0; l < unpacket_traits<SubPacket>::size; l++) 318 { 319 ScalarT *v = &sup->operator()(i+l, j+idx); 320 *v = block.packet[idx][l]; 321 } 322 } 323 }; 324 325 template<typename SubPacket, int n, int idx> 326 struct storePacketBlock_helper<SubPacket, std::complex<float>, n, idx> 327 { 328 storePacketBlock_helper<SubPacket, std::complex<float>, n, idx-1> spbh; 329 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const { 330 spbh.store(sup,i,j,block); 331 for(int l = 0; l < unpacket_traits<SubPacket>::size; l++) 332 { 333 std::complex<float> *v = &sup->operator()(i+l, j+idx); 334 v->real(block.packet[idx].v[2*l+0]); 335 v->imag(block.packet[idx].v[2*l+1]); 336 } 337 } 338 }; 339 340 template<typename SubPacket, int n, int idx> 341 struct storePacketBlock_helper<SubPacket, std::complex<double>, n, idx> 342 { 343 storePacketBlock_helper<SubPacket, std::complex<double>, n, idx-1> spbh; 344 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const { 345 spbh.store(sup,i,j,block); 346 for(int l = 0; l < unpacket_traits<SubPacket>::size; l++) 347 { 348 std::complex<double> *v = &sup->operator()(i+l, j+idx); 349 v->real(block.packet[idx].v[2*l+0]); 350 v->imag(block.packet[idx].v[2*l+1]); 351 } 352 } 353 }; 354 355 template<typename SubPacket, typename ScalarT, int n> 356 struct storePacketBlock_helper<SubPacket, ScalarT, n, -1> 357 { 358 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const { 359 } 360 }; 361 362 template<typename SubPacket, int n> 363 struct storePacketBlock_helper<SubPacket, std::complex<float>, n, -1> 364 { 365 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const { 366 } 367 }; 368 369 template<typename SubPacket, int n> 370 struct storePacketBlock_helper<SubPacket, std::complex<double>, n, -1> 371 { 372 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index, const PacketBlock<SubPacket, n>& ) const { 373 } 374 }; 375 // This function stores a PacketBlock on m_data, this approach is really quite slow compare to Incr=1 and should be avoided when possible. 376 template<typename SubPacket, int n> 377 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j, const PacketBlock<SubPacket, n>&block) const { 378 storePacketBlock_helper<SubPacket, Scalar, n, n-1> spb; 379 spb.store(this, i,j,block); 380 } 381 protected: 382 Scalar* EIGEN_RESTRICT m_data; 383 const Index m_stride; 384 const internal::variable_if_dynamic<Index,Incr> m_incr; 385 }; 386 387 // lightweight helper class to access matrix coefficients (const version) 388 template<typename Scalar, typename Index, int StorageOrder> 389 class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> { 390 public: 391 EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {} 392 393 EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const { 394 return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride); 395 } 396 }; 397 398 399 /* Helper class to analyze the factors of a Product expression. 400 * In particular it allows to pop out operator-, scalar multiples, 401 * and conjugate */ 402 template<typename XprType> struct blas_traits 403 { 404 typedef typename traits<XprType>::Scalar Scalar; 405 typedef const XprType& ExtractType; 406 typedef XprType _ExtractType; 407 enum { 408 IsComplex = NumTraits<Scalar>::IsComplex, 409 IsTransposed = false, 410 NeedToConjugate = false, 411 HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit) 412 && ( bool(XprType::IsVectorAtCompileTime) 413 || int(inner_stride_at_compile_time<XprType>::ret) == 1) 414 ) ? 1 : 0, 415 HasScalarFactor = false 416 }; 417 typedef typename conditional<bool(HasUsableDirectAccess), 418 ExtractType, 419 typename _ExtractType::PlainObject 420 >::type DirectLinearAccessType; 421 static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) { return x; } 422 static inline EIGEN_DEVICE_FUNC const Scalar extractScalarFactor(const XprType&) { return Scalar(1); } 423 }; 424 425 // pop conjugate 426 template<typename Scalar, typename NestedXpr> 427 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> > 428 : blas_traits<NestedXpr> 429 { 430 typedef blas_traits<NestedXpr> Base; 431 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType; 432 typedef typename Base::ExtractType ExtractType; 433 434 enum { 435 IsComplex = NumTraits<Scalar>::IsComplex, 436 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex 437 }; 438 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); } 439 static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); } 440 }; 441 442 // pop scalar multiple 443 template<typename Scalar, typename NestedXpr, typename Plain> 444 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> > 445 : blas_traits<NestedXpr> 446 { 447 enum { 448 HasScalarFactor = true 449 }; 450 typedef blas_traits<NestedXpr> Base; 451 typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType; 452 typedef typename Base::ExtractType ExtractType; 453 static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); } 454 static inline EIGEN_DEVICE_FUNC Scalar extractScalarFactor(const XprType& x) 455 { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); } 456 }; 457 template<typename Scalar, typename NestedXpr, typename Plain> 458 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > > 459 : blas_traits<NestedXpr> 460 { 461 enum { 462 HasScalarFactor = true 463 }; 464 typedef blas_traits<NestedXpr> Base; 465 typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType; 466 typedef typename Base::ExtractType ExtractType; 467 static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); } 468 static inline Scalar extractScalarFactor(const XprType& x) 469 { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; } 470 }; 471 template<typename Scalar, typename Plain1, typename Plain2> 472 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>, 473 const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > > 474 : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> > 475 {}; 476 477 // pop opposite 478 template<typename Scalar, typename NestedXpr> 479 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> > 480 : blas_traits<NestedXpr> 481 { 482 enum { 483 HasScalarFactor = true 484 }; 485 typedef blas_traits<NestedXpr> Base; 486 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType; 487 typedef typename Base::ExtractType ExtractType; 488 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); } 489 static inline Scalar extractScalarFactor(const XprType& x) 490 { return - Base::extractScalarFactor(x.nestedExpression()); } 491 }; 492 493 // pop/push transpose 494 template<typename NestedXpr> 495 struct blas_traits<Transpose<NestedXpr> > 496 : blas_traits<NestedXpr> 497 { 498 typedef typename NestedXpr::Scalar Scalar; 499 typedef blas_traits<NestedXpr> Base; 500 typedef Transpose<NestedXpr> XprType; 501 typedef Transpose<const typename Base::_ExtractType> ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS 502 typedef Transpose<const typename Base::_ExtractType> _ExtractType; 503 typedef typename conditional<bool(Base::HasUsableDirectAccess), 504 ExtractType, 505 typename ExtractType::PlainObject 506 >::type DirectLinearAccessType; 507 enum { 508 IsTransposed = Base::IsTransposed ? 0 : 1 509 }; 510 static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); } 511 static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); } 512 }; 513 514 template<typename T> 515 struct blas_traits<const T> 516 : blas_traits<T> 517 {}; 518 519 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess> 520 struct extract_data_selector { 521 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename T::Scalar* run(const T& m) 522 { 523 return blas_traits<T>::extract(m).data(); 524 } 525 }; 526 527 template<typename T> 528 struct extract_data_selector<T,false> { 529 static typename T::Scalar* run(const T&) { return 0; } 530 }; 531 532 template<typename T> 533 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename T::Scalar* extract_data(const T& m) 534 { 535 return extract_data_selector<T>::run(m); 536 } 537 538 /** 539 * \c combine_scalar_factors extracts and multiplies factors from GEMM and GEMV products. 540 * There is a specialization for booleans 541 */ 542 template<typename ResScalar, typename Lhs, typename Rhs> 543 struct combine_scalar_factors_impl 544 { 545 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs) 546 { 547 return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs); 548 } 549 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) 550 { 551 return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs); 552 } 553 }; 554 template<typename Lhs, typename Rhs> 555 struct combine_scalar_factors_impl<bool, Lhs, Rhs> 556 { 557 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs) 558 { 559 return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs); 560 } 561 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs) 562 { 563 return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs); 564 } 565 }; 566 567 template<typename ResScalar, typename Lhs, typename Rhs> 568 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) 569 { 570 return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(alpha, lhs, rhs); 571 } 572 template<typename ResScalar, typename Lhs, typename Rhs> 573 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs) 574 { 575 return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(lhs, rhs); 576 } 577 578 579 } // end namespace internal 580 581 } // end namespace Eigen 582 583 #endif // EIGEN_BLASUTIL_H 584