1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <unordered_map>
17 #include <unordered_set>
18 #include <utility>
19
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
28 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/MLIRContext.h" // from @llvm-project
34 #include "mlir/IR/Matchers.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/Value.h" // from @llvm-project
37 #include "mlir/Support/LLVM.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
39 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
41 #include "tensorflow/core/platform/logging.h"
42
43 #define DEBUG_TYPE "quantization-driver"
44
45 namespace mlir {
46 namespace quant {
47 namespace {
// Returns true if `p` holds no quantization parameters, i.e. it equals a
// default-constructed (null) QuantizedType.
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
49
// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag that indicates this state (the params) shouldn't be changed after
  // it is initialized. This flag will be set to true if the quantization
  // parameters come from quantization-aware training (i.e. the value already
  // carried a quantized element type when the state was created).
  const bool immutable;

  // Returns true if no quantization parameters have been propagated to this
  // state yet.
  bool IsEmpty() { return EmptyParams(params); }
};
61
// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or output
  // of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // Quantization parameters will be used to add the requantize ops.
  QuantParams params;

  // To avoid clobbering all uses of the value, limit the rewrite to just these
  // (op, operand index) pairs.
  SmallVector<std::pair<Operation *, int>> users;
};

// All requantize states attached to a single propagation state entry.
using RequantizeStates = SmallVector<RequantizeState>;
83
// This is a worklist-driven driver for propagating quantization parameters
// across operations.
//
// The initial quantization parameters are extracted from the quantized types
// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
// parameters are marked as immutable because they come from quantization-aware
// training.
//
// The algorithm traverses each op and sets the quantization parameters of its
// operands and results, according to its quantization specification, and then
// adds the operands and results to the worklist. If there are any conflicts
// (for example, quantization parameters propagated from a previous iteration),
// the process stops if the existing parameters are immutable; otherwise a
// `requantize` op is added to resolve the conflict.
//
// After the algorithm converges, pairs of tfl.quantize and tfl.dequantize ops
// are inserted at the right positions to materialize the propagation and
// requantize results.
//
class QuantizationDriver {
 public:
  // `fn` is the function whose quantization parameters are propagated.
  // `is_signed` selects signed quantized types; `disable_per_channel` forces
  // per-tensor quantization of weights; the two spec getters supply the
  // per-op quantization constraints; `infer_tensor_range` enables deriving
  // ranges from constants/activations (post-training quantization); and
  // `legacy_float_scale` computes scales in float for TOCO compatibility.
  explicit QuantizationDriver(func::FuncOp fn, bool is_signed,
                              bool disable_per_channel,
                              OpQuantSpecGetter op_quant_spec_getter,
                              OpQuantScaleSpecGetter op_quant_scale_spec_getter,
                              bool infer_tensor_range, bool legacy_float_scale)
      : fn_(fn),
        builder_(fn.getBody()),
        is_signed_(is_signed),
        disable_per_channel_(disable_per_channel),
        op_quant_spec_getter_(op_quant_spec_getter),
        op_quant_scale_spec_getter_(op_quant_scale_spec_getter),
        infer_tensor_range_(infer_tensor_range),
        legacy_float_scale_(legacy_float_scale) {}

  // The entry point of the quantization parameters propagation.
  void Run();

 private:
  // This is used to identify an operand or result of an op. The second element
  // of this pair is the index of the operand or result.
  using OpValue = std::pair<mlir::Operation *, int>;

  // Sets up the states for all the op results in the function.
  void Initialize();

  // Propagates the quantization parameters across all the ops.
  bool PropagateParams();

  // Duplicates the constant op if it has multiple uses, and replaces
  // target_op->operand[operand_index] with the newly created op. This also
  // replaces corresponding quantization states.
  arith::ConstantOp DuplicateConstantOpIfNeeded(arith::ConstantOp op,
                                                Operation *target_op,
                                                int operand_index);

  // Adjusts bias scale that is derived from other scales (fc, conv ops) to
  // prevent overflow of quantized bias values. This also changes the
  // quantization state of other inputs when needed.
  bool SetBiasParamsWithAdjustments(Operation *op, int bias_index,
                                    const std::vector<int> &input_indices,
                                    QuantParams params);

  // Helper for checking preconditions to adjust bias scale.
  bool ShouldCheckBiasScale(Operation *op, int bias_index,
                            const std::vector<int> &input_indices,
                            QuantParams params, int &input_index,
                            int &filter_index);

  // Inserts the Quantize and Dequantize ops according to the propagation
  // result.
  void Finalize();

  // The quantization parameters of a bias operand are usually determined by
  // other operands, so if a constant is used by different ops as bias, it
  // needs to be duplicated, thus each op can assign its own quantization
  // parameter for this bias. Also this method adds all the non-bias constants
  // (weights) to a set for looking up later. This method also adds all the
  // per-channel weights to a set for looking up later.
  void PreprocessConstantOps();

  // Sets up all the data structures for quantization propagation.
  void SetupAllStates();

  // Whether the constant is a weight, which shouldn't be shared by different
  // ops.
  bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); }

  // Returns all the related quantization constraints of the op.
  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);
  std::unique_ptr<OpQuantScaleSpec> GetQuantScaleSpec(Operation *op);

  // Whether quantization parameters have been propagated to all the results of
  // this op.
  bool IsQuantized(Operation *op);

  // Adds all the users of the index-th result of op to the work list.
  void AddUserToList(Operation *op, int index) {
    for (auto *user : op->getResult(index).getUsers()) {
      work_list_.push_back(user);
    }
  }

  // Adds the defining op of the index-th operand of op to the work list.
  // Block arguments have no defining op and are skipped.
  void AddOperandToList(Operation *op, int index) {
    if (auto *inst = op->getOperand(index).getDefiningOp()) {
      work_list_.push_back(inst);
    }
  }

  // Returns the quantization params for the bias input from the non-bias
  // operands which have their indexes in the `non_biases` vector. The returned
  // parameters are calculated by `func`.
  QuantParams GetBiasParams(Operation *op, int bias,
                            const std::vector<int> &non_biases,
                            AccumulatorScaleFunc func);

  // Sets the quantization parameters of the result to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the input of propagated quantization.
  bool SetResultParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the operand to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the output of propagated quantization. When `override` is set,
  // quantization state of the value is replaced instead of adding
  // requantization.
  bool SetOperandParams(Operation *op, int index, QuantParams params,
                        bool override = false);

  // Sets the quantization parameters of the constant result according to its
  // content.
  bool SetConstantResultParams(Operation *op);

  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
  // of the op.
  void QuantizeOpResult(Operation *op, int index, QuantParams params);

  // Inserts the Quantize and Dequantize ops for quantizing a block argument.
  void QuantizeArg(BlockArgument arg, QuantParams params);

  // Inserts the Quantize and Dequantize ops to quantize the value at the
  // builder's current insertion point.
  void QuantizeValue(Value value, QuantParams params, Location loc);

  // Inserts the Quantize ops for requantizing the index-th result of the op.
  void RequantizeOpResult(Operation *op, int index, RequantizeStates *states);

  // Inserts the Quantize ops for requantizing a block argument.
  void RequantizeArg(BlockArgument arg, RequantizeStates *states);

  // Inserts the Quantize ops to requantize the value according to `states`.
  void RequantizeValue(Value value, RequantizeStates *states, Location loc);

  // A heuristic to get the quantization parameter which satisfies the same
  // scale constraints for the op. Returns an empty option if this quantization
  // parameter doesn't exist.
  QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);

  // Returns the state of the index-th operand of the op.
  // NOTE: operator[] on the DenseMap default-inserts index 0 when the key is
  // missing, so states must have been initialized beforehand.
  QuantState &GetOperandQuantState(Operation *op, int index) {
    return states_[operand_states_[{op, index}]];
  }

  // Returns the state of the index-th result of the op.
  QuantState &GetResultQuantState(Operation *op, int index) {
    return states_[result_states_[{op, index}]];
  }

  // Returns the state of the block argument.
  QuantState &GetArgQuantState(BlockArgument arg) {
    return states_[arg_states_[arg]];
  }

  // Returns the requantize states of the index-th operand of the op.
  RequantizeStates &GetOperandRequantizeStates(Operation *op, int index) {
    return rescale_states_[operand_states_[{op, index}]];
  }

  // Returns the requantize states of the index-th result of the op.
  RequantizeStates &GetResultRequantizeStates(Operation *op, int index) {
    return rescale_states_[result_states_[{op, index}]];
  }

  // Returns the requantize states of the arg.
  RequantizeStates &GetArgRequantizeStates(BlockArgument arg) {
    return rescale_states_[arg_states_[arg]];
  }

  // Uses the type of `val` to set the initial state of the index-th result if
  // `as_result` is true or index-th operand if `as_result` is false. The state
  // is immutable if the type is a quantized type. Returns the index of this
  // new state in the state vector.
  int InitializeState(Operation *op, int index, Value val, bool as_result);

  // Sets the state of an argument. If this value is cached, uses the cached
  // result without creating a new entry in the state vector. Otherwise,
  // allocates a new entry in the state vector.
  void InitializeArgState(BlockArgument arg, Value in) {
    auto cached = value_to_state_.insert({in, 0});
    if (!cached.second) {
      arg_states_[arg] = cached.first->second;
      return;
    }
    QuantParams params =
        quant::QuantizedType::getQuantizedElementType(in.getType());
    bool immutable = !EmptyParams(params);
    int next_state_index = states_.size();
    states_.push_back({params, immutable});
    arg_states_[arg] = next_state_index;
    cached.first->second = next_state_index;
  }

  // Sets the state of the index-th operand of the op. If this operand is
  // cached, uses the cached result without creating a new entry in the state
  // vector. Otherwise, allocates a new entry in the state vector.
  void InitializeOperandState(Operation *op, int index, Value in) {
    auto cached = value_to_state_.insert({in, 0});
    if (!cached.second) {
      operand_states_[{op, index}] = cached.first->second;
      return;
    }
    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
  }

  // Sets the state of the index-th result of the op. If this result is cached,
  // uses the cached result without creating a new entry in the state vector.
  // Otherwise, allocates a new entry in the state vector.
  void InitializeResultState(Operation *op, int index, Value res) {
    auto cached = value_to_state_.insert({res, 0});
    if (!cached.second) {
      result_states_[{op, index}] = cached.first->second;
      return;
    }
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  // Utility function for debug output for requantize states.
  void DumpRequantizeStates(const RequantizeStates &requantize_states) {
    for (auto &requantize_state : requantize_states) {
      if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
        llvm::dbgs() << "+";
        requantize_state.params.print(llvm::dbgs());
      }
    }
  }

  // Debug helper: dumps the quantization state of every quantizable op in the
  // function, marking `current_op` with "===>>>".
  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::dbgs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(op);
      if (op->hasTrait<OpTrait::IsTerminator>() ||
          (IsOpNotQuantizable(op) && !scale_spec->has_same_scale_requirement) ||
          llvm::isa<quantfork::QuantizeCastOp, quantfork::DequantizeCastOp,
                    func::ConstantOp, arith::ConstantOp>(op)) {
        return;
      }
      if (current_op == op) llvm::dbgs() << "===>>>";
      llvm::dbgs() << op->getName() << " : (";
      if (llvm::isa<func::FuncOp>(op)) {
        for (auto &arg : fn_.getArguments()) {
          if (auto params = GetArgQuantState(arg).params) {
            params.print(llvm::dbgs());
            DumpRequantizeStates(GetArgRequantizeStates(arg));
          }
          llvm::dbgs() << ",";
        }
      }
      for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
        if (auto params = GetOperandQuantState(op, i).params) {
          params.print(llvm::dbgs());
          DumpRequantizeStates(GetOperandRequantizeStates(op, i));
        } else {
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ") -> (";
      for (int i = 0, e = op->getNumResults(); i < e; ++i) {
        if (auto params = GetResultQuantState(op, i).params) {
          params.print(llvm::dbgs());
          DumpRequantizeStates(GetResultRequantizeStates(op, i));
        } else {
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ")\n";
    });
  }

  func::FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
  bool disable_per_channel_;

  // We should distinguish weights and bias constants. Biases are specified by
  // the quantization spec or are the operands of ops with same scale spec. The
  // rest are weights.
  llvm::DenseSet<Operation *> weights_;

  // The weights require narrow_range quantization. This map collects all the
  // weight operands defined by the op quant spec. If the value of the entry is
  // positive, per-channel quantization is required.
  llvm::DenseMap<Operation *, int> optimized_weights_;

  // All the ops that need to propagate the quantization parameters to.
  std::vector<Operation *> work_list_;
  std::unordered_set<Operation *> quantized_;

  // The vector contains all the quantization parameters propagated from the
  // defining operations of the value, or from the quantization-aware training.
  std::vector<QuantState> states_;

  // The map contains all the quantization parameters which are required to
  // satisfy the same operands and results constraint. The keys of this map are
  // the values from `operand_states_` and `result_state_`.
  std::unordered_map<int, RequantizeStates> rescale_states_;

  // Maps of indexes to the propagation state vector from the ops operands,
  // results and arguments.
  llvm::DenseMap<OpValue, int> operand_states_;
  llvm::DenseMap<OpValue, int> result_states_;
  llvm::DenseMap<BlockArgument, int> arg_states_;
  llvm::DenseMap<Value, int> value_to_state_;

  // This vector is to preserve the arguments order, so the newly inserted
  // quantized ops for the arguments are deterministically ordered.
  llvm::SmallVector<BlockArgument, 4> args_;

  OpQuantSpecGetter op_quant_spec_getter_;
  OpQuantScaleSpecGetter op_quant_scale_spec_getter_;

  // Infer output ranges for activation ops and constants. This is usually
  // required for post-training quantization.
  bool infer_tensor_range_;

  // Calculate scales in float instead of double, so that the scales and
  // quantized values are exactly the same with the TOCO quantizer.
  bool legacy_float_scale_;
};
428 } // namespace
429
// Returns the quantization constraints for `op` via the injected spec getter.
std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
  return op_quant_spec_getter_(op);
}
433
// Returns the scale constraints for `op` via the injected scale-spec getter.
std::unique_ptr<OpQuantScaleSpec> QuantizationDriver::GetQuantScaleSpec(
    Operation *op) {
  return op_quant_scale_spec_getter_(op);
}
438
IsQuantized(Operation * op)439 bool QuantizationDriver::IsQuantized(Operation *op) {
440 for (int i = 0, e = op->getNumResults(); i != e; ++i) {
441 if (GetResultQuantState(op, i).IsEmpty()) return false;
442 }
443 return true;
444 }
445
InitializeState(Operation * op,int index,Value val,bool as_result)446 int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
447 bool as_result) {
448 QuantParams params =
449 quant::QuantizedType::getQuantizedElementType(val.getType());
450 bool immutable = !EmptyParams(params);
451 int next_state_index = states_.size();
452 states_.push_back({params, immutable});
453 if (as_result)
454 result_states_[{op, index}] = next_state_index;
455 else
456 operand_states_[{op, index}] = next_state_index;
457
458 return next_state_index;
459 }
460
// Derives quantization parameters for a float constant directly from its
// contents and applies them to result 0. Returns false if `op` is not a float
// constant or no quantized type could be derived.
bool QuantizationDriver::SetConstantResultParams(Operation *op) {
  DenseFPElementsAttr attr;
  Value res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
  }
  // TODO(fengliuai): make storage_type_width and narrow_range configurable.
  Type final_type;
  // `optimized_weights_` maps weight constants to their quantization axis;
  // -1 means per-channel quantization is not supported for this weight.
  auto it = optimized_weights_.find(op);
  bool is_weight = it != optimized_weights_.end();
  bool is_weight_with_per_channel_support =
      is_weight && it->second != -1 && is_signed_;

  if (is_weight_with_per_channel_support && !disable_per_channel_) {
    // When `disable_per_channel_` is false, per-channel symmetric quantization
    // parameters are created from the weights when the ops support per-channel
    // quantization. Otherwise, uses per-tensor asymmetric quantization with
    // narrow range.

    // per-axis quantization weight, with symmetric min/max enforced.
    final_type = GetUniformQuantizedPerAxisTypeForWeight(
        attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_,
        /*narrow_range=*/true, legacy_float_scale_);
  } else {
    // per-tensor quantization weight
    final_type = GetUniformQuantizedTypeForWeight(
        attr, /*symmetric=*/is_weight && is_signed_,
        /*num_bits=*/8, is_signed_,
        /*narrow_range_=*/is_weight, legacy_float_scale_);
  }
  if (auto quant_type = final_type.dyn_cast_or_null<quant::QuantizedType>()) {
    return SetResultParams(op, 0, quant_type);
  }
  return false;
}
496
SetResultParams(Operation * op,int res_index,QuantParams params)497 bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
498 QuantParams params) {
499 auto &state = GetResultQuantState(op, res_index);
500 if (state.params == params) {
501 return false;
502 }
503 if (!state.IsEmpty()) {
504 auto &rescales = GetResultRequantizeStates(op, res_index);
505 RequantizeState &rescale = rescales.emplace_back();
506 rescale.pos = RequantizeState::ON_INPUT;
507 rescale.params = params;
508 return true;
509 }
510 state.params = params;
511 AddUserToList(op, res_index);
512 return true;
513 }
514
GetBiasParams(Operation * op,int bias,const std::vector<int> & non_biases,AccumulatorScaleFunc func)515 QuantParams QuantizationDriver::GetBiasParams(
516 Operation *op, int bias, const std::vector<int> &non_biases,
517 AccumulatorScaleFunc func) {
518 auto &bias_state = GetOperandQuantState(op, bias);
519 if (!bias_state.IsEmpty()) {
520 return bias_state.params;
521 }
522 std::vector<QuantParams> op_types;
523 op_types.reserve(non_biases.size());
524 for (auto non_bias : non_biases) {
525 auto &non_bias_type = GetOperandQuantState(op, non_bias);
526 op_types.push_back(non_bias_type.params);
527 }
528 if (op_types.empty()) return {};
529 return func(op_types, legacy_float_scale_);
530 }
531
SetOperandParams(Operation * op,int index,QuantParams params,bool override)532 bool QuantizationDriver::SetOperandParams(Operation *op, int index,
533 QuantParams params, bool override) {
534 auto &state = GetOperandQuantState(op, index);
535 if (state.params == params) {
536 return false;
537 }
538
539 if (!state.IsEmpty() && !override) {
540 auto &rescales = GetOperandRequantizeStates(op, index);
541 for (RequantizeState &rescale : rescales) {
542 if (rescale.params == params) {
543 rescale.users.emplace_back(op, index);
544 return true;
545 }
546 }
547 RequantizeState &rescale = rescales.emplace_back();
548 rescale.pos = RequantizeState::ON_OUTPUT;
549 rescale.params = params;
550 rescale.users.emplace_back(op, index);
551 return true;
552 }
553
554 state.params = params;
555 AddOperandToList(op, index);
556 return true;
557 }
558
QuantizeOpResult(Operation * op,int index,QuantParams params)559 void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
560 QuantParams params) {
561 builder_.setInsertionPointAfter(op);
562 Value original_result = op->getResult(index);
563 QuantizeValue(original_result, params, op->getLoc());
564 }
565
// Inserts the Quantize/Dequantize pair for a block argument at the start of
// the argument's owning block.
void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
  builder_.setInsertionPointToStart(arg.getOwner());
  QuantizeValue(arg, params, builder_.getUnknownLoc());
}
570
QuantizeValue(Value value,QuantParams params,Location loc)571 void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
572 Location loc) {
573 Type expressed_type = value.getType();
574 Type new_type = params.castFromExpressedType(expressed_type);
575 // This value isn't an expressed type (float), skip.
576 if (!new_type) return;
577 auto quantize =
578 builder_.create<quantfork::QuantizeCastOp>(loc, new_type, value);
579 auto dequantize = builder_.create<quantfork::DequantizeCastOp>(
580 loc, expressed_type, quantize.getResult());
581
582 // This attribute is set to distinguish the quantize ops being added by the
583 // quantization pass. These ops can be removed without losing original
584 // program accuracy.
585 // TODO(fengliuai): make the attribute being part of op definition.
586 quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
587
588 // `original_result` has a use to `quantize`, so this will replace that use
589 // by the result of `dequantize`. Remember to reset that use afterwards
590 value.replaceAllUsesWith(dequantize);
591 quantize.getOperation()->replaceUsesOfWith(dequantize, value);
592 }
593
RequantizeOpResult(Operation * op,int index,RequantizeStates * states)594 void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
595 RequantizeStates *states) {
596 if (states->empty()) return;
597
598 builder_.setInsertionPointAfter(op);
599 Value value = op->getResult(index);
600 RequantizeState::RequantizePosition pos = states->front().pos;
601 if (pos == RequantizeState::NO_REQUANTIZE) {
602 return;
603 }
604 for (auto &state : *states) {
605 // Check that all requantization positions are the same for each state.
606 // Unsure if this check is required.
607 if (state.pos != pos) {
608 return;
609 }
610 }
611 if (pos == RequantizeState::ON_OUTPUT) {
612 Operation *user = value.getUses().begin().getUser();
613 if (llvm::isa<quantfork::QuantizeCastOp>(user)) {
614 // The requantize op is inserted between `quantize` and `dequantize` ops.
615 value = user->getResult(0);
616 builder_.setInsertionPointAfter(user);
617 }
618 }
619 RequantizeValue(value, states, op->getLoc());
620 }
621
RequantizeArg(BlockArgument arg,RequantizeStates * states)622 void QuantizationDriver::RequantizeArg(BlockArgument arg,
623 RequantizeStates *states) {
624 Value value = arg;
625 builder_.setInsertionPointToStart(arg.getOwner());
626 if (value.hasOneUse()) {
627 auto user = value.use_begin().getUser();
628 if (auto q = llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
629 value = q.getResult();
630 builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
631 }
632 }
633 RequantizeValue(value, states, builder_.getUnknownLoc());
634 }
635
// Materializes the requantizations described by `states` for `value`.
// ON_INPUT: inserts a single QuantizeCastOp consuming `value` and reroutes
// all uses through it. ON_OUTPUT: expects `value` to feed exactly one
// DequantizeCastOp and either rewires that op (clobber) or creates fresh
// quantize/dequantize pairs for the recorded users.
void QuantizationDriver::RequantizeValue(Value value, RequantizeStates *states,
                                         Location loc) {
  if (states->empty() ||
      states->front().pos == RequantizeState::NO_REQUANTIZE) {
    return;
  }
  if (states->front().pos == RequantizeState::ON_INPUT) {
    auto &state = states->front();
    Type expressed_type = value.getType();
    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    Type new_type = state.params.castFromExpressedType(expressed_type);
    if (!new_type) return;
    auto requantize_op =
        builder_.create<quantfork::QuantizeCastOp>(loc, new_type, value);
    value.replaceAllUsesWith(requantize_op);
    // Restore the requantize op's own operand, which the blanket RAUW above
    // also rewrote.
    requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
    // This requantization was defined as required for the result value, so
    // there should be only one requant state.
    return;
  }

  // If this is an operand that requires requantization, then the value should
  // only have one DequantizeCastOp user which produces the operand value.
  if (!value.hasOneUse()) {
    return;
  }
  auto dequant_op = llvm::dyn_cast_or_null<quantfork::DequantizeCastOp>(
      value.use_begin().getUser());
  if (!dequant_op) {
    return;
  }
  // It is possible that the dequant value is used by a op that doesn't require
  // requant, so only overwrite the first if that is not the case.
  const int num_uses = std::distance(dequant_op.getResult().use_begin(),
                                     dequant_op.getResult().use_end());

  // Whether to replace quantization params of the first dequantize op
  // after the quantized value is produced.
  // If there is a use other than the requantize states, then we can't clobber.
  bool clobber_first = num_uses <= states->size();
  for (auto &state : *states) {
    Type expressed_type =
        quant::QuantizedType::castToExpressedType(value.getType());
    if (!expressed_type) continue;
    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    Type new_type = state.params.castFromExpressedType(expressed_type);
    // This value isn't an expressed type (float), skip.
    if (!new_type) continue;

    auto requantize_op =
        builder_.create<quantfork::QuantizeCastOp>(loc, new_type, value);

    if (clobber_first) {
      // Redirect the existing dequantize through the new requantize so every
      // current user of the dequant result picks up the new params.
      dequant_op.setOperand(requantize_op.getResult());
      // All ops requiring this value already use the result of dequant.
      clobber_first = false;
    } else {
      // Create a dedicated dequantize for this state and rewire only the
      // users recorded in the state.
      auto new_dequant_op = builder_.create<quantfork::DequantizeCastOp>(
          loc, dequant_op.getResult().getType(), requantize_op.getResult());
      for (auto &op_index : state.users) {
        op_index.first->setOperand(op_index.second, new_dequant_op.getResult());
      }
    }
  }
}
703
// A heuristic to get quantization parameters satisfying the same-scale
// constraints:
// - If there are immutable states,
//   - use the single input, or,
//   - use the single output, or,
//   - use the first one in the collection,
// - use the single input if it is ready, or,
// - use the single output if it is ready, or,
// - use the first ready one in the collection.
// Returns empty params when no operand or result has parameters yet.
QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
    Operation *op) {
  // Two vectors to collect non-empty operand and result states. Operand
  // states are pushed first; the *_operands_num counts below rely on that
  // ordering (front() = first operand state, back() = last result state).
  std::vector<QuantState *> mutable_states, immutable_states;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    auto &state = GetOperandQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_operands_num = immutable_states.size();
  int mutable_operands_num = mutable_states.size();
  // Use the operand's state if it is immutable and it is the only one
  // operand.
  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
    return immutable_states.front()->params;
  }

  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    auto &state = GetResultQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_results_num = immutable_states.size() - immutable_operands_num;
  int mutable_results_num = mutable_states.size() - mutable_operands_num;
  // Use the result's state if it is immutable and it is the only one result.
  if (op->getNumResults() == 1 && immutable_results_num == 1) {
    return immutable_states.back()->params;
  }

  // Use the first immutable state to quantize the rest operands and results.
  if (!immutable_states.empty()) return immutable_states.front()->params;

  // If there are no immutable states, use the operand's state if it is the
  // only one operand and has parameters propagated.
  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
    return mutable_states.front()->params;
  }

  // If there are no immutable states, use the result's state if it is the
  // only one result and has parameters propagated.
  if (op->getNumResults() == 1 && mutable_results_num == 1) {
    return mutable_states.back()->params;
  }

  // Use the first propagated state to quantize the rest operands and results.
  if (!mutable_states.empty()) return mutable_states.front()->params;

  // No operands/results have parameters propagated; skip this node for now.
  return {};
}
771
PreprocessConstantOps()772 void QuantizationDriver::PreprocessConstantOps() {
773 fn_.walk([&](arith::ConstantOp cst) {
774 // Non-float tensors are neither weights nor require quantization.
775 auto type = cst.getType().dyn_cast<ShapedType>();
776 if (!type || !type.getElementType().isa<FloatType>()) return;
777
778 Value value = cst.getResult();
779 builder_.setInsertionPoint(cst);
780
781 // The following loop will change the value uses, thus we cache all the uses
782 // needs to be changed.
783 llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
784 for (auto &use : value.getUses()) {
785 uses.push_back({use.getOwner(), use.getOperandNumber()});
786 }
787 for (const auto &indexed_use : llvm::enumerate(uses)) {
788 Operation *user = indexed_use.value().first;
789 int operand_num = indexed_use.value().second;
790
791 std::unique_ptr<OpQuantSpec> spec = GetQuantSpec(user);
792 std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(user);
793 BiasParamsMap biases = spec->biases_params;
794
795 // The quantization parameters of a `weight` shouldn't be determined by
796 // other values. So any constants which are not bias, an operand of an
797 // op with same scale requirements, and haven't been quantized are
798 // weights.
799 if (biases.find(operand_num) == biases.end() &&
800 !scale_spec->has_same_scale_requirement &&
801 !llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
802 // Needs to scan the content of weights to get the quantization
803 // parameters if there are no quantization parameters (FakeQuant ops).
804 // For this case, the weight will not be duplicated.
805 weights_.insert(cst);
806 if (spec->coeff_op_quant_dim.find(operand_num) !=
807 spec->coeff_op_quant_dim.end()) {
808 optimized_weights_.insert(
809 {cst, spec->coeff_op_quant_dim[operand_num]});
810 }
811 } else {
812 // This is a bias or an operand of an op with same scale requirements,
813 // so the quantization parameter are propagated from or determined by
814 // other values. Duplicate this constant in case it is shared by
815 // different users.
816 if (uses.size() > 1) {
817 auto new_cst =
818 builder_.create<arith::ConstantOp>(cst.getLoc(), cst.getValue());
819 user->setOperand(operand_num, new_cst);
820 }
821 }
822 }
823 });
824 }
825
// Initializes the quantization state for every function argument and for the
// operands/results of every (potentially) quantizable op in the function.
void QuantizationDriver::SetupAllStates() {
  for (auto arg : fn_.getArguments()) {
    args_.push_back(arg);
    Value value = arg;
    // If the argument is quantized, it should only have one user: the
    // quantize cast. Use its result so the state carries the quantized type.
    if (arg.hasOneUse()) {
      auto user = value.use_begin().getUser();
      if (auto q = llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
        value = q.getResult();
      }
    }
    InitializeArgState(arg, value);
  }

  fn_.walk([&](Operation *op) {
    std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(op);
    // Skip ops that cannot be quantized, unless they still impose a
    // same-scale constraint on their neighbors (those must stay on the
    // work list so parameters can flow through them).
    if (IsOpNotQuantizable(op) && !scale_spec->has_same_scale_requirement) {
      return;
    }
    work_list_.push_back(op);

    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      auto operand = op->getOperand(i);
      if (auto *inst = operand.getDefiningOp()) {
        // If the operand comes from a tfl.dequantize op, we use the quantized
        // input of this tfl.dequantize op to set the state.
        if (auto dq = llvm::dyn_cast<quantfork::DequantizeCastOp>(inst)) {
          operand = dq.getArg();
        }
      }
      InitializeOperandState(op, i, operand);
    }

    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
      Value result = op->getResult(res);
      // If the result has been quantized, it should only be used by a
      // tfl.quantize op. For this case, we use the quantized result to
      // create the state and mark it immutable.
      if (result.hasOneUse()) {
        auto user = result.use_begin().getUser();
        if (auto q = llvm::dyn_cast<quantfork::QuantizeCastOp>(user)) {
          result = q.getResult();
        }
      }
      InitializeResultState(op, res, result);
    }
  });
}
874
// This method scans the operations in the function to setup the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there are only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
  // Duplicate the bias constants first, so each user gets its own copy and
  // the per-operand states can be set up correctly.
  // TODO(fengliuai): Function definition should also be duplicated if there
  // are multiple call sites.
  PreprocessConstantOps();

  // Setup all the internal states (argument, operand, and result states).
  SetupAllStates();
}
889
// Drains the work list, propagating quantization parameters between the
// operands and results of each op. Returns true if any parameter state was
// changed (callers use this to decide whether to materialize the result).
bool QuantizationDriver::PropagateParams() {
  // TODO(fengliuai): uses a typed indicator instead of a bool value.
  bool changed = false;
  while (!work_list_.empty()) {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);

    if (auto cst = llvm::dyn_cast<arith::ConstantOp>(op)) {
      // If the workflow requires inferring ranges from the content
      // (post-training quantization) and it is weight (filter) and hasn't
      // been quantized, we infer the quantization parameters from the content.
      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
        // The quantization parameters are determined by the content of the
        // constant.
        changed |= SetConstantResultParams(op);
      }
      // Constants need no further propagation handling.
      continue;
    }

    std::unique_ptr<OpQuantScaleSpec> scale_spec = GetQuantScaleSpec(op);

    if (scale_spec->has_same_scale_requirement) {
      auto params = GetQuantParamsForSameScaleConstraint(op);
      // The quantization parameters haven't been propagated to any operands
      // or results. Skip this node for now; erasing from `quantized_` lets it
      // be revisited when it re-enters the work list.
      if (!params) {
        quantized_.erase(op);
        continue;
      }

      // Use the final state to set all the operands' parameters.
      for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
        if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
          // Without this check, it will accidentally propagate the quantization
          // information by the shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetOperandParams(op, i, params);
        }
      }

      // Use the final state to set all the results' parameters.
      for (int res = 0, e = op->getNumResults(); res != e; ++res)
        if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
          // Without this check, it will accidentally propagate the quantization
          // information by the shared non-float-tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetResultParams(op, res, params);
        }
    }

    // TODO(fengliuai): make the bit width configurable.
    if (scale_spec->has_fixed_output_range && infer_tensor_range_) {
      // Infer ranges from the activation ops. This is usually required for
      // the post-training quantization workflow.
      // TODO(fengliuai): different result can have different fixed range.
      auto params = scale_spec->fixed_output_range_func(is_signed_,
                                                        /*bit_width=*/8);
      for (auto i = 0; i < op->getNumResults(); ++i) {
        // The range is null if the result has been quantized.
        if (params) {
          changed |= SetResultParams(op, i, params);
        }
      }
    }

    auto spec = GetQuantSpec(op);
    for (auto &it : spec->biases_params) {
      auto params =
          GetBiasParams(op, it.first, it.second.first, it.second.second);
      // NOTE(review): this `continue` only skips the current bias entry, not
      // the enclosing while-loop iteration, yet it erases the op from
      // `quantized_` so the op can be revisited — presumably intended, but
      // confirm that remaining bias entries should still be processed.
      if (!params) {
        quantized_.erase(op);
        continue;
      }
      changed |=
          SetBiasParamsWithAdjustments(op, it.first, it.second.first, params);
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "\n\n\n");
  LLVM_DEBUG(DumpStates(nullptr));

  return changed;
}
979
DuplicateConstantOpIfNeeded(arith::ConstantOp op,Operation * target_op,int operand_index)980 arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded(
981 arith::ConstantOp op, Operation *target_op, int operand_index) {
982 if (op.getResult().hasOneUse()) {
983 return op;
984 }
985 OpBuilder builder(op->getContext());
986 builder.setInsertionPointAfter(op);
987 arith::ConstantOp new_op = llvm::cast<arith::ConstantOp>(builder.clone(*op));
988 target_op->getOpOperand(operand_index).set(new_op.getResult());
989 InitializeOperandState(target_op, operand_index, new_op.getResult());
990 InitializeResultState(new_op, 0, new_op.getResult());
991 return new_op;
992 }
993
ShouldCheckBiasScale(Operation * op,int bias_index,const std::vector<int> & input_indices,QuantParams params,int & input_index,int & filter_index)994 bool QuantizationDriver::ShouldCheckBiasScale(
995 Operation *op, int bias_index, const std::vector<int> &input_indices,
996 QuantParams params, int &input_index, int &filter_index) {
997 // For now, restrict scale adjustment to ops with affine quantized weights,
998 // and having weights and biases as constants. This currently only applies to
999 // FC and Conv* ops. Restriction for the weight can be relaxed if there are
1000 // needs for adjusting scale of variable weights.
1001 auto affine_op = llvm::dyn_cast<AffineQuantizedOpInterface>(op);
1002 auto bias_op = op->getOperand(bias_index).getDefiningOp<arith::ConstantOp>();
1003 if (!affine_op || !bias_op || input_indices.size() != 2) return false;
1004 if (!bias_op.getValue().isa<DenseFPElementsAttr>()) return false;
1005 filter_index = affine_op.GetAffineOperandIndex();
1006 if (!op->getOperand(filter_index).getDefiningOp<arith::ConstantOp>()) {
1007 return false;
1008 }
1009 if (filter_index == input_indices[0]) {
1010 input_index = input_indices[1];
1011 } else if (filter_index == input_indices[1]) {
1012 input_index = input_indices[0];
1013 } else {
1014 return false;
1015 }
1016
1017 auto input_state = GetOperandQuantState(op, input_index);
1018 auto filter_state = GetOperandQuantState(op, filter_index);
1019 // If quantization paramater for the filter is fixed, should return it as-is.
1020 // Only checks ops with 8-bit input and weights, and 32-bit biases.
1021 if (!(input_state.params.getStorageTypeIntegralWidth() == 8 &&
1022 filter_state.params.getStorageTypeIntegralWidth() == 8 &&
1023 params.getStorageTypeIntegralWidth() == 32)) {
1024 return false;
1025 }
1026 return true;
1027 }
1028
SetBiasParamsWithAdjustments(Operation * op,int bias_index,const std::vector<int> & input_indices,QuantParams params)1029 bool QuantizationDriver::SetBiasParamsWithAdjustments(
1030 Operation *op, int bias_index, const std::vector<int> &input_indices,
1031 QuantParams params) {
1032 bool changed = false;
1033 int input_index;
1034 int filter_index;
1035 if (!ShouldCheckBiasScale(op, bias_index, input_indices, params, input_index,
1036 filter_index)) {
1037 return SetOperandParams(op, bias_index, params);
1038 }
1039
1040 quant::QuantState input_state = GetOperandQuantState(op, input_index);
1041 quant::QuantState filter_state = GetOperandQuantState(op, filter_index);
1042 auto bias_op = op->getOperand(bias_index).getDefiningOp<arith::ConstantOp>();
1043 const double input_scale =
1044 input_state.params.cast<UniformQuantizedType>().getScale();
1045
1046 auto bias_values = bias_op.getValue().cast<DenseFPElementsAttr>();
1047 // Restrict maximum absolute value of bias within INT_MAX / 2, to make some
1048 // room for accumulator.
1049 const int32_t kBiasMax = std::numeric_limits<int32_t>::max() / 2;
1050 if (auto bias_params = params.dyn_cast<UniformQuantizedType>()) {
1051 double bias_half_range = 0.0f;
1052 for (auto bias : bias_values.getValues<APFloat>()) {
1053 if (bias_half_range < std::abs(bias.convertToFloat())) {
1054 bias_half_range = std::abs(bias.convertToFloat());
1055 }
1056 }
1057 if (bias_half_range / bias_params.getScale() < kBiasMax) {
1058 return SetOperandParams(op, bias_index, params);
1059 }
1060 double new_bias_scale = static_cast<double>(bias_half_range) / kBiasMax;
1061
1062 changed |= SetOperandParams(
1063 op, bias_index,
1064 UniformQuantizedType::getChecked(
1065 bias_op->getLoc(), params.getFlags(), params.getStorageType(),
1066 params.getExpressedType(), new_bias_scale, 0,
1067 params.getStorageTypeMin(), params.getStorageTypeMax()));
1068 auto filter_op = DuplicateConstantOpIfNeeded(
1069 op->getOperand(filter_index).getDefiningOp<arith::ConstantOp>(), op,
1070 filter_index);
1071 if (!filter_op) {
1072 return SetOperandParams(op, bias_index, params);
1073 }
1074
1075 auto filter_param = filter_state.params.cast<UniformQuantizedType>();
1076 changed |= SetOperandParams(
1077 op, filter_index,
1078 UniformQuantizedType::getChecked(
1079 filter_op->getLoc(), filter_param.getFlags(),
1080 filter_param.getStorageType(), filter_param.getExpressedType(),
1081 new_bias_scale / input_scale, 0, filter_param.getStorageTypeMin(),
1082 filter_param.getStorageTypeMax()),
1083 /*override=*/true);
1084 } else if (auto bias_params =
1085 params.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
1086 auto filter_params =
1087 filter_state.params.cast<quant::UniformQuantizedPerAxisType>();
1088 std::vector<double> new_bias_scales = bias_params.getScales().vec();
1089 std::vector<double> new_filter_scales = filter_params.getScales().vec();
1090 bool needs_adjustment = false;
1091 for (int i = 0; i < bias_params.getScales().size(); ++i) {
1092 float abs_bias = std::abs(bias_values.getValues<float>()[i]);
1093 if (abs_bias / new_bias_scales[i] > kBiasMax) {
1094 new_bias_scales[i] = static_cast<double>(abs_bias) / kBiasMax;
1095 new_filter_scales[i] = new_bias_scales[i] / input_scale;
1096 needs_adjustment = true;
1097 }
1098 }
1099 if (!needs_adjustment) {
1100 return SetOperandParams(op, bias_index, params);
1101 }
1102 changed |= SetOperandParams(
1103 op, bias_index,
1104 quant::UniformQuantizedPerAxisType::getChecked(
1105 bias_op->getLoc(), params.getFlags(), params.getStorageType(),
1106 params.getExpressedType(), new_bias_scales,
1107 bias_params.getZeroPoints(), bias_params.getQuantizedDimension(),
1108 params.getStorageTypeMin(), params.getStorageTypeMax()));
1109
1110 auto filter_op = DuplicateConstantOpIfNeeded(
1111 op->getOperand(filter_index).getDefiningOp<arith::ConstantOp>(), op,
1112 filter_index);
1113 changed |= SetOperandParams(
1114 op, filter_index,
1115 quant::UniformQuantizedPerAxisType::getChecked(
1116 filter_op->getLoc(), filter_params.getFlags(),
1117 filter_params.getStorageType(), filter_params.getExpressedType(),
1118 new_filter_scales, filter_params.getZeroPoints(),
1119 filter_params.getQuantizedDimension(),
1120 filter_params.getStorageTypeMin(),
1121 filter_params.getStorageTypeMax()),
1122 /*override=*/true);
1123 }
1124 return changed;
1125 }
1126
Finalize()1127 void QuantizationDriver::Finalize() {
1128 for (auto arg : args_) {
1129 auto &state = GetArgQuantState(arg);
1130 auto &requantizes = GetArgRequantizeStates(arg);
1131 if (state.IsEmpty() || (state.immutable && requantizes.empty())) {
1132 continue;
1133 }
1134
1135 if (!state.immutable) {
1136 QuantizeArg(arg, state.params);
1137 }
1138
1139 if (!requantizes.empty()) {
1140 RequantizeArg(arg, &requantizes);
1141 }
1142 }
1143
1144 for (auto it : result_states_) {
1145 Operation *op = it.first.first;
1146 int res_index = it.first.second;
1147 auto &state = GetResultQuantState(op, res_index);
1148 auto &requantizes = GetResultRequantizeStates(op, res_index);
1149 if (state.IsEmpty() || (state.immutable && requantizes.empty())) {
1150 continue;
1151 }
1152
1153 if (!state.immutable) {
1154 QuantizeOpResult(op, res_index, state.params);
1155 }
1156
1157 if (!requantizes.empty()) {
1158 RequantizeOpResult(op, res_index, &requantizes);
1159 }
1160 }
1161 }
1162
Run()1163 void QuantizationDriver::Run() {
1164 Initialize();
1165 if (PropagateParams()) {
1166 Finalize();
1167 }
1168 }
1169
// Convenience overload that runs quantization parameter propagation on `func`
// using the default scale spec getter.
void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale) {
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel, op_quant_spec_getter,
      GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale);
}
1179
// Runs quantization parameter propagation on `func`: constructs a
// QuantizationDriver with the given spec getters and flags and executes it.
void ApplyQuantizationParamsPropagation(
    mlir::func::FuncOp func, bool is_signed, bool disable_per_channel,
    OpQuantSpecGetter op_quant_spec_getter,
    OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges,
    bool legacy_float_scale) {
  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
                     op_quant_scale_spec_getter, infer_tensor_ranges,
                     legacy_float_scale)
      .Run();
}
1190
1191 } // namespace quant
1192 } // namespace mlir
1193