/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the TFLite
// dialect.
#include <iterator>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/monitoring/counter.h"

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass.
//
namespace mlir {
namespace TFL {

namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

auto* tflite_quantizer_usage_stats = tensorflow::monitoring::Counter<1>::New(
    "/tensorflow/lite/quantization/transforms/stats",
    "The number of quantization pass invocations.", "path");

// Applies prepare-quantization on the model in the TFL dialect. This pass runs
// before the quantization pass and propagates the quantization parameters
// across ops. This step is necessary for post-training quantization and also
// makes the quantization rules for some operations in quantization-aware
// training simpler.
class PrepareQuantizePass
    : public PrepareQuantizePassBase<PrepareQuantizePass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizePass)

  // Constructor used by the PassRegistration; enforces uint8 quantization.
  // This is only used by tests.
  explicit PrepareQuantizePass() : use_quantization_flags_(true) {}

  // Constructor used when manually creating the pass.
  explicit PrepareQuantizePass(const quant::QuantizationSpecs& quant_specs)
      : use_quantization_flags_(false), quant_specs_(quant_specs) {}

  void runOnOperation() override;

 private:
  // Set the quantization parameters of the input nodes. These parameters are
  // converted from the user-specified input value ranges. Input nodes with
  // non-float tensor types will be skipped because they are not quantizable.
  // Returns true if the number of input nodes doesn't match the number of
  // input ranges.
  bool SetInputNodesQuantizationParams(func::FuncOp func);

  // The function might contain more stats ops than required, and it will
  // introduce requantization if the calibration stats have conflicts. This
  // method tries to remove all the redundant stats ops.
  bool RemoveRedundantStats(func::FuncOp func);

  // Verify that the quantization specification is as expected for quantizing
  // the current function.
  bool IsLegalQuantSpecs(func::FuncOp func) {
    if (func.getName() == quant_specs_.target_func) {
      return (quant_specs_.disable_set_input_nodes_quantization_params ||
              func.getNumArguments() == quant_specs_.input_ranges.size());
    }
    return true;
  }

  // Get the min and max values from the quantization specification for the
  // current function and argument index. Uses default values if the function
  // is specified in the `quantize_allowlist`.
  std::pair<llvm::Optional<double>, llvm::Optional<double>>
  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
    if (func_name == quant_specs_.target_func) {
      return quant_specs_.input_ranges[index];
    } else {
      return {0.0, 255.0};
    }
  }

  // Apply some sanity checks and report warnings for cases that don't follow
  // best quantization practice. This also fixes some simple violations.
  void SanityCheckAndAdjustment(func::FuncOp func);

  // Whether the func contains Quantize ops. This is used to determine whether
  // to use the quantization parameters from the fixed output range property.
  bool ContainsQuantizeOps(func::FuncOp func);

  bool use_quantization_flags_;
  quant::QuantizationSpecs quant_specs_;
};

bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) {
  if (quant_specs_.disable_set_input_nodes_quantization_params) {
    return false;
  }

  StringRef func_name = func.getName();
  auto& target_func = quant_specs_.target_func;
  // Skip this function because it isn't the target function from the spec or
  // in the function allowlist.
  if (target_func != func_name &&
      !llvm::is_contained(quantize_allowlist_, func_name)) {
    return false;
  }
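  // Returns true if the argument's only user is a QuantizeCast op, i.e. the
  // argument already has quantization parameters attached (e.g. from
  // quantization-aware training).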
  auto has_quantize_op = [&](const Value arg) {
    return (arg.hasOneUse() &&
            llvm::isa<quantfork::QuantizeCastOp>(*arg.user_begin()));
  };

  bool need_to_set_input_nodes_quantization_params = false;
  for (const BlockArgument arg : func.getArguments()) {
    auto shaped = arg.getType().dyn_cast<ShapedType>();
    if (shaped && shaped.getElementType().isa<FloatType>() &&
        !has_quantize_op(arg)) {
      need_to_set_input_nodes_quantization_params = true;
      break;
    }
  }

  if (!need_to_set_input_nodes_quantization_params) {
    return false;
  }

  // If the validation fails, the pass should stop immediately.
  if (!IsLegalQuantSpecs(func)) {
    return true;
  }

  OpBuilder builder(func);
  bool is_signed = quant_specs_.IsSignedInferenceType();
  IntegerAttr num_bits =
      builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
  BoolAttr narrow_range = builder.getBoolAttr(false);

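  // Inserts a QuantizeCast/DequantizeCast pair for a float argument using the
  // user-provided min/max range, and reroutes all other uses of the argument
  // through the dequantized value.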
  auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
                             Block::iterator insertion_point, Value arg,
                             int i) {
    if (auto shaped = input_type.dyn_cast<ShapedType>()) {
      if (shaped.getElementType().isa<FloatType>()) {
        // If there are existing quantize ops, they are from training and we
        // should respect them.
        if (has_quantize_op(arg)) {
          return;
        }

        auto min_max = GetMinMaxValuesForArgument(func_name, i);
        // If the input min/max or mean/std are not specified, skip.
        if (!min_max.first.has_value() || !min_max.second.has_value()) return;

        TypeAttr params = quant::GetQuantizedTypeAttr(
            builder, input_type,
            builder.getF64FloatAttr(min_max.first.getValue()),
            builder.getF64FloatAttr(min_max.second.getValue()),
            /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
        builder.setInsertionPoint(block, insertion_point);
        auto q_op = builder.create<quantfork::QuantizeCastOp>(
            loc, params.getValue(), arg);
        auto dq_op = builder.create<quantfork::DequantizeCastOp>(
            loc, input_type, q_op.getResult());
        arg.replaceAllUsesWith(dq_op.getResult());
        q_op.setOperand(arg);
      }
    }
  };

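  // Add a quantize/dequantize pair for each eligible function argument, near
  // the top of the entry block.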
  for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
    BlockArgument arg = func.getArgument(i);
    auto* arg_block = arg.getOwner();
    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                    std::next(arg_block->begin(), i), arg, i);
  }

  return false;
}

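// The generated file below provides GetOpQuantSpec(), the per-op quantization
// spec getter used by the stats-removal and parameter-propagation steps below.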
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"

bool PrepareQuantizePass::RemoveRedundantStats(func::FuncOp func) {
  return RemoveRedundantStatsOps(func, GetOpQuantSpec);
}

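// Returns the result of the dequantize op if `user` is a QuantizeCast op whose
// first user is a DequantizeCast op; otherwise returns a null Value.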
static Value Quantized(Operation* user) {
  if (auto q = llvm::dyn_cast_or_null<quantfork::QuantizeCastOp>(user)) {
    if (auto dq = llvm::dyn_cast_or_null<quantfork::DequantizeCastOp>(
            *q.getResult().user_begin())) {
      return dq.getResult();
    }
  }
  return {};
}

void PrepareQuantizePass::SanityCheckAndAdjustment(func::FuncOp func) {
  // If an op output has two users, one of which is a quantize op while the
  // other is returned directly, we decide to return the quantized result
  // instead, so this op can be quantized. This is only applied to the returned
  // result because the error will not be accumulated.

  func.walk([&](func::ReturnOp ret) {
    int i = 0;
    for (Value returned : ret.getOperands()) {
      llvm::SmallVector<Value, 4> quantized;
      for (auto user : returned.getUsers()) {
        if (auto q = Quantized(user)) {
          quantized.push_back(q);
        }
      }
      if (quantized.size() == 1) {
        ret.setOperand(i, quantized.front());
      }
      i++;
    }
  });

  // We prefer placing quantization emulation ops on the results of the
  // concat ops.
  func.walk([&](ConcatenationOp concat) {
    if (concat.output().hasOneUse() &&
        Quantized(*concat.output().user_begin())) {
      return;
    }
    concat.emitWarning(
        "Missing quantization parameter on the output might introduce "
        "quantization error!");
  });

  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
  // eliminated at this point. This only occurs for the pattern
  //      (Quant (Dequant (Quant $in, $qB)), $qA)   $qB != $qA
  // where the qdq pair denotes a non-trivial requantization of an
  // already quantized value. Since this makes little sense (directly quantizing
  // (Quant $in, $qA) would introduce less quantization noise), the likely cause
  // is a minor error in constructing the original network model that
  // introduced back-to-back Fake Quantization operations. Hence: emit a
  // warning. N.b. at this point we're (temporarily) in the quantization
  // dialect (presumably to enable re-use in XLA etc.), so it is
  // quantfork::*QuantizeCastOp that we're matching here.
  //
  func.walk([&](quantfork::QuantizeCastOp q_op) {
    auto dq_op = dyn_cast_or_null<quantfork::DequantizeCastOp>(
        q_op.getOperand().getDefiningOp());
    if (!dq_op) {
      return;
    }
    auto dq_arg = dq_op.getOperand();

    if (!dq_arg.hasOneUse()) {
      // The initial quantization is used someplace else ... so it might be
      // reasonable for it to be requantized for another purpose.
      // Ideally we would still check whether the requantization narrows
      // rather than widens the representation.
      return;
    }

    // Invariant:
    // isa<quantfork::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
    // dq_arg.getType() != q_op.getResult().getType()
    //
    // as otherwise the qdq pair would have been optimized away.
    auto qd_arg_def_q_op =
        dyn_cast_or_null<quantfork::QuantizeCastOp>(dq_arg.getDefiningOp());
    if (!qd_arg_def_q_op) {
      return;
    }

    qd_arg_def_q_op.emitWarning()
        << " quantizer's output has another quantizer (" << q_op.getLoc()
        << ") as consumer - intentional?";
  });
}

bool PrepareQuantizePass::ContainsQuantizeOps(func::FuncOp func) {
  for (const auto& op : func.getOps()) {
    if (llvm::isa<quantfork::DequantizeCastOp>(op)) return true;
  }
  return false;
}

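// Rewrite pattern that converts quantization statistics (stats) ops into
// QuantizeCast/DequantizeCast pairs carrying the derived quantization
// parameters.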
using PrepareQuantStats =
    quant::ConvertStatsToQDQs<quantfork::QuantizeCastOp,
                              quantfork::DequantizeCastOp>;

void PrepareQuantizePass::runOnOperation() {
  func::FuncOp func = getOperation();
  MLIRContext* ctx = func.getContext();
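  // Convert TFL quantize/dequantize ops into the corresponding MLIR quant
  // dialect ops for the duration of this pass; they are converted back when
  // `converter` goes out of scope.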
  ScopedTFLQuantOpsToMlirQuantOpsConverter converter(func);
  if (use_quantization_flags_) {
    quant_specs_.inference_type =
        this->quantize_signed_ ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
    quant_specs_.post_training_quantization = post_training_quantize_;
    quant_specs_.legacy_float_scale = legacy_float_scale_;
    quant_specs_.disable_set_input_nodes_quantization_params =
        disable_set_input_nodes_quantization_params_;
  }

  if (quant_specs_.post_training_quantization) {
    tflite_quantizer_usage_stats->GetCell("post_training")->IncrementBy(1);
    RemoveRedundantStats(func);
  } else {
    tflite_quantizer_usage_stats->GetCell("during_training")->IncrementBy(1);
    // Set the quantization parameters for the quantizable input nodes. If this
    // fails, return immediately. This is only required for quantization-aware
    // training model conversion.
    if (SetInputNodesQuantizationParams(func)) {
      return;
    }
  }

  bool is_signed = quant_specs_.IsSignedInferenceType();
  int bit_width = quant_specs_.GetQuantizationTypeWidth();
  // When this is true, the quantizer will try its best to extract the
  // quantization parameters from the op quantization property and constant
  // content. This is also set to true when the `quantize_allowlist` and
  // `quantize_signed` test flags are enabled.
  bool eager_quantize = ContainsQuantizeOps(func) ||
                        (!quantize_allowlist_.empty() || quantize_signed_);
  // Infer the tensor range for the activation ops and weight constants unless
  // it is disabled explicitly.
  bool infer_tensor_range =
      (quant_specs_.post_training_quantization || eager_quantize) &&
      !quant_specs_.disable_infer_tensor_range;

  // LSTM's restrict_scale requirement should be handled before converting
  // stats to Q-DQ ops. The pattern rewrite is applied for the non-PTQ case as
  // well, to keep op ordering consistent; otherwise some FileCheck tests would
  // fail.
  RewritePatternSet patterns_1(&getContext());
  if (quant_specs_.post_training_quantization) {
    patterns_1.add<PrepareLstmOutputScale<LSTMOp>>(ctx);
    patterns_1.add<PrepareLstmOutputScale<UnidirectionalSequenceLSTMOp>>(ctx);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_1));

  // During legalization, unsigned quantized types are used, so we have to
  // convert all of them to signed.
  RewritePatternSet patterns_2(&getContext());
  if (is_signed) {
    patterns_2.add<quant::ConvertUnsignedToSigned<quantfork::QuantizeCastOp>>(
        ctx);
    // Convert quant stats to int8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.add<PrepareQuantStats>(bit_width, false, true,
                                      quant_specs_.legacy_float_scale, ctx);
  } else {
    // Convert quant stats to uint8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns_2.add<PrepareQuantStats>(bit_width, false, false,
                                      quant_specs_.legacy_float_scale, ctx);
  }

  if (quant_specs_.post_training_quantization) {
    patterns_2.add<ConvertLstmStatsToQDQs<LSTMOp>>(ctx, quant_specs_);
    patterns_2.add<ConvertLstmStatsToQDQs<UnidirectionalSequenceLSTMOp>>(
        ctx, quant_specs_);
    patterns_2.add<ConvertSvdfStatsToQDQs>(ctx, quant_specs_);
  }
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));

  SanityCheckAndAdjustment(func);

  // Finally, the quantization parameters can be propagated to the rest of the
  // values (tensors).
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel_ || quant_specs_.disable_per_channel,
      GetOpQuantSpec, infer_tensor_range, quant_specs_.legacy_float_scale);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareQuantizePass(
    const quant::QuantizationSpecs& quant_specs) {
  return std::make_unique<PrepareQuantizePass>(quant_specs);
}

std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareQuantizePass() {
  return std::make_unique<PrepareQuantizePass>();
}
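
// Usage sketch (illustrative only, not part of the pass): the pass is meant to
// be scheduled on func ops inside a pass pipeline; `quant_specs`, `context`,
// and `module` below are assumed to be set up by the caller.
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::TFL::CreatePrepareQuantizePass(quant_specs));
//   if (failed(pm.run(module))) { /* handle the error */ }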

}  // namespace TFL
}  // namespace mlir