1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2017 Gael Guennebaud <[email protected]> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SYMBOLIC_INDEX_H 11 #define EIGEN_SYMBOLIC_INDEX_H 12 13 namespace Eigen { 14 15 /** \namespace Eigen::symbolic 16 * \ingroup Core_Module 17 * 18 * This namespace defines a set of classes and functions to build and evaluate symbolic expressions of scalar type Index. 19 * Here is a simple example: 20 * 21 * \code 22 * // First step, defines symbols: 23 * struct x_tag {}; static const symbolic::SymbolExpr<x_tag> x; 24 * struct y_tag {}; static const symbolic::SymbolExpr<y_tag> y; 25 * struct z_tag {}; static const symbolic::SymbolExpr<z_tag> z; 26 * 27 * // Defines an expression: 28 * auto expr = (x+3)/y+z; 29 * 30 * // And evaluate it: (c++14) 31 * std::cout << expr.eval(x=6,y=3,z=-13) << "\n"; 32 * 33 * // In c++98/11, only one symbol per expression is supported for now: 34 * auto expr98 = (3-x)/2; 35 * std::cout << expr98.eval(x=6) << "\n"; 36 * \endcode 37 * 38 * It is currently only used internally to define and manipulate the Eigen::last and Eigen::lastp1 symbols in Eigen::seq and Eigen::seqN. 39 * 40 */ 41 namespace symbolic { 42 43 template<typename Tag> class Symbol; 44 template<typename Arg0> class NegateExpr; 45 template<typename Arg1,typename Arg2> class AddExpr; 46 template<typename Arg1,typename Arg2> class ProductExpr; 47 template<typename Arg1,typename Arg2> class QuotientExpr; 48 49 // A simple wrapper around an integral value to provide the eval method. 50 // We could also use a free-function symbolic_eval... 51 template<typename IndexType=Index> 52 class ValueExpr { 53 public: ValueExpr(IndexType val)54 ValueExpr(IndexType val) : m_value(val) {} 55 template<typename T> eval_impl(const T &)56 IndexType eval_impl(const T&) const { return m_value; } 57 protected: 58 IndexType m_value; 59 }; 60 61 // Specialization for compile-time value, 62 // It is similar to ValueExpr(N) but this version helps the compiler to generate better code. 63 template<int N> 64 class ValueExpr<internal::FixedInt<N> > { 65 public: ValueExpr()66 ValueExpr() {} 67 template<typename T> eval_impl(const T &)68 EIGEN_CONSTEXPR Index eval_impl(const T&) const { return N; } 69 }; 70 71 72 /** \class BaseExpr 73 * \ingroup Core_Module 74 * Common base class of any symbolic expressions 75 */ 76 template<typename Derived> 77 class BaseExpr 78 { 79 public: derived()80 const Derived& derived() const { return *static_cast<const Derived*>(this); } 81 82 /** Evaluate the expression given the \a values of the symbols. 83 * 84 * \param values defines the values of the symbols, it can either be a SymbolValue or a std::tuple of SymbolValue 85 * as constructed by SymbolExpr::operator= operator. 86 * 87 */ 88 template<typename T> eval(const T & values)89 Index eval(const T& values) const { return derived().eval_impl(values); } 90 91 #if EIGEN_HAS_CXX14 92 template<typename... Types> eval(Types &&...values)93 Index eval(Types&&... values) const { return derived().eval_impl(std::make_tuple(values...)); } 94 #endif 95 96 NegateExpr<Derived> operator-() const { return NegateExpr<Derived>(derived()); } 97 98 AddExpr<Derived,ValueExpr<> > operator+(Index b) const 99 { return AddExpr<Derived,ValueExpr<> >(derived(), b); } 100 AddExpr<Derived,ValueExpr<> > operator-(Index a) const 101 { return AddExpr<Derived,ValueExpr<> >(derived(), -a); } 102 ProductExpr<Derived,ValueExpr<> > operator*(Index a) const 103 { return ProductExpr<Derived,ValueExpr<> >(derived(),a); } 104 QuotientExpr<Derived,ValueExpr<> > operator/(Index a) const 105 { return QuotientExpr<Derived,ValueExpr<> >(derived(),a); } 106 107 friend AddExpr<Derived,ValueExpr<> > operator+(Index a, const BaseExpr& b) 108 { return AddExpr<Derived,ValueExpr<> >(b.derived(), a); } 109 friend AddExpr<NegateExpr<Derived>,ValueExpr<> > operator-(Index a, const BaseExpr& b) 110 { return AddExpr<NegateExpr<Derived>,ValueExpr<> >(-b.derived(), a); } 111 friend ProductExpr<ValueExpr<>,Derived> operator*(Index a, const BaseExpr& b) 112 { return ProductExpr<ValueExpr<>,Derived>(a,b.derived()); } 113 friend QuotientExpr<ValueExpr<>,Derived> operator/(Index a, const BaseExpr& b) 114 { return QuotientExpr<ValueExpr<>,Derived>(a,b.derived()); } 115 116 template<int N> 117 AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>) const 118 { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); } 119 template<int N> 120 AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N>) const 121 { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); } 122 template<int N> 123 ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N>) const 124 { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); } 125 template<int N> 126 QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N>) const 127 { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); } 128 129 template<int N> 130 friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>, const BaseExpr& b) 131 { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); } 132 template<int N> 133 friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N>, const BaseExpr& b) 134 { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); } 135 template<int N> 136 friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N>, const BaseExpr& b) 137 { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); } 138 template<int N> 139 friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N>, const BaseExpr& b) 140 { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); } 141 142 #if (!EIGEN_HAS_CXX14) 143 template<int N> 144 AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)()) const 145 { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); } 146 template<int N> 147 AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N> (*)()) const 148 { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); } 149 template<int N> 150 ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N> (*)()) const 151 { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); } 152 template<int N> 153 QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N> (*)()) const 154 { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); } 155 156 template<int N> 157 friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)(), const BaseExpr& b) 158 { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); } 159 template<int N> 160 friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N> (*)(), const BaseExpr& b) 161 { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); } 162 template<int N> 163 friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N> (*)(), const BaseExpr& b) 164 { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); } 165 template<int N> 166 friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N> (*)(), const BaseExpr& b) 167 { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); } 168 #endif 169 170 171 template<typename OtherDerived> 172 AddExpr<Derived,OtherDerived> operator+(const BaseExpr<OtherDerived> &b) const 173 { return AddExpr<Derived,OtherDerived>(derived(), b.derived()); } 174 175 template<typename OtherDerived> 176 AddExpr<Derived,NegateExpr<OtherDerived> > operator-(const BaseExpr<OtherDerived> &b) const 177 { return AddExpr<Derived,NegateExpr<OtherDerived> >(derived(), -b.derived()); } 178 179 template<typename OtherDerived> 180 ProductExpr<Derived,OtherDerived> operator*(const BaseExpr<OtherDerived> &b) const 181 { return ProductExpr<Derived,OtherDerived>(derived(), b.derived()); } 182 183 template<typename OtherDerived> 184 QuotientExpr<Derived,OtherDerived> operator/(const BaseExpr<OtherDerived> &b) const 185 { return QuotientExpr<Derived,OtherDerived>(derived(), b.derived()); } 186 }; 187 188 template<typename T> 189 struct is_symbolic { 190 // BaseExpr has no conversion ctor, so we only have to check whether T can be statically cast to its base class BaseExpr<T>. 191 enum { value = internal::is_convertible<T,BaseExpr<T> >::value }; 192 }; 193 194 /** Represents the actual value of a symbol identified by its tag 195 * 196 * It is the return type of SymbolValue::operator=, and most of the time this is only way it is used. 197 */ 198 template<typename Tag> 199 class SymbolValue 200 { 201 public: 202 /** Default constructor from the value \a val */ SymbolValue(Index val)203 SymbolValue(Index val) : m_value(val) {} 204 205 /** \returns the stored value of the symbol */ value()206 Index value() const { return m_value; } 207 protected: 208 Index m_value; 209 }; 210 211 /** Expression of a symbol uniquely identified by the template parameter type \c tag */ 212 template<typename tag> 213 class SymbolExpr : public BaseExpr<SymbolExpr<tag> > 214 { 215 public: 216 /** Alias to the template parameter \c tag */ 217 typedef tag Tag; 218 SymbolExpr()219 SymbolExpr() {} 220 221 /** Associate the value \a val to the given symbol \c *this, uniquely identified by its \c Tag. 222 * 223 * The returned object should be passed to ExprBase::eval() to evaluate a given expression with this specified runtime-time value. 224 */ 225 SymbolValue<Tag> operator=(Index val) const { 226 return SymbolValue<Tag>(val); 227 } 228 eval_impl(const SymbolValue<Tag> & values)229 Index eval_impl(const SymbolValue<Tag> &values) const { return values.value(); } 230 231 #if EIGEN_HAS_CXX14 232 // C++14 versions suitable for multiple symbols 233 template<typename... Types> eval_impl(const std::tuple<Types...> & values)234 Index eval_impl(const std::tuple<Types...>& values) const { return std::get<SymbolValue<Tag> >(values).value(); } 235 #endif 236 }; 237 238 template<typename Arg0> 239 class NegateExpr : public BaseExpr<NegateExpr<Arg0> > 240 { 241 public: NegateExpr(const Arg0 & arg0)242 NegateExpr(const Arg0& arg0) : m_arg0(arg0) {} 243 244 template<typename T> eval_impl(const T & values)245 Index eval_impl(const T& values) const { return -m_arg0.eval_impl(values); } 246 protected: 247 Arg0 m_arg0; 248 }; 249 250 template<typename Arg0, typename Arg1> 251 class AddExpr : public BaseExpr<AddExpr<Arg0,Arg1> > 252 { 253 public: AddExpr(const Arg0 & arg0,const Arg1 & arg1)254 AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} 255 256 template<typename T> eval_impl(const T & values)257 Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) + m_arg1.eval_impl(values); } 258 protected: 259 Arg0 m_arg0; 260 Arg1 m_arg1; 261 }; 262 263 template<typename Arg0, typename Arg1> 264 class ProductExpr : public BaseExpr<ProductExpr<Arg0,Arg1> > 265 { 266 public: ProductExpr(const Arg0 & arg0,const Arg1 & arg1)267 ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} 268 269 template<typename T> eval_impl(const T & values)270 Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) * m_arg1.eval_impl(values); } 271 protected: 272 Arg0 m_arg0; 273 Arg1 m_arg1; 274 }; 275 276 template<typename Arg0, typename Arg1> 277 class QuotientExpr : public BaseExpr<QuotientExpr<Arg0,Arg1> > 278 { 279 public: QuotientExpr(const Arg0 & arg0,const Arg1 & arg1)280 QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} 281 282 template<typename T> eval_impl(const T & values)283 Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) / m_arg1.eval_impl(values); } 284 protected: 285 Arg0 m_arg0; 286 Arg1 m_arg1; 287 }; 288 289 } // end namespace symbolic 290 291 } // end namespace Eigen 292 293 #endif // EIGEN_SYMBOLIC_INDEX_H 294