/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "quantization-driver"

namespace mlir {
namespace quant {
namespace {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag indicating that this state (the params) shouldn't be changed after
  // it is initialized. This flag is set to true if the quantization parameters
  // come from quantization-aware training.
  const bool immutable;

  bool IsEmpty() { return EmptyParams(params); }
};

// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or output
  // of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // The quantization parameters used to add the requantize ops.
  QuantParams params;

  // Avoid clobbering all uses of the value; limit to just these ops.
  SmallVector<std::pair<Operation *, int>> users;
};

using RequantizeStates = SmallVector<RequantizeState>;

// This is a worklist-driven driver for propagating quantization parameters
// across operations.
//
// The initial quantization parameters are extracted from the quantized type
// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
// parameters are marked as immutable because they come from quantization-aware
// training.
//
// The algorithm traverses each op and sets the quantization parameters of its
// operands and results, according to its quantization specification, and then
// adds the operands and results to the worklist. If there is a conflict (for
// example, quantization parameters were already propagated from a previous
// iteration), the propagation stops if the existing parameters are immutable;
// otherwise a `requantize` op is added to resolve the conflict.
//
// After the algorithm converges, pairs of tfl.quantize and tfl.dequantize ops
// are inserted at the right positions to materialize the propagation and
// requantize results.
//
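// For example (an illustrative sketch, not taken from this file): quantizing
// a float value %v with propagated parameters !quant.uniform<i8:f32, 0.1>
// materializes as
//
//   %q = "quantfork.qcast"(%v)
//     : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<i8:f32, 1.000000e-01>>
//   %dq = "quantfork.dcast"(%q)
//     : (tensor<1x4x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<1x4xf32>
//
// with all original uses of %v rewired to %dq.
//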
class QuantizationDriver {
 public:
  explicit QuantizationDriver(func::FuncOp fn, bool is_signed,
                              bool disable_per_channel,
                              OpQuantSpecGetter op_quant_spec_getter,
                              OpQuantScaleSpecGetter op_quant_scale_spec_getter,
                              bool infer_tensor_range, bool legacy_float_scale)
      : fn_(fn),
        builder_(fn.getBody()),
        is_signed_(is_signed),
        disable_per_channel_(disable_per_channel),
        op_quant_spec_getter_(op_quant_spec_getter),
        op_quant_scale_spec_getter_(op_quant_scale_spec_getter),
        infer_tensor_range_(infer_tensor_range),
        legacy_float_scale_(legacy_float_scale) {}

  // The entry point of the quantization parameters propagation.
  void Run();

 private:
  // This is used to identify an operand or result of an op. The second element
  // of this pair is the index of the operand or result.
  using OpValue = std::pair<mlir::Operation *, int>;

  // Sets up the states for all the op results in the function.
  void Initialize();

  // Propagates the quantization parameters across all the ops.
  bool PropagateParams();

  // Duplicates the constant op if it has multiple uses, and replaces
  // target_op->operand[operand_index] with the newly created op. This also
  // replaces the corresponding quantization states.
  arith::ConstantOp DuplicateConstantOpIfNeeded(arith::ConstantOp op,
                                                Operation *target_op,
                                                int operand_index);

  // Adjusts the bias scale that is derived from other scales (fc, conv ops) to
  // prevent overflow of quantized bias values. This also changes the
  // quantization states of other inputs when needed.
  bool SetBiasParamsWithAdjustments(Operation *op, int bias_index,
                                    const std::vector<int> &input_indices,
                                    QuantParams params);

  // Helper for checking preconditions to adjust bias scale.
  bool ShouldCheckBiasScale(Operation *op, int bias_index,
                            const std::vector<int> &input_indices,
                            QuantParams params, int &input_index,
                            int &filter_index);

  // Inserts the Quantize and Dequantize ops according to the propagation
  // result.
  void Finalize();

  // The quantization parameters of a bias operand are usually determined by
  // other operands, so if a constant is used by different ops as a bias, it
  // needs to be duplicated so that each op can assign its own quantization
  // parameters to it. This method also adds all the non-bias constants
  // (weights), as well as the per-channel weights, to sets for later lookup.
  void PreprocessConstantOps();

  // Sets up all the data structures for quantization propagation.
  void SetupAllStates();

  // Whether the constant is a weight, which shouldn't be shared by different
  // ops.
  bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); }

  // Returns all the related quantization constraints of the op.
  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);
  std::unique_ptr<OpQuantScaleSpec> GetQuantScaleSpec(Operation *op);

  // Whether quantization parameters have been propagated to the results of
  // this op.
  bool IsQuantized(Operation *op);

  // Adds all the users of the index-th result of op to the work list.
  void AddUserToList(Operation *op, int index) {
    for (auto *user : op->getResult(index).getUsers()) {
      work_list_.push_back(user);
    }
  }

  // Adds the defining op of the index-th operand of op to the work list.
  void AddOperandToList(Operation *op, int index) {
    if (auto *inst = op->getOperand(index).getDefiningOp()) {
      work_list_.push_back(inst);
    }
  }

  // Returns the quantization params for the bias input from the non-bias
  // operands which have their indexes in the `non_biases` vector. The returned
  // parameters are calculated by `func`.
  QuantParams GetBiasParams(Operation *op, int bias,
                            const std::vector<int> &non_biases,
                            AccumulatorScaleFunc func);

  // Sets the quantization parameters of the result to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the input of propagated quantization.
  bool SetResultParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the operand to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the output of propagated quantization. When `override` is set, the
  // quantization state of the value is replaced instead of adding
  // requantization.
  bool SetOperandParams(Operation *op, int index, QuantParams params,
                        bool override = false);

  // Sets the quantization parameters of the constant result according to its
  // content.
  bool SetConstantResultParams(Operation *op);

  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
  // of the op.
  void QuantizeOpResult(Operation *op, int index, QuantParams params);

  void QuantizeArg(BlockArgument arg, QuantParams params);

  // Inserts the Quantize and Dequantize ops to quantize the value.
  void QuantizeValue(Value value, QuantParams params, Location loc);

  // Inserts the Quantize ops for requantizing the index-th result of the op.
  void RequantizeOpResult(Operation *op, int index, RequantizeStates *states);

  // Inserts the Quantize ops for requantizing a block argument.
  void RequantizeArg(BlockArgument arg, RequantizeStates *states);

  // Inserts the Quantize and Dequantize ops to requantize the value.
  void RequantizeValue(Value value, RequantizeStates *states, Location loc);

  // A heuristic to get a quantization parameter that satisfies the same-scale
  // constraints for the op. Returns empty parameters if no such quantization
  // parameter exists.
  QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);

  // Returns the state of the index-th operand of the op.
  QuantState &GetOperandQuantState(Operation *op, int index) {
    return states_[operand_states_[{op, index}]];
  }

  // Returns the state of the index-th result of the op.
  QuantState &GetResultQuantState(Operation *op, int index) {
    return states_[result_states_[{op, index}]];
  }

  // Returns the state of the block argument.
  QuantState &GetArgQuantState(BlockArgument arg) {
    return states_[arg_states_[arg]];
  }

  // Returns the requantize states of the index-th operand of the op.
  RequantizeStates &GetOperandRequantizeStates(Operation *op, int index) {
    return rescale_states_[operand_states_[{op, index}]];
  }

  // Returns the requantize states of the index-th result of the op.
  RequantizeStates &GetResultRequantizeStates(Operation *op, int index) {
    return rescale_states_[result_states_[{op, index}]];
  }

  // Returns the requantize states of the arg.
  RequantizeStates &GetArgRequantizeStates(BlockArgument arg) {
    return rescale_states_[arg_states_[arg]];
  }

  // Uses the type of `val` to set the initial state of the index-th result if
  // `as_result` is true, or of the index-th operand if `as_result` is false.
  // The state is immutable if the type is a quantized type. Returns the index
  // of this new state in the state vector.
  int InitializeState(Operation *op, int index, Value val, bool as_result);

  // Sets the state of an argument. If this value is cached, uses the cached
  // result without creating a new entry in the state vector. Otherwise,
  // allocates a new entry in the state vector.
  void InitializeArgState(BlockArgument arg, Value in) {
    auto cached = value_to_state_.insert({in, 0});
    if (!cached.second) {
      arg_states_[arg] = cached.first->second;
      return;
    }
    QuantParams params =
        quant::QuantizedType::getQuantizedElementType(in.getType());
    bool immutable = !EmptyParams(params);
    int next_state_index = states_.size();
    states_.push_back({params, immutable});
    arg_states_[arg] = next_state_index;
    cached.first->second = next_state_index;
  }

  // Sets the state of the index-th operand of the op. If this operand is
  // cached, uses the cached result without creating a new entry in the state
  // vector. Otherwise, allocates a new entry in the state vector.
  void InitializeOperandState(Operation *op, int index, Value in) {
    auto cached = value_to_state_.insert({in, 0});
    if (!cached.second) {
      operand_states_[{op, index}] = cached.first->second;
      return;
    }
    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
  }

  // Sets the state of the index-th result of the op. If this result is cached,
  // uses the cached result without creating a new entry in the state vector.
  // Otherwise, allocates a new entry in the state vector.
  void InitializeResultState(Operation *op, int index, Value res) {
    auto cached = value_to_state_.insert({res, 0});
    if (!cached.second) {
      result_states_[{op, index}] = cached.first->second;
      return;
    }
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  // Utility function for debug output for requantize states.
  void DumpRequantizeStates(const RequantizeStates &requantize_states) {
    for (auto &requantize_state : requantize_states) {
      if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
        llvm::dbgs() << "+";
        requantize_state.params.print(llvm::dbgs());
      }
    }
  }

  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::dbgs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(op);
      if (op->hasTrait<OpTrait::IsTerminator>() ||
          (IsOpNotQuantizable(op) && !scale_spec->has_same_scale_requirement) ||
          llvm::isa<quantfork::QuantizeCastOp, quantfork::DequantizeCastOp,
                    func::ConstantOp, arith::ConstantOp>(op)) {
        return;
      }
      if (current_op == op) llvm::dbgs() << "===>>>";
      llvm::dbgs() << op->getName() << " : (";
      if (llvm::isa<func::FuncOp>(op)) {
        for (auto &arg : fn_.getArguments()) {
          if (auto params = GetArgQuantState(arg).params) {
            params.print(llvm::dbgs());
            DumpRequantizeStates(GetArgRequantizeStates(arg));
          }
          llvm::dbgs() << ",";
        }
      }
      for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
        if (auto params = GetOperandQuantState(op, i).params) {
          params.print(llvm::dbgs());
          DumpRequantizeStates(GetOperandRequantizeStates(op, i));
        } else {
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ") -> (";
      for (int i = 0, e = op->getNumResults(); i < e; ++i) {
        if (auto params = GetResultQuantState(op, i).params) {
          params.print(llvm::dbgs());
          DumpRequantizeStates(GetResultRequantizeStates(op, i));
        } else {
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ")\n";
    });
  }

  func::FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
  bool disable_per_channel_;

  // We should distinguish weights and bias constants. Biases are specified by
  // the quantization spec or are the operands of ops with same scale spec. The
  // rest are weights.
  llvm::DenseSet<Operation *> weights_;

  // The weights require narrow_range quantization. This map collects all the
  // weight operands defined by the op quant spec. The value of each entry is
  // the quantization dimension; if it is non-negative, per-channel
  // quantization is required.
  llvm::DenseMap<Operation *, int> optimized_weights_;

  // All the ops that the quantization parameters still need to be propagated
  // to.
  std::vector<Operation *> work_list_;
  std::unordered_set<Operation *> quantized_;

  // The vector contains all the quantization parameters propagated from the
  // defining operations of the value, or from quantization-aware training.
  std::vector<QuantState> states_;

  // The map contains all the quantization parameters which are required to
  // satisfy the same operands and results constraint. The keys of this map are
  // the values from `operand_states_` and `result_states_`.
  std::unordered_map<int, RequantizeStates> rescale_states_;

  // Maps from the op's operands, results, and arguments to indexes into the
  // propagation state vector.
  llvm::DenseMap<OpValue, int> operand_states_;
  llvm::DenseMap<OpValue, int> result_states_;
  llvm::DenseMap<BlockArgument, int> arg_states_;
  llvm::DenseMap<Value, int> value_to_state_;

  // This vector preserves the argument order, so the newly inserted quantized
  // ops for the arguments are deterministically ordered.
  llvm::SmallVector<BlockArgument, 4> args_;

  OpQuantSpecGetter op_quant_spec_getter_;
  OpQuantScaleSpecGetter op_quant_scale_spec_getter_;

  // Infer output ranges for activation ops and constants. This is usually
  // required for post-training quantization.
  bool infer_tensor_range_;

  // Calculate scales in float instead of double, so that the scales and
  // quantized values are exactly the same as with the TOCO quantizer.
  bool legacy_float_scale_;
};
}  // namespace

std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
  return op_quant_spec_getter_(op);
}

std::unique_ptr<OpQuantScaleSpec> QuantizationDriver::GetQuantScaleSpec(
    Operation *op) {
  return op_quant_scale_spec_getter_(op);
}

bool QuantizationDriver::IsQuantized(Operation *op) {
  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    if (GetResultQuantState(op, i).IsEmpty()) return false;
  }
  return true;
}

int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
                                        bool as_result) {
  QuantParams params =
      quant::QuantizedType::getQuantizedElementType(val.getType());
  bool immutable = !EmptyParams(params);
  int next_state_index = states_.size();
  states_.push_back({params, immutable});
  if (as_result)
    result_states_[{op, index}] = next_state_index;
  else
    operand_states_[{op, index}] = next_state_index;

  return next_state_index;
}

bool QuantizationDriver::SetConstantResultParams(Operation *op) {
  DenseFPElementsAttr attr;
  Value res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
  }
  // TODO(fengliuai): make storage_type_width and narrow_range configurable.
  Type final_type;
  auto it = optimized_weights_.find(op);
  bool is_weight = it != optimized_weights_.end();
  bool is_weight_with_per_channel_support =
      is_weight && it->second != -1 && is_signed_;

  if (is_weight_with_per_channel_support && !disable_per_channel_) {
    // When `disable_per_channel_` is false, per-channel symmetric quantization
    // parameters are created from the weights when the ops support per-channel
    // quantization. Otherwise, per-tensor asymmetric quantization with narrow
    // range is used.

    // per-axis quantization weight, with symmetric min/max enforced.
    final_type = GetUniformQuantizedPerAxisTypeForWeight(
        attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_,
        /*narrow_range=*/true, legacy_float_scale_);
  } else {
    // per-tensor quantization weight
    final_type = GetUniformQuantizedTypeForWeight(
        attr, /*symmetric=*/is_weight && is_signed_,
        /*num_bits=*/8, is_signed_,
        /*narrow_range=*/is_weight, legacy_float_scale_);
  }
  if (auto quant_type = final_type.dyn_cast_or_null<quant::QuantizedType>()) {
    return SetResultParams(op, 0, quant_type);
  }
  return false;
}

bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
                                         QuantParams params) {
  auto &state = GetResultQuantState(op, res_index);
  if (state.params == params) {
    return false;
  }
  if (!state.IsEmpty()) {
    auto &rescales = GetResultRequantizeStates(op, res_index);
    RequantizeState &rescale = rescales.emplace_back();
    rescale.pos = RequantizeState::ON_INPUT;
    rescale.params = params;
    return true;
  }
  state.params = params;
  AddUserToList(op, res_index);
  return true;
}

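// Illustrative note (an assumption about the typical AccumulatorScaleFunc,
// not enforced here): for affine ops the conventional accumulator rule is
//   bias_scale = input_scale * filter_scale
// with a 32-bit storage type, so a bias can be represented exactly in the
// int32 accumulator.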
QuantParams QuantizationDriver::GetBiasParams(
    Operation *op, int bias, const std::vector<int> &non_biases,
    AccumulatorScaleFunc func) {
  auto &bias_state = GetOperandQuantState(op, bias);
  if (!bias_state.IsEmpty()) {
    return bias_state.params;
  }
  std::vector<QuantParams> op_types;
  op_types.reserve(non_biases.size());
  for (auto non_bias : non_biases) {
    auto &non_bias_type = GetOperandQuantState(op, non_bias);
    op_types.push_back(non_bias_type.params);
  }
  if (op_types.empty()) return {};
  return func(op_types, legacy_float_scale_);
}

bool QuantizationDriver::SetOperandParams(Operation *op, int index,
                                          QuantParams params, bool override) {
  auto &state = GetOperandQuantState(op, index);
  if (state.params == params) {
    return false;
  }

  if (!state.IsEmpty() && !override) {
    auto &rescales = GetOperandRequantizeStates(op, index);
    for (RequantizeState &rescale : rescales) {
      if (rescale.params == params) {
        rescale.users.emplace_back(op, index);
        return true;
      }
    }
    RequantizeState &rescale = rescales.emplace_back();
    rescale.pos = RequantizeState::ON_OUTPUT;
    rescale.params = params;
    rescale.users.emplace_back(op, index);
    return true;
  }

  state.params = params;
  AddOperandToList(op, index);
  return true;
}

void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
                                          QuantParams params) {
  builder_.setInsertionPointAfter(op);
  Value original_result = op->getResult(index);
  QuantizeValue(original_result, params, op->getLoc());
}

void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
  builder_.setInsertionPointToStart(arg.getOwner());
  QuantizeValue(arg, params, builder_.getUnknownLoc());
}

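// An illustrative sketch of the rewiring done below (tensor shapes made up):
// given a float value %v with users {opA, opB}, QuantizeValue creates
//   %q  = "quantfork.qcast"(%v)  : float tensor -> quantized tensor
//   %dq = "quantfork.dcast"(%q)  : quantized tensor -> float tensor
// replaceAllUsesWith() then points opA, opB, and %q at %dq, and
// replaceUsesOfWith() resets %q's operand back to %v.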
void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
                                       Location loc) {
  Type expressed_type = value.getType();
  Type new_type = params.castFromExpressedType(expressed_type);
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;
  auto quantize =
      builder_.create<quantfork::QuantizeCastOp>(loc, new_type, value);
  auto dequantize = builder_.create<quantfork::DequantizeCastOp>(
      loc, expressed_type, quantize.getResult());

  // This attribute is set to distinguish the quantize ops being added by the
  // quantization pass. These ops can be removed without losing original
  // program accuracy.
  // TODO(fengliuai): make the attribute part of the op definition.
  quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());

  // `value` has a use on `quantize`, so this will replace that use with the
  // result of `dequantize`. Remember to reset that use afterwards.
  value.replaceAllUsesWith(dequantize);
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}

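// Illustrative sketch of the ON_OUTPUT case handled below: if the result %r
// already feeds a quantize/dequantize pair, the requantize cast is inserted
// after the existing quantize op, i.e.
//   %r -> qcast -> (new qcast with the rescaled params) -> dcast
// so downstream consumers observe the rescaled parameters.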
void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
                                            RequantizeStates *states) {
  if (states->empty()) return;

  builder_.setInsertionPointAfter(op);
  Value value = op->getResult(index);
  RequantizeState::RequantizePosition pos = states->front().pos;
  if (pos == RequantizeState::NO_REQUANTIZE) {
    return;
  }
  for (auto &state : *states) {
    // Check that all requantization positions are the same for each state.
    // Unsure if this check is required.
    if (state.pos != pos) {
      return;
    }
  }
  if (pos == RequantizeState::ON_OUTPUT) {
    Operation *user = value.getUses().begin().getUser();
    if (llvm::isa<quantfork::QuantizeCastOp>(user)) {
      // The requantize op is inserted between `quantize` and `dequantize` ops.
      value = user->getResult(0);
      builder_.setInsertionPointAfter(user);
    }
  }
  RequantizeValue(value, states, op->getLoc());
}

void QuantizationDriver::RequantizeArg(BlockArgument arg,
                                       RequantizeStates *states) {
  Value value = arg;
  builder_.setInsertionPointToStart(arg.getOwner());
  if (value.hasOneUse()) {
    auto user = value.use_begin().getUser();
    if (auto q = llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
      value = q.getResult();
      builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
    }
  }
  RequantizeValue(value, states, builder_.getUnknownLoc());
}

void QuantizationDriver::RequantizeValue(Value value, RequantizeStates *states,
                                         Location loc) {
  if (states->empty() ||
      states->front().pos == RequantizeState::NO_REQUANTIZE) {
    return;
  }
  if (states->front().pos == RequantizeState::ON_INPUT) {
    auto &state = states->front();
    Type expressed_type = value.getType();
    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    Type new_type = state.params.castFromExpressedType(expressed_type);
    if (!new_type) return;
    auto requantize_op =
        builder_.create<quantfork::QuantizeCastOp>(loc, new_type, value);
    value.replaceAllUsesWith(requantize_op);
    requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
    // This requantization was defined as required for the result value, so
    // there should be only one requant state.
    return;
  }

  // If this is an operand that requires requantization, then the value should
  // only have one DequantizeCastOp user which produces the operand value.
  if (!value.hasOneUse()) {
    return;
  }
  auto dequant_op = llvm::dyn_cast_or_null<quantfork::DequantizeCastOp>(
      value.use_begin().getUser());
  if (!dequant_op) {
    return;
  }
  // It is possible that the dequantized value is used by an op that doesn't
  // require requantization, so only overwrite the first dequantize op if that
  // is not the case.
  const int num_uses = std::distance(dequant_op.getResult().use_begin(),
                                     dequant_op.getResult().use_end());

  // Whether to replace the quantization params of the first dequantize op
  // after the quantized value is produced.
  // If there is a use other than the requantize states, then we can't clobber.
  bool clobber_first = num_uses <= states->size();
  for (auto &state : *states) {
    Type expressed_type =
        quant::QuantizedType::castToExpressedType(value.getType());
    if (!expressed_type) continue;
    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    Type new_type = state.params.castFromExpressedType(expressed_type);
    // This value isn't an expressed type (float), skip.
    if (!new_type) continue;

    auto requantize_op =
        builder_.create<quantfork::QuantizeCastOp>(loc, new_type, value);

    if (clobber_first) {
      dequant_op.setOperand(requantize_op.getResult());
      // All ops requiring this value already use the result of dequant.
      clobber_first = false;
    } else {
      auto new_dequant_op = builder_.create<quantfork::DequantizeCastOp>(
          loc, dequant_op.getResult().getType(), requantize_op.getResult());
      for (auto &op_index : state.users) {
        op_index.first->setOperand(op_index.second, new_dequant_op.getResult());
      }
    }
  }
}

// A heuristic to get quantization parameters that satisfy the same-scale
// constraints:
// - If there are immutable states,
//   - use the single input, or,
//   - use the single output, or,
//   - use the first one in the collection,
// - use the single input if it is ready, or,
// - use the single output if it is ready, or,
// - use the first ready one in the collection.
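//
// For example (an illustrative case): for a same-scale op such as reshape
// with one operand and one result, if the operand's state is immutable
// (e.g. fixed by quantization-aware training), its parameters are returned
// and later forced onto the result; if nothing is immutable but the result
// already has propagated parameters, those are returned instead.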
QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
    Operation *op) {
  // Two vectors to collect the non-empty operand and result states.
  std::vector<QuantState *> mutable_states, immutable_states;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    auto &state = GetOperandQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_operands_num = immutable_states.size();
  int mutable_operands_num = mutable_states.size();
  // Use the operand's state if it is immutable and is the only operand.
  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
    return immutable_states.front()->params;
  }

  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    auto &state = GetResultQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_results_num = immutable_states.size() - immutable_operands_num;
  int mutable_results_num = mutable_states.size() - mutable_operands_num;
  // Use the result's state if it is immutable and is the only result.
  if (op->getNumResults() == 1 && immutable_results_num == 1) {
    return immutable_states.back()->params;
  }

  // Use the first immutable state to quantize the rest of the operands and
  // results.
  if (!immutable_states.empty()) return immutable_states.front()->params;

  // If there are no immutable states, use the operand's state if it is the
  // only operand and has parameters propagated.
  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
    return mutable_states.front()->params;
  }

  // If there are no immutable states, use the result's state if it is the
  // only result and has parameters propagated.
  if (op->getNumResults() == 1 && mutable_results_num == 1) {
    return mutable_states.back()->params;
  }

  // Use the first propagated state to quantize the rest of the operands and
  // results.
  if (!mutable_states.empty()) return mutable_states.front()->params;

  // No operands or results have parameters propagated; skip this node for now.
  return {};
}

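// Illustrative sketch of the bias duplication done below: if one constant
// feeds two convolutions as their bias, each conv may need a different bias
// scale (derived from its own input and filter scales), so the constant is
// cloned to give each user a private copy:
//   %bias = arith.constant dense<...> : tensor<16xf32>  // used by conv1, conv2
// becomes
//   %bias  = arith.constant dense<...> : tensor<16xf32>  // used by conv1
//   %bias2 = arith.constant dense<...> : tensor<16xf32>  // used by conv2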
void QuantizationDriver::PreprocessConstantOps() {
  fn_.walk([&](arith::ConstantOp cst) {
    // Non-float tensors are neither weights nor require quantization.
    auto type = cst.getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return;

    Value value = cst.getResult();
    builder_.setInsertionPoint(cst);

    // The following loop will change the value uses, so we cache all the uses
    // that need to be changed.
    llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
    for (auto &use : value.getUses()) {
      uses.push_back({use.getOwner(), use.getOperandNumber()});
    }
    for (const auto &indexed_use : llvm::enumerate(uses)) {
      Operation *user = indexed_use.value().first;
      int operand_num = indexed_use.value().second;

      std::unique_ptr<OpQuantSpec> spec = GetQuantSpec(user);
      std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(user);
      BiasParamsMap biases = spec->biases_params;

      // The quantization parameters of a `weight` shouldn't be determined by
      // other values. So any constant that is not a bias, not an operand of an
      // op with same-scale requirements, and hasn't been quantized yet is a
      // weight.
      if (biases.find(operand_num) == biases.end() &&
          !scale_spec->has_same_scale_requirement &&
          !llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
        // Needs to scan the content of weights to get the quantization
        // parameters if there are no quantization parameters (FakeQuant ops).
        // For this case, the weight will not be duplicated.
        weights_.insert(cst);
        if (spec->coeff_op_quant_dim.find(operand_num) !=
            spec->coeff_op_quant_dim.end()) {
          optimized_weights_.insert(
              {cst, spec->coeff_op_quant_dim[operand_num]});
        }
      } else {
        // This is a bias or an operand of an op with same-scale requirements,
        // so the quantization parameters are propagated from or determined by
        // other values. Duplicate this constant in case it is shared by
        // different users.
        if (uses.size() > 1) {
          auto new_cst =
              builder_.create<arith::ConstantOp>(cst.getLoc(), cst.getValue());
          user->setOperand(operand_num, new_cst);
        }
      }
    }
  });
}

void QuantizationDriver::SetupAllStates() {
  for (auto arg : fn_.getArguments()) {
    args_.push_back(arg);
    Value value = arg;
    // If the argument is quantized, it should only have one user.
    if (arg.hasOneUse()) {
      auto user = value.use_begin().getUser();
      if (auto q = llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
        value = q.getResult();
      }
    }
    InitializeArgState(arg, value);
  }

  fn_.walk([&](Operation *op) {
    std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(op);
    if (IsOpNotQuantizable(op) && !scale_spec->has_same_scale_requirement) {
      return;
    }
    work_list_.push_back(op);

    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      auto operand = op->getOperand(i);
      if (auto *inst = operand.getDefiningOp()) {
        // If the operand comes from a tfl.dequantize op, we use the quantized
        // input of this tfl.dequantize op to set the state.
        if (auto dq = llvm::dyn_cast<quantfork::DequantizeCastOp>(inst)) {
          operand = dq.getArg();
        }
      }
      InitializeOperandState(op, i, operand);
    }

    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
      Value result = op->getResult(res);
      // If the result has been quantized, it should only be used by a
      // tfl.quantize op. For this case, we use the quantized result to
      // create the state and mark it immutable.
      if (result.hasOneUse()) {
        auto user = result.use_begin().getUser();
        if (auto q = llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
          result = q.getResult();
        }
      }
      InitializeResultState(op, res, result);
    }
  });
}

// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
  // Duplicate the bias constants, so the states can be set up correctly.
  // TODO(fengliuai): Function definition should also be duplicated if there
  // are multiple call sites.
  PreprocessConstantOps();

  // Set up all the internal states.
  SetupAllStates();
}

bool QuantizationDriver::PropagateParams() {
  // TODO(fengliuai): use a typed indicator instead of a bool value.
  bool changed = false;
  while (!work_list_.empty()) {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);

    if (auto cst = llvm::dyn_cast<arith::ConstantOp>(op)) {
      // If the workflow requires inferring ranges from the content
      // (post-training quantization) and it is a weight (filter) that hasn't
      // been quantized, we infer the quantization parameters from the content.
      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
        // The quantization parameters are determined by the content of the
        // constant.
        changed |= SetConstantResultParams(op);
      }
      continue;
    }

    std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(op);

    if (scale_spec->has_same_scale_requirement) {
      auto params = GetQuantParamsForSameScaleConstraint(op);
      // The quantization parameters haven't been propagated to any operands
      // or results. Skip this node for now.
      if (!params) {
        quantized_.erase(op);
        continue;
      }

      // Use the final state to set all the operands' parameters.
      for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
        if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
          // Without this check, it would accidentally propagate the
          // quantization information through shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetOperandParams(op, i, params);
        }
      }

      // Use the final state to set all the results' parameters.
      for (int res = 0, e = op->getNumResults(); res != e; ++res)
        if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
          // Without this check, it would accidentally propagate the
          // quantization information through shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetResultParams(op, res, params);
        }
    }

    // TODO(fengliuai): make the bit width configurable.
    if (scale_spec->has_fixed_output_range && infer_tensor_range_) {
      // Infer ranges from the activation ops. This is usually required for
      // the post-training quantization workflow.
      // TODO(fengliuai): different results can have different fixed ranges.
      auto params = scale_spec->fixed_output_range_func(is_signed_,
                                                        /*bit_width=*/8);
      for (auto i = 0; i < op->getNumResults(); ++i) {
        // The range is null if the result has been quantized.
        if (params) {
          changed |= SetResultParams(op, i, params);
        }
      }
    }

    auto spec = GetQuantSpec(op);
    for (auto &it : spec->biases_params) {
      auto params =
          GetBiasParams(op, it.first, it.second.first, it.second.second);
      if (!params) {
        quantized_.erase(op);
        continue;
      }
      changed |=
          SetBiasParamsWithAdjustments(op, it.first, it.second.first, params);
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "\n\n\n");
  LLVM_DEBUG(DumpStates(nullptr));

  return changed;
}

arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded(
    arith::ConstantOp op, Operation *target_op, int operand_index) {
  if (op.getResult().hasOneUse()) {
    return op;
  }
  OpBuilder builder(op->getContext());
  builder.setInsertionPointAfter(op);
  arith::ConstantOp new_op = llvm::cast<arith::ConstantOp>(builder.clone(*op));
  target_op->getOpOperand(operand_index).set(new_op.getResult());
  InitializeOperandState(target_op, operand_index, new_op.getResult());
  InitializeResultState(new_op, 0, new_op.getResult());
  return new_op;
}

bool QuantizationDriver::ShouldCheckBiasScale(
    Operation *op, int bias_index, const std::vector<int> &input_indices,
    QuantParams params, int &input_index, int &filter_index) {
  // For now, restrict scale adjustment to ops with affine quantized weights,
  // and having weights and biases as constants. This currently only applies to
  // FC and Conv* ops. Restriction for the weight can be relaxed if there are
  // needs for adjusting the scale of variable weights.
  auto affine_op = llvm::dyn_cast<AffineQuantizedOpInterface>(op);
  auto bias_op = op->getOperand(bias_index).getDefiningOp<arith::ConstantOp>();
  if (!affine_op || !bias_op || input_indices.size() != 2) return false;
  if (!bias_op.getValue().isa<DenseFPElementsAttr>()) return false;
  filter_index = affine_op.GetAffineOperandIndex();
  if (!op->getOperand(filter_index).getDefiningOp<arith::ConstantOp>()) {
    return false;
  }
  if (filter_index == input_indices[0]) {
    input_index = input_indices[1];
  } else if (filter_index == input_indices[1]) {
    input_index = input_indices[0];
  } else {
    return false;
  }

  auto input_state = GetOperandQuantState(op, input_index);
  auto filter_state = GetOperandQuantState(op, filter_index);
  // If the quantization parameter of the filter is fixed, it should be
  // returned as-is. Only check ops with 8-bit inputs and weights, and 32-bit
  // biases.
  if (!(input_state.params.getStorageTypeIntegralWidth() == 8 &&
        filter_state.params.getStorageTypeIntegralWidth() == 8 &&
        params.getStorageTypeIntegralWidth() == 32)) {
    return false;
  }
  return true;
}

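// A worked example of the adjustment below (illustrative numbers): with
// input_scale = 0.5 and a proposed bias scale of 1.0e-9, a bias value of 10.0
// would quantize to 10.0 / 1.0e-9 = 1.0e10, which overflows int32. The bias
// scale is therefore raised to |bias|_max / (INT32_MAX / 2) ~= 10.0 / 1.07e9
// ~= 9.3e-9, and the filter scale is rewritten to new_bias_scale /
// input_scale so that bias_scale == input_scale * filter_scale still holds.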
bool QuantizationDriver::SetBiasParamsWithAdjustments(
    Operation *op, int bias_index, const std::vector<int> &input_indices,
    QuantParams params) {
  bool changed = false;
  int input_index;
  int filter_index;
  if (!ShouldCheckBiasScale(op, bias_index, input_indices, params, input_index,
                            filter_index)) {
    return SetOperandParams(op, bias_index, params);
  }

  quant::QuantState input_state = GetOperandQuantState(op, input_index);
  quant::QuantState filter_state = GetOperandQuantState(op, filter_index);
  auto bias_op = op->getOperand(bias_index).getDefiningOp<arith::ConstantOp>();
  const double input_scale =
      input_state.params.cast<UniformQuantizedType>().getScale();

  auto bias_values = bias_op.getValue().cast<DenseFPElementsAttr>();
  // Restrict the maximum absolute value of the bias to INT32_MAX / 2, to make
  // some room for the accumulator.
  const int32_t kBiasMax = std::numeric_limits<int32_t>::max() / 2;
  if (auto bias_params = params.dyn_cast<UniformQuantizedType>()) {
    double bias_half_range = 0.0;
    for (auto bias : bias_values.getValues<APFloat>()) {
      if (bias_half_range < std::abs(bias.convertToFloat())) {
        bias_half_range = std::abs(bias.convertToFloat());
      }
    }
    if (bias_half_range / bias_params.getScale() < kBiasMax) {
      return SetOperandParams(op, bias_index, params);
    }
    double new_bias_scale = static_cast<double>(bias_half_range) / kBiasMax;

    changed |= SetOperandParams(
        op, bias_index,
        UniformQuantizedType::getChecked(
            bias_op->getLoc(), params.getFlags(), params.getStorageType(),
            params.getExpressedType(), new_bias_scale, 0,
            params.getStorageTypeMin(), params.getStorageTypeMax()));
    auto filter_op = DuplicateConstantOpIfNeeded(
        op->getOperand(filter_index).getDefiningOp<arith::ConstantOp>(), op,
        filter_index);
    if (!filter_op) {
      return SetOperandParams(op, bias_index, params);
    }

    auto filter_param = filter_state.params.cast<UniformQuantizedType>();
    changed |= SetOperandParams(
        op, filter_index,
        UniformQuantizedType::getChecked(
            filter_op->getLoc(), filter_param.getFlags(),
            filter_param.getStorageType(), filter_param.getExpressedType(),
            new_bias_scale / input_scale, 0, filter_param.getStorageTypeMin(),
            filter_param.getStorageTypeMax()),
        /*override=*/true);
  } else if (auto bias_params =
                 params.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    auto filter_params =
        filter_state.params.cast<quant::UniformQuantizedPerAxisType>();
    std::vector<double> new_bias_scales = bias_params.getScales().vec();
    std::vector<double> new_filter_scales = filter_params.getScales().vec();
    bool needs_adjustment = false;
    for (int i = 0; i < bias_params.getScales().size(); ++i) {
      float abs_bias = std::abs(bias_values.getValues<float>()[i]);
      if (abs_bias / new_bias_scales[i] > kBiasMax) {
        new_bias_scales[i] = static_cast<double>(abs_bias) / kBiasMax;
        new_filter_scales[i] = new_bias_scales[i] / input_scale;
        needs_adjustment = true;
      }
    }
    if (!needs_adjustment) {
      return SetOperandParams(op, bias_index, params);
    }
    changed |= SetOperandParams(
        op, bias_index,
        quant::UniformQuantizedPerAxisType::getChecked(
            bias_op->getLoc(), params.getFlags(), params.getStorageType(),
            params.getExpressedType(), new_bias_scales,
            bias_params.getZeroPoints(), bias_params.getQuantizedDimension(),
            params.getStorageTypeMin(), params.getStorageTypeMax()));

    auto filter_op = DuplicateConstantOpIfNeeded(
        op->getOperand(filter_index).getDefiningOp<arith::ConstantOp>(), op,
        filter_index);
    changed |= SetOperandParams(
        op, filter_index,
        quant::UniformQuantizedPerAxisType::getChecked(
            filter_op->getLoc(), filter_params.getFlags(),
            filter_params.getStorageType(), filter_params.getExpressedType(),
            new_filter_scales, filter_params.getZeroPoints(),
            filter_params.getQuantizedDimension(),
            filter_params.getStorageTypeMin(),
            filter_params.getStorageTypeMax()),
        /*override=*/true);
  }
  return changed;
}

void QuantizationDriver::Finalize() {
  for (auto arg : args_) {
    auto &state = GetArgQuantState(arg);
    auto &requantizes = GetArgRequantizeStates(arg);
    if (state.IsEmpty() || (state.immutable && requantizes.empty())) {
      continue;
    }

    if (!state.immutable) {
      QuantizeArg(arg, state.params);
    }

    if (!requantizes.empty()) {
      RequantizeArg(arg, &requantizes);
    }
  }

  for (auto it : result_states_) {
    Operation *op = it.first.first;
    int res_index = it.first.second;
    auto &state = GetResultQuantState(op, res_index);
    auto &requantizes = GetResultRequantizeStates(op, res_index);
    if (state.IsEmpty() || (state.immutable && requantizes.empty())) {
      continue;
    }

    if (!state.immutable) {
      QuantizeOpResult(op, res_index, state.params);
    }

    if (!requantizes.empty()) {
      RequantizeOpResult(op, res_index, &requantizes);
    }
  }
}

void QuantizationDriver::Run() {
  Initialize();
  if (PropagateParams()) {
    Finalize();
  }
}

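// Example call site (an illustrative sketch; `GetOpQuantSpec` stands in for
// whatever OpQuantSpecGetter the calling pass provides):
//   ApplyQuantizationParamsPropagation(func, /*is_signed=*/true,
//                                      /*disable_per_channel=*/false,
//                                      GetOpQuantSpec,
//                                      /*infer_tensor_ranges=*/true,
//                                      /*legacy_float_scale=*/false);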
void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale) {
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel, op_quant_spec_getter,
      GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale);
}

void ApplyQuantizationParamsPropagation(
    mlir::func::FuncOp func, bool is_signed, bool disable_per_channel,
    OpQuantSpecGetter op_quant_spec_getter,
    OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges,
    bool legacy_float_scale) {
  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
                     op_quant_scale_spec_getter, infer_tensor_ranges,
                     legacy_float_scale)
      .Run();
}

}  // namespace quant
}  // namespace mlir