/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

typedef Eigen::Index Index;

// TODO(b/154339590): Needs to be vectorized.
template <typename Device, typename Reducer, typename T>
struct Scan {
  void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out, const Reducer& reducer,
                  const bool reverse, const bool exclusive) {
    // Perform the reverse ops directly with Eigen, which avoids copying the
    // tensor twice compared to using individual ops.
    Eigen::array<bool, 3> dims;
    dims[0] = false;
    dims[1] = reverse;
    dims[2] = false;
    MaybeWith32BitIndexing<Device>(
        [&](auto in32, auto out32) {
          out32.device(d) =
              in32.reverse(dims).scan(1, reducer, exclusive).reverse(dims);
        },
        in, out);
  }
};
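
// Illustrative usage sketch (assumed caller-side wiring; the names `d`, `in`,
// `out`, and `reducer` are placeholders, not part of this header): callers are
// expected to reshape the input to rank 3 as (outer, scan_axis, inner) so that
// the scan always runs along dimension 1. For a cumulative sum over float data
// on a CPU (Eigen::ThreadPoolDevice) device, the call would look roughly like:
//
//   Eigen::internal::SumReducer<float> reducer;
//   functor::Scan<Eigen::ThreadPoolDevice, Eigen::internal::SumReducer<float>,
//                 float>()(d, in, out, reducer,
//                          /*reverse=*/false, /*exclusive=*/false);
//
// With reverse=true the scan axis is flipped before and after the scan (done
// in a single Eigen expression above to avoid extra copies); with
// exclusive=true each output element excludes its own input element.
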
template <typename T>
struct LogSumExp {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
                                                     const T& b) const {
    auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
    auto ma = Eigen::internal::scalar_max_op<T>()(a, b);

    auto sub = Eigen::internal::scalar_difference_op<T>();
    auto add = Eigen::internal::scalar_sum_op<T>();
    auto exp = Eigen::internal::scalar_exp_op<T>();
    auto log1p = Eigen::internal::scalar_log1p_op<T>();
    auto cmp_lt =
        Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();

    auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
    return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
                                                   const T& b) const {
    auto mi = Eigen::internal::pmin(a, b);
    auto ma = Eigen::internal::pmax(a, b);
    using Eigen::internal::padd;
    using Eigen::internal::pcmp_lt;
    using Eigen::internal::pexp;
    using Eigen::internal::plog1p;
    using Eigen::internal::pset1;
    using Eigen::internal::psub;

    auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
    return pselect(pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma,
                   logsumexp);
  }
};
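
// Worked example of the computation above: for a = log 2 and b = log 3,
//   mi = log 2, ma = log 3, and
//   log1p(exp(log 2 - log 3)) + log 3 = log(1 + 2/3) + log 3
//                                     = log(5/3) + log 3 = log 5,
// which matches log(exp(a) + exp(b)) = log(2 + 3) directly. The max-shifted
// form only ever exponentiates a non-positive value, so it cannot overflow.
// The lowest() comparison handles the case where both inputs are -infinity:
// there mi - ma is (-inf) - (-inf) = NaN, so the shifted form would return
// NaN, and ma (= -infinity) is returned instead.
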
template <typename T>
struct LogSumExpReducer {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
    LogSumExp<T> logsumexp;
    *accum = logsumexp(*accum, t);
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
                                                          Packet* accum) const {
    LogSumExp<T> logsumexp;
    *accum = logsumexp.packetOp(*accum, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
    return -Eigen::NumTraits<T>::infinity();
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
    return Eigen::internal::pset1(initialize());
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
    return accum;
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
  finalizePacket(const Packet& vaccum) const {
    return vaccum;
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
  finalizeBoth(const T saccum, const Packet& vaccum) const {
    auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>();
    auto sum_reducer = Eigen::internal::SumReducer<T>();
    auto exp = Eigen::internal::scalar_exp_op<T>();
    auto cmp_lt =
        Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
    auto log = Eigen::internal::scalar_log_op<T>();
    auto add = Eigen::internal::scalar_sum_op<T>();

    using Eigen::internal::pexp;
    using Eigen::internal::psub;

    // `ma = max(x1, ..., xn)`
    // If the max of all of the `xi` is `-infinity` then the result is
    // -infinity. If the max is larger than `-infinity` then it's safe to use
    // for normalization even if the other elements are `-infinity`.
    //
    // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
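    //
    // Concretely, with (x1, x2, x3) = (log 1, log 4, -infinity): ma = log 4,
    // and `logsumexp = log 4 + log(exp(log 1 - log 4) + exp(0) + exp(-inf))
    //                = log 4 + log(1/4 + 1 + 0) = log 5`,
    // so the -infinity element contributes nothing and does not poison the
    // normalization as long as `ma` itself is finite.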
    auto ma = max_reducer.finalizeBoth(saccum, vaccum);
    auto logsumexp = add(log(sum_reducer.finalizeBoth(
                             exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
                         ma);
    return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
  }
};
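
// Sketch of the reducer protocol this type implements (assumed usage on the
// scalar path; values are illustrative): an Eigen reduction or scan seeds its
// accumulator with initialize(), folds each element in with reduce(), and
// reads the result back with finalize(). For the inputs {log 2, log 3}:
//
//   LogSumExpReducer<float> r;
//   float acc = r.initialize();      // -inf
//   r.reduce(std::log(2.0f), &acc);  // logsumexp(-inf, log 2) == log 2
//   r.reduce(std::log(3.0f), &acc);  // logsumexp(log 2, log 3) == log 5
//   float result = r.finalize(acc);  // log 5
//
// The packet overloads (initializePacket/reducePacket/finalizePacket) carry
// the same computation across SIMD lanes, and finalizeBoth merges the scalar
// and packet accumulators. Plugged into the Scan functor above, this reducer
// yields a cumulative log-sum-exp along the scan axis.
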
}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_