xref: /aosp_15_r20/external/eigen/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Eugene Brevdo <[email protected]>
5 // Copyright (C) 2016 Gael Guennebaud <[email protected]>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
12 #define EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
13 
14 namespace Eigen {
15 
16 namespace internal {
17 
18 
19 /** \internal
20   * \brief Template functor to compute the incomplete gamma function igamma(a, x)
21   *
22   * \sa class CwiseBinaryOp, Cwise::igamma
23   */
24 template<typename Scalar> struct scalar_igamma_op : binary_op_base<Scalar,Scalar>
25 {
EIGEN_EMPTY_STRUCT_CTORscalar_igamma_op26   EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_op)
27   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& x) const {
28     using numext::igamma; return igamma(a, x);
29   }
30   template<typename Packet>
packetOpscalar_igamma_op31   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
32     return internal::pigamma(a, x);
33   }
34 };
35 template<typename Scalar>
36 struct functor_traits<scalar_igamma_op<Scalar> > {
37   enum {
38     // Guesstimate
39     Cost = 20 * NumTraits<Scalar>::MulCost + 10 * NumTraits<Scalar>::AddCost,
40     PacketAccess = packet_traits<Scalar>::HasIGamma
41   };
42 };
43 
44 /** \internal
45   * \brief Template functor to compute the derivative of the incomplete gamma
46   * function igamma_der_a(a, x)
47   *
48   * \sa class CwiseBinaryOp, Cwise::igamma_der_a
49   */
50 template <typename Scalar>
51 struct scalar_igamma_der_a_op {
52   EIGEN_EMPTY_STRUCT_CTOR(scalar_igamma_der_a_op)
53   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const {
54     using numext::igamma_der_a;
55     return igamma_der_a(a, x);
56   }
57   template <typename Packet>
58   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
59     return internal::pigamma_der_a(a, x);
60   }
61 };
62 template <typename Scalar>
63 struct functor_traits<scalar_igamma_der_a_op<Scalar> > {
64   enum {
65     // 2x the cost of igamma
66     Cost = 40 * NumTraits<Scalar>::MulCost + 20 * NumTraits<Scalar>::AddCost,
67     PacketAccess = packet_traits<Scalar>::HasIGammaDerA
68   };
69 };
70 
71 /** \internal
72   * \brief Template functor to compute the derivative of the sample
73   * of a Gamma(alpha, 1) random variable with respect to the parameter alpha
74   * gamma_sample_der_alpha(alpha, sample)
75   *
76   * \sa class CwiseBinaryOp, Cwise::gamma_sample_der_alpha
77   */
78 template <typename Scalar>
79 struct scalar_gamma_sample_der_alpha_op {
80   EIGEN_EMPTY_STRUCT_CTOR(scalar_gamma_sample_der_alpha_op)
81   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& alpha, const Scalar& sample) const {
82     using numext::gamma_sample_der_alpha;
83     return gamma_sample_der_alpha(alpha, sample);
84   }
85   template <typename Packet>
86   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& alpha, const Packet& sample) const {
87     return internal::pgamma_sample_der_alpha(alpha, sample);
88   }
89 };
90 template <typename Scalar>
91 struct functor_traits<scalar_gamma_sample_der_alpha_op<Scalar> > {
92   enum {
93     // 2x the cost of igamma, minus the lgamma cost (the lgamma cancels out)
94     Cost = 30 * NumTraits<Scalar>::MulCost + 15 * NumTraits<Scalar>::AddCost,
95     PacketAccess = packet_traits<Scalar>::HasGammaSampleDerAlpha
96   };
97 };
98 
99 /** \internal
100   * \brief Template functor to compute the complementary incomplete gamma function igammac(a, x)
101   *
102   * \sa class CwiseBinaryOp, Cwise::igammac
103   */
104 template<typename Scalar> struct scalar_igammac_op : binary_op_base<Scalar,Scalar>
105 {
106   EIGEN_EMPTY_STRUCT_CTOR(scalar_igammac_op)
107   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& x) const {
108     using numext::igammac; return igammac(a, x);
109   }
110   template<typename Packet>
111   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const
112   {
113     return internal::pigammac(a, x);
114   }
115 };
116 template<typename Scalar>
117 struct functor_traits<scalar_igammac_op<Scalar> > {
118   enum {
119     // Guesstimate
120     Cost = 20 * NumTraits<Scalar>::MulCost + 10 * NumTraits<Scalar>::AddCost,
121     PacketAccess = packet_traits<Scalar>::HasIGammac
122   };
123 };
124 
125 
126 /** \internal
127   * \brief Template functor to compute the incomplete beta integral betainc(a, b, x)
128   *
129   */
130 template<typename Scalar> struct scalar_betainc_op {
131   EIGEN_EMPTY_STRUCT_CTOR(scalar_betainc_op)
132   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& x, const Scalar& a, const Scalar& b) const {
133     using numext::betainc; return betainc(x, a, b);
134   }
135   template<typename Packet>
136   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& x, const Packet& a, const Packet& b) const
137   {
138     return internal::pbetainc(x, a, b);
139   }
140 };
141 template<typename Scalar>
142 struct functor_traits<scalar_betainc_op<Scalar> > {
143   enum {
144     // Guesstimate
145     Cost = 400 * NumTraits<Scalar>::MulCost + 400 * NumTraits<Scalar>::AddCost,
146     PacketAccess = packet_traits<Scalar>::HasBetaInc
147   };
148 };
149 
150 
151 /** \internal
152  * \brief Template functor to compute the natural log of the absolute
153  * value of Gamma of a scalar
154  * \sa class CwiseUnaryOp, Cwise::lgamma()
155  */
156 template<typename Scalar> struct scalar_lgamma_op {
157   EIGEN_EMPTY_STRUCT_CTOR(scalar_lgamma_op)
158   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
159     using numext::lgamma; return lgamma(a);
160   }
161   typedef typename packet_traits<Scalar>::type Packet;
162   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::plgamma(a); }
163 };
164 template<typename Scalar>
165 struct functor_traits<scalar_lgamma_op<Scalar> >
166 {
167   enum {
168     // Guesstimate
169     Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
170     PacketAccess = packet_traits<Scalar>::HasLGamma
171   };
172 };
173 
174 /** \internal
175  * \brief Template functor to compute psi, the derivative of lgamma of a scalar.
176  * \sa class CwiseUnaryOp, Cwise::digamma()
177  */
178 template<typename Scalar> struct scalar_digamma_op {
179   EIGEN_EMPTY_STRUCT_CTOR(scalar_digamma_op)
180   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
181     using numext::digamma; return digamma(a);
182   }
183   typedef typename packet_traits<Scalar>::type Packet;
184   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pdigamma(a); }
185 };
186 template<typename Scalar>
187 struct functor_traits<scalar_digamma_op<Scalar> >
188 {
189   enum {
190     // Guesstimate
191     Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
192     PacketAccess = packet_traits<Scalar>::HasDiGamma
193   };
194 };
195 
196 /** \internal
197  * \brief Template functor to compute the Riemann Zeta function of two arguments.
198  * \sa class CwiseUnaryOp, Cwise::zeta()
199  */
200 template<typename Scalar> struct scalar_zeta_op {
201     EIGEN_EMPTY_STRUCT_CTOR(scalar_zeta_op)
202     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& x, const Scalar& q) const {
203         using numext::zeta; return zeta(x, q);
204     }
205     typedef typename packet_traits<Scalar>::type Packet;
206     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, const Packet& q) const { return internal::pzeta(x, q); }
207 };
208 template<typename Scalar>
209 struct functor_traits<scalar_zeta_op<Scalar> >
210 {
211     enum {
212         // Guesstimate
213         Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
214         PacketAccess = packet_traits<Scalar>::HasZeta
215     };
216 };
217 
218 /** \internal
219  * \brief Template functor to compute the polygamma function.
220  * \sa class CwiseUnaryOp, Cwise::polygamma()
221  */
222 template<typename Scalar> struct scalar_polygamma_op {
223     EIGEN_EMPTY_STRUCT_CTOR(scalar_polygamma_op)
224     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& n, const Scalar& x) const {
225         using numext::polygamma; return polygamma(n, x);
226     }
227     typedef typename packet_traits<Scalar>::type Packet;
228     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& n, const Packet& x) const { return internal::ppolygamma(n, x); }
229 };
230 template<typename Scalar>
231 struct functor_traits<scalar_polygamma_op<Scalar> >
232 {
233     enum {
234         // Guesstimate
235         Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
236         PacketAccess = packet_traits<Scalar>::HasPolygamma
237     };
238 };
239 
240 /** \internal
241  * \brief Template functor to compute the error function of a scalar
242  * \sa class CwiseUnaryOp, ArrayBase::erf()
243  */
244 template<typename Scalar> struct scalar_erf_op {
245   EIGEN_EMPTY_STRUCT_CTOR(scalar_erf_op)
246   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
247   operator()(const Scalar& a) const {
248     return numext::erf(a);
249   }
250   template <typename Packet>
251   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
252     return perf(x);
253   }
254 };
255 template <typename Scalar>
256 struct functor_traits<scalar_erf_op<Scalar> > {
257   enum {
258     PacketAccess = packet_traits<Scalar>::HasErf,
259     Cost =
260         (PacketAccess
261 #ifdef EIGEN_VECTORIZE_FMA
262              // TODO(rmlarsen): Move the FMA cost model to a central location.
263              // Haswell can issue 2 add/mul/madd per cycle.
264              // 10 pmadd, 2 pmul, 1 div, 2 other
265              ? (2 * NumTraits<Scalar>::AddCost +
266                 7 * NumTraits<Scalar>::MulCost +
267                 scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value)
268 #else
269              ? (12 * NumTraits<Scalar>::AddCost +
270                 12 * NumTraits<Scalar>::MulCost +
271                 scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value)
272 #endif
273              // Assume for simplicity that this is as expensive as an exp().
274              : (functor_traits<scalar_exp_op<Scalar> >::Cost))
275   };
276 };
277 
278 /** \internal
279  * \brief Template functor to compute the Complementary Error Function
280  * of a scalar
281  * \sa class CwiseUnaryOp, Cwise::erfc()
282  */
283 template<typename Scalar> struct scalar_erfc_op {
284   EIGEN_EMPTY_STRUCT_CTOR(scalar_erfc_op)
285   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
286     using numext::erfc; return erfc(a);
287   }
288   typedef typename packet_traits<Scalar>::type Packet;
289   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::perfc(a); }
290 };
291 template<typename Scalar>
292 struct functor_traits<scalar_erfc_op<Scalar> >
293 {
294   enum {
295     // Guesstimate
296     Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
297     PacketAccess = packet_traits<Scalar>::HasErfc
298   };
299 };
300 
301 /** \internal
302  * \brief Template functor to compute the Inverse of the normal distribution
303  * function of a scalar
304  * \sa class CwiseUnaryOp, Cwise::ndtri()
305  */
306 template<typename Scalar> struct scalar_ndtri_op {
307   EIGEN_EMPTY_STRUCT_CTOR(scalar_ndtri_op)
308   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const {
309     using numext::ndtri; return ndtri(a);
310   }
311   typedef typename packet_traits<Scalar>::type Packet;
312   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pndtri(a); }
313 };
314 template<typename Scalar>
315 struct functor_traits<scalar_ndtri_op<Scalar> >
316 {
317   enum {
318     // On average, We are evaluating rational functions with degree N=9 in the
319     // numerator and denominator. This results in 2*N additions and 2*N
320     // multiplications.
321     Cost = 18 * NumTraits<Scalar>::MulCost + 18 * NumTraits<Scalar>::AddCost,
322     PacketAccess = packet_traits<Scalar>::HasNdtri
323   };
324 };
325 
326 } // end namespace internal
327 
328 } // end namespace Eigen
329 
330 #endif // EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
331