/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/ir/importexport/convert_tensor.h"

namespace mlir {
namespace quant {
namespace {

constexpr char kQuantizeFuncName[] = "quantize_i8";
constexpr char kDequantizeFuncName[] = "dequantize_i8";
constexpr char kAttrMapAttribute[] = "attr_map";

class QuantizeCompositeFunctionsPass
    : public mlir::PassWrapper<QuantizeCompositeFunctionsPass,
                               OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeCompositeFunctionsPass)

  explicit QuantizeCompositeFunctionsPass() {}

  explicit QuantizeCompositeFunctionsPass(
      QuantizationMethod quantization_method, OpSet target_opset) {
    quantization_method_ = quantization_method;
    target_opset_ = target_opset;
  }

  QuantizeCompositeFunctionsPass(const QuantizeCompositeFunctionsPass& other) {
    quantization_method_ = other.quantization_method_;
    target_opset_ = other.target_opset_;
  }

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in the textual format
    // (e.g. on the command line).
    return "quant-quantize-composite-functions";
  }

  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Quantize composite functions with QDQ input/outputs.";
  }

  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<TF::TensorFlowDialect, quant::QuantizationDialect,
                    quantfork::QuantizationForkDialect>();
  }

 private:
  void runOnOperation() override;

  // These flags are only used for testing purposes.
  Option<QuantizationMethod> quantization_method_{
      *this, "quantization-method",
      llvm::cl::init(QuantizationMethod::kPostTrainingQuantization),
      llvm::cl::desc("Choose quantization method."),
      llvm::cl::values(
          clEnumValN(QuantizationMethod::kPostTrainingQuantization, "ptq",
                     "Post-training static-range quantization"),
          clEnumValN(QuantizationMethod::kDynamicRangeQuantization, "drq",
                     "Post-training dynamic-range quantization"))};
  Option<OpSet> target_opset_{
      *this, "target-opset", llvm::cl::init(OpSet::TF),
      llvm::cl::desc("Choose target opset."),
      llvm::cl::values(
          clEnumValN(OpSet::TF, "TF",
                     "Uses TF ops that mimic quantization behavior"),
          clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"),
          clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED",
                     "Uses TF Uniform Quantized ops"))};
};

LogicalResult CreateUniformQuantizedTypeParams(UniformQuantizedType qtype,
                                               Location loc,
                                               PatternRewriter& rewriter,
                                               Value& scale,
                                               Value& zero_point) {
  TensorType scale_type = RankedTensorType::get({}, rewriter.getF32Type());
  TensorType zero_point_type = scale_type.clone(rewriter.getI32Type());
  scale = rewriter.create<TF::ConstOp>(
      loc, scale_type,
      DenseFPElementsAttr::get(scale_type,
                               {static_cast<float>(qtype.getScale())}));
  zero_point = rewriter.create<TF::ConstOp>(
      loc, zero_point_type,
      DenseIntElementsAttr::get(zero_point_type,
                                {static_cast<int32_t>(qtype.getZeroPoint())}));
  return success(scale && zero_point);
}

LogicalResult CreateUniformQuantizedPerAxisTypeParams(
    quant::UniformQuantizedPerAxisType qtype, Location loc,
    PatternRewriter& rewriter, Value& scale, Value& zero_point) {
  // The consuming op should already know the quantized channel information,
  // so it is not passed during conversion. This design might change if needed.
  ArrayRef<double> scales = qtype.getScales();
  ArrayRef<int64_t> zero_points = qtype.getZeroPoints();
  const int num_channels = scales.size();
  TensorType scale_type = RankedTensorType::get(
      {static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
  TensorType zero_point_type = scale_type.clone(rewriter.getI32Type());

  llvm::SmallVector<float, 4> float_scales;
  llvm::SmallVector<int32_t, 4> int32_zero_points;
  float_scales.reserve(num_channels);
  int32_zero_points.reserve(num_channels);
  for (int i = 0; i < num_channels; ++i) {
    float_scales.push_back(scales[i]);
    int32_zero_points.push_back(zero_points[i]);
  }
  scale = rewriter.create<TF::ConstOp>(
      loc, scale_type, DenseFPElementsAttr::get(scale_type, float_scales));
  zero_point = rewriter.create<TF::ConstOp>(
      loc, zero_point_type,
      DenseIntElementsAttr::get(zero_point_type, int32_zero_points));
  return success(scale && zero_point);
}

LogicalResult CreateQuantizationParams(QuantizedType elem_type, Location loc,
                                       PatternRewriter& rewriter, Value& scale,
                                       Value& zero_point) {
  if (!elem_type) {
    return failure();
  }
  if (auto qtype = elem_type.dyn_cast<UniformQuantizedType>()) {
    return CreateUniformQuantizedTypeParams(qtype, loc, rewriter, scale,
                                            zero_point);
  } else if (auto qtype =
                 elem_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    return CreateUniformQuantizedPerAxisTypeParams(qtype, loc, rewriter, scale,
                                                   zero_point);
  }
  return failure();
}

// Replaces quant.qcast ops with calls to the composite quantize_i8 function.
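// As an illustrative sketch (the textual IR below is abbreviated and
// hypothetical; exact syntax and values may differ), a cast such as
//
//   %q = "quantfork.qcast"(%fp)
//       : (tensor<*xf32>) -> tensor<*x!quant.uniform<i8:f32, 0.1:0>>
//
// is rewritten into a call to @quantize_i8 with materialized scale and
// zero-point constants, followed by a storage cast back to the quantized
// type:
//
//   %scale = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>}
//   %zp = "tf.Const"() {value = dense<0> : tensor<i32>}
//   %call = "tf.PartitionedCall"(%fp, %scale, %zp) {f = @quantize_i8}
//   %q = "quantfork.scast"(%call)
//       : (tensor<*xi8>) -> tensor<*x!quant.uniform<i8:f32, 0.1:0>>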
class ReplaceQuantizePattern
    : public mlir::OpRewritePattern<quantfork::QuantizeCastOp> {
 public:
  explicit ReplaceQuantizePattern(MLIRContext* context)
      : OpRewritePattern<quantfork::QuantizeCastOp>(context) {}

 private:
  LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op,
                                PatternRewriter& rewriter) const override {
    auto output_type = q_op.getType().cast<TensorType>();
    auto elem_type = output_type.getElementType().dyn_cast<QuantizedType>();
    const Location loc = q_op->getLoc();
    Value scale, zero_point;

    if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale,
                                        zero_point))) {
      return failure();
    }

    SmallVector<Type> output_types = {
        output_type.clone(elem_type.getStorageType())};
    SmallVector<Value> args = {q_op.getArg(), scale, zero_point};
    FlatSymbolRefAttr func_name =
        FlatSymbolRefAttr::get(rewriter.getStringAttr(kQuantizeFuncName));

    auto quantize_call = rewriter.create<TF::PartitionedCallOp>(
        loc, output_types, args, func_name,
        /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
    auto scast_op = rewriter.create<quantfork::StorageCastOp>(
        loc, output_type, quantize_call->getResult(0));
    q_op->replaceAllUsesWith(scast_op);
    return success();
  }
};

// Replaces quant.dcast ops with calls to the composite dequantize_i8
// function.
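// A sketch of the symmetric rewrite (hypothetical IR): the quantized input is
// first storage-cast to its i8 storage type and then passed to @dequantize_i8
// together with its scale and zero-point constants:
//
//   %f = "quantfork.dcast"(%q)
//       : (tensor<*x!quant.uniform<i8:f32, 0.1:0>>) -> tensor<*xf32>
//
// becomes
//
//   %storage = "quantfork.scast"(%q)
//       : (tensor<*x!quant.uniform<i8:f32, 0.1:0>>) -> tensor<*xi8>
//   %f = "tf.PartitionedCall"(%storage, %scale, %zp) {f = @dequantize_i8}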
class ReplaceDequantizePattern
    : public mlir::OpRewritePattern<quantfork::DequantizeCastOp> {
 public:
  explicit ReplaceDequantizePattern(MLIRContext* context)
      : OpRewritePattern<quantfork::DequantizeCastOp>(context) {}

 private:
  LogicalResult matchAndRewrite(quantfork::DequantizeCastOp dq_op,
                                PatternRewriter& rewriter) const override {
    auto input_type = dq_op.getArg().getType().cast<TensorType>();
    auto elem_type = input_type.getElementType().dyn_cast<QuantizedType>();
    const Location loc = dq_op->getLoc();

    Value scale, zero_point;
    if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale,
                                        zero_point))) {
      return failure();
    }

    TensorType output_type = input_type.clone(elem_type.getStorageType());
    auto scast_op = rewriter.create<quantfork::StorageCastOp>(loc, output_type,
                                                              dq_op.getArg());

    FlatSymbolRefAttr func_name =
        FlatSymbolRefAttr::get(rewriter.getStringAttr(kDequantizeFuncName));
    SmallVector<Value> args = {scast_op->getResult(0), scale, zero_point};
    auto dequantize_call = rewriter.create<TF::PartitionedCallOp>(
        loc, dq_op.getResult().getType(), args, func_name,
        /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
    dq_op->replaceAllUsesWith(dequantize_call);
    return success();
  }
};

// Checks whether only the input weights are quantized. For now, the weight is
// assumed to be at argument index 1 (the rhs). Later this can be replaced with
// a map that holds the weight index information for each op.
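// For example (hypothetical IR), the following call matches because exactly
// the operand at index 1 is produced by a qcast to a quantized type, and no
// result is quantized:
//
//   %w_q = "quantfork.qcast"(%w)
//       : (tensor<2x2xf32>)
//       -> tensor<2x2x!quant.uniform<i8<-127:127>:f32, ...>>
//   %out = "tf.PartitionedCall"(%in, %w_q) {f = @composite_matmul_fn_1}
//       : (tensor<*xf32>, tensor<2x2x!quant.uniform<...>>) -> tensor<*xf32>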
bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) {
  bool has_quantized_types_for_weights = false;
  for (int32_t cur_idx = 0; cur_idx < call_op.args().size(); cur_idx++) {
    // Check whether only the weight index has a QuantizeCastOp.
    auto cur_op = dyn_cast_or_null<quantfork::QuantizeCastOp>(
        call_op.args()[cur_idx].getDefiningOp());
    if ((!cur_op && cur_idx == 1) || (cur_op && cur_idx != 1)) {
      return false;
    } else if (cur_op) {
      // Check whether the QuantizeCastOp has a quantized element type.
      if (!getElementTypeOrSelf(cur_op.getResult().getType())
               .isa<QuantizedType>()) {
        return false;
      }
      // Satisfies the input condition.
      has_quantized_types_for_weights = true;
    }
  }
  for (Value output : call_op.output()) {
    if (auto type = output.getType().dyn_cast<TensorType>()) {
      if (type.getElementType().isa<QuantizedType>()) {
        return false;
      }
    }
  }
  return has_quantized_types_for_weights;
}

// Checks if all the inputs are quantized.
bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) {
  bool has_quantized_types = false;
  for (Value input : call_op.args()) {
    if (auto type = input.getType().dyn_cast<TensorType>()) {
      if (type.getElementType().isa<FloatType>()) {
        return false;
      }
      if (type.getElementType().isa<QuantizedType>()) {
        has_quantized_types = true;
      }
    }
  }
  for (Value output : call_op.output()) {
    if (auto type = output.getType().dyn_cast<TensorType>()) {
      if (type.getElementType().isa<FloatType>()) {
        return false;
      }
      if (type.getElementType().isa<QuantizedType>()) {
        has_quantized_types = true;
      }
    }
  }
  return has_quantized_types;
}

// Converts the element type of the input tensor to the corresponding quantized
// version. Supports only int8 for now and returns nullptr if the input type is
// not supported.
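// For example, both tensor<2x3xi8> and a tensor whose element type has i8
// storage map to a tensor of TF qint8 (printed as !tf_type.qint8 here, though
// the printed name may vary by version); an i16 or unsigned input yields
// nullptr.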
ShapedType ConvertIntToQint(ShapedType input_type, MLIRContext* ctx) {
  int bit_width;
  bool is_signed;

  Type ele_type = input_type.getElementType();
  if (ele_type.isIntOrFloat()) {
    bit_width = ele_type.getIntOrFloatBitWidth();
    is_signed = ele_type.isSignlessIntOrFloat() || ele_type.isSignedInteger();
  } else if (QuantizedType qtype = ele_type.dyn_cast<QuantizedType>()) {
    bit_width = qtype.getStorageTypeIntegralWidth();
    is_signed = qtype.isSigned();
  } else {
    return input_type;
  }

  Type new_storage_type;
  if (is_signed) {
    switch (bit_width) {
      case 8:
        new_storage_type = mlir::TF::Qint8Type::get(ctx);
        break;
      default:
        return nullptr;  // Not yet supported
    }
  } else {
    return nullptr;  // Not yet supported
  }

  input_type = input_type.clone(new_storage_type);
  return input_type;
}

// Transfers the attributes of the corresponding ops from the float function to
// the quantized function using the attr_map attribute. In the quantized
// function, this map (map1) is in the {attr_name_1: attr_identifier} format,
// and in the float function, this map (map2) is in the
// {attr_identifier: attr_name_2} format. The attribute identifiers should
// match between the two maps; attr_name_1 is the name of the attribute that
// needs to be set in the quantized function, and attr_name_2 is the name of
// the attribute corresponding to the attribute identifier in the float
// function.
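// For example (attribute names here are illustrative), an op in the quantized
// function carrying
//   attr_map = "strides:0,padding:1"
// pairs with an op in the float function carrying
//   attr_map = "0:strides,1:padding"
// via the shared identifiers "0" and "1", so the quantized op's `strides` and
// `padding` attributes are set to the float op's corresponding values.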
LogicalResult TransferAttributes(func::FuncOp float_func,
                                 func::FuncOp quantized_func) {
  // A map to find an attribute from its identifier.
  llvm::StringMap<Attribute> identifier_to_attr;
  for (Operation& inner_op : float_func.getBody().front().getOperations()) {
    if (!inner_op.hasAttr(kAttrMapAttribute)) continue;
    std::string attr_map_str =
        inner_op.getAttrOfType<StringAttr>(kAttrMapAttribute).str();
    for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) {
      std::vector<absl::string_view> key_and_value_pair =
          absl::StrSplit(element_str, ':');
      if (key_and_value_pair.size() != 2) {
        float_func.emitError("The attr_map attribute is malformed");
        return failure();
      }
      identifier_to_attr.insert(
          {llvm::StringRef(std::string(key_and_value_pair[0])),
           inner_op.getAttr(
               llvm::StringRef(std::string(key_and_value_pair[1])))});
    }
  }

  // Set the attributes for ops with the attr_map attribute.
  for (Operation& inner_op : quantized_func.getBody().front().getOperations()) {
    if (!inner_op.hasAttr(kAttrMapAttribute)) continue;

    std::string attr_map_str =
        inner_op.getAttrOfType<StringAttr>(kAttrMapAttribute).str();
    for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) {
      std::vector<absl::string_view> key_and_value_pair =
          absl::StrSplit(element_str, ':');
      if (key_and_value_pair.size() != 2) {
        float_func.emitError("The attr_map attribute is malformed");
        return failure();
      }
      if (identifier_to_attr.count(
              llvm::StringRef(std::string(key_and_value_pair[1]))) == 0) {
        float_func.emitWarning(absl::StrCat("Using the default value for the '",
                                            key_and_value_pair[0],
                                            "' attribute"));
        continue;
      }
      inner_op.setAttr(llvm::StringRef(std::string(key_and_value_pair[0])),
                       identifier_to_attr[llvm::StringRef(
                           std::string(key_and_value_pair[1]))]);
    }
    inner_op.removeAttr(kAttrMapAttribute);
  }
  return success();
}

// Unwraps the quantization parameters of PartitionedCall ops with quantized
// inputs/outputs that are created by QuantizePass.
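// As a rough sketch (function and value names are illustrative), a call
//
//   "tf.PartitionedCall"(%q_in, %q_weight) {f = @composite_conv2d_fn_1, ...}
//
// whose operands/results carry quantized types is rewritten into a call to a
// fresh clone of the matching quantized function. Its operands are
// storage-cast to i8 (or qint8 for dynamic-range quantization), and the
// scale/zero-point constants of every quantized operand and result are
// appended as trailing arguments:
//
//   "tf.PartitionedCall"(%sc_in, %sc_weight, %in_scale, %in_zp, ...)
//       {f = @quantized_conv2d_fn_0}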
class QuantizeFunctionPattern
    : public mlir::OpRewritePattern<TF::PartitionedCallOp> {
 public:
  explicit QuantizeFunctionPattern(MLIRContext* context,
                                   QuantizationMethod quantization_method,
                                   OpSet target_opset)
      : OpRewritePattern<TF::PartitionedCallOp>(context),
        quantization_method_(quantization_method),
        target_opset_(target_opset) {}

 private:
  QuantizationMethod quantization_method_ =
      QuantizationMethod::kPostTrainingQuantization;
  OpSet target_opset_ = OpSet::TF;

  LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op,
                                PatternRewriter& rewriter) const override {
    const auto f_attr = call_op.fAttr().dyn_cast<FlatSymbolRefAttr>();
    // removeAttr will return nullptr if no attribute was removed.
    if (!call_op->removeAttr(kQuantTraitAttrName) || !f_attr) {
      return failure();
    }

    // Determines if all required float inputs/outputs are now quantized.
    bool has_quantized_types = false;
    if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) {
      has_quantized_types = IsQuantizedCallforDynamicRange(call_op);
      if (f_attr.getValue().startswith("composite_") && !has_quantized_types) {
        call_op->emitError(
            "Only quantizable ops need to be in composite function for dynamic"
            "-range PTQ case.");
        return failure();
      }
    } else {
      has_quantized_types = IsQuantizedCallforStaticRange(call_op);
    }

    if (!f_attr.getValue().startswith("composite_") || !has_quantized_types) {
      return failure();
    }

    SmallVector<Value, 4> args;
    SmallVector<Value, 4> qparam_args;
    for (Value arg : call_op.args()) {
      if (const auto arg_type = arg.getType().dyn_cast<TensorType>()) {
        QuantizedType qtype =
            arg_type.getElementType().dyn_cast<QuantizedType>();
        if (!qtype) continue;
        if (!qtype.isa<UniformQuantizedType,
                       quant::UniformQuantizedPerAxisType>()) {
          return failure();
        }
        Value scale, zero_point;
        if (failed(CreateQuantizationParams(qtype, arg.getLoc(), rewriter,
                                            scale, zero_point))) {
          // As the quantized types are already checked, this is unexpected.
          call_op->emitError(
              "Failed to create quantization parameter for an argument.");
          return failure();
        }
        qparam_args.push_back(scale);
        qparam_args.push_back(zero_point);
      }
    }

    for (Value result : call_op->getResults()) {
      if (auto result_type = result.getType().dyn_cast<TensorType>()) {
        QuantizedType qtype =
            result_type.getElementType().dyn_cast<QuantizedType>();
        if (!qtype) continue;
        if (!qtype.isa<UniformQuantizedType,
                       quant::UniformQuantizedPerAxisType>()) {
          return failure();
        }
        Value scale, zero_point;
        if (failed(CreateQuantizationParams(qtype, result.getLoc(), rewriter,
                                            scale, zero_point))) {
          // As the quantized types are already checked, this is unexpected.
          call_op->emitError(
              "Failed to create quantization parameter for a result.");
          return failure();
        }
        qparam_args.push_back(scale);
        qparam_args.push_back(zero_point);
      }
    }

    rewriter.setInsertionPoint(call_op);

    for (Value arg : call_op.args()) {
      TensorType arg_type = arg.getType().dyn_cast<TensorType>();
      if (!arg_type) {
        args.push_back(arg);
        continue;
      }
      QuantizedType qtype = arg_type.getElementType().dyn_cast<QuantizedType>();
      if (!qtype) {
        args.push_back(arg);
        continue;
      }

      quantfork::StorageCastOp scast_op;
      if (quantization_method_ ==
          QuantizationMethod::kDynamicRangeQuantization) {
        ShapedType new_arg_type = ConvertIntToQint(arg_type.cast<ShapedType>(),
                                                   rewriter.getContext());
        if (!new_arg_type) {
          call_op->emitError(
              "Failed to convert the type to the corresponding qtype.");
          return failure();
        }
        scast_op = rewriter.create<quantfork::StorageCastOp>(
            arg.getLoc(), new_arg_type.cast<TensorType>(), arg);
      } else {
        scast_op = rewriter.create<quantfork::StorageCastOp>(
            arg.getLoc(), arg_type.clone(qtype.getStorageType()), arg);
      }
      args.push_back(scast_op.getResult());
    }
    args.insert(args.end(), qparam_args.begin(), qparam_args.end());
    // For the XLA opset, try to merge the quantized function with a following
    // Dequantize op for optimization.
    if (target_opset_ == OpSet::XLA) {
      if (failed(mergeDequantizeOpFollowingQuantizedFunction(call_op, args,
                                                             rewriter))) {
        return failure();
      }
    }
    if (call_op->use_empty()) return success();

    DenseMap<Value, quantfork::StorageCastOp> replace_map;
    rewriter.setInsertionPointAfter(call_op);

    SmallVector<Type, 4> result_types;
    for (Value result : call_op->getResults()) {
      TensorType result_type = result.getType().dyn_cast<TensorType>();
      if (!result_type) {
        result_types.push_back(result.getType());
        continue;
      }
      QuantizedType qtype =
          result_type.getElementType().dyn_cast<QuantizedType>();
      if (!qtype) {
        result_types.push_back(result_type);
        continue;
      }
      auto scast_op = rewriter.create<quantfork::StorageCastOp>(
          call_op.getLoc(), result_type, result);
      replace_map.insert(std::make_pair(result, scast_op));

      result_types.push_back(result_type.clone(qtype.getStorageType()));
    }

    for (auto replace_pair : replace_map) {
      Value result = replace_pair.first;
      quantfork::StorageCastOp scast_op = replace_pair.second;
      result.replaceAllUsesExcept(scast_op, scast_op);
    }

    // Make a copy of the quantized function.
    auto module = call_op->getParentOfType<ModuleOp>();
    SymbolTable symbol_table(module);

    mlir::func::FuncOp float_func =
        dyn_cast<func::FuncOp>(symbol_table.lookup(f_attr.getValue()));
    rewriter.setInsertionPointAfter(float_func);

    // substr(10) == strip the "composite_" prefix.
    const llvm::Twine quantized_function_name = llvm::Twine(
        "quantized_", f_attr.getValue().substr(10).rsplit('_').first);
    const mlir::func::FuncOp quantized_func = dyn_cast<func::FuncOp>(
        symbol_table.lookup(quantized_function_name.str()));
    mlir::func::FuncOp new_quantized_func =
        dyn_cast<func::FuncOp>(quantized_func->clone());
    if (new_quantized_func == nullptr) {
      return failure();
    }
    new_quantized_func.setType(
        FunctionType::get(getContext(), TypeRange{ValueRange{args}},
                          new_quantized_func.getResultTypes()));
    for (auto [partitioned_call_arg, new_quantized_func_arg] :
         llvm::zip_first(args, new_quantized_func.getArguments())) {
      new_quantized_func_arg.setType(partitioned_call_arg.getType());
    }

    // Set the attributes for ops with the attr_map attribute.
    if (failed(TransferAttributes(float_func, new_quantized_func))) {
      return failure();
    }

    rewriter.setInsertionPoint(call_op);

    const StringAttr new_quant_func_name =
        symbol_table.insert(new_quantized_func);
    rewriter.replaceOpWithNewOp<TF::PartitionedCallOp>(
        call_op, result_types, args,
        FlatSymbolRefAttr::get(new_quant_func_name));

    return success();
  }

  // For composite functions followed by Dequantize ops, merges the Dequantize
  // op into the function by creating a quantized function with float output.
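  // Sketch (names are illustrative): when a result of the quantized call is
  // consumed by a quantfork.dcast op, the pair
  //
  //   %q = "tf.PartitionedCall"(...) {f = @quantized_matmul_fn_0}
  //   %f = "quantfork.dcast"(%q)
  //
  // is collapsed into a single call that produces floats directly:
  //
  //   %f = "tf.PartitionedCall"(...) {f = @quantized_matmul_float_output_fn_0}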
  LogicalResult mergeDequantizeOpFollowingQuantizedFunction(
      TF::PartitionedCallOp call_op, const SmallVector<Value, 4>& args,
      PatternRewriter& rewriter) const {
    bool followed_by_dequantize = false;
    for (Operation* user : call_op->getUsers()) {
      if (llvm::isa<quantfork::DequantizeCastOp>(user)) {
        followed_by_dequantize = true;
        break;
      }
    }
    if (!followed_by_dequantize) return success();

    rewriter.setInsertionPointAfter(call_op);
    SmallVector<Type, 4> result_types;
    for (Value result : call_op->getResults()) {
      TensorType result_type = result.getType().dyn_cast<TensorType>();
      if (!result_type) {
        result_types.push_back(result.getType());
        continue;
      }
      QuantizedType qtype =
          result_type.getElementType().dyn_cast<QuantizedType>();
      if (!qtype) {
        result_types.push_back(result_type);
        continue;
      }

      result_types.push_back(result_type.clone(qtype.getExpressedType()));
    }

    // Make a copy of the quantized function.
    auto module = call_op->getParentOfType<ModuleOp>();
    SymbolTable symbol_table(module);

    const auto f_attr = call_op.fAttr().dyn_cast<FlatSymbolRefAttr>();
    const auto float_func =
        dyn_cast<func::FuncOp>(symbol_table.lookup(f_attr.getValue()));
    rewriter.setInsertionPointAfter(float_func);

    // substr(10) == strip the "composite_" prefix.
    const std::string quantized_function_name =
        "quantized_" + f_attr.getValue().substr(10).rsplit("_fn_").first.str() +
        "_float_output_fn";
    const auto quantized_func =
        dyn_cast<func::FuncOp>(symbol_table.lookup(quantized_function_name));
    auto new_quantized_func = dyn_cast<func::FuncOp>(quantized_func->clone());
    if (new_quantized_func == nullptr) {
      return failure();
    }
    new_quantized_func.setType(
        FunctionType::get(getContext(), TypeRange{ValueRange{args}},
                          new_quantized_func.getResultTypes()));
    for (auto [partitioned_call_arg, new_quantized_func_arg] :
         llvm::zip_first(args, new_quantized_func.getArguments())) {
      new_quantized_func_arg.setType(partitioned_call_arg.getType());
    }

    // Set the attributes for ops with the attr_map attribute.
    if (failed(TransferAttributes(float_func, new_quantized_func))) {
      return failure();
    }

    rewriter.setInsertionPoint(call_op);
    const StringAttr new_quant_func_name =
        symbol_table.insert(new_quantized_func);
    auto quantized_call_op = rewriter.create<TF::PartitionedCallOp>(
        call_op.getLoc(), result_types, args,
        FlatSymbolRefAttr::get(new_quant_func_name));

    for (int result_idx : llvm::seq<int>(0, call_op->getNumResults())) {
      Value result = call_op->getResult(result_idx);
      for (Operation* user : result.getUsers()) {
        if (auto dequant_op =
                llvm::dyn_cast<quantfork::DequantizeCastOp>(user)) {
          dequant_op.getResult().replaceAllUsesWith(
              quantized_call_op->getResult(result_idx));
        }
      }
    }

    return success();
  }
};

// Converts the const -> quant.qcast pattern to a quantized constant, after the
// quantization parameters have been safely included in each quantized
// composite function.
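// A sketch (constant values are illustrative): a float constant feeding a
// qcast,
//
//   %cst = "tf.Const"() {value = dense<0.5> : tensor<2xf32>}
//   %q = "quantfork.qcast"(%cst)
//       : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:0>>
//
// folds into a pre-quantized i8 constant plus a storage cast:
//
//   %cst = "tf.Const"() {value = dense<5> : tensor<2xi8>}
//   %q = "quantfork.scast"(%cst)
//       : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32, 0.1:0>>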
class QuantizeConstPattern
    : public OpRewritePattern<quantfork::QuantizeCastOp> {
 public:
  // This pattern should have a larger benefit than ReplaceQuantizePattern.
  explicit QuantizeConstPattern(MLIRContext* context,
                                QuantizationMethod quantization_method)
      : OpRewritePattern<quantfork::QuantizeCastOp>(context, /*benefit=*/10),
        quantization_method_(quantization_method) {}

 private:
  QuantizationMethod quantization_method_ =
      QuantizationMethod::kPostTrainingQuantization;
  LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op,
                                PatternRewriter& rewriter) const override {
    DenseFPElementsAttr attr;
    if (!matchPattern(q_op.getArg(), m_Constant(&attr))) {
      return failure();
    }

    ShapedType tensor_qtype = q_op.getResult().getType().cast<ShapedType>();
    Attribute tensor_proto_attr = Quantize(attr, tensor_qtype);
    if (!tensor_proto_attr) {
      return failure();
    }

    Type storage_type =
        tensor_qtype.getElementType().cast<QuantizedType>().getStorageType();
    ShapedType new_type = tensor_qtype.clone(storage_type);
    Location loc = q_op.getArg().getLoc();
    // Converts the integer type to the corresponding quantized integer type.
    // Currently this is only applied for the dynamic-range quantization case.
    if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) {
      new_type = ConvertIntToQint(new_type, rewriter.getContext());
      tensor_qtype = ConvertIntToQint(tensor_qtype, rewriter.getContext());

      // TODO(b/225793355): It adds TensorProtoAttr to the constant as a
      // workaround.
      tensorflow::TensorProto tensor_proto;
      if (!mlir::tfg::ConvertToTensorProto(tensor_proto_attr, &tensor_proto)
               .ok()) {
        return failure();
      }

      tensor_proto.set_dtype(tensorflow::DT_QINT8);

      tensor_proto_attr = ElementsAttr(TF::TensorProtoAttr::get(
          new_type, tensorflow::mangling_util::MangleTensor(tensor_proto)));
    }
    auto const_op =
        rewriter.create<TF::ConstOp>(loc, new_type, tensor_proto_attr);
    // Add an scast op to match the quantize -> composite function pattern.
    // The added scast is then removed by canonicalization:
    // [scast - scast] -> [].
    auto scast_op = rewriter.create<quantfork::StorageCastOp>(
        loc, tensor_qtype, const_op.output());
    q_op->replaceAllUsesWith(scast_op);
    return success();
  }
};

static PassRegistration<QuantizeCompositeFunctionsPass> pass;

#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.inc"

void QuantizeCompositeFunctionsPass::runOnOperation() {
  MLIRContext* ctx = &getContext();
  ModuleOp module = getOperation();

  PassManager pm(ctx);
  // The intermediate output from QuantizePass will have PartitionedCall ops
  // with quantized input and output types, which are not allowed in the TF
  // dialect. This can be removed when the composite call supports quantized
  // types.
  pm.enableVerifier(false);

  QuantizationSpecs quant_specs;
  if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) {
    quant_specs.weight_quantization = true;
    quant_specs.inference_type = tensorflow::DT_QINT8;
    pm.addNestedPass<func::FuncOp>(CreatePrepareQuantizeDRQPass());
  } else {
    pm.addNestedPass<func::FuncOp>(
        CreatePrepareQuantizePass(quantization_method_));
  }
  pm.addNestedPass<func::FuncOp>(CreateQuantizePass(quant_specs));

  pm.addNestedPass<func::FuncOp>(CreatePostQuantizePass());
  if (failed(pm.run(module))) {
    signalPassFailure();
  }

  RewritePatternSet patterns(ctx);
  patterns.add<QuantizeFunctionPattern>(ctx, quantization_method_,
                                        target_opset_);

  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    signalPassFailure();
  }

  // Constant quantization is a lossy transformation, so it is applied only
  // after all the other patterns have been applied.
  RewritePatternSet patterns_2(ctx);
  populateWithGenerated(patterns_2);
  patterns_2.add<ReplaceQuantizePattern, ReplaceDequantizePattern>(ctx);
  patterns_2.add<QuantizeConstPattern>(ctx, quantization_method_);
  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns_2))) ||
      failed(verify(module))) {
    signalPassFailure();
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateQuantizeCompositeFunctionsPass(
    QuantizationMethod quantization_method, OpSet target_opset) {
  return std::make_unique<QuantizeCompositeFunctionsPass>(quantization_method,
                                                          target_opset);
}

}  // namespace quant
}  // namespace mlir