1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2*5f39d1b3SJooyung Han // 3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License"); 4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License. 5*5f39d1b3SJooyung Han // You may obtain a copy of the License at 6*5f39d1b3SJooyung Han // 7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0 8*5f39d1b3SJooyung Han // 9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software 10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS, 11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and 13*5f39d1b3SJooyung Han // limitations under the License. 14*5f39d1b3SJooyung Han 15*5f39d1b3SJooyung Han // fixedpoint.h: fixed-point arithmetic, with basic operations and 16*5f39d1b3SJooyung Han // a few math functions such as tanh. 17*5f39d1b3SJooyung Han 18*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 19*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 20*5f39d1b3SJooyung Han 21*5f39d1b3SJooyung Han #include <algorithm> 22*5f39d1b3SJooyung Han #include <cassert> 23*5f39d1b3SJooyung Han #include <cmath> 24*5f39d1b3SJooyung Han #include <cstdint> 25*5f39d1b3SJooyung Han #include <limits> 26*5f39d1b3SJooyung Han 27*5f39d1b3SJooyung Han #include "../internal/detect_platform.h" 28*5f39d1b3SJooyung Han 29*5f39d1b3SJooyung Han namespace gemmlowp { 30*5f39d1b3SJooyung Han 31*5f39d1b3SJooyung Han // Part 1: Low-level integer-arithmetic primitives. 32*5f39d1b3SJooyung Han // The implementations here are generic implementations valid for 33*5f39d1b3SJooyung Han // scalar types (e.g. std::int32_t). Architecture-specific SIMD types 34*5f39d1b3SJooyung Han // (e.g. NEON int32x4_t) may be supported by providing 35*5f39d1b3SJooyung Han // specializations for them in separate files. 36*5f39d1b3SJooyung Han // 37*5f39d1b3SJooyung Han // The purpose of these primitives is two-fold: 38*5f39d1b3SJooyung Han // - They will be used to implement higher-level fixed-point 39*5f39d1b3SJooyung Han // abstractions, namely the FixedPoint class and its arithmetic 40*5f39d1b3SJooyung Han // operators. 41*5f39d1b3SJooyung Han // - They will be directly used to implement some more involved 42*5f39d1b3SJooyung Han // fixed-point computations, e.g. the fixed-point implementation 43*5f39d1b3SJooyung Han // of math functions such as tanh. 44*5f39d1b3SJooyung Han 45*5f39d1b3SJooyung Han // Some compile-time traits around raw types to handle SIMD aspects: 46*5f39d1b3SJooyung Han // number of lanes, underlying scalar type. 47*5f39d1b3SJooyung Han template <typename tIntegerType> 48*5f39d1b3SJooyung Han struct FixedPointRawTypeTraits {}; 49*5f39d1b3SJooyung Han 50*5f39d1b3SJooyung Han template <> 51*5f39d1b3SJooyung Han struct FixedPointRawTypeTraits<std::int32_t> { 52*5f39d1b3SJooyung Han typedef std::int32_t ScalarRawType; 53*5f39d1b3SJooyung Han static constexpr int kLanes = 1; 54*5f39d1b3SJooyung Han }; 55*5f39d1b3SJooyung Han 56*5f39d1b3SJooyung Han template <> 57*5f39d1b3SJooyung Han struct FixedPointRawTypeTraits<std::int16_t> { 58*5f39d1b3SJooyung Han typedef std::int16_t ScalarRawType; 59*5f39d1b3SJooyung Han static constexpr int kLanes = 1; 60*5f39d1b3SJooyung Han }; 61*5f39d1b3SJooyung Han 62*5f39d1b3SJooyung Han // Returns a SIMD value duplicating a scalar value across all lanes. 63*5f39d1b3SJooyung Han template <typename tRawType> 64*5f39d1b3SJooyung Han tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) { 65*5f39d1b3SJooyung Han return x; 66*5f39d1b3SJooyung Han } 67*5f39d1b3SJooyung Han 68*5f39d1b3SJooyung Han // Plain bit-wise AND 69*5f39d1b3SJooyung Han template <typename tIntegerType> 70*5f39d1b3SJooyung Han tIntegerType BitAnd(tIntegerType a, tIntegerType b) { 71*5f39d1b3SJooyung Han return a & b; 72*5f39d1b3SJooyung Han } 73*5f39d1b3SJooyung Han 74*5f39d1b3SJooyung Han // Plain bit-wise OR 75*5f39d1b3SJooyung Han template <typename tIntegerType> 76*5f39d1b3SJooyung Han tIntegerType BitOr(tIntegerType a, tIntegerType b) { 77*5f39d1b3SJooyung Han return a | b; 78*5f39d1b3SJooyung Han } 79*5f39d1b3SJooyung Han 80*5f39d1b3SJooyung Han // Plain bit-wise XOR 81*5f39d1b3SJooyung Han template <typename tIntegerType> 82*5f39d1b3SJooyung Han tIntegerType BitXor(tIntegerType a, tIntegerType b) { 83*5f39d1b3SJooyung Han return a ^ b; 84*5f39d1b3SJooyung Han } 85*5f39d1b3SJooyung Han 86*5f39d1b3SJooyung Han // Plain bit-wise NOT 87*5f39d1b3SJooyung Han template <typename tIntegerType> 88*5f39d1b3SJooyung Han tIntegerType BitNot(tIntegerType a) { 89*5f39d1b3SJooyung Han return ~a; 90*5f39d1b3SJooyung Han } 91*5f39d1b3SJooyung Han 92*5f39d1b3SJooyung Han // Integer addition. Not saturating. Overflow is undefined behavior. 93*5f39d1b3SJooyung Han template <typename tIntegerType> 94*5f39d1b3SJooyung Han tIntegerType Add(tIntegerType a, tIntegerType b) { 95*5f39d1b3SJooyung Han return a + b; 96*5f39d1b3SJooyung Han } 97*5f39d1b3SJooyung Han 98*5f39d1b3SJooyung Han // Integer multiplication. Not saturating. Overflow is undefined behavior. 99*5f39d1b3SJooyung Han template <typename tIntegerType> 100*5f39d1b3SJooyung Han tIntegerType Mul(tIntegerType a, tIntegerType b) { 101*5f39d1b3SJooyung Han return a * b; 102*5f39d1b3SJooyung Han } 103*5f39d1b3SJooyung Han 104*5f39d1b3SJooyung Han // Integer subtraction. Not saturating. Overflow is undefined behavior. 105*5f39d1b3SJooyung Han template <typename tIntegerType> 106*5f39d1b3SJooyung Han tIntegerType Sub(tIntegerType a, tIntegerType b) { 107*5f39d1b3SJooyung Han return a - b; 108*5f39d1b3SJooyung Han } 109*5f39d1b3SJooyung Han 110*5f39d1b3SJooyung Han // Integer unary negative. Not saturating. Overflow is undefined behavior. 111*5f39d1b3SJooyung Han template <typename tIntegerType> 112*5f39d1b3SJooyung Han tIntegerType Neg(tIntegerType a) { 113*5f39d1b3SJooyung Han return -a; 114*5f39d1b3SJooyung Han } 115*5f39d1b3SJooyung Han 116*5f39d1b3SJooyung Han // Integer arithmetic left-shift, equivalent to multiplying with a power of two. 117*5f39d1b3SJooyung Han // Negative values are OK. In case of overflow, no Undefined 118*5f39d1b3SJooyung Han // Behavior, but the results are implementation-defined (in practice, 119*5f39d1b3SJooyung Han // they currently are saturated, but we make no commitment to that). The idea 120*5f39d1b3SJooyung Han // is that the caller will want to implement the overflowing cases with 121*5f39d1b3SJooyung Han // saturation with compare-and-mask, so we don't care about the results 122*5f39d1b3SJooyung Han // in the overflow case, we just want to avoid undefined behavior. 123*5f39d1b3SJooyung Han // 124*5f39d1b3SJooyung Han // tIntegerType may be int32 or any narrower signed type. 125*5f39d1b3SJooyung Han template <typename tIntegerType, typename OffsetType> 126*5f39d1b3SJooyung Han tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) { 127*5f39d1b3SJooyung Han const std::int64_t wide_a = static_cast<std::int64_t>(a); 128*5f39d1b3SJooyung Han const std::int64_t wide_shifted = wide_a * (1 << offset); 129*5f39d1b3SJooyung Han const auto min = std::numeric_limits<tIntegerType>::min(); 130*5f39d1b3SJooyung Han const auto max = std::numeric_limits<tIntegerType>::max(); 131*5f39d1b3SJooyung Han return wide_shifted < min 132*5f39d1b3SJooyung Han ? min 133*5f39d1b3SJooyung Han : wide_shifted > max ? max 134*5f39d1b3SJooyung Han : static_cast<tIntegerType>(wide_shifted); 135*5f39d1b3SJooyung Han } 136*5f39d1b3SJooyung Han 137*5f39d1b3SJooyung Han // Integer arithmetic right-shift. Not rounding. 138*5f39d1b3SJooyung Han // Relying on implementation-defined, but in-practice-consistent, 139*5f39d1b3SJooyung Han // C++ compiler behavior. 140*5f39d1b3SJooyung Han template <typename tIntegerType> 141*5f39d1b3SJooyung Han tIntegerType ShiftRight(tIntegerType a, int offset) { 142*5f39d1b3SJooyung Han return a >> offset; 143*5f39d1b3SJooyung Han } 144*5f39d1b3SJooyung Han 145*5f39d1b3SJooyung Han // Each bit of the result is set to the corresponding bit of either then_val or 146*5f39d1b3SJooyung Han // else_val depending on whether the corresponding bit of if_mask is set. 147*5f39d1b3SJooyung Han // Equivalent to the VBSL instruction in ARM NEON. 148*5f39d1b3SJooyung Han template <typename tIntegerType> 149*5f39d1b3SJooyung Han tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, 150*5f39d1b3SJooyung Han tIntegerType else_val) { 151*5f39d1b3SJooyung Han return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val)); 152*5f39d1b3SJooyung Han } 153*5f39d1b3SJooyung Han 154*5f39d1b3SJooyung Han // For each input scalar, the corresponding bits of the result are set if the 155*5f39d1b3SJooyung Han // input scalar is non-zero. 156*5f39d1b3SJooyung Han template <typename tIntegerType> 157*5f39d1b3SJooyung Han tIntegerType MaskIfNonZero(tIntegerType a) { 158*5f39d1b3SJooyung Han static constexpr tIntegerType zero = 0; 159*5f39d1b3SJooyung Han return a ? BitNot(zero) : zero; 160*5f39d1b3SJooyung Han } 161*5f39d1b3SJooyung Han 162*5f39d1b3SJooyung Han // For each input scalar, the corresponding bits of the result are set if the 163*5f39d1b3SJooyung Han // input scalar is zero. 164*5f39d1b3SJooyung Han template <typename tIntegerType> 165*5f39d1b3SJooyung Han tIntegerType MaskIfZero(tIntegerType a) { 166*5f39d1b3SJooyung Han return MaskIfNonZero<tIntegerType>(!a); 167*5f39d1b3SJooyung Han } 168*5f39d1b3SJooyung Han 169*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are 170*5f39d1b3SJooyung Han // set if the input scalars are equal. 171*5f39d1b3SJooyung Han template <typename tIntegerType> 172*5f39d1b3SJooyung Han tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) { 173*5f39d1b3SJooyung Han return MaskIfNonZero<tIntegerType>(a == b); 174*5f39d1b3SJooyung Han } 175*5f39d1b3SJooyung Han 176*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are 177*5f39d1b3SJooyung Han // set if the input scalars are not equal. 178*5f39d1b3SJooyung Han template <typename tIntegerType> 179*5f39d1b3SJooyung Han tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) { 180*5f39d1b3SJooyung Han return MaskIfNonZero<tIntegerType>(a != b); 181*5f39d1b3SJooyung Han } 182*5f39d1b3SJooyung Han 183*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are 184*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a > b. 185*5f39d1b3SJooyung Han template <typename tIntegerType> 186*5f39d1b3SJooyung Han tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { 187*5f39d1b3SJooyung Han return MaskIfNonZero<tIntegerType>(a > b); 188*5f39d1b3SJooyung Han } 189*5f39d1b3SJooyung Han 190*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are 191*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a >= b. 192*5f39d1b3SJooyung Han template <typename tIntegerType> 193*5f39d1b3SJooyung Han tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { 194*5f39d1b3SJooyung Han return MaskIfNonZero<tIntegerType>(a >= b); 195*5f39d1b3SJooyung Han } 196*5f39d1b3SJooyung Han 197*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are 198*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a < b. 199*5f39d1b3SJooyung Han template <typename tIntegerType> 200*5f39d1b3SJooyung Han tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { 201*5f39d1b3SJooyung Han return MaskIfNonZero<tIntegerType>(a < b); 202*5f39d1b3SJooyung Han } 203*5f39d1b3SJooyung Han 204*5f39d1b3SJooyung Han // For each pair of input scalars, the corresponding bits of the result are 205*5f39d1b3SJooyung Han // set if the input scalars a, b satisfy a <= b. 206*5f39d1b3SJooyung Han template <typename tIntegerType> 207*5f39d1b3SJooyung Han tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { 208*5f39d1b3SJooyung Han return MaskIfNonZero<tIntegerType>(a <= b); 209*5f39d1b3SJooyung Han } 210*5f39d1b3SJooyung Han 211*5f39d1b3SJooyung Han // Returns true if all of the input scalars are nonzero. 212*5f39d1b3SJooyung Han // This function may currently assume that each of the input scalars has either 213*5f39d1b3SJooyung Han // all or none of its bits set. Otherwise, its behavior is currently undefined. 214*5f39d1b3SJooyung Han template <typename tIntegerType> 215*5f39d1b3SJooyung Han bool All(tIntegerType a) { 216*5f39d1b3SJooyung Han return a; 217*5f39d1b3SJooyung Han } 218*5f39d1b3SJooyung Han 219*5f39d1b3SJooyung Han // Returns true if any of the input scalars are nonzero. 220*5f39d1b3SJooyung Han // This function may currently assume that each of the input scalars has either 221*5f39d1b3SJooyung Han // all or none of its bits set. Otherwise, its behavior is currently undefined. 222*5f39d1b3SJooyung Han template <typename tIntegerType> 223*5f39d1b3SJooyung Han bool Any(tIntegerType a) { 224*5f39d1b3SJooyung Han return a; 225*5f39d1b3SJooyung Han } 226*5f39d1b3SJooyung Han 227*5f39d1b3SJooyung Han // Returns (a+b)/2, rounded to the nearest integer. 228*5f39d1b3SJooyung Han // Equivalent to VRHADD in the ARM NEON instruction set. 229*5f39d1b3SJooyung Han template <typename IntegerType> 230*5f39d1b3SJooyung Han IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { 231*5f39d1b3SJooyung Han static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 232*5f39d1b3SJooyung Han (void)b; 233*5f39d1b3SJooyung Han return a; 234*5f39d1b3SJooyung Han } 235*5f39d1b3SJooyung Han 236*5f39d1b3SJooyung Han template <> 237*5f39d1b3SJooyung Han inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) { 238*5f39d1b3SJooyung Han std::int64_t a64 = a; 239*5f39d1b3SJooyung Han std::int64_t b64 = b; 240*5f39d1b3SJooyung Han std::int64_t sum = a64 + b64; 241*5f39d1b3SJooyung Han std::int64_t sign = sum >= 0 ? 1 : -1; 242*5f39d1b3SJooyung Han return static_cast<std::int32_t>((sum + sign) / 2); 243*5f39d1b3SJooyung Han } 244*5f39d1b3SJooyung Han 245*5f39d1b3SJooyung Han template <> 246*5f39d1b3SJooyung Han inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { 247*5f39d1b3SJooyung Han std::int32_t a32 = a; 248*5f39d1b3SJooyung Han std::int32_t b32 = b; 249*5f39d1b3SJooyung Han std::int32_t sum = a32 + b32; 250*5f39d1b3SJooyung Han std::int32_t sign = sum >= 0 ? 1 : -1; 251*5f39d1b3SJooyung Han return static_cast<std::int16_t>((sum + sign) / 2); 252*5f39d1b3SJooyung Han } 253*5f39d1b3SJooyung Han 254*5f39d1b3SJooyung Han template <typename IntegerType> 255*5f39d1b3SJooyung Han IntegerType SaturatingAdd(IntegerType a, IntegerType b) { 256*5f39d1b3SJooyung Han static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 257*5f39d1b3SJooyung Han (void)b; 258*5f39d1b3SJooyung Han return a; 259*5f39d1b3SJooyung Han } 260*5f39d1b3SJooyung Han 261*5f39d1b3SJooyung Han // So far this is only needed for int16. 262*5f39d1b3SJooyung Han template <> 263*5f39d1b3SJooyung Han inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { 264*5f39d1b3SJooyung Han std::int32_t a32 = a; 265*5f39d1b3SJooyung Han std::int32_t b32 = b; 266*5f39d1b3SJooyung Han std::int32_t sum = a32 + b32; 267*5f39d1b3SJooyung Han return static_cast<std::int16_t>( 268*5f39d1b3SJooyung Han std::min(static_cast<std::int32_t>(32767), 269*5f39d1b3SJooyung Han std::max(static_cast<std::int32_t>(-32768), sum))); 270*5f39d1b3SJooyung Han } 271*5f39d1b3SJooyung Han 272*5f39d1b3SJooyung Han template <> 273*5f39d1b3SJooyung Han inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) { 274*5f39d1b3SJooyung Han std::int16_t a16 = a; 275*5f39d1b3SJooyung Han std::int16_t b16 = b; 276*5f39d1b3SJooyung Han std::int16_t sum = a16 + b16; 277*5f39d1b3SJooyung Han return static_cast<std::int8_t>(std::min( 278*5f39d1b3SJooyung Han static_cast<int16_t>(std::numeric_limits<int8_t>::max()), 279*5f39d1b3SJooyung Han std::max(static_cast<int16_t>(std::numeric_limits<int8_t>::min()), sum))); 280*5f39d1b3SJooyung Han } 281*5f39d1b3SJooyung Han 282*5f39d1b3SJooyung Han // Returns a+b, saturating if the integers are 16bit or narrower, 283*5f39d1b3SJooyung Han // otherwise just a plain addition. 284*5f39d1b3SJooyung Han template <typename IntegerType, bool Is16Bit> 285*5f39d1b3SJooyung Han struct AddSaturatingIf16BitImpl { 286*5f39d1b3SJooyung Han static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } 287*5f39d1b3SJooyung Han }; 288*5f39d1b3SJooyung Han template <typename IntegerType> 289*5f39d1b3SJooyung Han struct AddSaturatingIf16BitImpl<IntegerType, true> { 290*5f39d1b3SJooyung Han static IntegerType Run(IntegerType a, IntegerType b) { 291*5f39d1b3SJooyung Han return SaturatingAdd(a, b); 292*5f39d1b3SJooyung Han } 293*5f39d1b3SJooyung Han }; 294*5f39d1b3SJooyung Han template <typename IntegerType> 295*5f39d1b3SJooyung Han IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { 296*5f39d1b3SJooyung Han using ScalarType = 297*5f39d1b3SJooyung Han typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 298*5f39d1b3SJooyung Han return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, 299*5f39d1b3SJooyung Han b); 300*5f39d1b3SJooyung Han } 301*5f39d1b3SJooyung Han 302*5f39d1b3SJooyung Han // Returns the integer that represents the product of two fixed-point 303*5f39d1b3SJooyung Han // numbers, interpreting all integers as fixed-point values in the 304*5f39d1b3SJooyung Han // interval [-1, 1), rounding to the nearest value, and saturating 305*5f39d1b3SJooyung Han // -1 * -1 to the maximum value (since 1 is not in the half-open 306*5f39d1b3SJooyung Han // interval [-1, 1)). 307*5f39d1b3SJooyung Han // 308*5f39d1b3SJooyung Han // [The explanation below specializes to std::int32_t for example purpose.] 309*5f39d1b3SJooyung Han // 310*5f39d1b3SJooyung Han // The mapping between IntegerType and the interval [-1, 1) is unique and 311*5f39d1b3SJooyung Han // implied by IntegerType, which is assumed to be signed. For example, 312*5f39d1b3SJooyung Han // for IntegerType==std::int32_t, the mapping is 313*5f39d1b3SJooyung Han // real_value = integer_value / 2^31. 314*5f39d1b3SJooyung Han // So in this case, and leaving aside rounding and saturating, this 315*5f39d1b3SJooyung Han // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to 316*5f39d1b3SJooyung Han // (a * b) / 2^31. 317*5f39d1b3SJooyung Han // 318*5f39d1b3SJooyung Han // The 'doubling' part in the name of this function comes from the fact that 319*5f39d1b3SJooyung Han // this operation is very close to a "multiply-high" operation, keeping only 320*5f39d1b3SJooyung Han // the top half bits, except that that would be effectively computing 321*5f39d1b3SJooyung Han // (a * b) / 2^32, 322*5f39d1b3SJooyung Han // so here we are computing 2x that, since 323*5f39d1b3SJooyung Han // 1/2^31 = 2 * 1/2^32. 324*5f39d1b3SJooyung Han // The idea is to use all of the available 32 bits in the destination int32 325*5f39d1b3SJooyung Han // value. 326*5f39d1b3SJooyung Han // 327*5f39d1b3SJooyung Han // [End of the explanation specializing to int32.] 328*5f39d1b3SJooyung Han // 329*5f39d1b3SJooyung Han // This is equivalent to the VQRDMULH instruction in ARM NEON. 330*5f39d1b3SJooyung Han template <typename IntegerType> 331*5f39d1b3SJooyung Han IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { 332*5f39d1b3SJooyung Han static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 333*5f39d1b3SJooyung Han (void)b; 334*5f39d1b3SJooyung Han return a; 335*5f39d1b3SJooyung Han } 336*5f39d1b3SJooyung Han 337*5f39d1b3SJooyung Han // This function implements the same computation as the ARMv7 NEON VQRDMULH 338*5f39d1b3SJooyung Han // instruction. 339*5f39d1b3SJooyung Han template <> 340*5f39d1b3SJooyung Han inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, 341*5f39d1b3SJooyung Han std::int32_t b) { 342*5f39d1b3SJooyung Han bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min(); 343*5f39d1b3SJooyung Han std::int64_t a_64(a); 344*5f39d1b3SJooyung Han std::int64_t b_64(b); 345*5f39d1b3SJooyung Han std::int64_t ab_64 = a_64 * b_64; 346*5f39d1b3SJooyung Han std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); 347*5f39d1b3SJooyung Han std::int32_t ab_x2_high32 = 348*5f39d1b3SJooyung Han static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31)); 349*5f39d1b3SJooyung Han return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; 350*5f39d1b3SJooyung Han } 351*5f39d1b3SJooyung Han 352*5f39d1b3SJooyung Han template <> 353*5f39d1b3SJooyung Han inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, 354*5f39d1b3SJooyung Han std::int16_t b) { 355*5f39d1b3SJooyung Han bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min(); 356*5f39d1b3SJooyung Han std::int32_t a_32(a); 357*5f39d1b3SJooyung Han std::int32_t b_32(b); 358*5f39d1b3SJooyung Han std::int32_t ab_32 = a_32 * b_32; 359*5f39d1b3SJooyung Han std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); 360*5f39d1b3SJooyung Han std::int16_t ab_x2_high16 = 361*5f39d1b3SJooyung Han static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15)); 362*5f39d1b3SJooyung Han return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16; 363*5f39d1b3SJooyung Han } 364*5f39d1b3SJooyung Han 365*5f39d1b3SJooyung Han // Correctly-rounded-to-nearest division by a power-of-two. 366*5f39d1b3SJooyung Han // Also known as a rounding arithmetic right shift. 367*5f39d1b3SJooyung Han template <typename IntegerType, typename ExponentType> 368*5f39d1b3SJooyung Han inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { 369*5f39d1b3SJooyung Han assert(exponent >= 0); 370*5f39d1b3SJooyung Han assert(exponent <= 31); 371*5f39d1b3SJooyung Han const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); 372*5f39d1b3SJooyung Han const IntegerType zero = Dup<IntegerType>(0); 373*5f39d1b3SJooyung Han const IntegerType one = Dup<IntegerType>(1); 374*5f39d1b3SJooyung Han const IntegerType remainder = BitAnd(x, mask); 375*5f39d1b3SJooyung Han const IntegerType threshold = 376*5f39d1b3SJooyung Han Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); 377*5f39d1b3SJooyung Han return Add(ShiftRight(x, exponent), 378*5f39d1b3SJooyung Han BitAnd(MaskIfGreaterThan(remainder, threshold), one)); 379*5f39d1b3SJooyung Han } 380*5f39d1b3SJooyung Han 381*5f39d1b3SJooyung Han // Returns the product of a run-time integer value by a compile-time power 382*5f39d1b3SJooyung Han // of two, with either a positive exponent (equivalent to an arithmetic 383*5f39d1b3SJooyung Han // left shift, saturating) or a negative exponent (equivalent to an arithmetic 384*5f39d1b3SJooyung Han // right shift, rounding to nearest). 385*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType, 386*5f39d1b3SJooyung Han int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> 387*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT {}; 388*5f39d1b3SJooyung Han 389*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType> 390*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { 391*5f39d1b3SJooyung Han static IntegerType eval(IntegerType x) { return x; } 392*5f39d1b3SJooyung Han }; 393*5f39d1b3SJooyung Han 394*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType> 395*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { 396*5f39d1b3SJooyung Han static IntegerType eval(IntegerType x) { 397*5f39d1b3SJooyung Han using ScalarIntegerType = 398*5f39d1b3SJooyung Han typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 399*5f39d1b3SJooyung Han const IntegerType min = 400*5f39d1b3SJooyung Han Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min()); 401*5f39d1b3SJooyung Han const IntegerType max = 402*5f39d1b3SJooyung Han Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max()); 403*5f39d1b3SJooyung Han const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); 404*5f39d1b3SJooyung Han 405*5f39d1b3SJooyung Han const std::int32_t threshold = 406*5f39d1b3SJooyung Han ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); 407*5f39d1b3SJooyung Han const IntegerType positive_mask = 408*5f39d1b3SJooyung Han MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); 409*5f39d1b3SJooyung Han const IntegerType negative_mask = 410*5f39d1b3SJooyung Han MaskIfLessThan(x, Dup<IntegerType>(-threshold)); 411*5f39d1b3SJooyung Han 412*5f39d1b3SJooyung Han IntegerType result = ShiftLeft(x, Exponent); 413*5f39d1b3SJooyung Han result = SelectUsingMask(positive_mask, max, result); 414*5f39d1b3SJooyung Han result = SelectUsingMask(negative_mask, min, result); 415*5f39d1b3SJooyung Han return result; 416*5f39d1b3SJooyung Han } 417*5f39d1b3SJooyung Han }; 418*5f39d1b3SJooyung Han 419*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType> 420*5f39d1b3SJooyung Han struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { 421*5f39d1b3SJooyung Han static IntegerType eval(IntegerType x) { 422*5f39d1b3SJooyung Han return RoundingDivideByPOT<IntegerType>(x, -Exponent); 423*5f39d1b3SJooyung Han } 424*5f39d1b3SJooyung Han }; 425*5f39d1b3SJooyung Han 426*5f39d1b3SJooyung Han template <int Exponent, typename IntegerType> 427*5f39d1b3SJooyung Han IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { 428*5f39d1b3SJooyung Han return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); 429*5f39d1b3SJooyung Han } 430*5f39d1b3SJooyung Han 431*5f39d1b3SJooyung Han // Part 2: the FixedPoint class. 432*5f39d1b3SJooyung Han 433*5f39d1b3SJooyung Han // A FixedPoint object represents a fixed-point value stored in the underlying 434*5f39d1b3SJooyung Han // integer type tRawType, if tRawType is a plain scalar integer type. 435*5f39d1b3SJooyung Han // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which 436*5f39d1b3SJooyung Han // case a FixedPoint object represents a corresponding SIMD vector of fixed 437*5f39d1b3SJooyung Han // point values. 438*5f39d1b3SJooyung Han // 439*5f39d1b3SJooyung Han // tIntegerBits describes the range of the fixed-point format: if 440*5f39d1b3SJooyung Han // tIntegerBits == m then the range of representable values is the half-open 441*5f39d1b3SJooyung Han // interval [-2^m; 2^m) where the open boundary on the right side means that 442*5f39d1b3SJooyung Han // 2^m is not representable (how close the maximum representable value is to 443*5f39d1b3SJooyung Han // it, depends on bit-depth of tRawType). 444*5f39d1b3SJooyung Han // 445*5f39d1b3SJooyung Han // In "Q format notation", 446*5f39d1b3SJooyung Han // https://en.wikipedia.org/wiki/Q_(number_format) 447*5f39d1b3SJooyung Han // we are describing the format 448*5f39d1b3SJooyung Han // Qm.n 449*5f39d1b3SJooyung Han // where 450*5f39d1b3SJooyung Han // m = tIntegerBits 451*5f39d1b3SJooyung Han // and 452*5f39d1b3SJooyung Han // n = NumberOfBits(tRawType) - (m + 1) 453*5f39d1b3SJooyung Han // Note that the (m + 1) in the above line is because we adopt the convention 454*5f39d1b3SJooyung Han // that we count the integer bits exclusively of the sign bit; so (m + 1) is 455*5f39d1b3SJooyung Han // the total number of integer bits inclusive of the sign bit. 456*5f39d1b3SJooyung Han // 457*5f39d1b3SJooyung Han // Accordingly, the number of integral representable values in our range 458*5f39d1b3SJooyung Han // [-2^m ; 2^m) 459*5f39d1b3SJooyung Han // is equal to 2^(m+1). 460*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 461*5f39d1b3SJooyung Han class FixedPoint { 462*5f39d1b3SJooyung Han public: 463*5f39d1b3SJooyung Han typedef tRawType RawType; 464*5f39d1b3SJooyung Han 465*5f39d1b3SJooyung Han typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; 466*5f39d1b3SJooyung Han typedef typename RawTypeTraits::ScalarRawType ScalarRawType; 467*5f39d1b3SJooyung Han 468*5f39d1b3SJooyung Han static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); 469*5f39d1b3SJooyung Han static constexpr int kIntegerBits = tIntegerBits; 470*5f39d1b3SJooyung Han static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; 471*5f39d1b3SJooyung Han static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, 472*5f39d1b3SJooyung Han "bad IntegerBits"); 473*5f39d1b3SJooyung Han 474*5f39d1b3SJooyung Han typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; 475*5f39d1b3SJooyung Han 476*5f39d1b3SJooyung Han static const ScalarRawType ScalarRawMin() { 477*5f39d1b3SJooyung Han return std::numeric_limits<ScalarRawType>::min(); 478*5f39d1b3SJooyung Han } 479*5f39d1b3SJooyung Han 480*5f39d1b3SJooyung Han static const ScalarRawType ScalarRawMax() { 481*5f39d1b3SJooyung Han return std::numeric_limits<ScalarRawType>::max(); 482*5f39d1b3SJooyung Han } 483*5f39d1b3SJooyung Han 484*5f39d1b3SJooyung Han static const ScalarRawType RawMin() { 485*5f39d1b3SJooyung Han return VectorFromScalar(ScalarRawMin()); 486*5f39d1b3SJooyung Han } 487*5f39d1b3SJooyung Han 488*5f39d1b3SJooyung Han static const ScalarRawType RawMax() { 489*5f39d1b3SJooyung Han return VectorFromScalar(ScalarRawMax()); 490*5f39d1b3SJooyung Han } 491*5f39d1b3SJooyung Han 492*5f39d1b3SJooyung Han static FixedPoint FromRaw(RawType x) { 493*5f39d1b3SJooyung Han FixedPoint retval; 494*5f39d1b3SJooyung Han retval.raw() = x; 495*5f39d1b3SJooyung Han return retval; 496*5f39d1b3SJooyung Han } 497*5f39d1b3SJooyung Han 498*5f39d1b3SJooyung Han static FixedPoint FromScalarRaw(ScalarRawType x) { 499*5f39d1b3SJooyung Han FixedPoint retval; 500*5f39d1b3SJooyung Han retval.raw() = Dup<RawType>(x); 501*5f39d1b3SJooyung Han return retval; 502*5f39d1b3SJooyung Han } 503*5f39d1b3SJooyung Han 504*5f39d1b3SJooyung Han static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { 505*5f39d1b3SJooyung Han return FromScalarRaw(x.raw()); 506*5f39d1b3SJooyung Han } 507*5f39d1b3SJooyung Han 508*5f39d1b3SJooyung Han template <int Exponent> 509*5f39d1b3SJooyung Han static FixedPoint ConstantPOT() { 510*5f39d1b3SJooyung Han static constexpr int kOffset = kFractionalBits + Exponent; 511*5f39d1b3SJooyung Han static_assert( 512*5f39d1b3SJooyung Han kOffset < 31, 513*5f39d1b3SJooyung Han "Constant not exactly representable in this fixed-point format"); 514*5f39d1b3SJooyung Han return FromScalarRaw(ScalarRawType(1) << kOffset); 515*5f39d1b3SJooyung Han } 516*5f39d1b3SJooyung Han 517*5f39d1b3SJooyung Han static FixedPoint Zero() { return FromScalarRaw(0); } 518*5f39d1b3SJooyung Han 519*5f39d1b3SJooyung Han static FixedPoint One() { 520*5f39d1b3SJooyung Han return FromScalarRaw( 521*5f39d1b3SJooyung Han kIntegerBits == 0 522*5f39d1b3SJooyung Han ? ScalarRawMax() 523*5f39d1b3SJooyung Han : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits))); 524*5f39d1b3SJooyung Han } 525*5f39d1b3SJooyung Han 526*5f39d1b3SJooyung Han static FixedPoint FromDouble(double x) { 527*5f39d1b3SJooyung Han const double min_bound = static_cast<double>(ScalarRawMin()); 528*5f39d1b3SJooyung Han const double max_bound = static_cast<double>(ScalarRawMax()); 529*5f39d1b3SJooyung Han return FromScalarRaw(static_cast<ScalarRawType>(std::min( 530*5f39d1b3SJooyung Han std::max(round(x * static_cast<double>(1ll << kFractionalBits)), 531*5f39d1b3SJooyung Han min_bound), 532*5f39d1b3SJooyung Han max_bound))); 533*5f39d1b3SJooyung Han } 534*5f39d1b3SJooyung Han 535*5f39d1b3SJooyung Han RawType raw() const { return i_; } 536*5f39d1b3SJooyung Han RawType& raw() { return i_; } 537*5f39d1b3SJooyung Han 538*5f39d1b3SJooyung Han private: 539*5f39d1b3SJooyung Han RawType i_; 540*5f39d1b3SJooyung Han }; 541*5f39d1b3SJooyung Han 542*5f39d1b3SJooyung Han // Part 3: implementation of arithmetic operators for the 543*5f39d1b3SJooyung Han // FixedPoint class, and a few related functions. 544*5f39d1b3SJooyung Han 545*5f39d1b3SJooyung Han // A FixedPoint multiplication is just a 546*5f39d1b3SJooyung Han // SaturatingRoundingDoublingHighMul operation on the underlying 547*5f39d1b3SJooyung Han // raw integer values. The IntegerBits simply add up, as is obvious 548*5f39d1b3SJooyung Han // from the fact that the range is [-2^IntegerBits, 2^IntegerBits). 549*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> 550*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( 551*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits_a> a, 552*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits_b> b) { 553*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; 554*5f39d1b3SJooyung Han c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); 555*5f39d1b3SJooyung Han return c; 556*5f39d1b3SJooyung Han } 557*5f39d1b3SJooyung Han 558*5f39d1b3SJooyung Han // Tweaking IntegerBits gives exact multiplication by a power of two. 559*5f39d1b3SJooyung Han template <int tExponent, typename tRawType, int tIntegerBits> 560*5f39d1b3SJooyung Han FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( 561*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a) { 562*5f39d1b3SJooyung Han FixedPoint<tRawType, tExponent + tIntegerBits> c; 563*5f39d1b3SJooyung Han c.raw() = a.raw(); 564*5f39d1b3SJooyung Han return c; 565*5f39d1b3SJooyung Han } 566*5f39d1b3SJooyung Han 567*5f39d1b3SJooyung Han // If we want to leave IntegerBits fixed, then multiplication 568*5f39d1b3SJooyung Han // by a power of two has to be saturating/rounding, not exact anymore. 569*5f39d1b3SJooyung Han template <int tExponent, typename tRawType, int tIntegerBits> 570*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( 571*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a) { 572*5f39d1b3SJooyung Han return FixedPoint<tRawType, tIntegerBits>::FromRaw( 573*5f39d1b3SJooyung Han SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); 574*5f39d1b3SJooyung Han } 575*5f39d1b3SJooyung Han 576*5f39d1b3SJooyung Han // Generic arithmetic operators. 577*5f39d1b3SJooyung Han 578*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ 579*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> \ 580*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> FuncName( \ 581*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a) { \ 582*5f39d1b3SJooyung Han return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ 583*5f39d1b3SJooyung Han } 584*5f39d1b3SJooyung Han 585*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ 586*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> \ 587*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> FuncName( \ 588*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a, \ 589*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> b) { \ 590*5f39d1b3SJooyung Han return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ 591*5f39d1b3SJooyung Han ImplFuncName(a.raw(), b.raw())); \ 592*5f39d1b3SJooyung Han } 593*5f39d1b3SJooyung Han 594*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) 595*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) 596*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) 597*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) 598*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) 599*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) 600*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) 601*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) 602*5f39d1b3SJooyung Han 603*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_UNARY_FUNC 604*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_BINARY_FUNC 605*5f39d1b3SJooyung Han 606*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ 607*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> \ 608*5f39d1b3SJooyung Han tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ 609*5f39d1b3SJooyung Han return FuncName(a.raw()); \ 610*5f39d1b3SJooyung Han } 611*5f39d1b3SJooyung Han 612*5f39d1b3SJooyung Han #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ 613*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> \ 614*5f39d1b3SJooyung Han tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \ 615*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> b) { \ 616*5f39d1b3SJooyung Han return FuncName(a.raw(), b.raw()); \ 617*5f39d1b3SJooyung Han } 618*5f39d1b3SJooyung Han 619*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) 620*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) 621*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) 622*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) 623*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) 624*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) 625*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) 626*5f39d1b3SJooyung Han MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) 627*5f39d1b3SJooyung Han 628*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW 629*5f39d1b3SJooyung Han #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW 630*5f39d1b3SJooyung Han 631*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 632*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> SelectUsingMask( 633*5f39d1b3SJooyung Han tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, 634*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> else_val) { 635*5f39d1b3SJooyung Han return FixedPoint<tRawType, tIntegerBits>::FromRaw( 636*5f39d1b3SJooyung Han SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); 637*5f39d1b3SJooyung Han } 638*5f39d1b3SJooyung Han 639*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 640*5f39d1b3SJooyung Han bool operator==(FixedPoint<tRawType, tIntegerBits> a, 641*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> b) { 642*5f39d1b3SJooyung Han return All(MaskIfEqual(a.raw(), b.raw())); 643*5f39d1b3SJooyung Han } 644*5f39d1b3SJooyung Han 645*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 646*5f39d1b3SJooyung Han bool operator!=(FixedPoint<tRawType, tIntegerBits> a, 647*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> b) { 648*5f39d1b3SJooyung Han return !(a == b); 649*5f39d1b3SJooyung Han } 650*5f39d1b3SJooyung Han 651*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 652*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> SaturatingAdd( 653*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a, 654*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> b) { 655*5f39d1b3SJooyung Han return FixedPoint<tRawType, tIntegerBits>::FromRaw( 656*5f39d1b3SJooyung Han SaturatingAdd(a.raw(), b.raw())); 657*5f39d1b3SJooyung Han } 658*5f39d1b3SJooyung Han 659*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 660*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit( 661*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a, 662*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> b) { 663*5f39d1b3SJooyung Han return FixedPoint<tRawType, tIntegerBits>::FromRaw( 664*5f39d1b3SJooyung Han AddSaturatingIf16Bit(a.raw(), b.raw())); 665*5f39d1b3SJooyung Han } 666*5f39d1b3SJooyung Han 667*5f39d1b3SJooyung Han // Conversion to floating-point. 668*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 669*5f39d1b3SJooyung Han double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { 670*5f39d1b3SJooyung Han static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, 671*5f39d1b3SJooyung Han "not applicable to SIMD types"); 672*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, tIntegerBits> F; 673*5f39d1b3SJooyung Han return x.raw() / static_cast<double>(1ll << F::kFractionalBits); 674*5f39d1b3SJooyung Han } 675*5f39d1b3SJooyung Han 676*5f39d1b3SJooyung Han // Rescale changes the number of IntegerBits and updates the underlying 677*5f39d1b3SJooyung Han // raw integer value accordingly. 678*5f39d1b3SJooyung Han template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> 679*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBitsDst> Rescale( 680*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBitsSrc> x) { 681*5f39d1b3SJooyung Han static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; 682*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBitsDst> result; 683*5f39d1b3SJooyung Han result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); 684*5f39d1b3SJooyung Han return result; 685*5f39d1b3SJooyung Han } 686*5f39d1b3SJooyung Han 687*5f39d1b3SJooyung Han // CheckedFixedPointConstant allows to specify fixed-point constants 688*5f39d1b3SJooyung Han // initialized as real numbers, in a way that does not compile floating-point 689*5f39d1b3SJooyung Han // arithmetic in production code, yet still checks agreement with the 690*5f39d1b3SJooyung Han // floating-point expressions when asserts are enabled. 691*5f39d1b3SJooyung Han // 692*5f39d1b3SJooyung Han // The raw integer value provided is always a int32, encoding a 32-bit 693*5f39d1b3SJooyung Han // fixed-point value, regardless of the actual Scalar type. This allows 694*5f39d1b3SJooyung Han // writing generic code that applies just as well to the 32-bit and 16-bit 695*5f39d1b3SJooyung Han // cases. In the 16-bit case, the raw integer value is internally 696*5f39d1b3SJooyung Han // rounding-shifted by 16 bits to the right. 697*5f39d1b3SJooyung Han template <typename FixedPointType> 698*5f39d1b3SJooyung Han inline typename FixedPointType::ScalarRawType RescaleConstantInitializer( 699*5f39d1b3SJooyung Han std::int32_t int32_value) { 700*5f39d1b3SJooyung Han typedef typename FixedPointType::ScalarRawType ScalarRawType; 701*5f39d1b3SJooyung Han static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); 702*5f39d1b3SJooyung Han return static_cast<ScalarRawType>( 703*5f39d1b3SJooyung Han RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits)); 704*5f39d1b3SJooyung Han } 705*5f39d1b3SJooyung Han #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 706*5f39d1b3SJooyung Han template <typename FixedPointType> 707*5f39d1b3SJooyung Han FixedPointType CheckedFixedPointConstant(std::int32_t raw_value, 708*5f39d1b3SJooyung Han double double_value) { 709*5f39d1b3SJooyung Han const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); 710*5f39d1b3SJooyung Han assert(result == FixedPointType::FromDouble(double_value)); 711*5f39d1b3SJooyung Han return result; 712*5f39d1b3SJooyung Han } 713*5f39d1b3SJooyung Han #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 714*5f39d1b3SJooyung Han ScalarRawInt32Value, DoubleValue) \ 715*5f39d1b3SJooyung Han (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \ 716*5f39d1b3SJooyung Han gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 717*5f39d1b3SJooyung Han ScalarRawInt32Value), \ 718*5f39d1b3SJooyung Han DoubleValue)) 719*5f39d1b3SJooyung Han 720*5f39d1b3SJooyung Han #else 721*5f39d1b3SJooyung Han #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 722*5f39d1b3SJooyung Han ScalarRawInt32Value, DoubleValue) \ 723*5f39d1b3SJooyung Han (FixedPointType::FromScalarRaw( \ 724*5f39d1b3SJooyung Han gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 725*5f39d1b3SJooyung Han ScalarRawInt32Value))) 726*5f39d1b3SJooyung Han #endif 727*5f39d1b3SJooyung Han 728*5f39d1b3SJooyung Han // Implementation of exponential function. 729*5f39d1b3SJooyung Han 730*5f39d1b3SJooyung Han // Returns exp(x) for x in [-1/4, 0). 731*5f39d1b3SJooyung Han template <typename tRawType> 732*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( 733*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> a) { 734*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 0> F; 735*5f39d1b3SJooyung Han const F constant_term = 736*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); 737*5f39d1b3SJooyung Han const F constant_1_over_3 = 738*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); 739*5f39d1b3SJooyung Han // We're evaluating a Taylor expansion around -1/8, so we do the change of 740*5f39d1b3SJooyung Han // variable: x = a + 1/8. 741*5f39d1b3SJooyung Han // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 742*5f39d1b3SJooyung Han F x = a + F::template ConstantPOT<-3>(); 743*5f39d1b3SJooyung Han F x2 = x * x; 744*5f39d1b3SJooyung Han F x3 = x2 * x; 745*5f39d1b3SJooyung Han F x4 = x2 * x2; 746*5f39d1b3SJooyung Han F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); 747*5f39d1b3SJooyung Han F x4_over_24_plus_x3_over_6_plus_x2_over_2 = 748*5f39d1b3SJooyung Han SaturatingRoundingMultiplyByPOT<-1>( 749*5f39d1b3SJooyung Han ((x4_over_4 + x3) * constant_1_over_3) + x2); 750*5f39d1b3SJooyung Han return AddSaturatingIf16Bit( 751*5f39d1b3SJooyung Han constant_term, 752*5f39d1b3SJooyung Han constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); 753*5f39d1b3SJooyung Han } 754*5f39d1b3SJooyung Han 755*5f39d1b3SJooyung Han // Returns exp(x) for x < 0. 756*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 757*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> exp_on_negative_values( 758*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a) { 759*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, tIntegerBits> InputF; 760*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 0> ResultF; 761*5f39d1b3SJooyung Han static constexpr int kFractionalBits = InputF::kFractionalBits; 762*5f39d1b3SJooyung Han static constexpr int kIntegerBits = InputF::kIntegerBits; 763*5f39d1b3SJooyung Han const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); 764*5f39d1b3SJooyung Han InputF mask = kOneQuarter - InputF::FromScalarRaw(1); 765*5f39d1b3SJooyung Han InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; 766*5f39d1b3SJooyung Han ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( 767*5f39d1b3SJooyung Han Rescale<0>(a_mod_quarter_minus_one_quarter)); 768*5f39d1b3SJooyung Han tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); 769*5f39d1b3SJooyung Han 770*5f39d1b3SJooyung Han #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ 771*5f39d1b3SJooyung Han if (kIntegerBits > Exponent) { \ 772*5f39d1b3SJooyung Han const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ 773*5f39d1b3SJooyung Han ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ 774*5f39d1b3SJooyung Han static constexpr int kShiftAmount = \ 775*5f39d1b3SJooyung Han kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \ 776*5f39d1b3SJooyung Han result = SelectUsingMask( \ 777*5f39d1b3SJooyung Han MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \ 778*5f39d1b3SJooyung Han result * kMultiplier, result); \ 779*5f39d1b3SJooyung Han } 780*5f39d1b3SJooyung Han 781*5f39d1b3SJooyung Han // Constants below are Q0 representations of negative exp fractionals: 782*5f39d1b3SJooyung Han GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); // exp(-1/4) 783*5f39d1b3SJooyung Han GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); // exp(-1/2) 784*5f39d1b3SJooyung Han GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); // exp(-1) 785*5f39d1b3SJooyung Han GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); // exp(-2) 786*5f39d1b3SJooyung Han GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); // exp(-4) 787*5f39d1b3SJooyung Han GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); // exp(-8) 788*5f39d1b3SJooyung Han GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); // exp(-16) 789*5f39d1b3SJooyung Han 790*5f39d1b3SJooyung Han #undef GEMMLOWP_EXP_BARREL_SHIFTER 791*5f39d1b3SJooyung Han 792*5f39d1b3SJooyung Han static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0; 793*5f39d1b3SJooyung Han if (kIntegerBits > 5) { 794*5f39d1b3SJooyung Han const InputF clamp = 795*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0); 796*5f39d1b3SJooyung Han result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); 797*5f39d1b3SJooyung Han } 798*5f39d1b3SJooyung Han 799*5f39d1b3SJooyung Han result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); 800*5f39d1b3SJooyung Han return result; 801*5f39d1b3SJooyung Han } 802*5f39d1b3SJooyung Han 803*5f39d1b3SJooyung Han // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)). 804*5f39d1b3SJooyung Han 805*5f39d1b3SJooyung Han // Returns (1 - x) / (1 + x) for x in (0, 1). 806*5f39d1b3SJooyung Han template <typename tRawType> 807*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( 808*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> a) { 809*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 0> F0; 810*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 2> F2; 811*5f39d1b3SJooyung Han F0 half_denominator = RoundingHalfSum(a, F0::One()); 812*5f39d1b3SJooyung Han // Newton-Raphson division 813*5f39d1b3SJooyung Han // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 814*5f39d1b3SJooyung Han // Refer to that page for the logic behind the 48/17 and 32/17 constants. 815*5f39d1b3SJooyung Han const F2 constant_48_over_17 = 816*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 817*5f39d1b3SJooyung Han const F2 constant_neg_32_over_17 = 818*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 819*5f39d1b3SJooyung Han F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 820*5f39d1b3SJooyung Han for (int i = 0; i < 3; i++) { 821*5f39d1b3SJooyung Han F2 half_denominator_times_x = half_denominator * x; 822*5f39d1b3SJooyung Han F2 one_minus_half_denominator_times_x = 823*5f39d1b3SJooyung Han F2::One() - half_denominator_times_x; 824*5f39d1b3SJooyung Han x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 825*5f39d1b3SJooyung Han } 826*5f39d1b3SJooyung Han return Rescale<0>(x - F2::One()); 827*5f39d1b3SJooyung Han } 828*5f39d1b3SJooyung Han 829*5f39d1b3SJooyung Han // Returns -tanh(x) for x < 0. 830*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 831*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> neg_tanh_on_negative_values( 832*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a) { 833*5f39d1b3SJooyung Han return one_minus_x_over_one_plus_x_for_x_in_0_1( 834*5f39d1b3SJooyung Han exp_on_negative_values(ExactMulByPot<1>(a))); 835*5f39d1b3SJooyung Han } 836*5f39d1b3SJooyung Han 837*5f39d1b3SJooyung Han // Returns tanh(x) for any x. 838*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 839*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { 840*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, tIntegerBits> InputF; 841*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 0> ResultF; 842*5f39d1b3SJooyung Han tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); 843*5f39d1b3SJooyung Han tRawType mask_if_zero = MaskIfZero(a); 844*5f39d1b3SJooyung Han InputF n = SelectUsingMask(mask_if_negative, a, -a); 845*5f39d1b3SJooyung Han ResultF t = neg_tanh_on_negative_values(n); 846*5f39d1b3SJooyung Han return SelectUsingMask(mask_if_zero, ResultF::Zero(), 847*5f39d1b3SJooyung Han SelectUsingMask(mask_if_negative, -t, t)); 848*5f39d1b3SJooyung Han } 849*5f39d1b3SJooyung Han 850*5f39d1b3SJooyung Han // Implementation of logistic function. 851*5f39d1b3SJooyung Han 852*5f39d1b3SJooyung Han // Returns 1 / (1 + x) for x in (0, 1). 853*5f39d1b3SJooyung Han template <typename tRawType> 854*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1( 855*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> a) { 856*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 0> F0; 857*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 2> F2; 858*5f39d1b3SJooyung Han F0 half_denominator = RoundingHalfSum(a, F0::One()); 859*5f39d1b3SJooyung Han // Newton-Raphson division 860*5f39d1b3SJooyung Han // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 861*5f39d1b3SJooyung Han // Refer to that page for the logic behind the 48/17 and 32/17 constants. 862*5f39d1b3SJooyung Han const F2 constant_48_over_17 = 863*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 864*5f39d1b3SJooyung Han const F2 constant_neg_32_over_17 = 865*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 866*5f39d1b3SJooyung Han F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 867*5f39d1b3SJooyung Han for (int i = 0; i < 3; i++) { 868*5f39d1b3SJooyung Han F2 half_denominator_times_x = half_denominator * x; 869*5f39d1b3SJooyung Han F2 one_minus_half_denominator_times_x = 870*5f39d1b3SJooyung Han F2::One() - half_denominator_times_x; 871*5f39d1b3SJooyung Han x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 872*5f39d1b3SJooyung Han } 873*5f39d1b3SJooyung Han return Rescale<0>(ExactMulByPot<-1>(x)); 874*5f39d1b3SJooyung Han } 875*5f39d1b3SJooyung Han 876*5f39d1b3SJooyung Han // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. 877*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 878*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> logistic_on_positive_values( 879*5f39d1b3SJooyung Han FixedPoint<tRawType, tIntegerBits> a) { 880*5f39d1b3SJooyung Han return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); 881*5f39d1b3SJooyung Han } 882*5f39d1b3SJooyung Han 883*5f39d1b3SJooyung Han // Returns logistic(x) = 1 / (1 + exp(-x)) for any x. 884*5f39d1b3SJooyung Han template <typename tRawType, int tIntegerBits> 885*5f39d1b3SJooyung Han FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { 886*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, tIntegerBits> InputF; 887*5f39d1b3SJooyung Han typedef FixedPoint<tRawType, 0> ResultF; 888*5f39d1b3SJooyung Han tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); 889*5f39d1b3SJooyung Han tRawType mask_if_zero = MaskIfZero(a); 890*5f39d1b3SJooyung Han InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); 891*5f39d1b3SJooyung Han ResultF result_if_positive = logistic_on_positive_values(abs_input); 892*5f39d1b3SJooyung Han ResultF result_if_negative = ResultF::One() - result_if_positive; 893*5f39d1b3SJooyung Han const ResultF one_half = 894*5f39d1b3SJooyung Han GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); 895*5f39d1b3SJooyung Han return SelectUsingMask(mask_if_zero, one_half, 896*5f39d1b3SJooyung Han SelectUsingMask(mask_if_positive, result_if_positive, 897*5f39d1b3SJooyung Han result_if_negative)); 898*5f39d1b3SJooyung Han } 899*5f39d1b3SJooyung Han 900*5f39d1b3SJooyung Han } // end namespace gemmlowp 901*5f39d1b3SJooyung Han 902*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON 903*5f39d1b3SJooyung Han #include "./fixedpoint_neon.h" 904*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_AVX2) 905*5f39d1b3SJooyung Han #include "./fixedpoint_avx.h" 906*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_SSE4) 907*5f39d1b3SJooyung Han #include "./fixedpoint_sse.h" 908*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_MSA) 909*5f39d1b3SJooyung Han #include "./fixedpoint_msa.h" 910*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_WASMSIMD) 911*5f39d1b3SJooyung Han #include "./fixedpoint_wasmsimd.h" 912*5f39d1b3SJooyung Han #endif 913*5f39d1b3SJooyung Han 914*5f39d1b3SJooyung Han #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 915