/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the TFLite
// dialect.
#include <iterator>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/monitoring/counter.h"

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass.
//
namespace mlir {
namespace TFL {

namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

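// Counts invocations of this pass; the single "path" label distinguishes
// post-training from during-training (QAT) runs (see runOnOperation below).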
auto* tflite_quantizer_usage_stats = tensorflow::monitoring::Counter<1>::New(
    "/tensorflow/lite/quantization/transforms/stats",
    "The number of quantization pass invocations.", "path");

// Applies prepare quantization on the model in the TFL dialect. This pass runs
// before the quantization pass and propagates the quantization parameters
// across ops. This step is necessary for post-training quantization and also
// makes the quantization rules for some operations simpler in
// quantization-aware training.
class PrepareQuantizePass
    : public PrepareQuantizePassBase<PrepareQuantizePass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizePass)

  // Constructor used by the PassRegistration; enforces uint8 quantization.
  // This is only used by tests.
  explicit PrepareQuantizePass() : use_quantization_flags_(true) {}

  // Constructor used when manually creating the pass.
  explicit PrepareQuantizePass(const quant::QuantizationSpecs& quant_specs)
      : use_quantization_flags_(false), quant_specs_(quant_specs) {}

  void runOnOperation() override;

 private:
  // Sets the quantization parameters of the input nodes. These parameters are
  // converted from the user-specified input value ranges. Input nodes with
  // non-float tensor types are skipped because they are not quantizable.
  // Returns true if the number of input nodes doesn't match the number of
  // input ranges.
  bool SetInputNodesQuantizationParams(func::FuncOp func);

  // The function might contain more stats ops than required, and it will
  // introduce requantize ops if the calibration stats have conflicts. This
  // method tries to remove all the redundant stats ops.
  bool RemoveRedundantStats(func::FuncOp func);

  // Verifies that the quantization specification is valid for quantizing the
  // current function.
  bool IsLegalQuantSpecs(func::FuncOp func) {
    if (func.getName() == quant_specs_.target_func) {
      return (quant_specs_.disable_set_input_nodes_quantization_params ||
              func.getNumArguments() == quant_specs_.input_ranges.size());
    }
    return true;
  }

  // Get the min and max values from the quantization specification for the
  // current function and argument index. Uses default values if the function
  // is specified in the `quantize_allowlist`.
  std::pair<llvm::Optional<double>, llvm::Optional<double>>
  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
    if (func_name == quant_specs_.target_func) {
      return quant_specs_.input_ranges[index];
    } else {
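      // Default to the full 8-bit unsigned range for allowlisted functions.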
      return {0.0, 255.0};
    }
  }

  // Applies some sanity checks and reports warnings for cases that don't
  // follow best quantization practice. This also fixes some simple violations.
  void SanityCheckAndAdjustment(func::FuncOp func);

  // Whether the func contains Quantize ops. This is used to determine whether
  // to use the quantization parameters from the fixed output range property.
  bool ContainsQuantizeOps(func::FuncOp func);

  bool use_quantization_flags_;
  quant::QuantizationSpecs quant_specs_;
};

bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) {
  if (quant_specs_.disable_set_input_nodes_quantization_params) {
    return false;
  }

  StringRef func_name = func.getName();
  auto& target_func = quant_specs_.target_func;
  // Skip this function because it isn't the target function from the spec or
  // in the function allowlist.
  if (target_func != func_name &&
      !llvm::is_contained(quantize_allowlist_, func_name)) {
    return false;
  }
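  // True if the argument's sole use is a QuantizeCast op, i.e. quantization
  // parameters were already provided (e.g. from training) and are respected.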
  auto has_quantize_op = [&](const Value arg) {
    return (arg.hasOneUse() &&
            llvm::isa<quantfork::QuantizeCastOp>(*arg.user_begin()));
  };

  bool need_to_set_input_nodes_quantization_params = false;
  for (const BlockArgument arg : func.getArguments()) {
    auto shaped = arg.getType().dyn_cast<ShapedType>();
    if (shaped && shaped.getElementType().isa<FloatType>() &&
        !has_quantize_op(arg)) {
      need_to_set_input_nodes_quantization_params = true;
      break;
    }
  }

  if (!need_to_set_input_nodes_quantization_params) {
    return false;
  }

  // If the validation fails, the pass should stop immediately.
  if (!IsLegalQuantSpecs(func)) {
    return true;
  }

  OpBuilder builder(func);
  bool is_signed = quant_specs_.IsSignedInferenceType();
  IntegerAttr num_bits =
      builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
  BoolAttr narrow_range = builder.getBoolAttr(false);

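  // Wraps the i-th float argument in a QuantizeCast/DequantizeCast pair built
  // from the user-provided min/max range, unless a quantize op is already
  // attached to the argument.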
  auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
                             Block::iterator insertion_point, Value arg,
                             int i) {
    if (auto shaped = input_type.dyn_cast<ShapedType>()) {
      if (shaped.getElementType().isa<FloatType>()) {
        // If there are existing quantize ops, they are from training and we
        // should respect them.
        if (has_quantize_op(arg)) {
          return;
        }

        auto min_max = GetMinMaxValuesForArgument(func_name, i);
        // If the input min/max or mean/std are not specified, skip.
        if (!min_max.first.has_value() || !min_max.second.has_value()) return;

        TypeAttr params = quant::GetQuantizedTypeAttr(
            builder, input_type,
            builder.getF64FloatAttr(min_max.first.getValue()),
            builder.getF64FloatAttr(min_max.second.getValue()),
            /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
        builder.setInsertionPoint(block, insertion_point);
        auto q_op = builder.create<quantfork::QuantizeCastOp>(
            loc, params.getValue(), arg);
        auto dq_op = builder.create<quantfork::DequantizeCastOp>(
            loc, input_type, q_op.getResult());
        arg.replaceAllUsesWith(dq_op.getResult());
        q_op.setOperand(arg);
      }
    }
  };

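  // Insert the quantize/dequantize pair for each argument near the top of the
  // entry block.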
  for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
    BlockArgument arg = func.getArgument(i);
    auto* arg_block = arg.getOwner();
    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                    std::next(arg_block->begin(), i), arg, i);
  }

  return false;
}

#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"

bool PrepareQuantizePass::RemoveRedundantStats(func::FuncOp func) {
  return RemoveRedundantStatsOps(func, GetOpQuantSpec);
}

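// If `user` is a QuantizeCast op whose result feeds a DequantizeCast op,
// returns the DequantizeCast result; otherwise returns an empty Value.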
static Value Quantized(Operation* user) {
  if (auto q = llvm::dyn_cast_or_null<quantfork::QuantizeCastOp>(user)) {
    if (auto dq = llvm::dyn_cast_or_null<quantfork::DequantizeCastOp>(
            *q.getResult().user_begin())) {
      return dq.getResult();
    }
  }
  return {};
}

void PrepareQuantizePass::SanityCheckAndAdjustment(func::FuncOp func) {
  // If an op output has two users, one of them being a quantize op and the
  // other being the return directly, return the quantized result instead so
  // that the op can be quantized. This is only applied to returned results
  // because the error will not be accumulated.

  func.walk([&](func::ReturnOp ret) {
    int i = 0;
    for (Value returned : ret.getOperands()) {
      llvm::SmallVector<Value, 4> quantized;
      for (auto user : returned.getUsers()) {
        if (auto q = Quantized(user)) {
          quantized.push_back(q);
        }
      }
      if (quantized.size() == 1) {
        ret.setOperand(i, quantized.front());
      }
      i++;
    }
  });

  // We prefer placing quantization emulation ops on the results of the
  // concat ops.
  func.walk([&](ConcatenationOp concat) {
    if (concat.output().hasOneUse() &&
        Quantized(*concat.output().user_begin())) {
      return;
    }
    concat.emitWarning(
        "Missing quantization parameter on the output might introduce "
        "quantization error!");
  });

  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
  // eliminated at this point. This only occurs for the pattern
  //       (Quant (Dequant (Quant $in, $qB)), $qA)   $qB != $qA
  // where the qdq pair denotes a non-trivial requantization of an already
  // quantized value. Since this makes little sense (directly quantizing
  // (Quant $in, $qA) would introduce less quantization noise), the likely
  // cause is a minor error in constructing the original network model that
  // introduced back-to-back Fake Quantization operations. Hence: emit a
  // warning. N.B. at this point we're (temporarily) in the quantization
  // dialect (presumably to enable reuse in XLA etc.), hence the
  // quantfork::*QuantizeCastOp ops we're matching here.
  //
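  // Schematically (op names and types approximate):
  //   %q0 = "quantfork.qcast"(%in) : (f32) -> !quant.uniform<...qB...>
  //   %d0 = "quantfork.dcast"(%q0) : (!quant.uniform<...qB...>) -> f32
  //   %q1 = "quantfork.qcast"(%d0) : (f32) -> !quant.uniform<...qA...>
  // with qA != qB.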
  func.walk([&](quantfork::QuantizeCastOp q_op) {
    auto dq_op = dyn_cast_or_null<quantfork::DequantizeCastOp>(
        q_op.getOperand().getDefiningOp());
    if (!dq_op) {
      return;
    }
    auto dq_arg = dq_op.getOperand();

    if (!dq_arg.hasOneUse()) {
      // The initial quantization is used someplace else ... so it might be
      // reasonable for it to be requantized for another purpose.
      // Ideally we would still check whether the requantization narrows
      // rather than widens the representation.
      return;
    }

    // Invariant:
    //   isa<quantfork::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
    //   dq_arg.getType() != q_op.getResult().getType()
    //
    // as otherwise the qdq pair would have been optimized away.
    auto qd_arg_def_q_op =
        dyn_cast_or_null<quantfork::QuantizeCastOp>(dq_arg.getDefiningOp());
    if (!qd_arg_def_q_op) {
      return;
    }

    qd_arg_def_q_op.emitWarning()
        << " quantizer's output has another quantizer (" << q_op.getLoc()
        << ") as consumer - intentional?";
  });
}

bool PrepareQuantizePass::ContainsQuantizeOps(func::FuncOp func) {
  for (const auto& op : func.getOps()) {
    if (llvm::isa<quantfork::DequantizeCastOp>(op)) return true;
  }
  return false;
}

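// Pattern that converts quant statistics ops (calibration min/max) into
// QuantizeCast/DequantizeCast pairs.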
using PrepareQuantStats =
    quant::ConvertStatsToQDQs<quantfork::QuantizeCastOp,
                              quantfork::DequantizeCastOp>;

void PrepareQuantizePass::runOnOperation() {
  func::FuncOp func = getOperation();
  MLIRContext* ctx = func.getContext();
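  // Temporarily convert TFL quantization ops to the corresponding MLIR quant
  // ops for the duration of this pass; the scoped converter reverses the
  // conversion when it goes out of scope.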
  ScopedTFLQuantOpsToMlirQuantOpsConverter converter(func);
  if (use_quantization_flags_) {
    quant_specs_.inference_type =
        this->quantize_signed_ ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
    quant_specs_.post_training_quantization = post_training_quantize_;
    quant_specs_.legacy_float_scale = legacy_float_scale_;
    quant_specs_.disable_set_input_nodes_quantization_params =
        disable_set_input_nodes_quantization_params_;
  }

  if (quant_specs_.post_training_quantization) {
    tflite_quantizer_usage_stats->GetCell("post_training")->IncrementBy(1);
    RemoveRedundantStats(func);
  } else {
    tflite_quantizer_usage_stats->GetCell("during_training")->IncrementBy(1);
    // Set the quantization parameters for the quantizable input nodes. If this
    // fails, return from the function immediately. This is only required for
    // quantization aware training model conversion.
    if (SetInputNodesQuantizationParams(func)) {
      return;
    }
  }

  bool is_signed = quant_specs_.IsSignedInferenceType();
  int bit_width = quant_specs_.GetQuantizationTypeWidth();
  // When this is true, the quantizer will try its best to extract the
  // quantization parameters from the op quantization property and constant
  // content. This is also set to true when the `quantize_allowlist` and
  // `quantize_signed` test flags are enabled.
  bool eager_quantize = ContainsQuantizeOps(func) ||
                        (!quantize_allowlist_.empty() || quantize_signed_);
  // Infer the tensor range for the activation ops and weight constants unless
  // it is disabled explicitly.
  bool infer_tensor_range =
      (quant_specs_.post_training_quantization || eager_quantize) &&
      !quant_specs_.disable_infer_tensor_range;

  // LSTM's restrict_scale requirement should be handled before converting
  // stats to Q-DQ ops. The pattern is applied for non-PTQ case to make op
  // ordering consistent. Otherwise some FileCheck tests would fail.
  RewritePatternSet patterns_1(&getContext());
  if (quant_specs_.post_training_quantization) {
    patterns_1.add<PrepareLstmOutputScale<LSTMOp>>(ctx);
    patterns_1.add<PrepareLstmOutputScale<UnidirectionalSequenceLSTMOp>>(ctx);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_1));

  // During the legalization, unsigned quantized type is used, so we have to
  // convert all of them to signed.
  RewritePatternSet patterns_2(&getContext());
  if (is_signed) {
    patterns_2.add<quant::ConvertUnsignedToSigned<quantfork::QuantizeCastOp>>(
        ctx);
    // Convert quant stats to int8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.add<PrepareQuantStats>(bit_width, false, true,
                                      quant_specs_.legacy_float_scale, ctx);
  } else {
    // Convert quant stats to uint8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.add<PrepareQuantStats>(bit_width, false, false,
                                      quant_specs_.legacy_float_scale, ctx);
  }

  if (quant_specs_.post_training_quantization) {
    patterns_2.add<ConvertLstmStatsToQDQs<LSTMOp>>(ctx, quant_specs_);
    patterns_2.add<ConvertLstmStatsToQDQs<UnidirectionalSequenceLSTMOp>>(
        ctx, quant_specs_);
    patterns_2.add<ConvertSvdfStatsToQDQs>(ctx, quant_specs_);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));

  SanityCheckAndAdjustment(func);

  // Finally, the quantization parameters can be propagated to the rest of the
  // values (tensors).
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel_ || quant_specs_.disable_per_channel,
      GetOpQuantSpec, infer_tensor_range, quant_specs_.legacy_float_scale);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
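// Typically added as a nested pass on func::FuncOp, e.g. (sketch):
//   pass_manager.addNestedPass<mlir::func::FuncOp>(
//       mlir::TFL::CreatePrepareQuantizePass(quant_specs));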
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareQuantizePass(
    const quant::QuantizationSpecs& quant_specs) {
  return std::make_unique<PrepareQuantizePass>(quant_specs);
}

std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareQuantizePass() {
  return std::make_unique<PrepareQuantizePass>();
}

}  // namespace TFL
}  // namespace mlir
423