1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009 Gael Guennebaud <[email protected]> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H 11 #define EIGEN_TRIANGULARMATRIXVECTOR_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized> 18 struct triangular_matrix_vector_product; 19 20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version> 21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version> 22 { 23 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; 24 enum { 25 IsLower = ((Mode&Lower)==Lower), 26 HasUnitDiag = (Mode & UnitDiag)==UnitDiag, 27 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag 28 }; 29 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, 30 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha); 31 }; 32 33 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version> 34 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version> 35 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, 36 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha) 37 { 38 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; 39 Index size = (std::min)(_rows,_cols); 40 Index rows = IsLower ? _rows : (std::min)(_rows,_cols); 41 Index cols = IsLower ? (std::min)(_rows,_cols) : _cols; 42 43 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap; 44 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride)); 45 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs); 46 47 typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap; 48 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr)); 49 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs); 50 51 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap; 52 ResMap res(_res,rows); 53 54 typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper; 55 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper; 56 57 for (Index pi=0; pi<size; pi+=PanelWidth) 58 { 59 Index actualPanelWidth = (std::min)(PanelWidth, size-pi); 60 for (Index k=0; k<actualPanelWidth; ++k) 61 { 62 Index i = pi + k; 63 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi; 64 Index r = IsLower ? actualPanelWidth-k : k+1; 65 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0) 66 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r); 67 if (HasUnitDiag) 68 res.coeffRef(i) += alpha * cjRhs.coeff(i); 69 } 70 Index r = IsLower ? rows - pi - actualPanelWidth : pi; 71 if (r>0) 72 { 73 Index s = IsLower ? pi+actualPanelWidth : 0; 74 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run( 75 r, actualPanelWidth, 76 LhsMapper(&lhs.coeffRef(s,pi), lhsStride), 77 RhsMapper(&rhs.coeffRef(pi), rhsIncr), 78 &res.coeffRef(s), resIncr, alpha); 79 } 80 } 81 if((!IsLower) && cols>size) 82 { 83 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run( 84 rows, cols-size, 85 LhsMapper(&lhs.coeffRef(0,size), lhsStride), 86 RhsMapper(&rhs.coeffRef(size), rhsIncr), 87 _res, resIncr, alpha); 88 } 89 } 90 91 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version> 92 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version> 93 { 94 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; 95 enum { 96 IsLower = ((Mode&Lower)==Lower), 97 HasUnitDiag = (Mode & UnitDiag)==UnitDiag, 98 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag 99 }; 100 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, 101 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha); 102 }; 103 104 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version> 105 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version> 106 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, 107 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha) 108 { 109 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; 110 Index diagSize = (std::min)(_rows,_cols); 111 Index rows = IsLower ? _rows : diagSize; 112 Index cols = IsLower ? diagSize : _cols; 113 114 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap; 115 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride)); 116 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs); 117 118 typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap; 119 const RhsMap rhs(_rhs,cols); 120 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs); 121 122 typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap; 123 ResMap res(_res,rows,InnerStride<>(resIncr)); 124 125 typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper; 126 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper; 127 128 for (Index pi=0; pi<diagSize; pi+=PanelWidth) 129 { 130 Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi); 131 for (Index k=0; k<actualPanelWidth; ++k) 132 { 133 Index i = pi + k; 134 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i); 135 Index r = IsLower ? k+1 : actualPanelWidth-k; 136 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0) 137 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum(); 138 if (HasUnitDiag) 139 res.coeffRef(i) += alpha * cjRhs.coeff(i); 140 } 141 Index r = IsLower ? pi : cols - pi - actualPanelWidth; 142 if (r>0) 143 { 144 Index s = IsLower ? 0 : pi + actualPanelWidth; 145 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run( 146 actualPanelWidth, r, 147 LhsMapper(&lhs.coeffRef(pi,s), lhsStride), 148 RhsMapper(&rhs.coeffRef(s), rhsIncr), 149 &res.coeffRef(pi), resIncr, alpha); 150 } 151 } 152 if(IsLower && rows>diagSize) 153 { 154 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run( 155 rows-diagSize, cols, 156 LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride), 157 RhsMapper(&rhs.coeffRef(0), rhsIncr), 158 &res.coeffRef(diagSize), resIncr, alpha); 159 } 160 } 161 162 /*************************************************************************** 163 * Wrapper to product_triangular_vector 164 ***************************************************************************/ 165 166 template<int Mode,int StorageOrder> 167 struct trmv_selector; 168 169 } // end namespace internal 170 171 namespace internal { 172 173 template<int Mode, typename Lhs, typename Rhs> 174 struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true> 175 { 176 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha) 177 { 178 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols()); 179 180 internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha); 181 } 182 }; 183 184 template<int Mode, typename Lhs, typename Rhs> 185 struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false> 186 { 187 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha) 188 { 189 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols()); 190 191 Transpose<Dest> dstT(dst); 192 internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower), 193 (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor> 194 ::run(rhs.transpose(),lhs.transpose(), dstT, alpha); 195 } 196 }; 197 198 } // end namespace internal 199 200 namespace internal { 201 202 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same. 203 204 template<int Mode> struct trmv_selector<Mode,ColMajor> 205 { 206 template<typename Lhs, typename Rhs, typename Dest> 207 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) 208 { 209 typedef typename Lhs::Scalar LhsScalar; 210 typedef typename Rhs::Scalar RhsScalar; 211 typedef typename Dest::Scalar ResScalar; 212 typedef typename Dest::RealScalar RealScalar; 213 214 typedef internal::blas_traits<Lhs> LhsBlasTraits; 215 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; 216 typedef internal::blas_traits<Rhs> RhsBlasTraits; 217 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; 218 219 typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest; 220 221 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs); 222 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs); 223 224 LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs); 225 RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs); 226 ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha; 227 228 enum { 229 // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1 230 // on, the other hand it is good for the cache to pack the vector anyways... 231 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1, 232 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex), 233 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal 234 }; 235 236 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest; 237 238 bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0)); 239 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; 240 241 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha); 242 243 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), 244 evalToDest ? dest.data() : static_dest.data()); 245 246 if(!evalToDest) 247 { 248 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 249 Index size = dest.size(); 250 EIGEN_DENSE_STORAGE_CTOR_PLUGIN 251 #endif 252 if(!alphaIsCompatible) 253 { 254 MappedDest(actualDestPtr, dest.size()).setZero(); 255 compatibleAlpha = RhsScalar(1); 256 } 257 else 258 MappedDest(actualDestPtr, dest.size()) = dest; 259 } 260 261 internal::triangular_matrix_vector_product 262 <Index,Mode, 263 LhsScalar, LhsBlasTraits::NeedToConjugate, 264 RhsScalar, RhsBlasTraits::NeedToConjugate, 265 ColMajor> 266 ::run(actualLhs.rows(),actualLhs.cols(), 267 actualLhs.data(),actualLhs.outerStride(), 268 actualRhs.data(),actualRhs.innerStride(), 269 actualDestPtr,1,compatibleAlpha); 270 271 if (!evalToDest) 272 { 273 if(!alphaIsCompatible) 274 dest += actualAlpha * MappedDest(actualDestPtr, dest.size()); 275 else 276 dest = MappedDest(actualDestPtr, dest.size()); 277 } 278 279 if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) ) 280 { 281 Index diagSize = (std::min)(lhs.rows(),lhs.cols()); 282 dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize); 283 } 284 } 285 }; 286 287 template<int Mode> struct trmv_selector<Mode,RowMajor> 288 { 289 template<typename Lhs, typename Rhs, typename Dest> 290 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) 291 { 292 typedef typename Lhs::Scalar LhsScalar; 293 typedef typename Rhs::Scalar RhsScalar; 294 typedef typename Dest::Scalar ResScalar; 295 296 typedef internal::blas_traits<Lhs> LhsBlasTraits; 297 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; 298 typedef internal::blas_traits<Rhs> RhsBlasTraits; 299 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; 300 typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; 301 302 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs); 303 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs); 304 305 LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs); 306 RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs); 307 ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha; 308 309 enum { 310 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 311 }; 312 313 gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs; 314 315 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(), 316 DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data()); 317 318 if(!DirectlyUseRhs) 319 { 320 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 321 Index size = actualRhs.size(); 322 EIGEN_DENSE_STORAGE_CTOR_PLUGIN 323 #endif 324 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; 325 } 326 327 internal::triangular_matrix_vector_product 328 <Index,Mode, 329 LhsScalar, LhsBlasTraits::NeedToConjugate, 330 RhsScalar, RhsBlasTraits::NeedToConjugate, 331 RowMajor> 332 ::run(actualLhs.rows(),actualLhs.cols(), 333 actualLhs.data(),actualLhs.outerStride(), 334 actualRhsPtr,1, 335 dest.data(),dest.innerStride(), 336 actualAlpha); 337 338 if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) ) 339 { 340 Index diagSize = (std::min)(lhs.rows(),lhs.cols()); 341 dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize); 342 } 343 } 344 }; 345 346 } // end namespace internal 347 348 } // end namespace Eigen 349 350 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H 351