xref: /aosp_15_r20/external/eigen/Eigen/src/Core/products/TriangularMatrixVector.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009 Gael Guennebaud <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
18 struct triangular_matrix_vector_product;
19 
20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22 {
23   typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
24   enum {
25     IsLower = ((Mode&Lower)==Lower),
26     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28   };
29   static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha);
31 };
32 
// Column-major kernel. It accumulates alpha * T * rhs into res, where:
//  - T is the Lower/Upper triangle (selected by Mode) of the _rows x _cols
//    column-major matrix stored at _lhs with outer stride lhsStride;
//  - rhs is read with stride rhsIncr;
//  - ConjLhs/ConjRhs select conjugation of the operands.
// The square diagonal part is processed in panels of PanelWidth columns:
//  - the small triangle inside each panel is handled column by column with
//    axpy-style segment updates;
//  - the dense rectangle below (Lower) or above (Upper) the panel goes through
//    general_matrix_vector_product.
// NOTE(review): the in-panel updates write res with unit stride (ResMap has no
// InnerStride); resIncr is only forwarded to the GEMV calls. The trmv_selector
// in this file always passes resIncr == 1.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha)
  {
    // Panel width is taken from the EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH macro.
    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
    // Length of the diagonal.
    Index size = (std::min)(_rows,_cols);
    // A Lower triangle has no entries in columns >= min(_rows,_cols), and an
    // Upper triangle has none in rows >= min(_rows,_cols), so trim accordingly.
    Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
    Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;

    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

    typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
    const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

    // Contiguous view of the result (see NOTE above about resIncr).
    typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
    ResMap res(_res,rows);

    typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
    typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;

    for (Index pi=0; pi<size; pi+=PanelWidth)
    {
      // The last panel may be narrower than PanelWidth.
      Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
      for (Index k=0; k<actualPanelWidth; ++k)
      {
        // Column i of the in-panel triangle updates rows [s, s+r):
        //   Lower: rows i .. end of panel;   Upper: rows pi .. i.
        // With a unit or zero diagonal, the diagonal row is excluded
        // (s = i+1 for Lower, and --r below for both).
        Index i = pi + k;
        Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
        Index r = IsLower ? actualPanelWidth-k : k+1;
        // (--r) is evaluated only for unit/zero diagonals (short-circuit);
        // the update is skipped when nothing remains off the diagonal.
        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
          res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
        // Implicit unit diagonal: T(i,i) == 1 is never read from memory.
        if (HasUnitDiag)
          res.coeffRef(i) += alpha * cjRhs.coeff(i);
      }
      // Dense block in the panel's columns, outside the panel's rows:
      //   Lower: rows [pi+width, rows), which also covers rows beyond the diagonal;
      //   Upper: rows [0, pi).
      Index r = IsLower ? rows - pi - actualPanelWidth : pi;
      if (r>0)
      {
        Index s = IsLower ? pi+actualPanelWidth : 0;
        general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
            r, actualPanelWidth,
            LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
            RhsMapper(&rhs.coeffRef(pi), rhsIncr),
            &res.coeffRef(s), resIncr, alpha);
      }
    }
    // Upper case with _cols > _rows: columns [size, cols) lie entirely above the
    // diagonal and are applied as one dense GEMV.
    // NOTE(review): unlike the per-panel call above, this one omits the BuiltIn
    // version argument.
    if((!IsLower) && cols>size)
    {
      general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
          rows, cols-size,
          LhsMapper(&lhs.coeffRef(0,size), lhsStride),
          RhsMapper(&rhs.coeffRef(size), rhsIncr),
          _res, resIncr, alpha);
    }
  }
90 
91 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
92 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
93 {
94   typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
95   enum {
96     IsLower = ((Mode&Lower)==Lower),
97     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
98     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
99   };
100   static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
101                                     const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
102 };
103 
// Row-major kernel. It accumulates alpha * T * rhs into res, where:
//  - T is the Lower/Upper triangle (selected by Mode) of the _rows x _cols
//    row-major matrix stored at _lhs with outer stride lhsStride;
//  - res is written with stride resIncr;
//  - ConjLhs/ConjRhs select conjugation of the operands.
// Rows of the square diagonal part are processed in panels of PanelWidth:
//  - inside a panel, each row's triangular segment is reduced with a dot
//    product;
//  - the dense part of the panel's rows to the left (Lower) or right (Upper)
//    of the panel goes through general_matrix_vector_product.
// NOTE(review): rhs is mapped with unit stride here (RhsMap has no InnerStride);
// rhsIncr is only forwarded to the GEMV calls. The trmv_selector in this file
// passes rhsIncr == 1, after copying a strided rhs into a contiguous buffer.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
  {
    // Panel width is taken from the EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH macro.
    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
    // Length of the diagonal. A Lower triangle has no entries in columns
    // >= diagSize, and an Upper triangle has none in rows >= diagSize.
    Index diagSize = (std::min)(_rows,_cols);
    Index rows = IsLower ? _rows : diagSize;
    Index cols = IsLower ? diagSize : _cols;

    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

    // Contiguous view of rhs (see NOTE above about rhsIncr).
    typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
    const RhsMap rhs(_rhs,cols);
    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

    typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
    ResMap res(_res,rows,InnerStride<>(resIncr));

    typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
    typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;

    for (Index pi=0; pi<diagSize; pi+=PanelWidth)
    {
      // The last panel may be narrower than PanelWidth.
      Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
      for (Index k=0; k<actualPanelWidth; ++k)
      {
        // Row i of the in-panel triangle covers columns [s, s+r):
        //   Lower: columns pi .. i;   Upper: columns i .. end of panel.
        // With a unit or zero diagonal, the diagonal column is excluded
        // (s = i+1 for Upper, and --r below for both).
        Index i = pi + k;
        Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
        Index r = IsLower ? k+1 : actualPanelWidth-k;
        // (--r) is evaluated only for unit/zero diagonals (short-circuit);
        // the dot product is skipped when nothing remains off the diagonal.
        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
          res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
        // Implicit unit diagonal: T(i,i) == 1 is never read from memory.
        if (HasUnitDiag)
          res.coeffRef(i) += alpha * cjRhs.coeff(i);
      }
      // Dense block in the panel's rows, outside the panel's columns:
      //   Lower: columns [0, pi);
      //   Upper: columns [pi+width, cols), which also covers columns beyond the diagonal.
      Index r = IsLower ? pi : cols - pi - actualPanelWidth;
      if (r>0)
      {
        Index s = IsLower ? 0 : pi + actualPanelWidth;
        general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
            actualPanelWidth, r,
            LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
            RhsMapper(&rhs.coeffRef(s), rhsIncr),
            &res.coeffRef(pi), resIncr, alpha);
      }
    }
    // Lower case with _rows > _cols: rows [diagSize, rows) lie entirely below the
    // diagonal and are applied as one dense GEMV.
    // NOTE(review): unlike the per-panel call above, this one omits the BuiltIn
    // version argument.
    if(IsLower && rows>diagSize)
    {
      general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
            rows-diagSize, cols,
            LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
            RhsMapper(&rhs.coeffRef(0), rhsIncr),
            &res.coeffRef(diagSize), resIncr, alpha);
    }
  }
161 
/***************************************************************************
* Wrappers dispatching triangular-matrix times vector products to the
* column-major / row-major kernels
***************************************************************************/
165 
// Driver for triangular matrix times vector products. It is selected by the
// storage order of the triangular operand and specialized for ColMajor and
// RowMajor further down in this file.
template<int Mode, int StorageOrder>
struct trmv_selector;
168 
169 } // end namespace internal
170 
171 namespace internal {
172 
173 template<int Mode, typename Lhs, typename Rhs>
174 struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
175 {
176   template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
177   {
178     eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
179 
180     internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
181   }
182 };
183 
184 template<int Mode, typename Lhs, typename Rhs>
185 struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
186 {
187   template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
188   {
189     eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
190 
191     Transpose<Dest> dstT(dst);
192     internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
193                             (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
194             ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
195   }
196 };
197 
198 } // end namespace internal
199 
200 namespace internal {
201 
202 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
203 
204 template<int Mode> struct trmv_selector<Mode,ColMajor>
205 {
206   template<typename Lhs, typename Rhs, typename Dest>
207   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
208   {
209     typedef typename Lhs::Scalar      LhsScalar;
210     typedef typename Rhs::Scalar      RhsScalar;
211     typedef typename Dest::Scalar     ResScalar;
212     typedef typename Dest::RealScalar RealScalar;
213 
214     typedef internal::blas_traits<Lhs> LhsBlasTraits;
215     typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
216     typedef internal::blas_traits<Rhs> RhsBlasTraits;
217     typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
218 
219     typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
220 
221     typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
222     typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
223 
224     LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
225     RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
226     ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
227 
228     enum {
229       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
230       // on, the other hand it is good for the cache to pack the vector anyways...
231       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
232       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
233       MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
234     };
235 
236     gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
237 
238     bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
239     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
240 
241     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
242 
243     ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
244                                                   evalToDest ? dest.data() : static_dest.data());
245 
246     if(!evalToDest)
247     {
248       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
249       Index size = dest.size();
250       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
251       #endif
252       if(!alphaIsCompatible)
253       {
254         MappedDest(actualDestPtr, dest.size()).setZero();
255         compatibleAlpha = RhsScalar(1);
256       }
257       else
258         MappedDest(actualDestPtr, dest.size()) = dest;
259     }
260 
261     internal::triangular_matrix_vector_product
262       <Index,Mode,
263        LhsScalar, LhsBlasTraits::NeedToConjugate,
264        RhsScalar, RhsBlasTraits::NeedToConjugate,
265        ColMajor>
266       ::run(actualLhs.rows(),actualLhs.cols(),
267             actualLhs.data(),actualLhs.outerStride(),
268             actualRhs.data(),actualRhs.innerStride(),
269             actualDestPtr,1,compatibleAlpha);
270 
271     if (!evalToDest)
272     {
273       if(!alphaIsCompatible)
274         dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
275       else
276         dest = MappedDest(actualDestPtr, dest.size());
277     }
278 
279     if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
280     {
281       Index diagSize = (std::min)(lhs.rows(),lhs.cols());
282       dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
283     }
284   }
285 };
286 
287 template<int Mode> struct trmv_selector<Mode,RowMajor>
288 {
289   template<typename Lhs, typename Rhs, typename Dest>
290   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
291   {
292     typedef typename Lhs::Scalar      LhsScalar;
293     typedef typename Rhs::Scalar      RhsScalar;
294     typedef typename Dest::Scalar     ResScalar;
295 
296     typedef internal::blas_traits<Lhs> LhsBlasTraits;
297     typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
298     typedef internal::blas_traits<Rhs> RhsBlasTraits;
299     typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
300     typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
301 
302     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
303     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
304 
305     LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
306     RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
307     ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
308 
309     enum {
310       DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
311     };
312 
313     gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
314 
315     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
316         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
317 
318     if(!DirectlyUseRhs)
319     {
320       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
321       Index size = actualRhs.size();
322       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
323       #endif
324       Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
325     }
326 
327     internal::triangular_matrix_vector_product
328       <Index,Mode,
329        LhsScalar, LhsBlasTraits::NeedToConjugate,
330        RhsScalar, RhsBlasTraits::NeedToConjugate,
331        RowMajor>
332       ::run(actualLhs.rows(),actualLhs.cols(),
333             actualLhs.data(),actualLhs.outerStride(),
334             actualRhsPtr,1,
335             dest.data(),dest.innerStride(),
336             actualAlpha);
337 
338     if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
339     {
340       Index diagSize = (std::min)(lhs.rows(),lhs.cols());
341       dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
342     }
343   }
344 };
345 
346 } // end namespace internal
347 
348 } // end namespace Eigen
349 
350 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H
351