1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2015 Gael Guennebaud <[email protected]> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSEDENSEPRODUCT_H 11 #define EIGEN_SPARSEDENSEPRODUCT_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; }; 18 template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; }; 19 20 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, 21 typename AlphaType, 22 int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor, 23 bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> 24 struct sparse_time_dense_product_impl; 25 26 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 27 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true> 28 { 29 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 30 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 31 typedef typename internal::remove_all<DenseResType>::type Res; 32 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 33 typedef evaluator<Lhs> LhsEval; 34 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 35 { 36 LhsEval lhsEval(lhs); 37 38 Index n = lhs.outerSize(); 39 #ifdef EIGEN_HAS_OPENMP 40 Eigen::initParallel(); 41 Index threads = Eigen::nbThreads(); 42 #endif 43 44 for(Index c=0; c<rhs.cols(); ++c) 45 { 46 #ifdef EIGEN_HAS_OPENMP 47 // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems. 48 // It basically represents the minimal amount of work to be done to be worth it. 49 if(threads>1 && lhsEval.nonZerosEstimate() > 20000) 50 { 51 #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads) 52 for(Index i=0; i<n; ++i) 53 processRow(lhsEval,rhs,res,alpha,i,c); 54 } 55 else 56 #endif 57 { 58 for(Index i=0; i<n; ++i) 59 processRow(lhsEval,rhs,res,alpha,i,c); 60 } 61 } 62 } 63 64 static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha, Index i, Index col) 65 { 66 typename Res::Scalar tmp(0); 67 for(LhsInnerIterator it(lhsEval,i); it ;++it) 68 tmp += it.value() * rhs.coeff(it.index(),col); 69 res.coeffRef(i,col) += alpha * tmp; 70 } 71 72 }; 73 74 // FIXME: what is the purpose of the following specialization? Is it for the BlockedSparse format? 75 // -> let's disable it for now as it is conflicting with generic scalar*matrix and matrix*scalar operators 76 // template<typename T1, typename T2/*, int _Options, typename _StrideType*/> 77 // struct ScalarBinaryOpTraits<T1, Ref<T2/*, _Options, _StrideType*/> > 78 // { 79 // enum { 80 // Defined = 1 81 // }; 82 // typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType; 83 // }; 84 85 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType> 86 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true> 87 { 88 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 89 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 90 typedef typename internal::remove_all<DenseResType>::type Res; 91 typedef evaluator<Lhs> LhsEval; 92 typedef typename LhsEval::InnerIterator LhsInnerIterator; 93 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) 94 { 95 LhsEval lhsEval(lhs); 96 for(Index c=0; c<rhs.cols(); ++c) 97 { 98 for(Index j=0; j<lhs.outerSize(); ++j) 99 { 100 // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); 101 typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c)); 102 for(LhsInnerIterator it(lhsEval,j); it ;++it) 103 res.coeffRef(it.index(),c) += it.value() * rhs_j; 104 } 105 } 106 } 107 }; 108 109 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 110 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false> 111 { 112 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 113 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 114 typedef typename internal::remove_all<DenseResType>::type Res; 115 typedef evaluator<Lhs> LhsEval; 116 typedef typename LhsEval::InnerIterator LhsInnerIterator; 117 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 118 { 119 Index n = lhs.rows(); 120 LhsEval lhsEval(lhs); 121 122 #ifdef EIGEN_HAS_OPENMP 123 Eigen::initParallel(); 124 Index threads = Eigen::nbThreads(); 125 // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems. 126 // It basically represents the minimal amount of work to be done to be worth it. 127 if(threads>1 && lhsEval.nonZerosEstimate()*rhs.cols() > 20000) 128 { 129 #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads) 130 for(Index i=0; i<n; ++i) 131 processRow(lhsEval,rhs,res,alpha,i); 132 } 133 else 134 #endif 135 { 136 for(Index i=0; i<n; ++i) 137 processRow(lhsEval, rhs, res, alpha, i); 138 } 139 } 140 141 static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, Res& res, const typename Res::Scalar& alpha, Index i) 142 { 143 typename Res::RowXpr res_i(res.row(i)); 144 for(LhsInnerIterator it(lhsEval,i); it ;++it) 145 res_i += (alpha*it.value()) * rhs.row(it.index()); 146 } 147 }; 148 149 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 150 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false> 151 { 152 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 153 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 154 typedef typename internal::remove_all<DenseResType>::type Res; 155 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 156 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 157 { 158 evaluator<Lhs> lhsEval(lhs); 159 for(Index j=0; j<lhs.outerSize(); ++j) 160 { 161 typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); 162 for(LhsInnerIterator it(lhsEval,j); it ;++it) 163 res.row(it.index()) += (alpha*it.value()) * rhs_j; 164 } 165 } 166 }; 167 168 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> 169 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) 170 { 171 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha); 172 } 173 174 } // end namespace internal 175 176 namespace internal { 177 178 template<typename Lhs, typename Rhs, int ProductType> 179 struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> 180 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> > 181 { 182 typedef typename Product<Lhs,Rhs>::Scalar Scalar; 183 184 template<typename Dest> 185 static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) 186 { 187 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested; 188 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested; 189 LhsNested lhsNested(lhs); 190 RhsNested rhsNested(rhs); 191 internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha); 192 } 193 }; 194 195 template<typename Lhs, typename Rhs, int ProductType> 196 struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType> 197 : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> 198 {}; 199 200 template<typename Lhs, typename Rhs, int ProductType> 201 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> 202 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> > 203 { 204 typedef typename Product<Lhs,Rhs>::Scalar Scalar; 205 206 template<typename Dst> 207 static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) 208 { 209 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested; 210 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested; 211 LhsNested lhsNested(lhs); 212 RhsNested rhsNested(rhs); 213 214 // transpose everything 215 Transpose<Dst> dstT(dst); 216 internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha); 217 } 218 }; 219 220 template<typename Lhs, typename Rhs, int ProductType> 221 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType> 222 : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> 223 {}; 224 225 template<typename LhsT, typename RhsT, bool NeedToTranspose> 226 struct sparse_dense_outer_product_evaluator 227 { 228 protected: 229 typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1; 230 typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs; 231 typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType; 232 233 // if the actual left-hand side is a dense vector, 234 // then build a sparse-view so that we can seamlessly iterate over it. 235 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, 236 Lhs1, SparseView<Lhs1> >::type ActualLhs; 237 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, 238 Lhs1 const&, SparseView<Lhs1> >::type LhsArg; 239 240 typedef evaluator<ActualLhs> LhsEval; 241 typedef evaluator<ActualRhs> RhsEval; 242 typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator; 243 typedef typename ProdXprType::Scalar Scalar; 244 245 public: 246 enum { 247 Flags = NeedToTranspose ? RowMajorBit : 0, 248 CoeffReadCost = HugeCost 249 }; 250 251 class InnerIterator : public LhsIterator 252 { 253 public: 254 InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer) 255 : LhsIterator(xprEval.m_lhsXprImpl, 0), 256 m_outer(outer), 257 m_empty(false), 258 m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() )) 259 {} 260 261 EIGEN_STRONG_INLINE Index outer() const { return m_outer; } 262 EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); } 263 EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; } 264 265 EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; } 266 EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); } 267 268 protected: 269 Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const 270 { 271 return rhs.coeff(outer); 272 } 273 274 Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse()) 275 { 276 typename RhsEval::InnerIterator it(rhs, outer); 277 if (it && it.index()==0 && it.value()!=Scalar(0)) 278 return it.value(); 279 m_empty = true; 280 return Scalar(0); 281 } 282 283 Index m_outer; 284 bool m_empty; 285 Scalar m_factor; 286 }; 287 288 sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs) 289 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) 290 { 291 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 292 } 293 294 // transpose case 295 sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs) 296 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) 297 { 298 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 299 } 300 301 protected: 302 const LhsArg m_lhs; 303 evaluator<ActualLhs> m_lhsXprImpl; 304 evaluator<ActualRhs> m_rhsXprImpl; 305 }; 306 307 // sparse * dense outer product 308 template<typename Lhs, typename Rhs> 309 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape> 310 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> 311 { 312 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base; 313 314 typedef Product<Lhs, Rhs> XprType; 315 typedef typename XprType::PlainObject PlainObject; 316 317 explicit product_evaluator(const XprType& xpr) 318 : Base(xpr.lhs(), xpr.rhs()) 319 {} 320 321 }; 322 323 template<typename Lhs, typename Rhs> 324 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape> 325 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> 326 { 327 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base; 328 329 typedef Product<Lhs, Rhs> XprType; 330 typedef typename XprType::PlainObject PlainObject; 331 332 explicit product_evaluator(const XprType& xpr) 333 : Base(xpr.lhs(), xpr.rhs()) 334 {} 335 336 }; 337 338 } // end namespace internal 339 340 } // end namespace Eigen 341 342 #endif // EIGEN_SPARSEDENSEPRODUCT_H 343