xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_
18 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
27 
28 namespace mlir {
29 namespace quantfork {
30 
31 /// Performs type conversion from an arbitrary input type to a type
32 /// that is expressed by a QuantizedType.
33 ///
34 /// This handles cases where the inputType is a supported primitive type
35 /// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
36 /// elemental type.
37 ///
38 /// Since conversion often involves introspecting some attributes of the
39 /// input type in order to determine how to represent it, this is a two step
40 /// process.
41 struct ExpressedToQuantizedConverter {
42   /// Creates a converter for the given input type.
43   static ExpressedToQuantizedConverter forInputType(Type inputType);
44 
45   /// Converts the inputType to be based on the given elemental type,
46   /// returning the new type (or nullptr and emit an error on failure).
47   Type convert(quant::QuantizedType elementalType) const;
48 
49   /// Whether the conversion is legal.
50   explicit operator bool() const { return (bool)expressedType; }
51 
52   /// The input type that is being converted from.
53   /// This may be an elemental or composite type.
54   const Type inputType;
55 
56   /// Supported, elemental expressed type (i.e. f32).
57   /// Will be nullptr if conversion is not supported.
58   const Type expressedType;
59 };
60 
61 /// Reference implementation of converting between real numbers and values
62 /// represented by a UniformQuantizedType.
63 /// Note that this is not expected to be speedy and may be superseded eventually
64 /// by a more optimal implementation.
65 /// Also, the interface assumes that quantization is done per-layer and will
66 /// need to be wider for various per-channel schemes. As such, this is a
67 /// placeholder.
class UniformQuantizedValueConverter {
 public:
  /// Convenience constructor: derives every quantization parameter (scale,
  /// zero point, clamp range, storage bit width, signedness) from a
  /// per-layer uniform quantized type.
  explicit UniformQuantizedValueConverter(
      quant::UniformQuantizedType uniformType)
      : UniformQuantizedValueConverter(
            uniformType.getScale(),
            static_cast<double>(uniformType.getZeroPoint()),
            static_cast<double>(uniformType.getStorageTypeMin()),
            static_cast<double>(uniformType.getStorageTypeMax()),
            uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
    // Only float expressed types with signless-integer storage are supported.
    assert(uniformType.getExpressedType().isa<FloatType>());
    assert(uniformType.getStorageType().isSignlessInteger());
  }

  /// Constructs from explicit double-valued parameters. Each parameter is
  /// stored twice — as an APFloat for the generic path and as a native
  /// double for the specialized f32->i8/u8 fast path — so neither path pays
  /// a conversion cost per value.
  UniformQuantizedValueConverter(double scale, double zeroPoint,
                                 double clampMin, double clampMax,
                                 uint32_t storageBitWidth, bool isSigned)
      : scale(scale),
        zeroPoint(zeroPoint),
        clampMin(clampMin),
        clampMax(clampMax),
        scaleDouble(scale),
        zeroPointDouble(zeroPoint),
        clampMinDouble(clampMin),
        clampMaxDouble(clampMax),
        storageBitWidth(storageBitWidth),
        isSigned(isSigned),
        roundMode(APFloat::rmNearestTiesToAway) {}

  /// Same as above, but takes the clamp bounds as APFloats; double copies
  /// are derived for the specialized path.
  UniformQuantizedValueConverter(double scale, double zeroPoint,
                                 const APFloat &clampMin,
                                 const APFloat &clampMax,
                                 uint32_t storageBitWidth, bool isSigned)
      : scale(scale),
        zeroPoint(zeroPoint),
        clampMin(clampMin),
        clampMax(clampMax),
        scaleDouble(scale),
        zeroPointDouble(zeroPoint),
        clampMinDouble(clampMin.convertToDouble()),
        clampMaxDouble(clampMax.convertToDouble()),
        storageBitWidth(storageBitWidth),
        isSigned(isSigned),
        roundMode(APFloat::rmNearestTiesToAway) {}

  /// Quantizes a single real value to its integer (storage) representation:
  ///   clamp(clampMin, clampMax, round(expressedValue / scale) + zeroPoint)
  /// with ties rounded away from zero (rmNearestTiesToAway).
  virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
    // This function is a performance critical code path in quantization
    // since it runs for each single float parameter value.

    // Specialize f32->u8/i8 case to optimize performance.
    if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
        storageBitWidth == 8 &&
        roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
      return quantizeF32ToInt8(expressedValue);
    }

    bool lossy;
    // Bring the value into the same float semantics as `scale` so the
    // APFloat arithmetic below is well-defined.
    expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
    // fixedpoint = clamp(clampMin, clampMax, (
    //   roundHalfToEven(expressed / scale) + zeroPoint))
    APFloat scaled = (expressedValue / scale);
    scaled.roundToIntegral(roundMode);
    scaled.add(zeroPoint, roundMode);
    APFloat fixedpoint = llvm::minimum(scaled, clampMax);
    fixedpoint = llvm::maximum(fixedpoint, clampMin);

    llvm::APSInt result(storageBitWidth, !isSigned);
    fixedpoint.convertToInteger(result, roundMode, &lossy);

    // APSInt derives from APInt; move to slice down without copying bits.
    return std::move(result);
  }

  /// Like quantizeFloatToInt, but widens the storage value to int64_t,
  /// sign- or zero-extending according to `isSigned`.
  int64_t quantizeFloatToInt64(APFloat expressedValue) const {
    APInt qValue = quantizeFloatToInt(std::move(expressedValue));
    return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
  }

  virtual ~UniformQuantizedValueConverter() = default;

 private:
  // An optimized implementation to quantize f32 to i8/u8 with C++ native
  // arithmetic. Preconditions (asserted): IEEE single input, 8-bit storage,
  // round-half-away-from-zero mode — exactly the guard in
  // quantizeFloatToInt.
  virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
    assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
    assert(storageBitWidth == 8);
    assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);

    const float realValue = expressedValue.convertToFloat();

    const double scaled = realValue / scaleDouble + zeroPointDouble;
    // Round to nearest integer with halfway cases rounded away from zero.
    const double scaledRounded = std::round(scaled);
    const double clamped =
        std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble);

    uint64_t signlessResult;
    if (isSigned) {
      // Narrow to i8, sign-extend to i64, then bit-copy so the
      // two's-complement bit pattern is preserved in the unsigned word
      // that APInt consumes.
      int64_t clampedInt = static_cast<int8_t>(clamped);
      memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
    } else {
      signlessResult = static_cast<uint8_t>(clamped);
    }
    // APInt keeps only the low `storageBitWidth` (8) bits.
    return APInt(storageBitWidth, signlessResult);
  }

  // Keep both APFloat and double versions of the quantization parameters
  // around since they will be used in generic and specialized arithmetic,
  // respectively.
  const APFloat scale;
  const APFloat zeroPoint;
  const APFloat clampMin;
  const APFloat clampMax;

  const double scaleDouble;
  const double zeroPointDouble;
  const double clampMinDouble;
  const double clampMaxDouble;

  const uint32_t storageBitWidth;
  const bool isSigned;
  const llvm::APFloat::roundingMode roundMode;
};
190 
/// A utility class to quantize an attribute by the per-axis quantization
/// parameters. The size of the quantization dim in the converted elements
/// attribute should match the size of the scales/zeroPoints vectors in the
/// quantization parameters.
195 class UniformQuantizedPerAxisValueConverter {
196  public:
UniformQuantizedPerAxisValueConverter(quant::UniformQuantizedPerAxisType uniformType)197   explicit UniformQuantizedPerAxisValueConverter(
198       quant::UniformQuantizedPerAxisType uniformType)
199       : scales(uniformType.getScales()),
200         zeroPoints(uniformType.getZeroPoints()),
201         clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
202         clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
203         storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
204         isSigned(uniformType.isSigned()),
205         quantizationDim(uniformType.getQuantizedDimension()) {
206     assert(uniformType.getExpressedType().isa<FloatType>());
207     assert(uniformType.getStorageType().isSignlessInteger());
208     assert(scales.size() == zeroPoints.size());
209   }
210 
211   /// Quantize an Attribute by the quantization parameters. Return nullptr if
212   /// the conversion fails or the input array isn't an ElementsAttr.
213   ElementsAttr convert(Attribute realValue);
214 
215  private:
216   /// Quantize an DenseFPElementsAttr by the quantization parameters.
217   DenseElementsAttr convert(DenseFPElementsAttr attr);
218 
219   /// Get a uniform converter for the index-th chunk along the quantizationDim.
220   /// All the elements in this chunk is quantized by the returned converter.
getPerChunkConverter(int index)221   UniformQuantizedValueConverter getPerChunkConverter(int index) const {
222     UniformQuantizedValueConverter converter(scales[index], zeroPoints[index],
223                                              clampMin, clampMax,
224                                              storageBitWidth, isSigned);
225     return converter;
226   }
227 
228   const ArrayRef<double> scales;
229   const ArrayRef<int64_t> zeroPoints;
230   const APFloat clampMin;
231   const APFloat clampMax;
232   const uint32_t storageBitWidth;
233   const bool isSigned;
234   int32_t quantizationDim;
235 };
236 
237 }  // namespace quantfork
238 }  // namespace mlir
239 
240 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_UNIFORMSUPPORT_H_
241