xref: /aosp_15_r20/external/eigen/Eigen/src/Core/util/SymbolicIndex.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Gael Guennebaud <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SYMBOLIC_INDEX_H
11 #define EIGEN_SYMBOLIC_INDEX_H
12 
13 namespace Eigen {
14 
/** \namespace Eigen::symbolic
  * \ingroup Core_Module
  *
  * This namespace defines a set of classes and functions to build and evaluate symbolic expressions of scalar type Index.
  * Here is a simple example:
  *
  * \code
  * // First step, define the symbols:
  * struct x_tag {};  static const symbolic::SymbolExpr<x_tag> x;
  * struct y_tag {};  static const symbolic::SymbolExpr<y_tag> y;
  * struct z_tag {};  static const symbolic::SymbolExpr<z_tag> z;
  *
  * // Define an expression:
  * auto expr = (x+3)/y+z;
  *
  * // And evaluate it (passing the values of several symbols at once requires C++14):
  * std::cout << expr.eval(x=6,y=3,z=-13) << "\n";
  *
  * // In C++98/11, only one symbol per expression is supported for now:
  * auto expr98 = (3-x)/2;
  * std::cout << expr98.eval(x=6) << "\n";
  * \endcode
  *
  * It is currently only used internally to define and manipulate the Eigen::last and Eigen::lastp1 symbols in Eigen::seq and Eigen::seqN.
  *
  */
41 namespace symbolic {
42 
// Forward declarations of the expression node types. BaseExpr (below) names them as
// return types of its operators before their definitions appear later in this header.
// Template parameter names match the ones used in the class definitions.
//
// NOTE(review): `Symbol` is not defined anywhere in this header -- the symbol expression
// class is `SymbolExpr`. The declaration is kept only so that any external code naming
// symbolic::Symbol keeps compiling.
template<typename Tag> class Symbol;
template<typename Arg0> class NegateExpr;
template<typename Arg0,typename Arg1> class AddExpr;
template<typename Arg0,typename Arg1> class ProductExpr;
template<typename Arg0,typename Arg1> class QuotientExpr;
48 
49 // A simple wrapper around an integral value to provide the eval method.
50 // We could also use a free-function symbolic_eval...
51 template<typename IndexType=Index>
52 class ValueExpr {
53 public:
ValueExpr(IndexType val)54   ValueExpr(IndexType val) : m_value(val) {}
55   template<typename T>
eval_impl(const T &)56   IndexType eval_impl(const T&) const { return m_value; }
57 protected:
58   IndexType m_value;
59 };
60 
61 // Specialization for compile-time value,
62 // It is similar to ValueExpr(N) but this version helps the compiler to generate better code.
63 template<int N>
64 class ValueExpr<internal::FixedInt<N> > {
65 public:
ValueExpr()66   ValueExpr() {}
67   template<typename T>
eval_impl(const T &)68   EIGEN_CONSTEXPR Index eval_impl(const T&) const { return N; }
69 };
70 
71 
72 /** \class BaseExpr
73   * \ingroup Core_Module
74   * Common base class of any symbolic expressions
75   */
76 template<typename Derived>
77 class BaseExpr
78 {
79 public:
derived()80   const Derived& derived() const { return *static_cast<const Derived*>(this); }
81 
82   /** Evaluate the expression given the \a values of the symbols.
83     *
84     * \param values defines the values of the symbols, it can either be a SymbolValue or a std::tuple of SymbolValue
85     *               as constructed by SymbolExpr::operator= operator.
86     *
87     */
88   template<typename T>
eval(const T & values)89   Index eval(const T& values) const { return derived().eval_impl(values); }
90 
91 #if EIGEN_HAS_CXX14
92   template<typename... Types>
eval(Types &&...values)93   Index eval(Types&&... values) const { return derived().eval_impl(std::make_tuple(values...)); }
94 #endif
95 
96   NegateExpr<Derived> operator-() const { return NegateExpr<Derived>(derived()); }
97 
98   AddExpr<Derived,ValueExpr<> > operator+(Index b) const
99   { return AddExpr<Derived,ValueExpr<> >(derived(),  b); }
100   AddExpr<Derived,ValueExpr<> > operator-(Index a) const
101   { return AddExpr<Derived,ValueExpr<> >(derived(), -a); }
102   ProductExpr<Derived,ValueExpr<> > operator*(Index a) const
103   { return ProductExpr<Derived,ValueExpr<> >(derived(),a); }
104   QuotientExpr<Derived,ValueExpr<> > operator/(Index a) const
105   { return QuotientExpr<Derived,ValueExpr<> >(derived(),a); }
106 
107   friend AddExpr<Derived,ValueExpr<> > operator+(Index a, const BaseExpr& b)
108   { return AddExpr<Derived,ValueExpr<> >(b.derived(), a); }
109   friend AddExpr<NegateExpr<Derived>,ValueExpr<> > operator-(Index a, const BaseExpr& b)
110   { return AddExpr<NegateExpr<Derived>,ValueExpr<> >(-b.derived(), a); }
111   friend ProductExpr<ValueExpr<>,Derived> operator*(Index a, const BaseExpr& b)
112   { return ProductExpr<ValueExpr<>,Derived>(a,b.derived()); }
113   friend QuotientExpr<ValueExpr<>,Derived> operator/(Index a, const BaseExpr& b)
114   { return QuotientExpr<ValueExpr<>,Derived>(a,b.derived()); }
115 
116   template<int N>
117   AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>) const
118   { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
119   template<int N>
120   AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N>) const
121   { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
122   template<int N>
123   ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N>) const
124   { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
125   template<int N>
126   QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N>) const
127   { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
128 
129   template<int N>
130   friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>, const BaseExpr& b)
131   { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
132   template<int N>
133   friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N>, const BaseExpr& b)
134   { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
135   template<int N>
136   friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N>, const BaseExpr& b)
137   { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
138   template<int N>
139   friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N>, const BaseExpr& b)
140   { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
141 
142 #if (!EIGEN_HAS_CXX14)
143   template<int N>
144   AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)()) const
145   { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
146   template<int N>
147   AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N> (*)()) const
148   { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
149   template<int N>
150   ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N> (*)()) const
151   { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
152   template<int N>
153   QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N> (*)()) const
154   { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
155 
156   template<int N>
157   friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)(), const BaseExpr& b)
158   { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
159   template<int N>
160   friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N> (*)(), const BaseExpr& b)
161   { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
162   template<int N>
163   friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N> (*)(), const BaseExpr& b)
164   { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
165   template<int N>
166   friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N> (*)(), const BaseExpr& b)
167   { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
168 #endif
169 
170 
171   template<typename OtherDerived>
172   AddExpr<Derived,OtherDerived> operator+(const BaseExpr<OtherDerived> &b) const
173   { return AddExpr<Derived,OtherDerived>(derived(),  b.derived()); }
174 
175   template<typename OtherDerived>
176   AddExpr<Derived,NegateExpr<OtherDerived> > operator-(const BaseExpr<OtherDerived> &b) const
177   { return AddExpr<Derived,NegateExpr<OtherDerived> >(derived(), -b.derived()); }
178 
179   template<typename OtherDerived>
180   ProductExpr<Derived,OtherDerived> operator*(const BaseExpr<OtherDerived> &b) const
181   { return ProductExpr<Derived,OtherDerived>(derived(), b.derived()); }
182 
183   template<typename OtherDerived>
184   QuotientExpr<Derived,OtherDerived> operator/(const BaseExpr<OtherDerived> &b) const
185   { return QuotientExpr<Derived,OtherDerived>(derived(), b.derived()); }
186 };
187 
188 template<typename T>
189 struct is_symbolic {
190   // BaseExpr has no conversion ctor, so we only have to check whether T can be statically cast to its base class BaseExpr<T>.
191   enum { value = internal::is_convertible<T,BaseExpr<T> >::value };
192 };
193 
194 /** Represents the actual value of a symbol identified by its tag
195   *
196   * It is the return type of SymbolValue::operator=, and most of the time this is only way it is used.
197   */
198 template<typename Tag>
199 class SymbolValue
200 {
201 public:
202   /** Default constructor from the value \a val */
SymbolValue(Index val)203   SymbolValue(Index val) : m_value(val) {}
204 
205   /** \returns the stored value of the symbol */
value()206   Index value() const { return m_value; }
207 protected:
208   Index m_value;
209 };
210 
211 /** Expression of a symbol uniquely identified by the template parameter type \c tag */
212 template<typename tag>
213 class SymbolExpr : public BaseExpr<SymbolExpr<tag> >
214 {
215 public:
216   /** Alias to the template parameter \c tag */
217   typedef tag Tag;
218 
SymbolExpr()219   SymbolExpr() {}
220 
221   /** Associate the value \a val to the given symbol \c *this, uniquely identified by its \c Tag.
222     *
223     * The returned object should be passed to ExprBase::eval() to evaluate a given expression with this specified runtime-time value.
224     */
225   SymbolValue<Tag> operator=(Index val) const {
226     return SymbolValue<Tag>(val);
227   }
228 
eval_impl(const SymbolValue<Tag> & values)229   Index eval_impl(const SymbolValue<Tag> &values) const { return values.value(); }
230 
231 #if EIGEN_HAS_CXX14
232   // C++14 versions suitable for multiple symbols
233   template<typename... Types>
eval_impl(const std::tuple<Types...> & values)234   Index eval_impl(const std::tuple<Types...>& values) const { return std::get<SymbolValue<Tag> >(values).value(); }
235 #endif
236 };
237 
238 template<typename Arg0>
239 class NegateExpr : public BaseExpr<NegateExpr<Arg0> >
240 {
241 public:
NegateExpr(const Arg0 & arg0)242   NegateExpr(const Arg0& arg0) : m_arg0(arg0) {}
243 
244   template<typename T>
eval_impl(const T & values)245   Index eval_impl(const T& values) const { return -m_arg0.eval_impl(values); }
246 protected:
247   Arg0 m_arg0;
248 };
249 
250 template<typename Arg0, typename Arg1>
251 class AddExpr : public BaseExpr<AddExpr<Arg0,Arg1> >
252 {
253 public:
AddExpr(const Arg0 & arg0,const Arg1 & arg1)254   AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
255 
256   template<typename T>
eval_impl(const T & values)257   Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) + m_arg1.eval_impl(values); }
258 protected:
259   Arg0 m_arg0;
260   Arg1 m_arg1;
261 };
262 
263 template<typename Arg0, typename Arg1>
264 class ProductExpr : public BaseExpr<ProductExpr<Arg0,Arg1> >
265 {
266 public:
ProductExpr(const Arg0 & arg0,const Arg1 & arg1)267   ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
268 
269   template<typename T>
eval_impl(const T & values)270   Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) * m_arg1.eval_impl(values); }
271 protected:
272   Arg0 m_arg0;
273   Arg1 m_arg1;
274 };
275 
276 template<typename Arg0, typename Arg1>
277 class QuotientExpr : public BaseExpr<QuotientExpr<Arg0,Arg1> >
278 {
279 public:
QuotientExpr(const Arg0 & arg0,const Arg1 & arg1)280   QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
281 
282   template<typename T>
eval_impl(const T & values)283   Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) / m_arg1.eval_impl(values); }
284 protected:
285   Arg0 m_arg0;
286   Arg1 m_arg1;
287 };
288 
289 } // end namespace symbolic
290 
291 } // end namespace Eigen
292 
293 #endif // EIGEN_SYMBOLIC_INDEX_H
294