xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This header file defines common utils used by TFLite transformation
17 // passes to work with op attributes.
18 
19 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
20 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
21 
#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/core/framework/types.pb.h"
53 
54 namespace mlir {
55 namespace quant {
56 
// A unit attribute that can be attached to the quantize/dequantize ops added
// by the quantization passes. Ops carrying it can be erased without losing
// accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";

// The following attributes are used to mark ops that are not quantizable
// during the debug model generation process for the whole-model verify mode.
// When they are attached, the upstream float/quantized ops know which ops to
// connect to, and they also prevent these ops from being copied again.
constexpr char kDebugModeOpFloatAttrName[] = "debug_float";
constexpr char kDebugModeOpQuantAttrName[] = "debug_quant";

// Name of the attribute used to annotate whether custom ops are quantizable.
constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait";
// Quantization trait of an op. The enumerator values match the positions of
// the corresponding strings in `QuantTraitValues`.
enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 };
72 constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable",
73                                                   "not_quantizable"};
74 
75 constexpr double kNearZeroTolerance = 1.0e-6;
76 
77 using QuantParams = mlir::quant::QuantizedType;
78 using QuantSpec = QuantizationSpecs;
79 using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
80 using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
81 using AccumulatorScaleFunc =
82     std::function<QuantParams(const std::vector<QuantParams>&, bool)>;
83 using BiasParamsMap =
84     std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>;
85 // UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width)
86 using GetFixedOutputRangeFunc = std::function<UniformQuantizedType(bool, int)>;
87 // bool RequiredSameOperandsAndResultsScale(bool sign, int $bit_width)
88 using RequiredSameOperandsAndResultsScaleFunc = std::function<bool(bool, int)>;
89 // bool RequiredSameQuantizedAxes()
90 using RequiredSameQuantizedAxesFunc = std::function<bool()>;
91 
92 using StringSet = absl::flat_hash_set<std::string>;
93 using CustomMap = quant::CustomOpMap;
94 
95 // Quantization spec of an op, driving the quantization algorithm.
96 struct OpQuantSpec {
97   // Maps the operand index of a bias input to its quantization specifications,
98   // including the non-bias operand indexes and the method retrieving
99   // quantization parameters from list of parameters of the non-bias operands.
100   // This map is empty if the op doesn't have a bias operand.
101   BiasParamsMap biases_params;
102 
103   // Quantization parameters for value restricted outputs. This is the
104   // "hard-coded" parameters and should be used unconditionally for the
105   // quantized op. This vector is empty if the op doesn't have value restricted
106   // outputs.
107   llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;
108 
109   // Coefficient operand index and whether supporting per-channel quantization.
110   // For QAT, this information is carried by the FakeQuant*/QDQ ops, but
111   // post-training quantization, the quantization parameters need to be inferred
112   // from the tensor content and op property. A "-1" value indicates the
113   // operand doesn't support per-channel quantization.
114   llvm::DenseMap<int, int> coeff_op_quant_dim;
115 
116   // Indices of quantizable operands. Biases are not included in this field,
117   // the indices of biases can be found in the `biases_params`.
118   absl::flat_hash_set<int> quantizable_operands;
119 };
120 
121 // Quantization scale spec of an op. The information defined in the MLIR
122 // interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should
123 // be checked first if present.
124 struct OpQuantScaleSpec {
125   // Whether this op has a fixed range requirement (e.g. sigmoid)
126   bool has_fixed_output_range = false;
127   // Whether this op should have same result and operand scales (e.g. concat)
128   bool has_same_scale_requirement = false;
129   // Returns the fixed output range, when has_fixed_output_range is set.
130   GetFixedOutputRangeFunc fixed_output_range_func;
131   // Returns whether same operands and results scales are required.
132   RequiredSameOperandsAndResultsScaleFunc required_same_scale_func =
133       [](bool sign, int bit_width) { return true; };
134   // Returns whether operands and results must have the same quantized axis.
135   RequiredSameQuantizedAxesFunc required_same_quantized_axes_func = []() {
136     return true;
137   };
138 };
139 
// Options of the TFL NumericVerify debugging mode.
struct NumericVerifySpec {
  // Turns numeric verification on.
  bool verify_numeric{false};

  // Tolerance level, relative to the quantized value, used for verification.
  // With a very small tolerance (< 0.1) only the statistics of the difference
  // are displayed.
  float error_tolerance{5.0f};

  // Verify numerical correctness over the whole model (true) instead of layer
  // by layer (false).
  bool whole_model_verify{false};

  // Log verification failures.
  bool log_if_failed_flag{false};
};
155 
156 // Used in TFL Quantize Pass
157 struct QuantPassSpec {
158   // Variables to control TFL Numeric Verify
159   NumericVerifySpec numeric_verify_spec;
160 
161   // Variables related to quantization
162   QuantSpec quant_spec;
163 };
164 
165 // A function signature for getting the particular OpQuantSpec for the provided
166 // op.
167 typedef std::unique_ptr<OpQuantSpec> (*OpQuantSpecGetter)(Operation* op);
168 // A function signature for getting the particular OpQuantScaleSpec for the
169 // provided op.
170 typedef std::unique_ptr<OpQuantScaleSpec> (*OpQuantScaleSpecGetter)(
171     Operation* op);
172 
173 // Re-calculates scales again in float instead of simply downcasting existing
174 // scales.
175 quant::QuantizedType DownCastScale(quant::QuantizedType type,
176                                    const SmallVectorImpl<double>& mins,
177                                    const SmallVectorImpl<double>& maxs,
178                                    Location loc);
179 
180 quant::QuantizedType DownCastScale(quant::QuantizedType type, double min,
181                                    double max, Location loc);
182 
183 bool IsOpNotQuantizable(Operation* op);
184 
185 // Specialized version of location to string for flatbuffer exported locations.
GetTensorNameFromLoc(Location loc)186 inline std::string GetTensorNameFromLoc(Location loc) {
187   if (auto name_loc = loc.dyn_cast<NameLoc>()) {
188     return name_loc.getName().str();
189   }
190   return "";
191 }
192 
193 template <typename Q, typename DQ>
194 struct ConvertStatsToQDQs : public OpRewritePattern<quantfork::StatisticsOp> {
ConvertStatsToQDQsConvertStatsToQDQs195   ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
196                      bool legacy_float_scale, MLIRContext* context)
197       : OpRewritePattern<quantfork::StatisticsOp>(context),
198         num_bits(num_bits),
199         narrow_range(narrow_range),
200         is_signed(is_signed),
201         legacy_float_scale(legacy_float_scale) {}
202 
matchAndRewriteConvertStatsToQDQs203   LogicalResult matchAndRewrite(quantfork::StatisticsOp op,
204                                 PatternRewriter& rewriter) const override {
205     Type expressed = op.getType().cast<ShapedType>().getElementType();
206     quant::QuantizedType quant_type;
207     SmallVector<double, 4> mins, maxs;
208 
209     if (op.getAxisStats().has_value()) {
210       int stats_num = op.getAxisStats()->getNumElements();
211       if (stats_num == 0 || stats_num % 2 != 0) return failure();
212       auto stats = op.getAxisStats()->dyn_cast<DenseFPElementsAttr>();
213       if (!stats) return failure();
214 
215       for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
216         double rmin = FloatAttr::getValueAsDouble(*it++);
217         double rmax = FloatAttr::getValueAsDouble(*it);
218         // The default nudging implementation of mlir quant library might cause
219         // clamping during inference if the calibration range isn't wide enough.
220         // So here we adjust the range to include 0.0.
221         rmin = std::min(rmin, 0.0);
222         rmax = std::max(rmax, 0.0);
223         TensorRangeSanityCheck(op, rmin, rmax);
224         mins.push_back(rmin);
225         maxs.push_back(rmax);
226       }
227       quant_type = quantfork::fakeQuantAttrsToType(
228           op.getLoc(), num_bits, *op.getAxis(), mins, maxs, narrow_range,
229           expressed, is_signed);
230       if (legacy_float_scale) {
231         quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc());
232       }
233     } else if (auto stats =
234                    op.getLayerStats().dyn_cast<DenseFPElementsAttr>()) {
235       auto statValues = stats.getValues<APFloat>();
236       double rmin = FloatAttr::getValueAsDouble(statValues[0]);
237       double rmax = FloatAttr::getValueAsDouble(statValues[1]);
238       // The default nudging implementation of mlir quant library might cause
239       // clamping during inference if the calibration range isn't wide enough.
240       // So here we adjust the range to include 0.0.
241       rmin = std::min(rmin, 0.0);
242       rmax = std::max(rmax, 0.0);
243       TensorRangeSanityCheck(op, rmin, rmax);
244       quant_type =
245           quantfork::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
246                                           narrow_range, expressed, is_signed);
247       if (legacy_float_scale) {
248         quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc());
249       }
250     } else {
251       return failure();
252     }
253 
254     rewriter.setInsertionPointAfter(op.getOperation());
255     Type result_type = quant_type.castFromExpressedType(op.getType());
256     auto q = rewriter.create<Q>(op.getLoc(), result_type, op.getArg());
257     q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());
258 
259     auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
260     op.getResult().replaceAllUsesWith(dq);
261     q.getOperation()->replaceUsesOfWith(dq, op.getArg());
262     op.erase();
263 
264     return success();
265   }
266 
267  private:
268   int num_bits;
269   bool narrow_range;
270   bool is_signed;
271   bool legacy_float_scale;
272 
273   // Emits an op warning message if the calibrated range is larger than 10.0 and
274   // the storage type is less than or equal to 8 bits.
TensorRangeSanityCheckConvertStatsToQDQs275   void TensorRangeSanityCheck(quantfork::StatisticsOp op, double& min,
276                               double& max) const {
277     double range = std::fabs(max - min);
278     if (num_bits <= 8 && range >= 10.0) {
279       op.emitWarning()
280           << "Tensor range is too wide to be quantized. Use tf.clip_by_value "
281              "or tf.relu6 to narrow the tensor range. Range: "
282           << range << ", bit width: " << num_bits;
283     }
284     if (std::abs(max - min) < kNearZeroTolerance) {
285       op.emitWarning() << "Tensor range (" << min << ", " << max
286                        << ") is too narrow and it might cause overflow. "
287                           "Expanding range symmetrically by "
288                        << kNearZeroTolerance;
289       min -= kNearZeroTolerance;
290       max += kNearZeroTolerance;
291     }
292   }
293 };
294 
295 template <typename VerifierT>
UsedBy(Operation * op)296 bool UsedBy(Operation* op) {
297   for (Operation* user : op->getUsers()) {
298     if (llvm::isa_and_nonnull<VerifierT>(user)) return true;
299   }
300   return false;
301 }
302 
303 template <typename VerifierT>
CreateVerifier(Operation * quantizing_op,Operation * quantized_op,PatternRewriter & rewriter,int result_idx,const QuantPassSpec & quant_params)304 void CreateVerifier(Operation* quantizing_op, Operation* quantized_op,
305                     PatternRewriter& rewriter, int result_idx,
306                     const QuantPassSpec& quant_params) {
307   rewriter.setInsertionPointAfter(quantized_op);
308   FloatAttr tolerance = rewriter.getF32FloatAttr(
309       quant_params.numeric_verify_spec.error_tolerance);
310   BoolAttr log =
311       rewriter.getBoolAttr(quant_params.numeric_verify_spec.log_if_failed_flag);
312   // Verify the quantized value by sending the result to the verifier.
313   rewriter.create<VerifierT>(
314       quantizing_op->getLoc(), quantized_op->getResult(result_idx).getType(),
315       quantized_op->getResult(result_idx), quantizing_op->getResult(result_idx),
316       tolerance, log);
317 }
318 
319 template <>
320 inline bool UsedBy<void>(Operation* op) {
321   return false;
322 }
323 
324 // This specialization is not going to be called, but needed for compilation.
325 template <>
326 inline void CreateVerifier<void>(Operation* quantizing_op,
327                                  Operation* quantized_op,
328                                  PatternRewriter& rewriter, int result_idx,
329                                  const QuantPassSpec& quant_params) {}
330 
331 // A base rewrite pattern which matches any N-in-M-out operations with
332 // quantization parameters propagated to at least one of its operands. The
333 // quantization parameters are annotated by the Q/DQ op pairs. Each
334 // matched pattern are rewritten by its quantized alternatives.
335 //
336 // The concrete pattern, extends from this base pattern, can specify whether it
337 // allows dynamic range quantized operands and results for the operations in the
338 // current context. These "DynamicRangeQuantized" operands and results don't
339 // have quantization parameters propagated to, so will be in float in the
340 // quantized results. The concrete pattern should define the following two
341 // functions:
342 //
343 //   bool AllowDynamicRangeQuantizedOperand(Operation *) const
344 //   bool AllowDynamicRangeQuantizedResult(Operation *) const
345 //
346 // Full integer quantization disallows "DynamicRangeQuantized" operands or
347 // results. Dynamic range quantization allows "DynamicRangeQuantized" operands
348 // and results.
349 template <typename ConcretTy, typename Q, typename DQ, typename VERIFIER,
350           typename RootOp = DQ>
351 class QuantizationPattern : public RewritePattern {
352  public:
353   using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER, RootOp>;
354 
QuantizationPattern(MLIRContext * context,const QuantPassSpec & quant_params)355   explicit QuantizationPattern(MLIRContext* context,
356                                const QuantPassSpec& quant_params)
357       // Set the score to a large number so it is always preferred.
358       : RewritePattern(RootOp::getOperationName(), 300, context),
359         quant_params_(quant_params) {}
360 
matchAndRewrite(Operation * op,PatternRewriter & rewriter)361   LogicalResult matchAndRewrite(Operation* op,
362                                 PatternRewriter& rewriter) const override {
363     llvm::SmallVector<Operation*, 4> quantizing_ops;
364 
365     // Collect all the ops to quantize, as the user / producer of the root op.
366     if (std::is_same<RootOp, DQ>::value) {
367       if (op->getNumResults() != 1) {
368         return failure();
369       }
370       auto users = op->getResult(0).getUsers();
371       quantizing_ops.append(users.begin(), users.end());
372     } else if (std::is_same<RootOp, Q>::value) {
373       if (op->getNumOperands() != 1) {
374         return failure();
375       }
376       Value quantize_operand = op->getOperand(0);
377       if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) {
378         // The input of this Q op has already been quantized, i.e. rescale.
379         return failure();
380       }
381       DenseFPElementsAttr attr;
382       if (matchPattern(quantize_operand, m_Constant(&attr))) {
383         // Const->Q pattern will be handled separately.
384         return failure();
385       }
386       if (Operation* quantizing_op = quantize_operand.getDefiningOp()) {
387         quantizing_ops.push_back(quantizing_op);
388       }
389     }
390 
391     tensorflow::DataType inference_type =
392         quant_params_.quant_spec.inference_type;
393     bool weight_only_quantization =
394         quant_params_.quant_spec.weight_only_quantization;
395     bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric;
396     bool enable_whole_model_verify =
397         quant_params_.numeric_verify_spec.whole_model_verify;
398     StringSet ops_blocklist = quant_params_.quant_spec.ops_blocklist;
399     StringSet nodes_blocklist = quant_params_.quant_spec.nodes_blocklist;
400     CustomMap custom_map = quant_params_.quant_spec.custom_map;
401 
402     // Rewrite the floating-point ops to the quantized version, by fusing
403     // preceding dequantize ops and succeding quantize ops.
404     for (Operation* quantizing_op : quantizing_ops) {
405       // If it is requantize op, we shouldn't rewrite this op.
406       if (llvm::isa<Q, DQ>(quantizing_op)) {
407         return failure();
408       }
409 
410       // If the op is terminator, not quantizable or any ops from the mlir quant
411       // ops dialect, we shouldn't rewrite. In case of whole-model verify debug
412       // mode, not-quantizable ops should be duplicated to keep parallel
413       // float/quant model execution.
414       if (quantizing_op->hasTrait<OpTrait::IsTerminator>()) {
415         return failure();
416       }
417 
418       if (IsOpNotQuantizable(quantizing_op) &&
419           !static_cast<const ConcretTy*>(this)->IsQuantizableCustomOp(
420               quantizing_op, custom_map)) {
421         if (!(enable_verify && enable_whole_model_verify)) {
422           return failure();
423         }
424         if (quantizing_op->hasAttr(kDebugModeOpQuantAttrName) ||
425             quantizing_op->hasAttr(kDebugModeOpFloatAttrName)) {
426           return failure();
427         }
428 
429         rewriter.setInsertionPoint(quantizing_op);
430         Operation* float_op = rewriter.clone(*quantizing_op);
431         quantizing_op->setAttr(kDebugModeOpQuantAttrName,
432                                rewriter.getUnitAttr());
433         float_op->setAttr(kDebugModeOpFloatAttrName, rewriter.getUnitAttr());
434         RewireFloatModelBackbone(quantizing_op, float_op);
435         return success();
436       }
437 
438       // Blocklist op is checked in advance for non-dynamic range quantization
439       // case.
440       if (!quant_params_.quant_spec.weight_quantization &&
441           (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) !=
442            ops_blocklist.end())) {
443         return failure();
444       }
445 
446       if (!nodes_blocklist.empty()) {
447         if (auto name_loc = quantizing_op->getLoc().dyn_cast<NameLoc>()) {
448           std::string sloc = name_loc.getName().str();
449           if (!sloc.empty() &&
450               (nodes_blocklist.find(sloc) != nodes_blocklist.end())) {
451             return failure();
452           }
453         }
454       }
455 
456       // An op with float inputs and outputs are expected when it's used by a
457       // NumericVerify op. Skip this op.
458       if (enable_verify && UsedBy<VERIFIER>(quantizing_op)) {
459         continue;
460       }
461 
462       // Collect all the quantized inputs and "clone" the matched op by these
463       // inputs.
464       SmallVector<Value, 4> inputs;
465       inputs.reserve(quantizing_op->getNumOperands());
466       for (auto operand : quantizing_op->getOperands()) {
467         Type operand_type = operand.getType();
468         if (operand_type.isa<NoneType>()) {
469           inputs.push_back(operand);
470           continue;
471         }
472 
473         auto ele_type = operand.getType().cast<TensorType>().getElementType();
474         if (static_cast<const ConcretTy*>(this)
475                 ->AllowDynamicRangeQuantizedOperand(quantizing_op,
476                                                     custom_map)) {
477           auto dq_op = dyn_cast_or_null<DQ>(operand.getDefiningOp());
478 
479           if (dq_op && inference_type == tensorflow::DT_QINT8 &&
480               !static_cast<const ConcretTy*>(this)->IsWeightOnlyOp(
481                   quantizing_op, ops_blocklist, weight_only_quantization,
482                   custom_map)) {
483             // Dynamic range quantization is applied by having Q as an input.
484             // Only int8 weight is supported for now.
485             inputs.push_back(dq_op.getOperand());
486           } else {
487             // Otherwise, it's the case where the operand is activations or the
488             // quantizing_op is non-supported/weight-only.
489             inputs.push_back(operand);
490           }
491         } else {
492           if (auto dq_op = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
493             inputs.push_back(dq_op.getOperand());
494           } else if (!ele_type.isF32()) {
495             // If the operand is an integer tensor, then it doesn't require the
496             // DQ op in the pattern.
497             inputs.push_back(operand);
498           } else {
499             return failure();
500           }
501         }
502       }
503 
504       // Collect all the quantized outputs and replace them by the results of
505       // the new quantized op.
506       llvm::SmallDenseMap<Value, int> outputs_replaced;
507       SmallVector<Type, 4> output_types;
508       output_types.reserve(quantizing_op->getNumResults());
509       for (const auto& enumerated_result :
510            llvm::enumerate(quantizing_op->getResults())) {
511         Value result = enumerated_result.value();
512         Type result_type = result.getType();
513         // Add this to the test coverage once we create test ops with none type
514         // results.
515         if (result_type.isa<NoneType>()) {
516           outputs_replaced.insert({result, enumerated_result.index()});
517           output_types.push_back(result_type);
518           continue;
519         }
520         Type result_ele_type =
521             result.getType().cast<TensorType>().getElementType();
522         // If the user is the Quantize op, it must be the only user.
523         if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
524           auto user = llvm::cast<Q>(*result.user_begin());
525           outputs_replaced.insert(
526               {user.getResult(), enumerated_result.index()});
527           output_types.push_back(user.getType());
528         } else if (!result_ele_type.isF32()) {
529           // If the result is an integer tensor, then it doesn't require the
530           // D op in the pattern.
531           outputs_replaced.insert({result, enumerated_result.index()});
532           output_types.push_back(result.getType());
533         } else if (static_cast<const ConcretTy*>(this)
534                        ->AllowDynamicRangeQuantizedResult(quantizing_op,
535                                                           custom_map)) {
536           outputs_replaced.insert({result, enumerated_result.index()});
537           output_types.push_back(result.getType());
538         } else {
539           return failure();
540         }
541       }
542 
543       rewriter.setInsertionPointAfter(quantizing_op);
544       OperationState new_state(quantizing_op->getLoc(),
545                                quantizing_op->getName().getStringRef(), inputs,
546                                output_types, quantizing_op->getAttrs());
547       for (int i = 0; i < quantizing_op->getNumRegions(); ++i) {
548         new_state.addRegion();
549       }
550       Operation* quantized_op = rewriter.create(new_state);
551       if (quantizing_op->getNumRegions() != 0) {
552         for (const auto& indexed_regions :
553              llvm::enumerate(quantizing_op->getRegions())) {
554           Region& target_region =
555               quantized_op->getRegion(indexed_regions.index());
556           BlockAndValueMapping mapping;
557           indexed_regions.value().cloneInto(&target_region, mapping);
558         }
559       }
560       for (auto output : outputs_replaced) {
561         output.getFirst().replaceAllUsesWith(
562             quantized_op->getResult(output.getSecond()));
563       }
564 
565       // To verify the numericals, the original floating-point ops are
566       // preserved in the graph. The result of these floating-point ops are sent
567       // to a numeric verifier op as the reference.
568       if (enable_verify && !std::is_same<VERIFIER, void>()) {
569         // For constant operands, the floating-point constant is duplicated in
570         // case it is quantized.
571         for (int i = 0, e = quantized_op->getNumOperands(); i < e; ++i) {
572           auto def = quantized_op->getOperand(i).getDefiningOp();
573           if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
574             DenseFPElementsAttr attr;
575             if (!matchPattern(q.getOperand(), m_Constant(&attr))) {
576               continue;
577             }
578             auto cst = rewriter.create<arith::ConstantOp>(
579                 quantized_op->getLoc(), attr);
580             quantizing_op->setOperand(i, cst.getResult());
581           }
582         }
583 
584         for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) {
585           if (!quantizing_op->getResult(i)
586                    .getType()
587                    .cast<ShapedType>()
588                    .getElementType()
589                    .isa<FloatType>()) {
590             continue;
591           }
592           CreateVerifier<VERIFIER>(quantizing_op, quantized_op, rewriter, i,
593                                    quant_params_);
594 
595           if (enable_whole_model_verify) {
596             RewireFloatModelBackbone(quantized_op, quantizing_op);
597           }
598         }
599       }
600     }
601     return success();
602   }
603 
604  private:
605   // Reconnects float ops in the whole-model verify mode. Works for both
606   // Quantizable ops and Unquantizable ops
RewireFloatModelBackbone(Operation * quantized_op,Operation * float_op)607   void RewireFloatModelBackbone(Operation* quantized_op,
608                                 Operation* float_op) const {
609     for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) {
610       if (!float_op->getResult(i)
611                .getType()
612                .cast<ShapedType>()
613                .getElementType()
614                .isF32()) {
615         continue;
616       }
617       // Find the Quantize/Dequantize users of the new op results, and replace
618       // the usage. Then all the floating-point ops are connected, forming a
619       // separate float "backbone" model that the quantized model can be
620       // compared against in parallel.
621       // N.B. the return op will use this floating-point result.
622       Value result;
623       if (IsOpNotQuantizable(float_op)) {
624         // For not quantizable ops, search for dequantize attached to the
625         // quantized op of the output.
626         if (Operation* quantize_op = dyn_cast_or_null<Q>(
627                 *quantized_op->getResult(i).getUsers().begin())) {
628           result = quantize_op->getResult(0);
629         } else {
630           quantize_op->emitError()
631               << "Output[" << i
632               << "] is expected to have only one user [QUANTIZE]";
633           return;
634         }
635       } else {
636         result = quantized_op->getResult(i);
637       }
638       for (auto user : result.getUsers()) {
639         // Skip the Requantize op and set the user to the following dequantize
640         // op. This happens when the quantizer tries to match the scale conflict
641         // with Q - Q(requant) - DQ op triples. The correct float op should be
642         // the user of the last DQ op.
643         if (llvm::isa<Q>(user)) {
644           user = *user->getResult(0).getUsers().begin();
645         }
646         if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
647           // Replace all uses, except not quantizable ops that are being used in
648           // the float backbone.
649           dequantize.getResult().replaceUsesWithIf(
650               float_op->getResult(i), [&](OpOperand& use) {
651                 return !use.getOwner()->hasAttr(kDebugModeOpQuantAttrName);
652               });
653         }
654       }
655     }
656   }
657 
658   QuantPassSpec quant_params_;
659 };
660 
661 // A pattern that removes debug attributes that are annotated to ops during
662 // the debug model creation.
663 class RemoveDebugAttrPattern : public RewritePattern {
664  public:
RemoveDebugAttrPattern(MLIRContext * context)665   explicit RemoveDebugAttrPattern(MLIRContext* context)
666       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
667   LogicalResult matchAndRewrite(Operation* op,
668                                 PatternRewriter& rewriter) const override;
669 };
670 
671 // Converts quantized tensor type with signed integer type to quantized tensor
672 // type with unsigned integer type.
673 Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc);
674 
675 // Converts quantize ops with unsigned quantized types to these with signed
676 // quantized types and preserves the scales.
677 template <typename Q>
678 struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
679   using BaseType = ConvertUnsignedToSigned<Q>;
680   using QType = quant::QuantizedType;
681 
ConvertUnsignedToSignedConvertUnsignedToSigned682   explicit ConvertUnsignedToSigned(MLIRContext* context)
683       : OpRewritePattern<Q>(context, 1) {}
684 
matchAndRewriteConvertUnsignedToSigned685   LogicalResult matchAndRewrite(Q op,
686                                 PatternRewriter& rewriter) const override {
687     Type output_type = op.getResult().getType();
688     auto qtype = QType::getQuantizedElementType(output_type);
689     if (!qtype || qtype.isSigned()) return failure();
690 
691     int num_bits = qtype.getStorageTypeIntegralWidth();
692     if (num_bits == 8) {
693       // If storage is 8-bit, trained num bits may be less than 8 so check here.
694       num_bits =
695           static_cast<int>(std::ceil(std::log2(qtype.getStorageTypeMax())));
696     }
697     // This is a positive value, and will be applied on zero points and fixed
698     // point ranges.
699     int64_t offset =
700         QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) -
701         QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits);
702 
703     auto flags = quant::QuantizationFlags::Signed;
704     QType new_qtype;
705     if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
706       new_qtype = quant::UniformQuantizedType::getChecked(
707           op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
708           uqtype.getScale(), uqtype.getZeroPoint() - offset,
709           uqtype.getStorageTypeMin() - offset,
710           uqtype.getStorageTypeMax() - offset);
711     } else if (auto aqtype = qtype.template dyn_cast<
712                              quant::UniformQuantizedPerAxisType>()) {
713       auto zero_points = aqtype.getZeroPoints();
714       llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
715                                                     zero_points.end());
716       for (int i = 0, e = new_zero_points.size(); i < e; ++i) {
717         new_zero_points[i] -= offset;
718       }
719       new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
720           op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
721           aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
722           aqtype.getStorageTypeMin() - offset,
723           aqtype.getStorageTypeMax() - offset);
724     } else {
725       return failure();
726     }
727 
728     if (!new_qtype) return failure();
729     Type new_output_type = new_qtype.castFromExpressedType(
730         QType::castToExpressedType(output_type));
731     rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.getArg());
732     return success();
733   }
734 };
735 
736 // Fold Extra Requantize ops if the preceding ops has free scale requirement.
737 template <typename RQ>
738 struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
FoldTrivalRequantizeOpFoldTrivalRequantizeOp739   explicit FoldTrivalRequantizeOp(MLIRContext* context)
740       : OpRewritePattern<RQ>(context, 1) {}
741 
matchAndRewriteFoldTrivalRequantizeOp742   LogicalResult matchAndRewrite(RQ op,
743                                 PatternRewriter& rewriter) const override {
744     Value pre_quantized = op->getOperand(0);
745     auto pre_quantized_type =
746         quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
747     if (!pre_quantized_type) return failure();
748 
749     Operation* def = pre_quantized.getDefiningOp();
750     if (!def) return failure();
751     if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
752         !def->hasTrait<OpTrait::quant::QuantizableResult>()) {
753       return failure();
754     }
755 
756     // This op should not clobber def, if more than one requant of this value.
757     if (!pre_quantized.hasOneUse()) {
758       return failure();
759     }
760 
761     op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
762 
763     llvm::SmallVector<Type, 4> new_output_types;
764     for (auto result : def->getResults()) {
765       if (result.hasOneUse() && *result.getUsers().begin() == op) {
766         new_output_types.push_back(op.getResult().getType());
767       } else {
768         new_output_types.push_back(result.getType());
769       }
770     }
771 
772     // Remove this rescale op.
773     rewriter.replaceOp(op, {pre_quantized});
774 
775     // Replace the output scale of the preceding op.
776     rewriter.setInsertionPointAfter(def);
777     OperationState new_state(def->getLoc(), def->getName().getStringRef(),
778                              def->getOperands(), new_output_types,
779                              def->getAttrs());
780     Operation* new_op = rewriter.create(new_state);
781 
782     rewriter.replaceOp(def, new_op->getResults());
783     return success();
784   }
785 };
786 
787 // Given a quantized type `input`, magnifying its scales by the factor stored in
788 // `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
789 // dimension size of `input` or isn't floating-point, nullptr will be returned.
790 TypeAttr RescaleQuantizedType(Type input, Attribute factor);
791 
792 // Converts the min/max/num_bits/narrow_range information to a
793 // QuantizedType, and then returns the attribute containing the QuantizedType.
794 // The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
795 // returns UniformQuantizedType or UniformQuantizedPerAxisType respectively.
796 // `narrow_range` is set to true for weights and `is_signed` is set to true
797 // if it is using signed int symmetric quantization.
798 //
799 // Note that this method may broadcast min and max to match the dimension length
800 // of `input_type`, if the `quant_dim` is valid. On the other hand, the
801 // symmetry of min and max is not adjusted by this method. The QAT workflow
802 // should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
803 // if symmetric quantization is required.
804 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
805                               Attribute max, int quant_dim,
806                               IntegerAttr num_bits, BoolAttr narrow_range,
807                               bool is_signed, bool legacy_float_scale = false,
808                               bool use_fake_quant_num_bits = false);
809 
810 // Casts the `target` type to a quantized type by using the quantization
811 // parameters from the type in the `source` type attribute.
812 // Examples:
813 //   f32 -> !quant.uniform<i8:f32, 1.0>
814 //   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
815 // The result is wrapped by a type attribute. Returns nullptr if the cast
816 // isn't valid.
817 //
818 // `axis` is to specify the quantization dimension in the `target` and only
819 // used if the element type of `source` is a per-channel quantized type. During
820 // the casting, the quantization dimension of the result type needs to be set
821 // this new `axis` value.
822 TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
823                                                 TypeAttr source, Type target,
824                                                 int axis);
825 
826 // Quantizes the elements in the attribute `real_value` by the quantization
827 // parameters in `tensor_type`. Returns empty Attribute if the
828 // `tensor_type` is not a QuantizedType or the quantization fails.
829 ElementsAttr Quantize(Attribute real_value, Type tensor_type);
830 
831 // Quantizes the elements in "legacy mode", where it calls TOCO's methods to
832 // to quantize values with float scale.
833 ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);
834 
835 // Returns the quantized type for an element attribute. The quantization
836 // parameters in this type is based on the min and max element of the
837 // attribute. When the elements in the `attr` are not in floating-point, or
838 // the value range isn't straddling zero, an empty type is returned. The min/max
839 // are adjusted to be symmetric if `symmetric` flag is set to True. And
840 // `symmetric` can only be set to true when it is signed and narrow_range.
841 Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
842                                       unsigned num_bits, bool is_signed,
843                                       bool narrow_range,
844                                       bool legacy_float_scale = false,
845                                       bool use_fake_quant_num_bits = false);
846 
847 // Returns the per channel quantized type for an element attribute.
848 // `quant_dim` defines the quantization axis. The channel min/max are adjusted
849 // to be symmetric if `symmetric` flag is set to True. And `symmetric` can only
850 // be set to true when it is signed and narrow_range.
851 Type GetUniformQuantizedPerAxisTypeForWeight(
852     ElementsAttr attr, int quant_dim, bool symmetric, unsigned num_bits,
853     bool is_signed, bool narrow_range, bool legacy_float_scale = false,
854     bool use_fake_quant_num_bits = false);
855 
856 // Returns the quantized type of a bias input, given the quantized types of
857 // other operands which are multiply-accumulated (the bias is added to the
858 // accumulated value).
859 quant::QuantizedType GetUniformQuantizedTypeForBias(
860     const std::vector<quant::QuantizedType>& op_types,
861     bool legacy_float_scale = false);
862 
863 // Propagates quantization parameters across ops in this function and satisfy
864 // the quantization specification of the ops. This methods assumes the initial
865 // quantization parameters are stored as adjacent quantize and dequantize ops
866 // and the propagation results are materialized by inserting pairs of quantize
867 // and dequantize ops to this function. Set `disable_per_channel` to true to not
868 // use per channel quantization even the op supports it.
869 // Setting `infer_tensor_range` to true, to infer quantization parameters from
870 // the activation ops and weight constants. This is only used for post-training
871 // quantization.
872 void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed,
873                                         bool disable_per_channel,
874                                         OpQuantSpecGetter op_quant_spec_getter,
875                                         bool infer_tensor_ranges,
876                                         bool legacy_float_scale = false);
877 
878 void ApplyQuantizationParamsPropagation(
879     mlir::func::FuncOp func, bool is_signed, bool disable_per_channel,
880     OpQuantSpecGetter op_quant_spec_getter,
881     OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges,
882     bool legacy_float_scale = false);
883 
884 // Gets quantization scale specs (e.g. fixed output range, same result and
885 // operand scales) from the default quantization interfaces. The op should
886 // outlive returned spec for its interface methods to be properly referenced.
887 std::unique_ptr<OpQuantScaleSpec> GetDefaultQuantScaleSpec(Operation* op);
888 
889 // The function might contain more stats ops than required, and it will
890 // introduce requantize if the calibration stats have conflicts. This method
891 // tries to remove all the redundant stats ops.
892 bool RemoveRedundantStatsOps(mlir::func::FuncOp func,
893                              OpQuantSpecGetter op_quant_spec_getter,
894                              OpQuantScaleSpecGetter op_quant_scale_spec_getter =
895                                  GetDefaultQuantScaleSpec);
896 
897 // Given quantization parameters for int8, compute the quantization parameters
898 // for uint if it is required, and wrap the result in an UniformQuantizedType.
899 quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
900                                                 Type tensor_type, double scale,
901                                                 int64_t zero_point,
902                                                 int64_t storage_min = -128,
903                                                 int64_t storage_max = 127);
904 
905 // Extrace min and max values from the DenseFPElementsAttr, and stores them into
906 // `mins` and `maxs`. When mins and maxs are extracted per-channel, `dim_size`
907 // is number of channels and `slice_size` is the size of slice per each channel.
908 // When `symmetric` is true, the range is expanded to [-M, M].
909 void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
910                            int slice_size, bool symmetric,
911                            SmallVectorImpl<double>& mins,
912                            SmallVectorImpl<double>& maxs);
913 
914 // Returns the quantized type for the
915 // input_type/min/max/storag_type_width/narrow_range.
916 Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
917                       ArrayRef<double> max, int quant_dim,
918                       int storage_type_width, bool narrow_range, bool is_signed,
919                       bool legacy_float_scale = false,
920                       bool use_fake_quant_num_bits = false);
921 }  // namespace quant
922 }  // namespace mlir
923 
924 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
925