xref: /aosp_15_r20/external/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2015 Gael Guennebaud <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
11 #define EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
// Computes res = lhs * rhs where both operands and the result are sparse.
// A dense scratch (mask/values/indices) of size lhs.innerSize() is reused for
// every outer vector of the result, so each result column is accumulated in
// O(flops) regardless of sparsity pattern.
//
// The evaluators fake a common storage order, hence the use of
// innerSize()/outerSize() instead of rows()/cols() below.
//
// If sortedInsertion is true, the nonzeros of each result column are inserted
// with strictly increasing inner index (res ends up sorted); otherwise they
// are inserted unordered, and the caller is responsible for sorting res
// afterwards (e.g. via a storage-order round trip, see the selectors below).
template<typename Lhs, typename Rhs, typename ResultType>
static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, bool sortedInsertion = false)
{
  typedef typename remove_all<Lhs>::type::Scalar LhsScalar;
  typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
  typedef typename remove_all<ResultType>::type::Scalar ResScalar;

  // make sure to call innerSize/outerSize since we fake the storage order.
  Index rows = lhs.innerSize();
  Index cols = rhs.outerSize();
  eigen_assert(lhs.outerSize() == rhs.innerSize());

  // Per-column scratch: mask[i] marks whether inner index i already holds a
  // partial sum, values[i] is that partial sum, and indices[0..nnz) lists the
  // inner indices touched so far (in arbitrary order).
  ei_declare_aligned_stack_constructed_variable(bool,   mask,     rows, 0);
  ei_declare_aligned_stack_constructed_variable(ResScalar, values,   rows, 0);
  ei_declare_aligned_stack_constructed_variable(Index,  indices,  rows, 0);

  std::memset(mask,0,sizeof(bool)*rows);

  evaluator<Lhs> lhsEval(lhs);
  evaluator<Rhs> rhsEval(rhs);

  // estimate the number of non zero entries
  // given a rhs column containing Y non zeros, we assume that the respective Y columns
  // of the lhs differ on average by one non zero, thus the number of non zeros for
  // the product of a rhs column with the lhs is X+Y where X is the average number of non zeros
  // per column of the lhs.
  // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
  Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();

  res.setZero();
  res.reserve(Index(estimated_nnz_prod));
  // we compute each column of the result, one after the other
  for (Index j=0; j<cols; ++j)
  {

    res.startVec(j);
    Index nnz = 0;
    // Accumulate column j of the product: for each nonzero y at inner index k
    // of rhs's column j, scatter y times lhs's column k into the scratch.
    for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
    {
      RhsScalar y = rhsIt.value();
      Index k = rhsIt.index();
      for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
      {
        Index i = lhsIt.index();
        LhsScalar x = lhsIt.value();
        if(!mask[i])
        {
          // first contribution to inner index i for this column
          mask[i] = true;
          values[i] = x * y;
          indices[nnz] = i;
          ++nnz;
        }
        else
          values[i] += x * y;
      }
    }
    if(!sortedInsertion)
    {
      // unordered insertion
      for(Index k=0; k<nnz; ++k)
      {
        Index i = indices[k];
        res.insertBackByOuterInnerUnordered(j,i) = values[i];
        mask[i] = false; // reset the scratch for the next column
      }
    }
    else
    {
      // alternative ordered insertion code:
      const Index t200 = rows/11; // 11 == (log2(200)*1.39)
      const Index t = (rows*100)/139;

      // FIXME reserve nnz non zeros
      // FIXME implement faster sorting algorithms for very small nnz
      // if the result is sparse enough => use a quick sort
      // otherwise => loop through the entire vector
      // In order to avoid to perform an expensive log2 when the
      // result is clearly very sparse we use a linear bound up to 200.
      if((nnz<200 && nnz<t200) || nnz * numext::log2(int(nnz)) < t)
      {
        if(nnz>1) std::sort(indices,indices+nnz);
        for(Index k=0; k<nnz; ++k)
        {
          Index i = indices[k];
          res.insertBackByOuterInner(j,i) = values[i];
          mask[i] = false;
        }
      }
      else
      {
        // dense path: a linear scan over all inner indices yields the
        // nonzeros in sorted order without sorting.
        for(Index i=0; i<rows; ++i)
        {
          if(mask[i])
          {
            mask[i] = false;
            res.insertBackByOuterInner(j,i) = values[i];
          }
        }
      }
    }
  }
  res.finalize();
}
121 
122 
123 } // end namespace internal
124 
125 namespace internal {
126 
// Dispatches on the storage orders of the two operands and of the result.
// Each specialization below reduces its case to
// conservative_sparse_sparse_product_impl, inserting storage-order
// conversions and/or swapping the operands as needed.
template<typename Lhs, typename Rhs, typename ResultType,
  int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
  int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
  int ResStorageOrder = (traits<ResultType>::Flags&RowMajorBit) ? RowMajor : ColMajor>
struct conservative_sparse_sparse_product_selector;
132 
133 template<typename Lhs, typename Rhs, typename ResultType>
134 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
135 {
136   typedef typename remove_all<Lhs>::type LhsCleaned;
137   typedef typename LhsCleaned::Scalar Scalar;
138 
139   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
140   {
141     typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
142     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrixAux;
143     typedef typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime,ColMajorMatrixAux::Flags>::type ColMajorMatrix;
144 
145     // If the result is tall and thin (in the extreme case a column vector)
146     // then it is faster to sort the coefficients inplace instead of transposing twice.
147     // FIXME, the following heuristic is probably not very good.
148     if(lhs.rows()>rhs.cols())
149     {
150       ColMajorMatrix resCol(lhs.rows(),rhs.cols());
151       // perform sorted insertion
152       internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol, true);
153       res = resCol.markAsRValue();
154     }
155     else
156     {
157       ColMajorMatrixAux resCol(lhs.rows(),rhs.cols());
158       // resort to transpose to sort the entries
159       internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrixAux>(lhs, rhs, resCol, false);
160       RowMajorMatrix resRow(resCol);
161       res = resRow.markAsRValue();
162     }
163   }
164 };
165 
166 template<typename Lhs, typename Rhs, typename ResultType>
167 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
168 {
169   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
170   {
171     typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRhs;
172     typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
173     RowMajorRhs rhsRow = rhs;
174     RowMajorRes resRow(lhs.rows(), rhs.cols());
175     internal::conservative_sparse_sparse_product_impl<RowMajorRhs,Lhs,RowMajorRes>(rhsRow, lhs, resRow);
176     res = resRow;
177   }
178 };
179 
180 template<typename Lhs, typename Rhs, typename ResultType>
181 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
182 {
183   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
184   {
185     typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorLhs;
186     typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
187     RowMajorLhs lhsRow = lhs;
188     RowMajorRes resRow(lhs.rows(), rhs.cols());
189     internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorLhs,RowMajorRes>(rhs, lhsRow, resRow);
190     res = resRow;
191   }
192 };
193 
194 template<typename Lhs, typename Rhs, typename ResultType>
195 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
196 {
197   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
198   {
199     typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
200     RowMajorMatrix resRow(lhs.rows(), rhs.cols());
201     internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
202     res = resRow;
203   }
204 };
205 
206 
207 template<typename Lhs, typename Rhs, typename ResultType>
208 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
209 {
210   typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
211 
212   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
213   {
214     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
215     ColMajorMatrix resCol(lhs.rows(), rhs.cols());
216     internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
217     res = resCol;
218   }
219 };
220 
221 template<typename Lhs, typename Rhs, typename ResultType>
222 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
223 {
224   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
225   {
226     typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
227     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
228     ColMajorLhs lhsCol = lhs;
229     ColMajorRes resCol(lhs.rows(), rhs.cols());
230     internal::conservative_sparse_sparse_product_impl<ColMajorLhs,Rhs,ColMajorRes>(lhsCol, rhs, resCol);
231     res = resCol;
232   }
233 };
234 
235 template<typename Lhs, typename Rhs, typename ResultType>
236 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
237 {
238   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
239   {
240     typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
241     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
242     ColMajorRhs rhsCol = rhs;
243     ColMajorRes resCol(lhs.rows(), rhs.cols());
244     internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorRhs,ColMajorRes>(lhs, rhsCol, resCol);
245     res = resCol;
246   }
247 };
248 
249 template<typename Lhs, typename Rhs, typename ResultType>
250 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
251 {
252   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
253   {
254     typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
255     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
256     RowMajorMatrix resRow(lhs.rows(),rhs.cols());
257     internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
258     // sort the non zeros:
259     ColMajorMatrix resCol(resRow);
260     res = resCol;
261   }
262 };
263 
264 } // end namespace internal
265 
266 
267 namespace internal {
268 
269 template<typename Lhs, typename Rhs, typename ResultType>
270 static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
271 {
272   typedef typename remove_all<Lhs>::type::Scalar LhsScalar;
273   typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
274   Index cols = rhs.outerSize();
275   eigen_assert(lhs.outerSize() == rhs.innerSize());
276 
277   evaluator<Lhs> lhsEval(lhs);
278   evaluator<Rhs> rhsEval(rhs);
279 
280   for (Index j=0; j<cols; ++j)
281   {
282     for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
283     {
284       RhsScalar y = rhsIt.value();
285       Index k = rhsIt.index();
286       for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
287       {
288         Index i = lhsIt.index();
289         LhsScalar x = lhsIt.value();
290         res.coeffRef(i,j) += x * y;
291       }
292     }
293   }
294 }
295 
296 
297 } // end namespace internal
298 
299 namespace internal {
300 
// Dispatches sparse*sparse-into-dense on the storage orders of the two sparse
// operands; specializations convert an operand (or transpose the result) so
// that sparse_sparse_to_dense_product_impl sees a consistent pair.
template<typename Lhs, typename Rhs, typename ResultType,
  int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
  int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor>
struct sparse_sparse_to_dense_product_selector;
305 
// Both operands column-major: forward directly to the kernel.
template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    internal::sparse_sparse_to_dense_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, res);
  }
};
314 
315 template<typename Lhs, typename Rhs, typename ResultType>
316 struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor>
317 {
318   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
319   {
320     typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
321     ColMajorLhs lhsCol(lhs);
322     internal::sparse_sparse_to_dense_product_impl<ColMajorLhs,Rhs,ResultType>(lhsCol, rhs, res);
323   }
324 };
325 
326 template<typename Lhs, typename Rhs, typename ResultType>
327 struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor>
328 {
329   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
330   {
331     typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
332     ColMajorRhs rhsCol(rhs);
333     internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorRhs,ResultType>(lhs, rhsCol, res);
334   }
335 };
336 
337 template<typename Lhs, typename Rhs, typename ResultType>
338 struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor>
339 {
340   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
341   {
342     Transpose<ResultType> trRes(res);
343     internal::sparse_sparse_to_dense_product_impl<Rhs,Lhs,Transpose<ResultType> >(rhs, lhs, trRes);
344   }
345 };
346 
347 
348 } // end namespace internal
349 
350 } // end namespace Eigen
351 
352 #endif // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
353