xref: /aosp_15_r20/external/gemmlowp/fixedpoint/fixedpoint.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // fixedpoint.h: fixed-point arithmetic, with basic operations and
16*5f39d1b3SJooyung Han // a few math functions such as tanh.
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
19*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han #include <algorithm>
22*5f39d1b3SJooyung Han #include <cassert>
23*5f39d1b3SJooyung Han #include <cmath>
24*5f39d1b3SJooyung Han #include <cstdint>
25*5f39d1b3SJooyung Han #include <limits>
26*5f39d1b3SJooyung Han 
27*5f39d1b3SJooyung Han #include "../internal/detect_platform.h"
28*5f39d1b3SJooyung Han 
29*5f39d1b3SJooyung Han namespace gemmlowp {
30*5f39d1b3SJooyung Han 
31*5f39d1b3SJooyung Han // Part 1: Low-level integer-arithmetic primitives.
32*5f39d1b3SJooyung Han // The implementations here are generic implementations valid for
33*5f39d1b3SJooyung Han // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
34*5f39d1b3SJooyung Han // (e.g. NEON int32x4_t) may be supported by providing
35*5f39d1b3SJooyung Han // specializations for them in separate files.
36*5f39d1b3SJooyung Han //
37*5f39d1b3SJooyung Han // The purpose of these primitives is two-fold:
38*5f39d1b3SJooyung Han //  - They will be used to implement higher-level fixed-point
39*5f39d1b3SJooyung Han //    abstractions, namely the FixedPoint class and its arithmetic
40*5f39d1b3SJooyung Han //    operators.
41*5f39d1b3SJooyung Han //  - They will be directly used to implement some more involved
42*5f39d1b3SJooyung Han //    fixed-point computations, e.g. the fixed-point implementation
43*5f39d1b3SJooyung Han //    of math functions such as tanh.
44*5f39d1b3SJooyung Han 
// Compile-time description of a raw storage type, used to deal with SIMD
// aspects generically: which scalar type each lane holds (ScalarRawType)
// and how many lanes there are (kLanes).
//
// The primary template is intentionally empty; every supported raw type
// provides its own specialization.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Plain 32-bit scalar: one lane, and the lane type is the type itself.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static constexpr int kLanes = 1;
};

// Plain 16-bit scalar: one lane, and the lane type is the type itself.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static constexpr int kLanes = 1;
};
61*5f39d1b3SJooyung Han 
62*5f39d1b3SJooyung Han // Returns a SIMD value duplicating a scalar value across all lanes.
63*5f39d1b3SJooyung Han template <typename tRawType>
64*5f39d1b3SJooyung Han tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
65*5f39d1b3SJooyung Han   return x;
66*5f39d1b3SJooyung Han }
67*5f39d1b3SJooyung Han 
// Bit-wise AND of a and b.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  const tIntegerType conjunction = a & b;
  return conjunction;
}
73*5f39d1b3SJooyung Han 
// Bit-wise OR of a and b.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  const tIntegerType disjunction = a | b;
  return disjunction;
}
79*5f39d1b3SJooyung Han 
// Bit-wise exclusive OR of a and b.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  const tIntegerType difference_bits = a ^ b;
  return difference_bits;
}
85*5f39d1b3SJooyung Han 
// Bit-wise complement of a: every bit is flipped.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  const tIntegerType complement = ~a;
  return complement;
}
91*5f39d1b3SJooyung Han 
// Plain integer sum a + b. No saturation is applied; overflowing the type
// is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  const tIntegerType sum = a + b;
  return sum;
}
97*5f39d1b3SJooyung Han 
// Plain integer product a * b. No saturation is applied; overflowing the
// type is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  const tIntegerType product = a * b;
  return product;
}
103*5f39d1b3SJooyung Han 
// Plain integer difference a - b. No saturation is applied; overflowing the
// type is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  const tIntegerType difference = a - b;
  return difference;
}
109*5f39d1b3SJooyung Han 
// Arithmetic negation of a. No saturation is applied; overflowing the type
// (negating the most negative value) is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
115*5f39d1b3SJooyung Han 
// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Negative values of `a` are OK. In case of overflow, no Undefined
// Behavior, but the results are implementation-defined (in practice,
// they currently are saturated, but we make no commitment to that). The idea
// is that the caller will want to implement the overflowing cases with
// saturation with compare-and-mask, so we don't care about the results
// in the overflow case, we just want to avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type.
//
// `offset` must be non-negative and small enough that a * 2^offset still
// fits in 64 bits (for a 32-bit tIntegerType that means offset <= 32).
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
  // The power of two is built in 64 bits. Building it as a 32-bit `int`
  // goes wrong for offset == 31: 2^31 does not fit in `int`, so the
  // multiplier would come out negative and flip the sign of the result.
  const std::int64_t power_of_two = static_cast<std::int64_t>(1) << offset;
  const std::int64_t wide_shifted = static_cast<std::int64_t>(a) * power_of_two;

  const tIntegerType min = std::numeric_limits<tIntegerType>::min();
  const tIntegerType max = std::numeric_limits<tIntegerType>::max();
  if (wide_shifted < min) {
    return min;
  }
  if (wide_shifted > max) {
    return max;
  }
  return static_cast<tIntegerType>(wide_shifted);
}
136*5f39d1b3SJooyung Han 
// Integer arithmetic right-shift of a by offset bits, without rounding.
// The result for negative a is implementation-defined by the language, but
// compilers are consistent about it in practice, and this code relies on
// that.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}
144*5f39d1b3SJooyung Han 
// Bit-wise select: wherever a bit of if_mask is set, the result takes that
// bit from then_val; wherever it is clear, the result takes it from
// else_val. Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType taken_from_then = BitAnd(if_mask, then_val);
  const tIntegerType taken_from_else = BitAnd(BitNot(if_mask), else_val);
  // The two parts occupy disjoint bit positions, so XOR simply merges them.
  return BitXor(taken_from_then, taken_from_else);
}
153*5f39d1b3SJooyung Han 
// Builds a mask from a: for each input scalar, all of the corresponding
// result bits are set when that scalar is non-zero, and all are clear
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  const tIntegerType all_clear = 0;
  if (a) {
    return BitNot(all_clear);
  }
  return all_clear;
}
161*5f39d1b3SJooyung Han 
162*5f39d1b3SJooyung Han // For each input scalar, the corresponding bits of the result are set if the
163*5f39d1b3SJooyung Han // input scalar is zero.
164*5f39d1b3SJooyung Han template <typename tIntegerType>
165*5f39d1b3SJooyung Han tIntegerType MaskIfZero(tIntegerType a) {
166*5f39d1b3SJooyung Han   return MaskIfNonZero<tIntegerType>(!a);
167*5f39d1b3SJooyung Han }
168*5f39d1b3SJooyung Han 
169*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are
170*5f39d1b3SJooyung Han // set if the input scalars are equal.
171*5f39d1b3SJooyung Han template <typename tIntegerType>
172*5f39d1b3SJooyung Han tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
173*5f39d1b3SJooyung Han   return MaskIfNonZero<tIntegerType>(a == b);
174*5f39d1b3SJooyung Han }
175*5f39d1b3SJooyung Han 
176*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are
177*5f39d1b3SJooyung Han // set if the input scalars are not equal.
178*5f39d1b3SJooyung Han template <typename tIntegerType>
179*5f39d1b3SJooyung Han tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
180*5f39d1b3SJooyung Han   return MaskIfNonZero<tIntegerType>(a != b);
181*5f39d1b3SJooyung Han }
182*5f39d1b3SJooyung Han 
183*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are
184*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a > b.
185*5f39d1b3SJooyung Han template <typename tIntegerType>
186*5f39d1b3SJooyung Han tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
187*5f39d1b3SJooyung Han   return MaskIfNonZero<tIntegerType>(a > b);
188*5f39d1b3SJooyung Han }
189*5f39d1b3SJooyung Han 
190*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are
191*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a >= b.
192*5f39d1b3SJooyung Han template <typename tIntegerType>
193*5f39d1b3SJooyung Han tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
194*5f39d1b3SJooyung Han   return MaskIfNonZero<tIntegerType>(a >= b);
195*5f39d1b3SJooyung Han }
196*5f39d1b3SJooyung Han 
197*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are
198*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a < b.
199*5f39d1b3SJooyung Han template <typename tIntegerType>
200*5f39d1b3SJooyung Han tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
201*5f39d1b3SJooyung Han   return MaskIfNonZero<tIntegerType>(a < b);
202*5f39d1b3SJooyung Han }
203*5f39d1b3SJooyung Han 
204*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are
205*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a <= b.
206*5f39d1b3SJooyung Han template <typename tIntegerType>
207*5f39d1b3SJooyung Han tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
208*5f39d1b3SJooyung Han   return MaskIfNonZero<tIntegerType>(a <= b);
209*5f39d1b3SJooyung Han }
210*5f39d1b3SJooyung Han 
// Reports whether every input scalar is nonzero.
// Callers are expected to pass mask-like values, where each scalar has
// either all of its bits set or none of them; for any other input the
// behavior is currently undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
  const bool every_scalar_set = static_cast<bool>(a);
  return every_scalar_set;
}
218*5f39d1b3SJooyung Han 
// Reports whether at least one input scalar is nonzero.
// Callers are expected to pass mask-like values, where each scalar has
// either all of its bits set or none of them; for any other input the
// behavior is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  const bool some_scalar_set = static_cast<bool>(a);
  return some_scalar_set;
}
226*5f39d1b3SJooyung Han 
// Computes (a+b)/2 rounded to the nearest integer, without overflowing on
// the intermediate sum. Intended as the counterpart of VRHADD in the ARM
// NEON instruction set.
//
// The primary template only exists to reject unsupported types at compile
// time; the real work is done by the specializations.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  (void)b;
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// 32-bit version: the sum is formed in 64 bits so it cannot overflow.
// An exact half is rounded away from zero (the bias carries the sign of
// the sum and the division truncates).
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t total =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  const std::int64_t bias = (total < 0) ? -1 : 1;
  return static_cast<std::int32_t>((total + bias) / 2);
}

// 16-bit version: same scheme, with the sum formed in 32 bits.
template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t total =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t bias = (total < 0) ? -1 : 1;
  return static_cast<std::int16_t>((total + bias) / 2);
}
253*5f39d1b3SJooyung Han 
// Returns a + b, clamped to the representable range of IntegerType instead
// of overflowing.
//
// The primary template only exists to reject unsupported types at compile
// time. Scalar specializations are provided below for 16-bit and 8-bit
// integers.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// 16-bit saturating addition. The exact sum is formed in 32 bits, where it
// cannot overflow, and then clamped to [-32768, 32767].
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  const std::int32_t sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  const std::int32_t lowest = std::numeric_limits<std::int16_t>::min();
  const std::int32_t highest = std::numeric_limits<std::int16_t>::max();
  return static_cast<std::int16_t>(std::min(highest, std::max(lowest, sum)));
}

// 8-bit saturating addition. The exact sum is formed in 16 bits, where it
// cannot overflow, and then clamped to [-128, 127].
template <>
inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
  const std::int16_t sum = static_cast<std::int16_t>(
      static_cast<std::int16_t>(a) + static_cast<std::int16_t>(b));
  const std::int16_t lowest = std::numeric_limits<std::int8_t>::min();
  const std::int16_t highest = std::numeric_limits<std::int8_t>::max();
  return static_cast<std::int8_t>(std::min(highest, std::max(lowest, sum)));
}
281*5f39d1b3SJooyung Han 
282*5f39d1b3SJooyung Han // Returns a+b, saturating if the integers are 16bit or narrower,
283*5f39d1b3SJooyung Han // otherwise just a plain addition.
284*5f39d1b3SJooyung Han template <typename IntegerType, bool Is16Bit>
285*5f39d1b3SJooyung Han struct AddSaturatingIf16BitImpl {
286*5f39d1b3SJooyung Han   static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
287*5f39d1b3SJooyung Han };
288*5f39d1b3SJooyung Han template <typename IntegerType>
289*5f39d1b3SJooyung Han struct AddSaturatingIf16BitImpl<IntegerType, true> {
290*5f39d1b3SJooyung Han   static IntegerType Run(IntegerType a, IntegerType b) {
291*5f39d1b3SJooyung Han     return SaturatingAdd(a, b);
292*5f39d1b3SJooyung Han   }
293*5f39d1b3SJooyung Han };
294*5f39d1b3SJooyung Han template <typename IntegerType>
295*5f39d1b3SJooyung Han IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
296*5f39d1b3SJooyung Han   using ScalarType =
297*5f39d1b3SJooyung Han       typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
298*5f39d1b3SJooyung Han   return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
299*5f39d1b3SJooyung Han                                                                              b);
300*5f39d1b3SJooyung Han }
301*5f39d1b3SJooyung Han 
// Fixed-point multiplication of two raw integers, each read as a value in
// the interval [-1, 1). The product is rounded to the nearest representable
// value and returned in the same format. The one product that falls outside
// the half-open interval, (-1) * (-1) = 1, is saturated to the maximum
// value.
//
// [Worked out for std::int32_t as an example.]
//
// IntegerType is assumed to be signed, and the mapping to [-1, 1) is fixed
// by its width. For std::int32_t it is
//   real_value = integer_value / 2^31.
// Leaving rounding and saturation aside, the function therefore computes
//   ((a / 2^31) * (b / 2^31)) * 2^31 = (a * b) / 2^31.
//
// The word 'doubling' in the name refers to how close this is to a
// "multiply-high" operation. Keeping just the top half of the product bits
// would yield
//   (a * b) / 2^32,
// and this function returns twice that, because
//   1/2^31 = 2 * 1/2^32.
// That way all 32 bits of the int32 destination are put to use.
//
// [End of the int32 example.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
//
// The primary template only exists to reject unsupported types at compile
// time; the real work is done by the specializations.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  (void)b;
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// 32-bit version; performs the same computation as the ARMv7 NEON VQRDMULH
// instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  typedef std::numeric_limits<std::int32_t> Limits;
  // min * min is the only input pair whose result does not fit.
  if (a == b && a == Limits::min()) {
    return Limits::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Half of the divisor, carrying the sign of the product, so that the
  // truncating division below rounds to nearest.
  const std::int32_t rounding = (product < 0) ? (1 - (1 << 30)) : (1 << 30);
  return static_cast<std::int32_t>((product + rounding) / (1ll << 31));
}

// 16-bit version: same scheme, with the product formed in 32 bits and a
// divisor of 2^15.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  typedef std::numeric_limits<std::int16_t> Limits;
  // min * min is the only input pair whose result does not fit.
  if (a == b && a == Limits::min()) {
    return Limits::max();
  }
  const std::int32_t product =
      static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
  // Half of the divisor, carrying the sign of the product, so that the
  // truncating division below rounds to nearest.
  const std::int16_t rounding = (product < 0) ? (1 - (1 << 14)) : (1 << 14);
  return static_cast<std::int16_t>((product + rounding) / (1 << 15));
}
364*5f39d1b3SJooyung Han 
365*5f39d1b3SJooyung Han // Correctly-rounded-to-nearest division by a power-of-two.
366*5f39d1b3SJooyung Han // Also known as a rounding arithmetic right shift.
367*5f39d1b3SJooyung Han template <typename IntegerType, typename ExponentType>
368*5f39d1b3SJooyung Han inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
369*5f39d1b3SJooyung Han   assert(exponent >= 0);
370*5f39d1b3SJooyung Han   assert(exponent <= 31);
371*5f39d1b3SJooyung Han   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
372*5f39d1b3SJooyung Han   const IntegerType zero = Dup<IntegerType>(0);
373*5f39d1b3SJooyung Han   const IntegerType one = Dup<IntegerType>(1);
374*5f39d1b3SJooyung Han   const IntegerType remainder = BitAnd(x, mask);
375*5f39d1b3SJooyung Han   const IntegerType threshold =
376*5f39d1b3SJooyung Han       Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
377*5f39d1b3SJooyung Han   return Add(ShiftRight(x, exponent),
378*5f39d1b3SJooyung Han              BitAnd(MaskIfGreaterThan(remainder, threshold), one));
379*5f39d1b3SJooyung Han }
380*5f39d1b3SJooyung Han 
381*5f39d1b3SJooyung Han // Returns the product of a run-time integer value by a compile-time power
382*5f39d1b3SJooyung Han // of two, with either a positive exponent (equivalent to an arithmetic
383*5f39d1b3SJooyung Han // left shift, saturating) or a negative exponent (equivalent to an arithmetic
384*5f39d1b3SJooyung Han // right shift, rounding to nearest).
385*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType,
386*5f39d1b3SJooyung Han           int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
387*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT {};
388*5f39d1b3SJooyung Han 
389*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType>
390*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
391*5f39d1b3SJooyung Han   static IntegerType eval(IntegerType x) { return x; }
392*5f39d1b3SJooyung Han };
393*5f39d1b3SJooyung Han 
394*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType>
395*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
396*5f39d1b3SJooyung Han   static IntegerType eval(IntegerType x) {
397*5f39d1b3SJooyung Han     using ScalarIntegerType =
398*5f39d1b3SJooyung Han         typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
399*5f39d1b3SJooyung Han     const IntegerType min =
400*5f39d1b3SJooyung Han         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
401*5f39d1b3SJooyung Han     const IntegerType max =
402*5f39d1b3SJooyung Han         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
403*5f39d1b3SJooyung Han     const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
404*5f39d1b3SJooyung Han 
405*5f39d1b3SJooyung Han     const std::int32_t threshold =
406*5f39d1b3SJooyung Han         ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
407*5f39d1b3SJooyung Han     const IntegerType positive_mask =
408*5f39d1b3SJooyung Han         MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
409*5f39d1b3SJooyung Han     const IntegerType negative_mask =
410*5f39d1b3SJooyung Han         MaskIfLessThan(x, Dup<IntegerType>(-threshold));
411*5f39d1b3SJooyung Han 
412*5f39d1b3SJooyung Han     IntegerType result = ShiftLeft(x, Exponent);
413*5f39d1b3SJooyung Han     result = SelectUsingMask(positive_mask, max, result);
414*5f39d1b3SJooyung Han     result = SelectUsingMask(negative_mask, min, result);
415*5f39d1b3SJooyung Han     return result;
416*5f39d1b3SJooyung Han   }
417*5f39d1b3SJooyung Han };
418*5f39d1b3SJooyung Han 
419*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType>
420*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
421*5f39d1b3SJooyung Han   static IntegerType eval(IntegerType x) {
422*5f39d1b3SJooyung Han     return RoundingDivideByPOT<IntegerType>(x, -Exponent);
423*5f39d1b3SJooyung Han   }
424*5f39d1b3SJooyung Han };
425*5f39d1b3SJooyung Han 
426*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType>
427*5f39d1b3SJooyung Han IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
428*5f39d1b3SJooyung Han   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
429*5f39d1b3SJooyung Han }
430*5f39d1b3SJooyung Han 
431*5f39d1b3SJooyung Han // Part 2: the FixedPoint class.
432*5f39d1b3SJooyung Han 
433*5f39d1b3SJooyung Han // A FixedPoint object represents a fixed-point value stored in the underlying
434*5f39d1b3SJooyung Han // integer type tRawType, if tRawType is a plain scalar integer type.
435*5f39d1b3SJooyung Han // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
436*5f39d1b3SJooyung Han // case a FixedPoint object represents a corresponding SIMD vector of fixed
437*5f39d1b3SJooyung Han // point values.
438*5f39d1b3SJooyung Han //
439*5f39d1b3SJooyung Han // tIntegerBits describes the range of the fixed-point format: if
440*5f39d1b3SJooyung Han // tIntegerBits == m then the range of representable values is the half-open
441*5f39d1b3SJooyung Han // interval [-2^m; 2^m) where the open boundary on the right side means that
442*5f39d1b3SJooyung Han // 2^m is not representable (how close the maximum representable value is to
443*5f39d1b3SJooyung Han // it, depends on bit-depth of tRawType).
444*5f39d1b3SJooyung Han //
445*5f39d1b3SJooyung Han // In "Q format notation",
446*5f39d1b3SJooyung Han //   https://en.wikipedia.org/wiki/Q_(number_format)
447*5f39d1b3SJooyung Han // we are describing the format
448*5f39d1b3SJooyung Han //   Qm.n
449*5f39d1b3SJooyung Han // where
450*5f39d1b3SJooyung Han //   m = tIntegerBits
451*5f39d1b3SJooyung Han // and
452*5f39d1b3SJooyung Han //   n = NumberOfBits(tRawType) - (m + 1)
453*5f39d1b3SJooyung Han // Note that the (m + 1) in the above line is because we adopt the convention
454*5f39d1b3SJooyung Han // that we count the integer bits exclusively of the sign bit; so (m + 1) is
455*5f39d1b3SJooyung Han // the total number of integer bits inclusive of the sign bit.
456*5f39d1b3SJooyung Han //
457*5f39d1b3SJooyung Han // Accordingly, the number of integral representable values in our range
458*5f39d1b3SJooyung Han //   [-2^m ; 2^m)
459*5f39d1b3SJooyung Han // is equal to 2^(m+1).
460*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
461*5f39d1b3SJooyung Han class FixedPoint {
462*5f39d1b3SJooyung Han  public:
463*5f39d1b3SJooyung Han   typedef tRawType RawType;
464*5f39d1b3SJooyung Han 
465*5f39d1b3SJooyung Han   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
466*5f39d1b3SJooyung Han   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
467*5f39d1b3SJooyung Han 
468*5f39d1b3SJooyung Han   static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
469*5f39d1b3SJooyung Han   static constexpr int kIntegerBits = tIntegerBits;
470*5f39d1b3SJooyung Han   static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
471*5f39d1b3SJooyung Han   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
472*5f39d1b3SJooyung Han                 "bad IntegerBits");
473*5f39d1b3SJooyung Han 
474*5f39d1b3SJooyung Han   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
475*5f39d1b3SJooyung Han 
476*5f39d1b3SJooyung Han   static const ScalarRawType ScalarRawMin() {
477*5f39d1b3SJooyung Han     return std::numeric_limits<ScalarRawType>::min();
478*5f39d1b3SJooyung Han   }
479*5f39d1b3SJooyung Han 
480*5f39d1b3SJooyung Han   static const ScalarRawType ScalarRawMax() {
481*5f39d1b3SJooyung Han     return std::numeric_limits<ScalarRawType>::max();
482*5f39d1b3SJooyung Han   }
483*5f39d1b3SJooyung Han 
484*5f39d1b3SJooyung Han   static const ScalarRawType RawMin() {
485*5f39d1b3SJooyung Han     return VectorFromScalar(ScalarRawMin());
486*5f39d1b3SJooyung Han   }
487*5f39d1b3SJooyung Han 
488*5f39d1b3SJooyung Han   static const ScalarRawType RawMax() {
489*5f39d1b3SJooyung Han     return VectorFromScalar(ScalarRawMax());
490*5f39d1b3SJooyung Han   }
491*5f39d1b3SJooyung Han 
492*5f39d1b3SJooyung Han   static FixedPoint FromRaw(RawType x) {
493*5f39d1b3SJooyung Han     FixedPoint retval;
494*5f39d1b3SJooyung Han     retval.raw() = x;
495*5f39d1b3SJooyung Han     return retval;
496*5f39d1b3SJooyung Han   }
497*5f39d1b3SJooyung Han 
498*5f39d1b3SJooyung Han   static FixedPoint FromScalarRaw(ScalarRawType x) {
499*5f39d1b3SJooyung Han     FixedPoint retval;
500*5f39d1b3SJooyung Han     retval.raw() = Dup<RawType>(x);
501*5f39d1b3SJooyung Han     return retval;
502*5f39d1b3SJooyung Han   }
503*5f39d1b3SJooyung Han 
504*5f39d1b3SJooyung Han   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
505*5f39d1b3SJooyung Han     return FromScalarRaw(x.raw());
506*5f39d1b3SJooyung Han   }
507*5f39d1b3SJooyung Han 
508*5f39d1b3SJooyung Han   template <int Exponent>
509*5f39d1b3SJooyung Han   static FixedPoint ConstantPOT() {
510*5f39d1b3SJooyung Han     static constexpr int kOffset = kFractionalBits + Exponent;
511*5f39d1b3SJooyung Han     static_assert(
512*5f39d1b3SJooyung Han         kOffset < 31,
513*5f39d1b3SJooyung Han         "Constant not exactly representable in this fixed-point format");
514*5f39d1b3SJooyung Han     return FromScalarRaw(ScalarRawType(1) << kOffset);
515*5f39d1b3SJooyung Han   }
516*5f39d1b3SJooyung Han 
517*5f39d1b3SJooyung Han   static FixedPoint Zero() { return FromScalarRaw(0); }
518*5f39d1b3SJooyung Han 
519*5f39d1b3SJooyung Han   static FixedPoint One() {
520*5f39d1b3SJooyung Han     return FromScalarRaw(
521*5f39d1b3SJooyung Han         kIntegerBits == 0
522*5f39d1b3SJooyung Han             ? ScalarRawMax()
523*5f39d1b3SJooyung Han             : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
524*5f39d1b3SJooyung Han   }
525*5f39d1b3SJooyung Han 
526*5f39d1b3SJooyung Han   static FixedPoint FromDouble(double x) {
527*5f39d1b3SJooyung Han     const double min_bound = static_cast<double>(ScalarRawMin());
528*5f39d1b3SJooyung Han     const double max_bound = static_cast<double>(ScalarRawMax());
529*5f39d1b3SJooyung Han     return FromScalarRaw(static_cast<ScalarRawType>(std::min(
530*5f39d1b3SJooyung Han         std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
531*5f39d1b3SJooyung Han                  min_bound),
532*5f39d1b3SJooyung Han         max_bound)));
533*5f39d1b3SJooyung Han   }
534*5f39d1b3SJooyung Han 
535*5f39d1b3SJooyung Han   RawType raw() const { return i_; }
536*5f39d1b3SJooyung Han   RawType& raw() { return i_; }
537*5f39d1b3SJooyung Han 
538*5f39d1b3SJooyung Han  private:
539*5f39d1b3SJooyung Han   RawType i_;
540*5f39d1b3SJooyung Han };
541*5f39d1b3SJooyung Han 
542*5f39d1b3SJooyung Han // Part 3: implementation of arithmetic operators for the
543*5f39d1b3SJooyung Han // FixedPoint class, and a few related functions.
544*5f39d1b3SJooyung Han 
545*5f39d1b3SJooyung Han // A FixedPoint multiplication is just a
546*5f39d1b3SJooyung Han // SaturatingRoundingDoublingHighMul operation on the underlying
547*5f39d1b3SJooyung Han // raw integer values. The IntegerBits simply add up, as is obvious
548*5f39d1b3SJooyung Han // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
549*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
550*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
551*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits_a> a,
552*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits_b> b) {
553*5f39d1b3SJooyung Han   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
554*5f39d1b3SJooyung Han   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
555*5f39d1b3SJooyung Han   return c;
556*5f39d1b3SJooyung Han }
557*5f39d1b3SJooyung Han 
558*5f39d1b3SJooyung Han // Tweaking IntegerBits gives exact multiplication by a power of two.
559*5f39d1b3SJooyung Han template <int tExponent, typename tRawType, int tIntegerBits>
560*5f39d1b3SJooyung Han FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
561*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> a) {
562*5f39d1b3SJooyung Han   FixedPoint<tRawType, tExponent + tIntegerBits> c;
563*5f39d1b3SJooyung Han   c.raw() = a.raw();
564*5f39d1b3SJooyung Han   return c;
565*5f39d1b3SJooyung Han }
566*5f39d1b3SJooyung Han 
567*5f39d1b3SJooyung Han // If we want to leave IntegerBits fixed, then multiplication
568*5f39d1b3SJooyung Han // by a power of two has to be saturating/rounding, not exact anymore.
569*5f39d1b3SJooyung Han template <int tExponent, typename tRawType, int tIntegerBits>
570*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
571*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> a) {
572*5f39d1b3SJooyung Han   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
573*5f39d1b3SJooyung Han       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
574*5f39d1b3SJooyung Han }
575*5f39d1b3SJooyung Han 
576*5f39d1b3SJooyung Han // Generic arithmetic operators.
577*5f39d1b3SJooyung Han 
578*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
579*5f39d1b3SJooyung Han   template <typename tRawType, int tIntegerBits>                               \
580*5f39d1b3SJooyung Han   FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
581*5f39d1b3SJooyung Han       FixedPoint<tRawType, tIntegerBits> a) {                                  \
582*5f39d1b3SJooyung Han     return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
583*5f39d1b3SJooyung Han   }
584*5f39d1b3SJooyung Han 
585*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
586*5f39d1b3SJooyung Han   template <typename tRawType, int tIntegerBits>            \
587*5f39d1b3SJooyung Han   FixedPoint<tRawType, tIntegerBits> FuncName(              \
588*5f39d1b3SJooyung Han       FixedPoint<tRawType, tIntegerBits> a,                 \
589*5f39d1b3SJooyung Han       FixedPoint<tRawType, tIntegerBits> b) {               \
590*5f39d1b3SJooyung Han     return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
591*5f39d1b3SJooyung Han         ImplFuncName(a.raw(), b.raw()));                    \
592*5f39d1b3SJooyung Han   }
593*5f39d1b3SJooyung Han 
594*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
595*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
596*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
597*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
598*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
599*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
600*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
601*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
602*5f39d1b3SJooyung Han 
603*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_UNARY_FUNC
604*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_BINARY_FUNC
605*5f39d1b3SJooyung Han 
606*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
607*5f39d1b3SJooyung Han   template <typename tRawType, int tIntegerBits>            \
608*5f39d1b3SJooyung Han   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
609*5f39d1b3SJooyung Han     return FuncName(a.raw());                               \
610*5f39d1b3SJooyung Han   }
611*5f39d1b3SJooyung Han 
612*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
613*5f39d1b3SJooyung Han   template <typename tRawType, int tIntegerBits>            \
614*5f39d1b3SJooyung Han   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
615*5f39d1b3SJooyung Han                     FixedPoint<tRawType, tIntegerBits> b) { \
616*5f39d1b3SJooyung Han     return FuncName(a.raw(), b.raw());                      \
617*5f39d1b3SJooyung Han   }
618*5f39d1b3SJooyung Han 
619*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
620*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
621*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
622*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
623*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
624*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
625*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
626*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
627*5f39d1b3SJooyung Han 
628*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
629*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
630*5f39d1b3SJooyung Han 
631*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
632*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
633*5f39d1b3SJooyung Han     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
634*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> else_val) {
635*5f39d1b3SJooyung Han   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
636*5f39d1b3SJooyung Han       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
637*5f39d1b3SJooyung Han }
638*5f39d1b3SJooyung Han 
639*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
640*5f39d1b3SJooyung Han bool operator==(FixedPoint<tRawType, tIntegerBits> a,
641*5f39d1b3SJooyung Han                 FixedPoint<tRawType, tIntegerBits> b) {
642*5f39d1b3SJooyung Han   return All(MaskIfEqual(a.raw(), b.raw()));
643*5f39d1b3SJooyung Han }
644*5f39d1b3SJooyung Han 
645*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
646*5f39d1b3SJooyung Han bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
647*5f39d1b3SJooyung Han                 FixedPoint<tRawType, tIntegerBits> b) {
648*5f39d1b3SJooyung Han   return !(a == b);
649*5f39d1b3SJooyung Han }
650*5f39d1b3SJooyung Han 
651*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
652*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
653*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> a,
654*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> b) {
655*5f39d1b3SJooyung Han   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
656*5f39d1b3SJooyung Han       SaturatingAdd(a.raw(), b.raw()));
657*5f39d1b3SJooyung Han }
658*5f39d1b3SJooyung Han 
659*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
660*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
661*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> a,
662*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> b) {
663*5f39d1b3SJooyung Han   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
664*5f39d1b3SJooyung Han       AddSaturatingIf16Bit(a.raw(), b.raw()));
665*5f39d1b3SJooyung Han }
666*5f39d1b3SJooyung Han 
667*5f39d1b3SJooyung Han // Conversion to floating-point.
668*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
669*5f39d1b3SJooyung Han double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
670*5f39d1b3SJooyung Han   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
671*5f39d1b3SJooyung Han                 "not applicable to SIMD types");
672*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, tIntegerBits> F;
673*5f39d1b3SJooyung Han   return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
674*5f39d1b3SJooyung Han }
675*5f39d1b3SJooyung Han 
676*5f39d1b3SJooyung Han // Rescale changes the number of IntegerBits and updates the underlying
677*5f39d1b3SJooyung Han // raw integer value accordingly.
678*5f39d1b3SJooyung Han template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
679*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBitsDst> Rescale(
680*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBitsSrc> x) {
681*5f39d1b3SJooyung Han   static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
682*5f39d1b3SJooyung Han   FixedPoint<tRawType, tIntegerBitsDst> result;
683*5f39d1b3SJooyung Han   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
684*5f39d1b3SJooyung Han   return result;
685*5f39d1b3SJooyung Han }
686*5f39d1b3SJooyung Han 
// CheckedFixedPointConstant allows specifying fixed-point constants
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
//
// The raw integer value provided is always an int32, encoding a 32-bit
// fixed-point value, regardless of the actual Scalar type. This allows
// writing generic code that applies just as well to the 32-bit and 16-bit
// cases. In the 16-bit case, the raw integer value is internally
// rounding-shifted by 16 bits to the right.
697*5f39d1b3SJooyung Han template <typename FixedPointType>
698*5f39d1b3SJooyung Han inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
699*5f39d1b3SJooyung Han     std::int32_t int32_value) {
700*5f39d1b3SJooyung Han   typedef typename FixedPointType::ScalarRawType ScalarRawType;
701*5f39d1b3SJooyung Han   static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
702*5f39d1b3SJooyung Han   return static_cast<ScalarRawType>(
703*5f39d1b3SJooyung Han       RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
704*5f39d1b3SJooyung Han }
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Builds a constant from `raw_value` and asserts that it equals the constant
// obtained from `double_value` via FromDouble.
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                         double double_value) {
  const FixedPointType constant = FixedPointType::FromScalarRaw(raw_value);
  assert(constant == FixedPointType::FromDouble(double_value));
  return constant;
}
// Checked variant: the rescaled raw value is verified against DoubleValue.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
                                             ScalarRawInt32Value, DoubleValue) \
  (gemmlowp::CheckedFixedPointConstant<FixedPointType>(                        \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
          ScalarRawInt32Value),                                                \
      DoubleValue))

#else
// Unchecked variant: DoubleValue is not used, only the rescaled raw value.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
                                             ScalarRawInt32Value, DoubleValue) \
  (FixedPointType::FromScalarRaw(                                              \
      gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
          ScalarRawInt32Value)))
#endif
727*5f39d1b3SJooyung Han 
728*5f39d1b3SJooyung Han // Implementation of exponential function.
729*5f39d1b3SJooyung Han 
730*5f39d1b3SJooyung Han // Returns exp(x) for x in [-1/4, 0).
731*5f39d1b3SJooyung Han template <typename tRawType>
732*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
733*5f39d1b3SJooyung Han     FixedPoint<tRawType, 0> a) {
734*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 0> F;
735*5f39d1b3SJooyung Han   const F constant_term =
736*5f39d1b3SJooyung Han       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
737*5f39d1b3SJooyung Han   const F constant_1_over_3 =
738*5f39d1b3SJooyung Han       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
739*5f39d1b3SJooyung Han   // We're evaluating a Taylor expansion around -1/8, so we do the change of
740*5f39d1b3SJooyung Han   // variable: x = a + 1/8.
741*5f39d1b3SJooyung Han   // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
742*5f39d1b3SJooyung Han   F x = a + F::template ConstantPOT<-3>();
743*5f39d1b3SJooyung Han   F x2 = x * x;
744*5f39d1b3SJooyung Han   F x3 = x2 * x;
745*5f39d1b3SJooyung Han   F x4 = x2 * x2;
746*5f39d1b3SJooyung Han   F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
747*5f39d1b3SJooyung Han   F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
748*5f39d1b3SJooyung Han       SaturatingRoundingMultiplyByPOT<-1>(
749*5f39d1b3SJooyung Han           ((x4_over_4 + x3) * constant_1_over_3) + x2);
750*5f39d1b3SJooyung Han   return AddSaturatingIf16Bit(
751*5f39d1b3SJooyung Han       constant_term,
752*5f39d1b3SJooyung Han       constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
753*5f39d1b3SJooyung Han }
754*5f39d1b3SJooyung Han 
755*5f39d1b3SJooyung Han // Returns exp(x) for x < 0.
756*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
757*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> exp_on_negative_values(
758*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> a) {
759*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, tIntegerBits> InputF;
760*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 0> ResultF;
761*5f39d1b3SJooyung Han   static constexpr int kFractionalBits = InputF::kFractionalBits;
762*5f39d1b3SJooyung Han   static constexpr int kIntegerBits = InputF::kIntegerBits;
763*5f39d1b3SJooyung Han   const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
764*5f39d1b3SJooyung Han   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
765*5f39d1b3SJooyung Han   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
766*5f39d1b3SJooyung Han   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
767*5f39d1b3SJooyung Han       Rescale<0>(a_mod_quarter_minus_one_quarter));
768*5f39d1b3SJooyung Han   tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
769*5f39d1b3SJooyung Han 
770*5f39d1b3SJooyung Han #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
771*5f39d1b3SJooyung Han   if (kIntegerBits > Exponent) {                                            \
772*5f39d1b3SJooyung Han     const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
773*5f39d1b3SJooyung Han         ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
774*5f39d1b3SJooyung Han     static constexpr int kShiftAmount =                                     \
775*5f39d1b3SJooyung Han         kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
776*5f39d1b3SJooyung Han     result = SelectUsingMask(                                               \
777*5f39d1b3SJooyung Han         MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
778*5f39d1b3SJooyung Han         result * kMultiplier, result);                                      \
779*5f39d1b3SJooyung Han   }
780*5f39d1b3SJooyung Han 
781*5f39d1b3SJooyung Han   // Constants below are Q0 representations of negative exp fractionals:
782*5f39d1b3SJooyung Han   GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);  // exp(-1/4)
783*5f39d1b3SJooyung Han   GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);  // exp(-1/2)
784*5f39d1b3SJooyung Han   GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);   // exp(-1)
785*5f39d1b3SJooyung Han   GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);   // exp(-2)
786*5f39d1b3SJooyung Han   GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);    // exp(-4)
787*5f39d1b3SJooyung Han   GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);      // exp(-8)
788*5f39d1b3SJooyung Han   GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);         // exp(-16)
789*5f39d1b3SJooyung Han 
790*5f39d1b3SJooyung Han #undef GEMMLOWP_EXP_BARREL_SHIFTER
791*5f39d1b3SJooyung Han 
792*5f39d1b3SJooyung Han   static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
793*5f39d1b3SJooyung Han   if (kIntegerBits > 5) {
794*5f39d1b3SJooyung Han     const InputF clamp =
795*5f39d1b3SJooyung Han         GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
796*5f39d1b3SJooyung Han     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
797*5f39d1b3SJooyung Han   }
798*5f39d1b3SJooyung Han 
799*5f39d1b3SJooyung Han   result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
800*5f39d1b3SJooyung Han   return result;
801*5f39d1b3SJooyung Han }
802*5f39d1b3SJooyung Han 
803*5f39d1b3SJooyung Han // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
804*5f39d1b3SJooyung Han 
805*5f39d1b3SJooyung Han // Returns (1 - x) / (1 + x) for x in (0, 1).
806*5f39d1b3SJooyung Han template <typename tRawType>
807*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
808*5f39d1b3SJooyung Han     FixedPoint<tRawType, 0> a) {
809*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 0> F0;
810*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 2> F2;
811*5f39d1b3SJooyung Han   F0 half_denominator = RoundingHalfSum(a, F0::One());
812*5f39d1b3SJooyung Han   // Newton-Raphson division
813*5f39d1b3SJooyung Han   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
814*5f39d1b3SJooyung Han   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
815*5f39d1b3SJooyung Han   const F2 constant_48_over_17 =
816*5f39d1b3SJooyung Han       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
817*5f39d1b3SJooyung Han   const F2 constant_neg_32_over_17 =
818*5f39d1b3SJooyung Han       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
819*5f39d1b3SJooyung Han   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
820*5f39d1b3SJooyung Han   for (int i = 0; i < 3; i++) {
821*5f39d1b3SJooyung Han     F2 half_denominator_times_x = half_denominator * x;
822*5f39d1b3SJooyung Han     F2 one_minus_half_denominator_times_x =
823*5f39d1b3SJooyung Han         F2::One() - half_denominator_times_x;
824*5f39d1b3SJooyung Han     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
825*5f39d1b3SJooyung Han   }
826*5f39d1b3SJooyung Han   return Rescale<0>(x - F2::One());
827*5f39d1b3SJooyung Han }
828*5f39d1b3SJooyung Han 
829*5f39d1b3SJooyung Han // Returns -tanh(x) for x < 0.
830*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
831*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
832*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> a) {
833*5f39d1b3SJooyung Han   return one_minus_x_over_one_plus_x_for_x_in_0_1(
834*5f39d1b3SJooyung Han       exp_on_negative_values(ExactMulByPot<1>(a)));
835*5f39d1b3SJooyung Han }
836*5f39d1b3SJooyung Han 
837*5f39d1b3SJooyung Han // Returns tanh(x) for any x.
838*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
839*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
840*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, tIntegerBits> InputF;
841*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 0> ResultF;
842*5f39d1b3SJooyung Han   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
843*5f39d1b3SJooyung Han   tRawType mask_if_zero = MaskIfZero(a);
844*5f39d1b3SJooyung Han   InputF n = SelectUsingMask(mask_if_negative, a, -a);
845*5f39d1b3SJooyung Han   ResultF t = neg_tanh_on_negative_values(n);
846*5f39d1b3SJooyung Han   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
847*5f39d1b3SJooyung Han                          SelectUsingMask(mask_if_negative, -t, t));
848*5f39d1b3SJooyung Han }
849*5f39d1b3SJooyung Han 
850*5f39d1b3SJooyung Han // Implementation of logistic function.
851*5f39d1b3SJooyung Han 
852*5f39d1b3SJooyung Han // Returns 1 / (1 + x) for x in (0, 1).
853*5f39d1b3SJooyung Han template <typename tRawType>
854*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
855*5f39d1b3SJooyung Han     FixedPoint<tRawType, 0> a) {
856*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 0> F0;
857*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 2> F2;
858*5f39d1b3SJooyung Han   F0 half_denominator = RoundingHalfSum(a, F0::One());
859*5f39d1b3SJooyung Han   // Newton-Raphson division
860*5f39d1b3SJooyung Han   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
861*5f39d1b3SJooyung Han   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
862*5f39d1b3SJooyung Han   const F2 constant_48_over_17 =
863*5f39d1b3SJooyung Han       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
864*5f39d1b3SJooyung Han   const F2 constant_neg_32_over_17 =
865*5f39d1b3SJooyung Han       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
866*5f39d1b3SJooyung Han   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
867*5f39d1b3SJooyung Han   for (int i = 0; i < 3; i++) {
868*5f39d1b3SJooyung Han     F2 half_denominator_times_x = half_denominator * x;
869*5f39d1b3SJooyung Han     F2 one_minus_half_denominator_times_x =
870*5f39d1b3SJooyung Han         F2::One() - half_denominator_times_x;
871*5f39d1b3SJooyung Han     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
872*5f39d1b3SJooyung Han   }
873*5f39d1b3SJooyung Han   return Rescale<0>(ExactMulByPot<-1>(x));
874*5f39d1b3SJooyung Han }
875*5f39d1b3SJooyung Han 
876*5f39d1b3SJooyung Han // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
877*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
878*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> logistic_on_positive_values(
879*5f39d1b3SJooyung Han     FixedPoint<tRawType, tIntegerBits> a) {
880*5f39d1b3SJooyung Han   return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
881*5f39d1b3SJooyung Han }
882*5f39d1b3SJooyung Han 
883*5f39d1b3SJooyung Han // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
884*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits>
885*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
886*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, tIntegerBits> InputF;
887*5f39d1b3SJooyung Han   typedef FixedPoint<tRawType, 0> ResultF;
888*5f39d1b3SJooyung Han   tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
889*5f39d1b3SJooyung Han   tRawType mask_if_zero = MaskIfZero(a);
890*5f39d1b3SJooyung Han   InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
891*5f39d1b3SJooyung Han   ResultF result_if_positive = logistic_on_positive_values(abs_input);
892*5f39d1b3SJooyung Han   ResultF result_if_negative = ResultF::One() - result_if_positive;
893*5f39d1b3SJooyung Han   const ResultF one_half =
894*5f39d1b3SJooyung Han       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
895*5f39d1b3SJooyung Han   return SelectUsingMask(mask_if_zero, one_half,
896*5f39d1b3SJooyung Han                          SelectUsingMask(mask_if_positive, result_if_positive,
897*5f39d1b3SJooyung Han                                          result_if_negative));
898*5f39d1b3SJooyung Han }
899*5f39d1b3SJooyung Han 
900*5f39d1b3SJooyung Han }  // end namespace gemmlowp
901*5f39d1b3SJooyung Han 
902*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON
903*5f39d1b3SJooyung Han #include "./fixedpoint_neon.h"
904*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_AVX2)
905*5f39d1b3SJooyung Han #include "./fixedpoint_avx.h"
906*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_SSE4)
907*5f39d1b3SJooyung Han #include "./fixedpoint_sse.h"
908*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_MSA)
909*5f39d1b3SJooyung Han #include "./fixedpoint_msa.h"
910*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_WASMSIMD)
911*5f39d1b3SJooyung Han #include "./fixedpoint_wasmsimd.h"
912*5f39d1b3SJooyung Han #endif
913*5f39d1b3SJooyung Han 
914*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
915