1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This header file defines common utils used by TFLite transformation
17 // passes to work with op attributes.
18
19 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
20 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
21
22 #include <algorithm>
23 #include <functional>
24 #include <string>
25 #include <unordered_map>
26
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/strings/string_view.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/Twine.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
35 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
36 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
37 #include "mlir/IR/Attributes.h" // from @llvm-project
38 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
39 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
41 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
42 #include "mlir/IR/MLIRContext.h" // from @llvm-project
43 #include "mlir/IR/Matchers.h" // from @llvm-project
44 #include "mlir/IR/OpDefinition.h" // from @llvm-project
45 #include "mlir/IR/PatternMatch.h" // from @llvm-project
46 #include "mlir/Support/LLVM.h" // from @llvm-project
47 #include "mlir/Support/LogicalResult.h" // from @llvm-project
48 #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
49 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
50 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
51 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
52 #include "tensorflow/core/framework/types.pb.h"
53
54 namespace mlir {
55 namespace quant {
56
57 // A unit attribute can be attached to the quantize/dequantize ops which are
58 // added by the quantization passes. These ops can be erased without
59 // losing accuracy.
60 constexpr char kVolatileOpAttrName[] = "volatile";
61
62 // The following attributes are used to mark ops that are not quantizable
63 // during the debug model generation process for whole-model verify mode. If
64 // these attributes are attached, the upstream float/quantized ops know which
65 // ops to connect to, and they also prevent these ops from being copied again.
66 constexpr char kDebugModeOpFloatAttrName[] = "debug_float";
67 constexpr char kDebugModeOpQuantAttrName[] = "debug_quant";
68
69 // Used to annotate custom ops if they are quantizable.
70 constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait";
71 enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 };
72 constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable",
73 "not_quantizable"};
74
75 constexpr double kNearZeroTolerance = 1.0e-6;
76
77 using QuantParams = mlir::quant::QuantizedType;
78 using QuantSpec = QuantizationSpecs;
79 using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
80 using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
81 using AccumulatorScaleFunc =
82 std::function<QuantParams(const std::vector<QuantParams>&, bool)>;
83 using BiasParamsMap =
84 std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>;
85 // UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width)
86 using GetFixedOutputRangeFunc = std::function<UniformQuantizedType(bool, int)>;
87 // bool RequiredSameOperandsAndResultsScale(bool sign, int bit_width)
88 using RequiredSameOperandsAndResultsScaleFunc = std::function<bool(bool, int)>;
89 // bool RequiredSameQuantizedAxes()
90 using RequiredSameQuantizedAxesFunc = std::function<bool()>;
91
92 using StringSet = absl::flat_hash_set<std::string>;
93 using CustomMap = quant::CustomOpMap;
94
95 // Quantization spec of an op, driving the quantization algorithm.
96 struct OpQuantSpec {
97 // Maps the operand index of a bias input to its quantization specifications,
98 // including the non-bias operand indexes and the method for retrieving
99 // quantization parameters from the list of parameters of the non-bias operands.
100 // This map is empty if the op doesn't have a bias operand.
101 BiasParamsMap biases_params;
102
103 // Quantization parameters for value-restricted outputs. These are the
104 // "hard-coded" parameters and should be used unconditionally for the
105 // quantized op. This vector is empty if the op doesn't have value restricted
106 // outputs.
107 llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;
108
109 // Coefficient operand index and whether it supports per-channel quantization.
110 // For QAT, this information is carried by the FakeQuant*/QDQ ops, but for
111 // post-training quantization, the quantization parameters need to be inferred
112 // from the tensor content and op properties. A "-1" value indicates the
113 // operand doesn't support per-channel quantization.
114 llvm::DenseMap<int, int> coeff_op_quant_dim;
115
116 // Indices of quantizable operands. Biases are not included in this field;
117 // their indices can be found in `biases_params`.
118 absl::flat_hash_set<int> quantizable_operands;
119 };
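// A rough illustration (a hypothetical sketch, not modeled on any particular
// TFLite op): a conv-like op with operands (input=0, filter=1, bias=2) might
// populate this spec along these lines, using GetUniformQuantizedTypeForBias
// (declared later in this header) as the accumulator scale function.
//
//   auto spec = std::make_unique<OpQuantSpec>();
//   // The bias at operand 2 derives its scale from operands 0 and 1.
//   spec->biases_params[2] = {
//       {0, 1},
//       [](const std::vector<QuantParams>& non_bias_params, bool legacy) {
//         return GetUniformQuantizedTypeForBias(non_bias_params, legacy);
//       }};
//   // The filter at operand 1 supports per-channel quantization on dim 0.
//   spec->coeff_op_quant_dim[1] = 0;
//   spec->quantizable_operands = {1};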
120
121 // Quantization scale spec of an op. The information defined in the MLIR
122 // interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should
123 // be checked first if present.
124 struct OpQuantScaleSpec {
125 // Whether this op has a fixed range requirement (e.g. sigmoid)
126 bool has_fixed_output_range = false;
127 // Whether this op should have same result and operand scales (e.g. concat)
128 bool has_same_scale_requirement = false;
129 // Returns the fixed output range, when has_fixed_output_range is set.
130 GetFixedOutputRangeFunc fixed_output_range_func;
131 // Returns whether same operands and results scales are required.
132 RequiredSameOperandsAndResultsScaleFunc required_same_scale_func =
133 [](bool sign, int bit_width) { return true; };
134 // Returns whether operands and results must have the same quantized axis.
135 RequiredSameQuantizedAxesFunc required_same_quantized_axes_func = []() {
136 return true;
137 };
138 };
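// For example (a minimal sketch), a concat-like op that requires matching
// operand and result scales could be described as:
//
//   auto scale_spec = std::make_unique<OpQuantScaleSpec>();
//   scale_spec->has_same_scale_requirement = true;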
139
140 // Used in TFL Numeric Verify
141 struct NumericVerifySpec {
142 // Whether to enable numeric verification
143 bool verify_numeric = false;
144
145 // Tolerance level from the quantized value for verification. If the tolerance
146 // is very small (< 0.1), only the stats of the diff are displayed.
147 float error_tolerance = 5.0f;
148
149 // Whether to verify numerical correctness layer by layer or by whole model
150 bool whole_model_verify = false;
151
152 // Whether to enable log for failures
153 bool log_if_failed_flag = false;
154 };
155
156 // Used in TFL Quantize Pass
157 struct QuantPassSpec {
158 // Variables to control TFL Numeric Verify
159 NumericVerifySpec numeric_verify_spec;
160
161 // Variables related to quantization
162 QuantSpec quant_spec;
163 };
164
165 // A function signature for getting the particular OpQuantSpec for the provided
166 // op.
167 typedef std::unique_ptr<OpQuantSpec> (*OpQuantSpecGetter)(Operation* op);
168 // A function signature for getting the particular OpQuantScaleSpec for the
169 // provided op.
170 typedef std::unique_ptr<OpQuantScaleSpec> (*OpQuantScaleSpecGetter)(
171 Operation* op);
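// A hedged sketch of how a getter is typically wired into a pass (`func` is a
// hypothetical mlir::func::FuncOp; ApplyQuantizationParamsPropagation is
// declared later in this header):
//
//   std::unique_ptr<OpQuantSpec> GetOpQuantSpec(Operation* op) {
//     auto spec = std::make_unique<OpQuantSpec>();
//     // Fill `spec` based on the op, e.g. by dispatching on op->getName().
//     return spec;
//   }
//
//   ApplyQuantizationParamsPropagation(func, /*is_signed=*/true,
//                                      /*disable_per_channel=*/false,
//                                      GetOpQuantSpec,
//                                      /*infer_tensor_ranges=*/true);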
172
173 // Re-calculates scales in float instead of simply downcasting existing
174 // scales.
175 quant::QuantizedType DownCastScale(quant::QuantizedType type,
176 const SmallVectorImpl<double>& mins,
177 const SmallVectorImpl<double>& maxs,
178 Location loc);
179
180 quant::QuantizedType DownCastScale(quant::QuantizedType type, double min,
181 double max, Location loc);
182
183 bool IsOpNotQuantizable(Operation* op);
184
185 // Specialized version of location to string for flatbuffer exported locations.
186 inline std::string GetTensorNameFromLoc(Location loc) {
187 if (auto name_loc = loc.dyn_cast<NameLoc>()) {
188 return name_loc.getName().str();
189 }
190 return "";
191 }
192
193 template <typename Q, typename DQ>
194 struct ConvertStatsToQDQs : public OpRewritePattern<quantfork::StatisticsOp> {
195 ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
196 bool legacy_float_scale, MLIRContext* context)
197 : OpRewritePattern<quantfork::StatisticsOp>(context),
198 num_bits(num_bits),
199 narrow_range(narrow_range),
200 is_signed(is_signed),
201 legacy_float_scale(legacy_float_scale) {}
202
203 LogicalResult matchAndRewrite(quantfork::StatisticsOp op,
204 PatternRewriter& rewriter) const override {
205 Type expressed = op.getType().cast<ShapedType>().getElementType();
206 quant::QuantizedType quant_type;
207 SmallVector<double, 4> mins, maxs;
208
209 if (op.getAxisStats().has_value()) {
210 int stats_num = op.getAxisStats()->getNumElements();
211 if (stats_num == 0 || stats_num % 2 != 0) return failure();
212 auto stats = op.getAxisStats()->dyn_cast<DenseFPElementsAttr>();
213 if (!stats) return failure();
214
215 for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
216 double rmin = FloatAttr::getValueAsDouble(*it++);
217 double rmax = FloatAttr::getValueAsDouble(*it);
218 // The default nudging implementation of mlir quant library might cause
219 // clamping during inference if the calibration range isn't wide enough.
220 // So here we adjust the range to include 0.0.
221 rmin = std::min(rmin, 0.0);
222 rmax = std::max(rmax, 0.0);
223 TensorRangeSanityCheck(op, rmin, rmax);
224 mins.push_back(rmin);
225 maxs.push_back(rmax);
226 }
227 quant_type = quantfork::fakeQuantAttrsToType(
228 op.getLoc(), num_bits, *op.getAxis(), mins, maxs, narrow_range,
229 expressed, is_signed);
230 if (legacy_float_scale) {
231 quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc());
232 }
233 } else if (auto stats =
234 op.getLayerStats().dyn_cast<DenseFPElementsAttr>()) {
235 auto statValues = stats.getValues<APFloat>();
236 double rmin = FloatAttr::getValueAsDouble(statValues[0]);
237 double rmax = FloatAttr::getValueAsDouble(statValues[1]);
238 // The default nudging implementation of mlir quant library might cause
239 // clamping during inference if the calibration range isn't wide enough.
240 // So here we adjust the range to include 0.0.
241 rmin = std::min(rmin, 0.0);
242 rmax = std::max(rmax, 0.0);
243 TensorRangeSanityCheck(op, rmin, rmax);
244 quant_type =
245 quantfork::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
246 narrow_range, expressed, is_signed);
247 if (legacy_float_scale) {
248 quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc());
249 }
250 } else {
251 return failure();
252 }
253
254 rewriter.setInsertionPointAfter(op.getOperation());
255 Type result_type = quant_type.castFromExpressedType(op.getType());
256 auto q = rewriter.create<Q>(op.getLoc(), result_type, op.getArg());
257 q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());
258
259 auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
260 op.getResult().replaceAllUsesWith(dq);
261 q.getOperation()->replaceUsesOfWith(dq, op.getArg());
262 op.erase();
263
264 return success();
265 }
266
267 private:
268 int num_bits;
269 bool narrow_range;
270 bool is_signed;
271 bool legacy_float_scale;
272
273 // Emits an op warning if the calibrated range is wider than 10.0 for storage
274 // of 8 bits or fewer, and symmetrically expands near-zero ranges.
275 void TensorRangeSanityCheck(quantfork::StatisticsOp op, double& min,
276 double& max) const {
277 double range = std::fabs(max - min);
278 if (num_bits <= 8 && range >= 10.0) {
279 op.emitWarning()
280 << "Tensor range is too wide to be quantized. Use tf.clip_by_value "
281 "or tf.relu6 to narrow the tensor range. Range: "
282 << range << ", bit width: " << num_bits;
283 }
284 if (std::abs(max - min) < kNearZeroTolerance) {
285 op.emitWarning() << "Tensor range (" << min << ", " << max
286 << ") is too narrow and it might cause overflow. "
287 "Expanding range symmetrically by "
288 << kNearZeroTolerance;
289 min -= kNearZeroTolerance;
290 max += kNearZeroTolerance;
291 }
292 }
293 };
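// A hedged usage sketch for this pattern (the quantize/dequantize op classes
// are assumptions about the instantiating dialect, e.g. the quantfork cast
// ops; `func` and `ctx` are hypothetical):
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<ConvertStatsToQDQs<quantfork::QuantizeCastOp,
//                                   quantfork::DequantizeCastOp>>(
//       /*num_bits=*/8, /*narrow_range=*/false, /*is_signed=*/true,
//       /*legacy_float_scale=*/false, ctx);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));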
294
295 template <typename VerifierT>
296 bool UsedBy(Operation* op) {
297 for (Operation* user : op->getUsers()) {
298 if (llvm::isa_and_nonnull<VerifierT>(user)) return true;
299 }
300 return false;
301 }
302
303 template <typename VerifierT>
304 void CreateVerifier(Operation* quantizing_op, Operation* quantized_op,
305 PatternRewriter& rewriter, int result_idx,
306 const QuantPassSpec& quant_params) {
307 rewriter.setInsertionPointAfter(quantized_op);
308 FloatAttr tolerance = rewriter.getF32FloatAttr(
309 quant_params.numeric_verify_spec.error_tolerance);
310 BoolAttr log =
311 rewriter.getBoolAttr(quant_params.numeric_verify_spec.log_if_failed_flag);
312 // Verify the quantized value by sending the result to the verifier.
313 rewriter.create<VerifierT>(
314 quantizing_op->getLoc(), quantized_op->getResult(result_idx).getType(),
315 quantized_op->getResult(result_idx), quantizing_op->getResult(result_idx),
316 tolerance, log);
317 }
318
319 template <>
320 inline bool UsedBy<void>(Operation* op) {
321 return false;
322 }
323
324 // This specialization is not going to be called, but needed for compilation.
325 template <>
326 inline void CreateVerifier<void>(Operation* quantizing_op,
327 Operation* quantized_op,
328 PatternRewriter& rewriter, int result_idx,
329 const QuantPassSpec& quant_params) {}
330
331 // A base rewrite pattern which matches any N-in-M-out operations with
332 // quantization parameters propagated to at least one of its operands. The
333 // quantization parameters are annotated by the Q/DQ op pairs. Each
334 // matched pattern is rewritten by its quantized alternative.
335 //
336 // The concrete pattern, which extends this base pattern, can specify whether it
337 // allows dynamic range quantized operands and results for the operations in the
338 // current context. These "DynamicRangeQuantized" operands and results don't
339 // have quantization parameters propagated to them, so they remain float in the
340 // quantized results. The concrete pattern should define the following two
341 // functions:
342 //
343 // bool AllowDynamicRangeQuantizedOperand(Operation *) const
344 // bool AllowDynamicRangeQuantizedResult(Operation *) const
345 //
346 // Full integer quantization disallows "DynamicRangeQuantized" operands or
347 // results. Dynamic range quantization allows "DynamicRangeQuantized" operands
348 // and results.
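//
// In addition to the two hooks above, the implementation below also calls
// IsQuantizableCustomOp() and IsWeightOnlyOp() on the concrete type. A rough
// sketch of a concrete pattern (the op class names are placeholders, not the
// real TFLite ops, and every hook here is intentionally trivial):
//
//   struct MyQuantizationPattern
//       : public QuantizationPattern<MyQuantizationPattern, QuantizeOp,
//                                    DequantizeOp, NumericVerifyOp> {
//     using BaseType::BaseType;
//     bool IsQuantizableCustomOp(Operation* op, const CustomMap& map) const {
//       return false;
//     }
//     bool AllowDynamicRangeQuantizedOperand(Operation* op,
//                                            const CustomMap& map) const {
//       return false;
//     }
//     bool AllowDynamicRangeQuantizedResult(Operation* op,
//                                           const CustomMap& map) const {
//       return false;
//     }
//     bool IsWeightOnlyOp(Operation* op, const StringSet& blocklist,
//                         bool weight_only, const CustomMap& map) const {
//       return false;
//     }
//   };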
349 template <typename ConcretTy, typename Q, typename DQ, typename VERIFIER,
350 typename RootOp = DQ>
351 class QuantizationPattern : public RewritePattern {
352 public:
353 using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER, RootOp>;
354
355 explicit QuantizationPattern(MLIRContext* context,
356 const QuantPassSpec& quant_params)
357 // Set the score to a large number so it is always preferred.
358 : RewritePattern(RootOp::getOperationName(), 300, context),
359 quant_params_(quant_params) {}
360
361 LogicalResult matchAndRewrite(Operation* op,
362 PatternRewriter& rewriter) const override {
363 llvm::SmallVector<Operation*, 4> quantizing_ops;
364
365 // Collect all the ops to quantize, as the user / producer of the root op.
366 if (std::is_same<RootOp, DQ>::value) {
367 if (op->getNumResults() != 1) {
368 return failure();
369 }
370 auto users = op->getResult(0).getUsers();
371 quantizing_ops.append(users.begin(), users.end());
372 } else if (std::is_same<RootOp, Q>::value) {
373 if (op->getNumOperands() != 1) {
374 return failure();
375 }
376 Value quantize_operand = op->getOperand(0);
377 if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) {
378 // The input of this Q op has already been quantized, i.e. rescale.
379 return failure();
380 }
381 DenseFPElementsAttr attr;
382 if (matchPattern(quantize_operand, m_Constant(&attr))) {
383 // Const->Q pattern will be handled separately.
384 return failure();
385 }
386 if (Operation* quantizing_op = quantize_operand.getDefiningOp()) {
387 quantizing_ops.push_back(quantizing_op);
388 }
389 }
390
391 tensorflow::DataType inference_type =
392 quant_params_.quant_spec.inference_type;
393 bool weight_only_quantization =
394 quant_params_.quant_spec.weight_only_quantization;
395 bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric;
396 bool enable_whole_model_verify =
397 quant_params_.numeric_verify_spec.whole_model_verify;
398 StringSet ops_blocklist = quant_params_.quant_spec.ops_blocklist;
399 StringSet nodes_blocklist = quant_params_.quant_spec.nodes_blocklist;
400 CustomMap custom_map = quant_params_.quant_spec.custom_map;
401
402 // Rewrite the floating-point ops to the quantized version, by fusing
403 // preceding dequantize ops and succeeding quantize ops.
404 for (Operation* quantizing_op : quantizing_ops) {
405 // If it is a requantize op, we shouldn't rewrite it.
406 if (llvm::isa<Q, DQ>(quantizing_op)) {
407 return failure();
408 }
409
410 // If the op is a terminator, is not quantizable, or is any op from the mlir
411 // quant ops dialect, we shouldn't rewrite it. In whole-model verify debug
412 // mode, not-quantizable ops should be duplicated to keep parallel
413 // float/quant model execution.
414 if (quantizing_op->hasTrait<OpTrait::IsTerminator>()) {
415 return failure();
416 }
417
418 if (IsOpNotQuantizable(quantizing_op) &&
419 !static_cast<const ConcretTy*>(this)->IsQuantizableCustomOp(
420 quantizing_op, custom_map)) {
421 if (!(enable_verify && enable_whole_model_verify)) {
422 return failure();
423 }
424 if (quantizing_op->hasAttr(kDebugModeOpQuantAttrName) ||
425 quantizing_op->hasAttr(kDebugModeOpFloatAttrName)) {
426 return failure();
427 }
428
429 rewriter.setInsertionPoint(quantizing_op);
430 Operation* float_op = rewriter.clone(*quantizing_op);
431 quantizing_op->setAttr(kDebugModeOpQuantAttrName,
432 rewriter.getUnitAttr());
433 float_op->setAttr(kDebugModeOpFloatAttrName, rewriter.getUnitAttr());
434 RewireFloatModelBackbone(quantizing_op, float_op);
435 return success();
436 }
437
438 // Blocklisted ops are checked in advance for the non-dynamic-range
439 // quantization case.
440 if (!quant_params_.quant_spec.weight_quantization &&
441 (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) !=
442 ops_blocklist.end())) {
443 return failure();
444 }
445
446 if (!nodes_blocklist.empty()) {
447 if (auto name_loc = quantizing_op->getLoc().dyn_cast<NameLoc>()) {
448 std::string sloc = name_loc.getName().str();
449 if (!sloc.empty() &&
450 (nodes_blocklist.find(sloc) != nodes_blocklist.end())) {
451 return failure();
452 }
453 }
454 }
455
456 // An op with float inputs and outputs is expected when it's used by a
457 // NumericVerify op. Skip this op.
458 if (enable_verify && UsedBy<VERIFIER>(quantizing_op)) {
459 continue;
460 }
461
462 // Collect all the quantized inputs and "clone" the matched op by these
463 // inputs.
464 SmallVector<Value, 4> inputs;
465 inputs.reserve(quantizing_op->getNumOperands());
466 for (auto operand : quantizing_op->getOperands()) {
467 Type operand_type = operand.getType();
468 if (operand_type.isa<NoneType>()) {
469 inputs.push_back(operand);
470 continue;
471 }
472
473 auto ele_type = operand.getType().cast<TensorType>().getElementType();
474 if (static_cast<const ConcretTy*>(this)
475 ->AllowDynamicRangeQuantizedOperand(quantizing_op,
476 custom_map)) {
477 auto dq_op = dyn_cast_or_null<DQ>(operand.getDefiningOp());
478
479 if (dq_op && inference_type == tensorflow::DT_QINT8 &&
480 !static_cast<const ConcretTy*>(this)->IsWeightOnlyOp(
481 quantizing_op, ops_blocklist, weight_only_quantization,
482 custom_map)) {
483 // Dynamic range quantization is applied by having Q as an input.
484 // Only int8 weight is supported for now.
485 inputs.push_back(dq_op.getOperand());
486 } else {
487 // Otherwise, it's the case where the operand is activations or the
488 // quantizing_op is non-supported/weight-only.
489 inputs.push_back(operand);
490 }
491 } else {
492 if (auto dq_op = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
493 inputs.push_back(dq_op.getOperand());
494 } else if (!ele_type.isF32()) {
495 // If the operand is an integer tensor, then it doesn't require the
496 // DQ op in the pattern.
497 inputs.push_back(operand);
498 } else {
499 return failure();
500 }
501 }
502 }
503
504 // Collect all the quantized outputs and replace them by the results of
505 // the new quantized op.
506 llvm::SmallDenseMap<Value, int> outputs_replaced;
507 SmallVector<Type, 4> output_types;
508 output_types.reserve(quantizing_op->getNumResults());
509 for (const auto& enumerated_result :
510 llvm::enumerate(quantizing_op->getResults())) {
511 Value result = enumerated_result.value();
512 Type result_type = result.getType();
513 // Add this to the test coverage once we create test ops with none type
514 // results.
515 if (result_type.isa<NoneType>()) {
516 outputs_replaced.insert({result, enumerated_result.index()});
517 output_types.push_back(result_type);
518 continue;
519 }
520 Type result_ele_type =
521 result.getType().cast<TensorType>().getElementType();
522 // If the user is the Quantize op, it must be the only user.
523 if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
524 auto user = llvm::cast<Q>(*result.user_begin());
525 outputs_replaced.insert(
526 {user.getResult(), enumerated_result.index()});
527 output_types.push_back(user.getType());
528 } else if (!result_ele_type.isF32()) {
529 // If the result is an integer tensor, then it doesn't require the
530 // quantize op in the pattern.
531 outputs_replaced.insert({result, enumerated_result.index()});
532 output_types.push_back(result.getType());
533 } else if (static_cast<const ConcretTy*>(this)
534 ->AllowDynamicRangeQuantizedResult(quantizing_op,
535 custom_map)) {
536 outputs_replaced.insert({result, enumerated_result.index()});
537 output_types.push_back(result.getType());
538 } else {
539 return failure();
540 }
541 }
542
543 rewriter.setInsertionPointAfter(quantizing_op);
544 OperationState new_state(quantizing_op->getLoc(),
545 quantizing_op->getName().getStringRef(), inputs,
546 output_types, quantizing_op->getAttrs());
547 for (int i = 0; i < quantizing_op->getNumRegions(); ++i) {
548 new_state.addRegion();
549 }
550 Operation* quantized_op = rewriter.create(new_state);
551 if (quantizing_op->getNumRegions() != 0) {
552 for (const auto& indexed_regions :
553 llvm::enumerate(quantizing_op->getRegions())) {
554 Region& target_region =
555 quantized_op->getRegion(indexed_regions.index());
556 BlockAndValueMapping mapping;
557 indexed_regions.value().cloneInto(&target_region, mapping);
558 }
559 }
560 for (auto output : outputs_replaced) {
561 output.getFirst().replaceAllUsesWith(
562 quantized_op->getResult(output.getSecond()));
563 }
564
565 // To verify the numerics, the original floating-point ops are
566 // preserved in the graph. The results of these floating-point ops are sent
567 // to a numeric verifier op as the reference.
568 if (enable_verify && !std::is_same<VERIFIER, void>()) {
569 // For constant operands, the floating-point constant is duplicated in
570 // case it is quantized.
571 for (int i = 0, e = quantized_op->getNumOperands(); i < e; ++i) {
572 auto def = quantized_op->getOperand(i).getDefiningOp();
573 if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
574 DenseFPElementsAttr attr;
575 if (!matchPattern(q.getOperand(), m_Constant(&attr))) {
576 continue;
577 }
578 auto cst = rewriter.create<arith::ConstantOp>(
579 quantized_op->getLoc(), attr);
580 quantizing_op->setOperand(i, cst.getResult());
581 }
582 }
583
584 for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) {
585 if (!quantizing_op->getResult(i)
586 .getType()
587 .cast<ShapedType>()
588 .getElementType()
589 .isa<FloatType>()) {
590 continue;
591 }
592 CreateVerifier<VERIFIER>(quantizing_op, quantized_op, rewriter, i,
593 quant_params_);
594
595 if (enable_whole_model_verify) {
596 RewireFloatModelBackbone(quantized_op, quantizing_op);
597 }
598 }
599 }
600 }
601 return success();
602 }
603
604 private:
605 // Reconnects float ops in the whole-model verify mode. Works for both
606 // quantizable and unquantizable ops.
607 void RewireFloatModelBackbone(Operation* quantized_op,
608 Operation* float_op) const {
609 for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) {
610 if (!float_op->getResult(i)
611 .getType()
612 .cast<ShapedType>()
613 .getElementType()
614 .isF32()) {
615 continue;
616 }
617 // Find the Quantize/Dequantize users of the new op results, and replace
618 // the usage. Then all the floating-point ops are connected, forming a
619 // separate float "backbone" model that the quantized model can be
620 // compared against in parallel.
621 // N.B. the return op will use this floating-point result.
622 Value result;
623 if (IsOpNotQuantizable(float_op)) {
624 // For not quantizable ops, search for dequantize attached to the
625 // quantized op of the output.
626 if (Operation* quantize_op = dyn_cast_or_null<Q>(
627 *quantized_op->getResult(i).getUsers().begin())) {
628 result = quantize_op->getResult(0);
629 } else {
630 quantized_op->emitError()
631 << "Output[" << i
632 << "] is expected to have only one user [QUANTIZE]";
633 return;
634 }
635 } else {
636 result = quantized_op->getResult(i);
637 }
638 for (auto user : result.getUsers()) {
639 // Skip the Requantize op and set the user to the following dequantize
640 // op. This happens when the quantizer tries to match the scale conflict
641 // with Q - Q(requant) - DQ op triples. The correct float op should be
642 // the user of the last DQ op.
643 if (llvm::isa<Q>(user)) {
644 user = *user->getResult(0).getUsers().begin();
645 }
646 if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
647 // Replace all uses, except not quantizable ops that are being used in
648 // the float backbone.
649 dequantize.getResult().replaceUsesWithIf(
650 float_op->getResult(i), [&](OpOperand& use) {
651 return !use.getOwner()->hasAttr(kDebugModeOpQuantAttrName);
652 });
653 }
654 }
655 }
656 }
657
658 QuantPassSpec quant_params_;
659 };
660
661 // A pattern that removes debug attributes that are annotated to ops during
662 // the debug model creation.
663 class RemoveDebugAttrPattern : public RewritePattern {
664 public:
665 explicit RemoveDebugAttrPattern(MLIRContext* context)
666 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
667 LogicalResult matchAndRewrite(Operation* op,
668 PatternRewriter& rewriter) const override;
669 };
670
671 // Converts quantized tensor type with signed integer type to quantized tensor
672 // type with unsigned integer type.
673 Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc);
674
675 // Converts quantize ops with unsigned quantized types to those with signed
676 // quantized types and preserves the scales.
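// For example, with 8-bit storage the `offset` computed below is
// 0 - (-128) = 128, so an element type like !quant.uniform<u8:f32, s:128>
// becomes !quant.uniform<i8:f32, s:0>: the zero point and storage bounds
// shift down by 128 while the scale s is preserved.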
677 template <typename Q>
678 struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
679 using BaseType = ConvertUnsignedToSigned<Q>;
680 using QType = quant::QuantizedType;
681
682 explicit ConvertUnsignedToSigned(MLIRContext* context)
683 : OpRewritePattern<Q>(context, 1) {}
684
685 LogicalResult matchAndRewrite(Q op,
686 PatternRewriter& rewriter) const override {
687 Type output_type = op.getResult().getType();
688 auto qtype = QType::getQuantizedElementType(output_type);
689 if (!qtype || qtype.isSigned()) return failure();
690
691 int num_bits = qtype.getStorageTypeIntegralWidth();
692 if (num_bits == 8) {
693 // If storage is 8-bit, trained num bits may be less than 8 so check here.
694 num_bits =
695 static_cast<int>(std::ceil(std::log2(qtype.getStorageTypeMax())));
696 }
697 // This is a positive value, and will be applied on zero points and fixed
698 // point ranges.
699 int64_t offset =
700 QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) -
701 QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits);
702
703 auto flags = quant::QuantizationFlags::Signed;
704 QType new_qtype;
705 if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
706 new_qtype = quant::UniformQuantizedType::getChecked(
707 op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
708 uqtype.getScale(), uqtype.getZeroPoint() - offset,
709 uqtype.getStorageTypeMin() - offset,
710 uqtype.getStorageTypeMax() - offset);
711 } else if (auto aqtype = qtype.template dyn_cast<
712 quant::UniformQuantizedPerAxisType>()) {
713 auto zero_points = aqtype.getZeroPoints();
714 llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
715 zero_points.end());
716 for (int i = 0, e = new_zero_points.size(); i < e; ++i) {
717 new_zero_points[i] -= offset;
718 }
719 new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
720 op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
721 aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
722 aqtype.getStorageTypeMin() - offset,
723 aqtype.getStorageTypeMax() - offset);
724 } else {
725 return failure();
726 }
727
728 if (!new_qtype) return failure();
729 Type new_output_type = new_qtype.castFromExpressedType(
730 QType::castToExpressedType(output_type));
731 rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.getArg());
732 return success();
733 }
734 };
735
736 // Folds extra requantize ops if the preceding op has a free scale requirement.
737 template <typename RQ>
738 struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
739 explicit FoldTrivalRequantizeOp(MLIRContext* context)
740 : OpRewritePattern<RQ>(context, 1) {}
741
742 LogicalResult matchAndRewrite(RQ op,
743 PatternRewriter& rewriter) const override {
744 Value pre_quantized = op->getOperand(0);
745 auto pre_quantized_type =
746 quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
747 if (!pre_quantized_type) return failure();
748
749 Operation* def = pre_quantized.getDefiningOp();
750 if (!def) return failure();
751 if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
752 !def->hasTrait<OpTrait::quant::QuantizableResult>()) {
753 return failure();
754 }
755
756 // This op should not clobber def if there is more than one requant of this value.
757 if (!pre_quantized.hasOneUse()) {
758 return failure();
759 }
760
761 op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
762
763 llvm::SmallVector<Type, 4> new_output_types;
764 for (auto result : def->getResults()) {
765 if (result.hasOneUse() && *result.getUsers().begin() == op) {
766 new_output_types.push_back(op.getResult().getType());
767 } else {
768 new_output_types.push_back(result.getType());
769 }
770 }
771
772 // Remove this rescale op.
773 rewriter.replaceOp(op, {pre_quantized});
774
775 // Replace the output scale of the preceding op.
776 rewriter.setInsertionPointAfter(def);
777 OperationState new_state(def->getLoc(), def->getName().getStringRef(),
778 def->getOperands(), new_output_types,
779 def->getAttrs());
780 Operation* new_op = rewriter.create(new_state);
781
782 rewriter.replaceOp(def, new_op->getResults());
783 return success();
784 }
785 };
786
787 // Given a quantized type `input`, magnifies its scales by the factor stored in
788 // `factor`. If `input` isn't a quantized type, or `factor` doesn't match the
789 // dimension size of `input` or isn't floating-point, nullptr will be returned.
790 TypeAttr RescaleQuantizedType(Type input, Attribute factor);
791
792 // Converts the min/max/num_bits/narrow_range information to a
793 // QuantizedType, and then returns the attribute containing the QuantizedType.
794 // The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr, and
795 // the result is UniformQuantizedType or UniformQuantizedPerAxisType respectively.
796 // `narrow_range` is set to true for weights and `is_signed` is set to true
797 // if it is using signed int symmetric quantization.
798 //
799 // Note that this method may broadcast min and max to match the dimension length
800 // of `input_type`, if the `quant_dim` is valid. On the other hand, the
801 // symmetry of min and max is not adjusted by this method. The QAT workflow
802 // should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
803 // if symmetric quantization is required.
804 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
805 Attribute max, int quant_dim,
806 IntegerAttr num_bits, BoolAttr narrow_range,
807 bool is_signed, bool legacy_float_scale = false,
808 bool use_fake_quant_num_bits = false);
809
810 // Casts the `target` type to a quantized type by using the quantization
811 // parameters from the type in the `source` type attribute.
812 // Examples:
813 // f32 -> !quant.uniform<i8:f32, 1.0>
814 // tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
815 // The result is wrapped by a type attribute. Returns nullptr if the cast
816 // isn't valid.
817 //
818 // `axis` is to specify the quantization dimension in the `target` and only
819 // used if the element type of `source` is a per-channel quantized type. During
820 // the casting, the quantization dimension of the result type needs to be set
821 // this new `axis` value.
822 TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
823 TypeAttr source, Type target,
824 int axis);
825
826 // Quantizes the elements in the attribute `real_value` by the quantization
827 // parameters in `tensor_type`. Returns empty Attribute if the
828 // `tensor_type` is not a QuantizedType or the quantization fails.
829 ElementsAttr Quantize(Attribute real_value, Type tensor_type);
830
831 // Quantizes the elements in "legacy mode", where it calls TOCO's methods to
832 // quantize values with float scale.
833 ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);
834
835 // Returns the quantized type for an element attribute. The quantization
836 // parameters in this type are based on the min and max element of the
837 // attribute. When the elements in the `attr` are not in floating-point, or
838 // the value range isn't straddling zero, an empty type is returned. The min/max
839 // are adjusted to be symmetric if the `symmetric` flag is set to True, and
840 // `symmetric` can only be set to true when it is signed and narrow_range.
841 Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
842 unsigned num_bits, bool is_signed,
843 bool narrow_range,
844 bool legacy_float_scale = false,
845 bool use_fake_quant_num_bits = false);
846
847 // Returns the per channel quantized type for an element attribute.
848 // `quant_dim` defines the quantization axis. The channel min/max are adjusted
849 // to be symmetric if the `symmetric` flag is set to True, and `symmetric` can
850 // only be set to true when it is signed and narrow_range.
851 Type GetUniformQuantizedPerAxisTypeForWeight(
852 ElementsAttr attr, int quant_dim, bool symmetric, unsigned num_bits,
853 bool is_signed, bool narrow_range, bool legacy_float_scale = false,
854 bool use_fake_quant_num_bits = false);
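// A hedged usage sketch (`filter_attr` is a hypothetical DenseFPElementsAttr
// holding the weight values; the argument values are only illustrative):
//
//   Type per_axis_qtype = GetUniformQuantizedPerAxisTypeForWeight(
//       filter_attr, /*quant_dim=*/0, /*symmetric=*/true, /*num_bits=*/8,
//       /*is_signed=*/true, /*narrow_range=*/true);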
855
856 // Returns the quantized type of a bias input, given the quantized types of
857 // other operands which are multiply-accumulated (the bias is added to the
858 // accumulated value).
859 quant::QuantizedType GetUniformQuantizedTypeForBias(
860 const std::vector<quant::QuantizedType>& op_types,
861 bool legacy_float_scale = false);
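// As a hedged illustration: for a conv-like op whose accumulator sums products
// of an input with scale s_in and a filter with scale s_w, the bias is
// typically quantized to a 32-bit integer with scale s_in * s_w and zero point
// 0, so it can be added directly in the accumulator's fixed-point domain.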
862
863 // Propagates quantization parameters across ops in this function and satisfies
864 // the quantization specification of the ops. This method assumes the initial
865 // quantization parameters are stored as adjacent quantize and dequantize ops
866 // and the propagation results are materialized by inserting pairs of quantize
867 // and dequantize ops to this function. Set `disable_per_channel` to true to not
868 // use per-channel quantization even if the op supports it.
869 // Set `infer_tensor_ranges` to true to infer quantization parameters from
870 // the activation ops and weight constants. This is only used for post-training
871 // quantization.
872 void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed,
873 bool disable_per_channel,
874 OpQuantSpecGetter op_quant_spec_getter,
875 bool infer_tensor_ranges,
876 bool legacy_float_scale = false);
877
878 void ApplyQuantizationParamsPropagation(
879 mlir::func::FuncOp func, bool is_signed, bool disable_per_channel,
880 OpQuantSpecGetter op_quant_spec_getter,
881 OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges,
882 bool legacy_float_scale = false);
883
884 // Gets quantization scale specs (e.g. fixed output range, same result and
885 // operand scales) from the default quantization interfaces. The op should
886 // outlive the returned spec for its interface methods to be properly referenced.
887 std::unique_ptr<OpQuantScaleSpec> GetDefaultQuantScaleSpec(Operation* op);
888
889 // The function might contain more stats ops than required, and it will
890 // introduce requantize ops if the calibration stats have conflicts. This method
891 // tries to remove all the redundant stats ops.
892 bool RemoveRedundantStatsOps(mlir::func::FuncOp func,
893 OpQuantSpecGetter op_quant_spec_getter,
894 OpQuantScaleSpecGetter op_quant_scale_spec_getter =
895 GetDefaultQuantScaleSpec);
896
897 // Given quantization parameters for int8, computes the quantization parameters
898 // for uint if required, and wraps the result in a UniformQuantizedType.
899 quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
900 Type tensor_type, double scale,
901 int64_t zero_point,
902 int64_t storage_min = -128,
903 int64_t storage_max = 127);
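// A hedged usage sketch with illustrative values only (`result_type` is a
// hypothetical tensor type; the scale/zero point shown correspond to an int8
// output covering roughly [0, 1)):
//
//   auto fixed_qtype = GetFixedOutputRange(/*is_signed=*/true, /*bit_width=*/8,
//                                          result_type, /*scale=*/1.0 / 256,
//                                          /*zero_point=*/-128);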
904
905 // Extracts min and max values from the DenseFPElementsAttr, and stores them
906 // into `mins` and `maxs`. When mins and maxs are extracted per-channel,
907 // `dim_size` is the number of channels and `slice_size` is the slice size per
908 // channel. When `symmetric` is true, the range is expanded to [-M, M].
909 void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
910 int slice_size, bool symmetric,
911 SmallVectorImpl<double>& mins,
912 SmallVectorImpl<double>& maxs);
913
914 // Returns the quantized type for the
915 // input_type/min/max/storage_type_width/narrow_range.
916 Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
917 ArrayRef<double> max, int quant_dim,
918 int storage_type_width, bool narrow_range, bool is_signed,
919 bool legacy_float_scale = false,
920 bool use_fake_quant_num_bits = false);
921 } // namespace quant
922 } // namespace mlir
923
924 #endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
925