1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_ 17 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_ 18 19 #include <utility> 20 21 #include "llvm/ADT/APFloat.h" 22 #include "llvm/ADT/APInt.h" 23 #include "llvm/ADT/APSInt.h" 24 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project 25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 26 #include "mlir/IR/Types.h" // from @llvm-project 27 28 namespace mlir { 29 namespace quantfork { 30 31 /// Performs type conversion from an arbitrary input type to a type 32 /// that is expressed by a QuantizedType. 33 /// 34 /// This handles cases where the inputType is a supported primitive type 35 /// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported 36 /// elemental type. 37 /// 38 /// Since conversion often involves introspecting some attributes of the 39 /// input type in order to determine how to represent it, this is a two step 40 /// process. 41 struct ExpressedToQuantizedConverter { 42 /// Creates a converter for the given input type. 43 static ExpressedToQuantizedConverter forInputType(Type inputType); 44 45 /// Converts the inputType to be based on the given elemental type, 46 /// returning the new type (or nullptr and emit an error on failure). 47 Type convert(quant::QuantizedType elementalType) const; 48 49 /// Whether the conversion is legal. 50 explicit operator bool() const { return (bool)expressedType; } 51 52 /// The input type that is being converted from. 53 /// This may be an elemental or composite type. 54 const Type inputType; 55 56 /// Supported, elemental expressed type (i.e. f32). 57 /// Will be nullptr if conversion is not supported. 58 const Type expressedType; 59 }; 60 61 /// Reference implementation of converting between real numbers and values 62 /// represented by a UniformQuantizedType. 63 /// Note that this is not expected to be speedy and may be superseded eventually 64 /// by a more optimal implementation. 65 /// Also, the interface assumes that quantization is done per-layer and will 66 /// need to be wider for various per-channel schemes. As such, this is a 67 /// placeholder. 68 class UniformQuantizedValueConverter { 69 public: UniformQuantizedValueConverter(quant::UniformQuantizedType uniformType)70 explicit UniformQuantizedValueConverter( 71 quant::UniformQuantizedType uniformType) 72 : UniformQuantizedValueConverter( 73 uniformType.getScale(), 74 static_cast<double>(uniformType.getZeroPoint()), 75 static_cast<double>(uniformType.getStorageTypeMin()), 76 static_cast<double>(uniformType.getStorageTypeMax()), 77 uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) { 78 assert(uniformType.getExpressedType().isa<FloatType>()); 79 assert(uniformType.getStorageType().isSignlessInteger()); 80 } 81 UniformQuantizedValueConverter(double scale,double zeroPoint,double clampMin,double clampMax,uint32_t storageBitWidth,bool isSigned)82 UniformQuantizedValueConverter(double scale, double zeroPoint, 83 double clampMin, double clampMax, 84 uint32_t storageBitWidth, bool isSigned) 85 : scale(scale), 86 zeroPoint(zeroPoint), 87 clampMin(clampMin), 88 clampMax(clampMax), 89 scaleDouble(scale), 90 zeroPointDouble(zeroPoint), 91 clampMinDouble(clampMin), 92 clampMaxDouble(clampMax), 93 storageBitWidth(storageBitWidth), 94 isSigned(isSigned), 95 roundMode(APFloat::rmNearestTiesToAway) {} 96 UniformQuantizedValueConverter(double scale,double zeroPoint,const APFloat & clampMin,const APFloat & clampMax,uint32_t storageBitWidth,bool isSigned)97 UniformQuantizedValueConverter(double scale, double zeroPoint, 98 const APFloat &clampMin, 99 const APFloat &clampMax, 100 uint32_t storageBitWidth, bool isSigned) 101 : scale(scale), 102 zeroPoint(zeroPoint), 103 clampMin(clampMin), 104 clampMax(clampMax), 105 scaleDouble(scale), 106 zeroPointDouble(zeroPoint), 107 clampMinDouble(clampMin.convertToDouble()), 108 clampMaxDouble(clampMax.convertToDouble()), 109 storageBitWidth(storageBitWidth), 110 isSigned(isSigned), 111 roundMode(APFloat::rmNearestTiesToAway) {} 112 quantizeFloatToInt(APFloat expressedValue)113 virtual APInt quantizeFloatToInt(APFloat expressedValue) const { 114 // This function is a performance critical code path in quantization 115 // since it runs for each single float parameter value. 116 117 // Specialize f32->u8/i8 case to optimize performance. 118 if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() && 119 storageBitWidth == 8 && 120 roundMode == llvm::APFloatBase::rmNearestTiesToAway) { 121 return quantizeF32ToInt8(expressedValue); 122 } 123 124 bool lossy; 125 expressedValue.convert(scale.getSemantics(), roundMode, &lossy); 126 // fixedpoint = clamp(clampMin, clampMax, ( 127 // roundHalfToEven(expressed / scale) + zeroPoint)) 128 APFloat scaled = (expressedValue / scale); 129 scaled.roundToIntegral(roundMode); 130 scaled.add(zeroPoint, roundMode); 131 APFloat fixedpoint = llvm::minimum(scaled, clampMax); 132 fixedpoint = llvm::maximum(fixedpoint, clampMin); 133 134 llvm::APSInt result(storageBitWidth, !isSigned); 135 fixedpoint.convertToInteger(result, roundMode, &lossy); 136 137 return std::move(result); 138 } 139 quantizeFloatToInt64(APFloat expressedValue)140 int64_t quantizeFloatToInt64(APFloat expressedValue) const { 141 APInt qValue = quantizeFloatToInt(std::move(expressedValue)); 142 return isSigned ? qValue.getSExtValue() : qValue.getZExtValue(); 143 } 144 145 virtual ~UniformQuantizedValueConverter() = default; 146 147 private: 148 // An optimized implementation to quantize f32 to i8/u8 with C++ native 149 // arithmetic. quantizeF32ToInt8(APFloat expressedValue)150 virtual APInt quantizeF32ToInt8(APFloat expressedValue) const { 151 assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle()); 152 assert(storageBitWidth == 8); 153 assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway); 154 155 const float realValue = expressedValue.convertToFloat(); 156 157 const double scaled = realValue / scaleDouble + zeroPointDouble; 158 // Round to nearest integer with halfway cases rounded away from zero. 159 const double scaledRounded = std::round(scaled); 160 const double clamped = 161 std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble); 162 163 uint64_t signlessResult; 164 if (isSigned) { 165 int64_t clampedInt = static_cast<int8_t>(clamped); 166 memcpy(&signlessResult, &clampedInt, sizeof(clampedInt)); 167 } else { 168 signlessResult = static_cast<uint8_t>(clamped); 169 } 170 return APInt(storageBitWidth, signlessResult); 171 } 172 173 // Keep both APFloat and double versions of the quantization parameters 174 // around since they will be used in generic and specialized arithmetic, 175 // respectively. 176 const APFloat scale; 177 const APFloat zeroPoint; 178 const APFloat clampMin; 179 const APFloat clampMax; 180 181 const double scaleDouble; 182 const double zeroPointDouble; 183 const double clampMinDouble; 184 const double clampMaxDouble; 185 186 const uint32_t storageBitWidth; 187 const bool isSigned; 188 const llvm::APFloat::roundingMode roundMode; 189 }; 190 191 /// An utility class to quantize an attribute by the per-axis quantization 192 /// parameters. The size of the quantization dim in the converted elements 193 /// attribute should matche the size of of scales/zeroPoints vectors in the 194 /// quantization parameters. 195 class UniformQuantizedPerAxisValueConverter { 196 public: UniformQuantizedPerAxisValueConverter(quant::UniformQuantizedPerAxisType uniformType)197 explicit UniformQuantizedPerAxisValueConverter( 198 quant::UniformQuantizedPerAxisType uniformType) 199 : scales(uniformType.getScales()), 200 zeroPoints(uniformType.getZeroPoints()), 201 clampMin(static_cast<double>(uniformType.getStorageTypeMin())), 202 clampMax(static_cast<double>(uniformType.getStorageTypeMax())), 203 storageBitWidth(uniformType.getStorageTypeIntegralWidth()), 204 isSigned(uniformType.isSigned()), 205 quantizationDim(uniformType.getQuantizedDimension()) { 206 assert(uniformType.getExpressedType().isa<FloatType>()); 207 assert(uniformType.getStorageType().isSignlessInteger()); 208 assert(scales.size() == zeroPoints.size()); 209 } 210 211 /// Quantize an Attribute by the quantization parameters. Return nullptr if 212 /// the conversion fails or the input array isn't an ElementsAttr. 213 ElementsAttr convert(Attribute realValue); 214 215 private: 216 /// Quantize an DenseFPElementsAttr by the quantization parameters. 217 DenseElementsAttr convert(DenseFPElementsAttr attr); 218 219 /// Get a uniform converter for the index-th chunk along the quantizationDim. 220 /// All the elements in this chunk is quantized by the returned converter. getPerChunkConverter(int index)221 UniformQuantizedValueConverter getPerChunkConverter(int index) const { 222 UniformQuantizedValueConverter converter(scales[index], zeroPoints[index], 223 clampMin, clampMax, 224 storageBitWidth, isSigned); 225 return converter; 226 } 227 228 const ArrayRef<double> scales; 229 const ArrayRef<int64_t> zeroPoints; 230 const APFloat clampMin; 231 const APFloat clampMax; 232 const uint32_t storageBitWidth; 233 const bool isSigned; 234 int32_t quantizationDim; 235 }; 236 237 } // namespace quantfork 238 } // namespace mlir 239 240 #endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_ 241