xref: /aosp_15_r20/external/eigen/Eigen/src/Core/IndexedView.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Gael Guennebaud <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_INDEXED_VIEW_H
11 #define EIGEN_INDEXED_VIEW_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename XprType, typename RowIndices, typename ColIndices>
18 struct traits<IndexedView<XprType, RowIndices, ColIndices> >
19  : traits<XprType>
20 {
21   enum {
22     RowsAtCompileTime = int(array_size<RowIndices>::value),
23     ColsAtCompileTime = int(array_size<ColIndices>::value),
24     MaxRowsAtCompileTime = RowsAtCompileTime != Dynamic ? int(RowsAtCompileTime) : Dynamic,
25     MaxColsAtCompileTime = ColsAtCompileTime != Dynamic ? int(ColsAtCompileTime) : Dynamic,
26 
27     XprTypeIsRowMajor = (int(traits<XprType>::Flags)&RowMajorBit) != 0,
28     IsRowMajor = (MaxRowsAtCompileTime==1&&MaxColsAtCompileTime!=1) ? 1
29                : (MaxColsAtCompileTime==1&&MaxRowsAtCompileTime!=1) ? 0
30                : XprTypeIsRowMajor,
31 
32     RowIncr = int(get_compile_time_incr<RowIndices>::value),
33     ColIncr = int(get_compile_time_incr<ColIndices>::value),
34     InnerIncr = IsRowMajor ? ColIncr : RowIncr,
35     OuterIncr = IsRowMajor ? RowIncr : ColIncr,
36 
37     HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
38     XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret) : int(outer_stride_at_compile_time<XprType>::ret),
39     XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret) : int(inner_stride_at_compile_time<XprType>::ret),
40 
41     InnerSize = XprTypeIsRowMajor ? ColsAtCompileTime : RowsAtCompileTime,
42     IsBlockAlike = InnerIncr==1 && OuterIncr==1,
43     IsInnerPannel = HasSameStorageOrderAsXprType && is_same<AllRange<InnerSize>,typename conditional<XprTypeIsRowMajor,ColIndices,RowIndices>::type>::value,
44 
45     InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic ? Dynamic : XprInnerStride * InnerIncr,
46     OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic ? Dynamic : XprOuterstride * OuterIncr,
47 
48     ReturnAsScalar = is_same<RowIndices,SingleRange>::value && is_same<ColIndices,SingleRange>::value,
49     ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike,
50     ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock),
51 
52     // FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag,
53     // but this is too strict regarding negative strides...
54     DirectAccessMask = (int(InnerIncr)!=UndefinedIncr && int(OuterIncr)!=UndefinedIncr && InnerIncr>=0 && OuterIncr>=0) ? DirectAccessBit : 0,
55     FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
56     FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
57     FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
58     Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask )) | FlagsLvalueBit | FlagsRowMajorBit | FlagsLinearAccessBit
59   };
60 
61   typedef Block<XprType,RowsAtCompileTime,ColsAtCompileTime,IsInnerPannel> BlockType;
62 };
63 
64 }
65 
66 template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
67 class IndexedViewImpl;
68 
69 
70 /** \class IndexedView
71   * \ingroup Core_Module
72   *
73   * \brief Expression of a non-sequential sub-matrix defined by arbitrary sequences of row and column indices
74   *
75   * \tparam XprType the type of the expression in which we are taking the intersections of sub-rows and sub-columns
76   * \tparam RowIndices the type of the object defining the sequence of row indices
77   * \tparam ColIndices the type of the object defining the sequence of column indices
78   *
79   * This class represents an expression of a sub-matrix (or sub-vector) defined as the intersection
80   * of sub-sets of rows and columns, that are themself defined by generic sequences of row indices \f$ \{r_0,r_1,..r_{m-1}\} \f$
81   * and column indices \f$ \{c_0,c_1,..c_{n-1} \}\f$. Let \f$ A \f$  be the nested matrix, then the resulting matrix \f$ B \f$ has \c m
82   * rows and \c n columns, and its entries are given by: \f$ B(i,j) = A(r_i,c_j) \f$.
83   *
84   * The \c RowIndices and \c ColIndices types must be compatible with the following API:
85   * \code
86   * <integral type> operator[](Index) const;
87   * Index size() const;
88   * \endcode
89   *
90   * Typical supported types thus include:
91   *  - std::vector<int>
92   *  - std::valarray<int>
93   *  - std::array<int>
94   *  - Plain C arrays: int[N]
95   *  - Eigen::ArrayXi
96   *  - decltype(ArrayXi::LinSpaced(...))
97   *  - Any view/expressions of the previous types
98   *  - Eigen::ArithmeticSequence
99   *  - Eigen::internal::AllRange      (helper for Eigen::all)
100   *  - Eigen::internal::SingleRange  (helper for single index)
101   *  - etc.
102   *
103   * In typical usages of %Eigen, this class should never be used directly. It is the return type of
104   * DenseBase::operator()(const RowIndices&, const ColIndices&).
105   *
106   * \sa class Block
107   */
108 template<typename XprType, typename RowIndices, typename ColIndices>
109 class IndexedView : public IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>
110 {
111 public:
112   typedef typename IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>::Base Base;
113   EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
114   EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
115 
116   typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
117   typedef typename internal::remove_all<XprType>::type NestedExpression;
118 
119   template<typename T0, typename T1>
120   IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
121     : m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices)
122   {}
123 
124   /** \returns number of rows */
125   Index rows() const { return internal::size(m_rowIndices); }
126 
127   /** \returns number of columns */
128   Index cols() const { return internal::size(m_colIndices); }
129 
130   /** \returns the nested expression */
131   const typename internal::remove_all<XprType>::type&
132   nestedExpression() const { return m_xpr; }
133 
134   /** \returns the nested expression */
135   typename internal::remove_reference<XprType>::type&
136   nestedExpression() { return m_xpr; }
137 
138   /** \returns a const reference to the object storing/generating the row indices */
139   const RowIndices& rowIndices() const { return m_rowIndices; }
140 
141   /** \returns a const reference to the object storing/generating the column indices */
142   const ColIndices& colIndices() const { return m_colIndices; }
143 
144 protected:
145   MatrixTypeNested m_xpr;
146   RowIndices m_rowIndices;
147   ColIndices m_colIndices;
148 };
149 
150 
151 // Generic API dispatcher
152 template<typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
153 class IndexedViewImpl
154   : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type
155 {
156 public:
157   typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices> >::type Base;
158 };
159 
160 namespace internal {
161 
162 
163 template<typename ArgType, typename RowIndices, typename ColIndices>
164 struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
165   : evaluator_base<IndexedView<ArgType, RowIndices, ColIndices> >
166 {
167   typedef IndexedView<ArgType, RowIndices, ColIndices> XprType;
168 
169   enum {
170     CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of row/col index */,
171 
172     FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1) ? LinearAccessBit : 0,
173 
174     FlagsRowMajorBit = traits<XprType>::FlagsRowMajorBit,
175 
176     Flags = (evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit /*| LinearAccessBit | DirectAccessBit*/)) | FlagsLinearAccessBit | FlagsRowMajorBit,
177 
178     Alignment = 0
179   };
180 
181   EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
182   {
183     EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
184   }
185 
186   typedef typename XprType::Scalar Scalar;
187   typedef typename XprType::CoeffReturnType CoeffReturnType;
188 
189   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
190   CoeffReturnType coeff(Index row, Index col) const
191   {
192     return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
193   }
194 
195   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
196   Scalar& coeffRef(Index row, Index col)
197   {
198     return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
199   }
200 
201   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
202   Scalar& coeffRef(Index index)
203   {
204     EIGEN_STATIC_ASSERT_LVALUE(XprType)
205     Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
206     Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
207     return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
208   }
209 
210   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
211   const Scalar& coeffRef(Index index) const
212   {
213     Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
214     Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
215     return m_argImpl.coeffRef( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
216   }
217 
218   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
219   const CoeffReturnType coeff(Index index) const
220   {
221     Index row = XprType::RowsAtCompileTime == 1 ? 0 : index;
222     Index col = XprType::RowsAtCompileTime == 1 ? index : 0;
223     return m_argImpl.coeff( m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
224   }
225 
226 protected:
227 
228   evaluator<ArgType> m_argImpl;
229   const XprType& m_xpr;
230 
231 };
232 
233 } // end namespace internal
234 
235 } // end namespace Eigen
236 
237 #endif // EIGEN_INDEXED_VIEW_H
238