1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Gael Guennebaud <[email protected]> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SOLVEWITHGUESS_H 11 #define EIGEN_SOLVEWITHGUESS_H 12 13 namespace Eigen { 14 15 template<typename Decomposition, typename RhsType, typename GuessType> class SolveWithGuess; 16 17 /** \class SolveWithGuess 18 * \ingroup IterativeLinearSolvers_Module 19 * 20 * \brief Pseudo expression representing a solving operation 21 * 22 * \tparam Decomposition the type of the matrix or decomposion object 23 * \tparam Rhstype the type of the right-hand side 24 * 25 * This class represents an expression of A.solve(B) 26 * and most of the time this is the only way it is used. 27 * 28 */ 29 namespace internal { 30 31 32 template<typename Decomposition, typename RhsType, typename GuessType> 33 struct traits<SolveWithGuess<Decomposition, RhsType, GuessType> > 34 : traits<Solve<Decomposition,RhsType> > 35 {}; 36 37 } 38 39 40 template<typename Decomposition, typename RhsType, typename GuessType> 41 class SolveWithGuess : public internal::generic_xpr_base<SolveWithGuess<Decomposition,RhsType,GuessType>, MatrixXpr, typename internal::traits<RhsType>::StorageKind>::type 42 { 43 public: 44 typedef typename internal::traits<SolveWithGuess>::Scalar Scalar; 45 typedef typename internal::traits<SolveWithGuess>::PlainObject PlainObject; 46 typedef typename internal::generic_xpr_base<SolveWithGuess<Decomposition,RhsType,GuessType>, MatrixXpr, typename internal::traits<RhsType>::StorageKind>::type Base; 47 typedef typename internal::ref_selector<SolveWithGuess>::type Nested; 48 49 SolveWithGuess(const Decomposition &dec, const RhsType &rhs, const GuessType &guess) 50 : m_dec(dec), m_rhs(rhs), m_guess(guess) 51 {} 52 53 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR 54 Index rows() const EIGEN_NOEXCEPT { return m_dec.cols(); } 55 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR 56 Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); } 57 58 EIGEN_DEVICE_FUNC const Decomposition& dec() const { return m_dec; } 59 EIGEN_DEVICE_FUNC const RhsType& rhs() const { return m_rhs; } 60 EIGEN_DEVICE_FUNC const GuessType& guess() const { return m_guess; } 61 62 protected: 63 const Decomposition &m_dec; 64 const RhsType &m_rhs; 65 const GuessType &m_guess; 66 67 private: 68 Scalar coeff(Index row, Index col) const; 69 Scalar coeff(Index i) const; 70 }; 71 72 namespace internal { 73 74 // Evaluator of SolveWithGuess -> eval into a temporary 75 template<typename Decomposition, typename RhsType, typename GuessType> 76 struct evaluator<SolveWithGuess<Decomposition,RhsType, GuessType> > 77 : public evaluator<typename SolveWithGuess<Decomposition,RhsType,GuessType>::PlainObject> 78 { 79 typedef SolveWithGuess<Decomposition,RhsType,GuessType> SolveType; 80 typedef typename SolveType::PlainObject PlainObject; 81 typedef evaluator<PlainObject> Base; 82 83 evaluator(const SolveType& solve) 84 : m_result(solve.rows(), solve.cols()) 85 { 86 ::new (static_cast<Base*>(this)) Base(m_result); 87 m_result = solve.guess(); 88 solve.dec()._solve_with_guess_impl(solve.rhs(), m_result); 89 } 90 91 protected: 92 PlainObject m_result; 93 }; 94 95 // Specialization for "dst = dec.solveWithGuess(rhs)" 96 // NOTE we need to specialize it for Dense2Dense to avoid ambiguous specialization error and a Sparse2Sparse specialization must exist somewhere 97 template<typename DstXprType, typename DecType, typename RhsType, typename GuessType, typename Scalar> 98 struct Assignment<DstXprType, SolveWithGuess<DecType,RhsType,GuessType>, internal::assign_op<Scalar,Scalar>, Dense2Dense> 99 { 100 typedef SolveWithGuess<DecType,RhsType,GuessType> SrcXprType; 101 static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &) 102 { 103 Index dstRows = src.rows(); 104 Index dstCols = src.cols(); 105 if((dst.rows()!=dstRows) || (dst.cols()!=dstCols)) 106 dst.resize(dstRows, dstCols); 107 108 dst = src.guess(); 109 src.dec()._solve_with_guess_impl(src.rhs(), dst/*, src.guess()*/); 110 } 111 }; 112 113 } // end namespace internal 114 115 } // end namespace Eigen 116 117 #endif // EIGEN_SOLVEWITHGUESS_H 118