/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/ir/importexport/convert_tensor.h"

namespace mlir {
namespace quant {
namespace {

constexpr char kQuantizeFuncName[] = "quantize_i8";
constexpr char kDequantizeFuncName[] = "dequantize_i8";
constexpr char kAttrMapAttribute[] = "attr_map";

class QuantizeCompositeFunctionsPass
    : public mlir::PassWrapper<QuantizeCompositeFunctionsPass,
                               OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeCompositeFunctionsPass)

  explicit QuantizeCompositeFunctionsPass() {}

  explicit QuantizeCompositeFunctionsPass(
      QuantizationMethod quantization_method, OpSet target_opset) {
    quantization_method_ = quantization_method;
    target_opset_ = target_opset;
  }

  QuantizeCompositeFunctionsPass(const QuantizeCompositeFunctionsPass& other) {
    quantization_method_ = other.quantization_method_;
    target_opset_ = other.target_opset_;
  }

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in the textual format
    // (e.g. on the command line).
    return "quant-quantize-composite-functions";
  }

  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Quantize composite functions with QDQ input/outputs.";
  }

  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<TF::TensorFlowDialect, quant::QuantizationDialect,
                    quantfork::QuantizationForkDialect>();
  }

 private:
  void runOnOperation() override;

  // These flags are only used for testing purposes.
  Option<QuantizationMethod> quantization_method_{
      *this, "quantization-method",
      llvm::cl::init(QuantizationMethod::kPostTrainingQuantization),
      llvm::cl::desc("Choose quantization method."),
      llvm::cl::values(
          clEnumValN(QuantizationMethod::kPostTrainingQuantization, "ptq",
                     "Post-training static-range quantization"),
          clEnumValN(QuantizationMethod::kDynamicRangeQuantization, "drq",
                     "Post-training dynamic-range quantization"))};
  Option<OpSet> target_opset_{
      *this, "target-opset", llvm::cl::init(OpSet::TF),
      llvm::cl::desc("Choose target opset."),
      llvm::cl::values(
          clEnumValN(OpSet::TF, "TF",
                     "Uses TF ops that mimic quantization behavior"),
          clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"),
          clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED",
                     "Uses TF Uniform Quantized ops"))};
};

LogicalResult CreateUniformQuantizedTypeParams(UniformQuantizedType qtype,
                                               Location loc,
                                               PatternRewriter& rewriter,
                                               Value& scale,
                                               Value& zero_point) {
  TensorType scale_type = RankedTensorType::get({}, rewriter.getF32Type());
  TensorType zero_point_type = scale_type.clone(rewriter.getI32Type());
  scale = rewriter.create<TF::ConstOp>(
      loc, scale_type,
      DenseFPElementsAttr::get(scale_type,
                               {static_cast<float>(qtype.getScale())}));
  zero_point = rewriter.create<TF::ConstOp>(
      loc, zero_point_type,
      DenseIntElementsAttr::get(zero_point_type,
                                {static_cast<int32_t>(qtype.getZeroPoint())}));
  return success(scale && zero_point);
}

LogicalResult CreateUniformQuantizedPerAxisTypeParams(
    quant::UniformQuantizedPerAxisType qtype, Location loc,
    PatternRewriter& rewriter, Value& scale, Value& zero_point) {
  // The consuming op should already know the quantized channel information,
  // so it is not passed during the conversion. This design might change if
  // needed.
  ArrayRef<double> scales = qtype.getScales();
  ArrayRef<int64_t> zero_points = qtype.getZeroPoints();
  const int num_channels = scales.size();
  TensorType scale_type = RankedTensorType::get(
      {static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
  TensorType zero_point_type = scale_type.clone(rewriter.getI32Type());

  llvm::SmallVector<float, 4> float_scales;
  llvm::SmallVector<int32_t, 4> int32_zero_points;
  float_scales.reserve(num_channels);
  int32_zero_points.reserve(num_channels);
  for (int i = 0; i < num_channels; ++i) {
    float_scales.push_back(scales[i]);
    int32_zero_points.push_back(zero_points[i]);
  }
  scale = rewriter.create<TF::ConstOp>(
      loc, scale_type, DenseFPElementsAttr::get(scale_type, float_scales));
  zero_point = rewriter.create<TF::ConstOp>(
      loc, zero_point_type,
      DenseIntElementsAttr::get(zero_point_type, int32_zero_points));
  return success(scale && zero_point);
}

LogicalResult CreateQuantizationParams(QuantizedType elem_type, Location loc,
                                       PatternRewriter& rewriter, Value& scale,
                                       Value& zero_point) {
  if (!elem_type) {
    return failure();
  }
  if (auto qtype = elem_type.dyn_cast<UniformQuantizedType>()) {
    return CreateUniformQuantizedTypeParams(qtype, loc, rewriter, scale,
                                            zero_point);
  } else if (auto qtype =
                 elem_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    return CreateUniformQuantizedPerAxisTypeParams(qtype, loc, rewriter, scale,
                                                   zero_point);
  }
  return failure();
}

// Replaces a quant.qcast op with a call to the composite quantize_i8
// function.
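// Example (illustrative; shapes and quantized types are schematic):
//   %q = "quant.qcast"(%arg) : (tensor<*xf32>) -> tensor<*x!quant.uniform<...>>
// becomes
//   %scale = "tf.Const"() {...} : () -> tensor<f32>
//   %zp = "tf.Const"() {...} : () -> tensor<i32>
//   %call = "tf.PartitionedCall"(%arg, %scale, %zp) {f = @quantize_i8}
//       : (...) -> tensor<*xi8>
//   %q = "quant.scast"(%call) : (tensor<*xi8>) -> tensor<*x!quant.uniform<...>>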
class ReplaceQuantizePattern
    : public mlir::OpRewritePattern<quantfork::QuantizeCastOp> {
 public:
  explicit ReplaceQuantizePattern(MLIRContext* context)
      : OpRewritePattern<quantfork::QuantizeCastOp>(context) {}

 private:
  LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op,
                                PatternRewriter& rewriter) const override {
    auto output_type = q_op.getType().cast<TensorType>();
    auto elem_type = output_type.getElementType().dyn_cast<QuantizedType>();
    const Location loc = q_op->getLoc();
    Value scale, zero_point;

    if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale,
                                        zero_point))) {
      return failure();
    }

    SmallVector<Type> output_types = {
        output_type.clone(elem_type.getStorageType())};
    SmallVector<Value> args = {q_op.getArg(), scale, zero_point};
    FlatSymbolRefAttr func_name =
        FlatSymbolRefAttr::get(rewriter.getStringAttr(kQuantizeFuncName));

    auto quantize_call = rewriter.create<TF::PartitionedCallOp>(
        loc, output_types, args, func_name,
        /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
    auto scast_op = rewriter.create<quantfork::StorageCastOp>(
        loc, output_type, quantize_call->getResult(0));
    q_op->replaceAllUsesWith(scast_op);
    return success();
  }
};

// Replaces a quant.dcast op with a call to the composite dequantize_i8
// function.
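// Example (illustrative):
//   %dq = "quant.dcast"(%q) : (tensor<*x!quant.uniform<...>>) -> tensor<*xf32>
// becomes
//   %sc = "quant.scast"(%q) : (tensor<*x!quant.uniform<...>>) -> tensor<*xi8>
//   %dq = "tf.PartitionedCall"(%sc, %scale, %zp) {f = @dequantize_i8}
//       : (...) -> tensor<*xf32>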
class ReplaceDequantizePattern
    : public mlir::OpRewritePattern<quantfork::DequantizeCastOp> {
 public:
  explicit ReplaceDequantizePattern(MLIRContext* context)
      : OpRewritePattern<quantfork::DequantizeCastOp>(context) {}

 private:
  LogicalResult matchAndRewrite(quantfork::DequantizeCastOp dq_op,
                                PatternRewriter& rewriter) const override {
    auto input_type = dq_op.getArg().getType().cast<TensorType>();
    auto elem_type = input_type.getElementType().dyn_cast<QuantizedType>();
    const Location loc = dq_op->getLoc();

    Value scale, zero_point;
    if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale,
                                        zero_point))) {
      return failure();
    }

    TensorType output_type = input_type.clone(elem_type.getStorageType());
    auto scast_op = rewriter.create<quantfork::StorageCastOp>(loc, output_type,
                                                              dq_op.getArg());

    FlatSymbolRefAttr func_name =
        FlatSymbolRefAttr::get(rewriter.getStringAttr(kDequantizeFuncName));
    SmallVector<Value> args = {scast_op->getResult(0), scale, zero_point};
    auto dequantize_call = rewriter.create<TF::PartitionedCallOp>(
        loc, dq_op.getResult().getType(), args, func_name,
        /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
    dq_op->replaceAllUsesWith(dequantize_call);
    return success();
  }
};

// Checks that only the input weights are quantized. For now, the weight is
// assumed to be at argument index 1 (the rhs). Later this can be replaced
// with a map that has the weight index information for each op.
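// For example (illustrative), a composite matmul call (lhs, rhs) qualifies
// only when the rhs (index 1) is produced by a quant.qcast with a quantized
// element type, the lhs is not, and no result carries a quantized type.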
bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) {
  bool has_quantized_types_for_weights = false;
  for (int32_t cur_idx = 0; cur_idx < call_op.args().size(); cur_idx++) {
    // Check whether only the weight index has a QuantizeCastOp.
    auto cur_op = dyn_cast_or_null<quantfork::QuantizeCastOp>(
        call_op.args()[cur_idx].getDefiningOp());
    if ((!cur_op && cur_idx == 1) || (cur_op && cur_idx != 1)) {
      return false;
    } else if (cur_op) {
      // Check whether the QuantizeCastOp has a quantized element type.
      if (!getElementTypeOrSelf(cur_op.getResult().getType())
               .isa<QuantizedType>()) {
        return false;
      }
      // Satisfies the input condition.
      has_quantized_types_for_weights = true;
    }
  }
  for (Value output : call_op.output()) {
    if (auto type = output.getType().dyn_cast<TensorType>()) {
      if (type.getElementType().isa<QuantizedType>()) {
        return false;
      }
    }
  }
  return has_quantized_types_for_weights;
}

// Checks if all the inputs are quantized.
bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) {
  bool has_quantized_types = false;
  for (Value input : call_op.args()) {
    if (auto type = input.getType().dyn_cast<TensorType>()) {
      if (type.getElementType().isa<FloatType>()) {
        return false;
      }
      if (type.getElementType().isa<QuantizedType>()) {
        has_quantized_types = true;
      }
    }
  }
  for (Value output : call_op.output()) {
    if (auto type = output.getType().dyn_cast<TensorType>()) {
      if (type.getElementType().isa<FloatType>()) {
        return false;
      }
      if (type.getElementType().isa<QuantizedType>()) {
        has_quantized_types = true;
      }
    }
  }
  return has_quantized_types;
}

// Converts the element type of the input tensor to the corresponding
// quantized version. Supports only int8 for now and returns nullptr if the
// input type is not supported.
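// For example (illustrative), tensor<2x2xi8> or a quantized type with i8
// storage becomes tensor<2x2x!tf_type.qint8>.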
ShapedType ConvertIntToQint(ShapedType input_type, MLIRContext* ctx) {
  int bit_width;
  bool is_signed;

  Type ele_type = input_type.getElementType();
  if (ele_type.isIntOrFloat()) {
    bit_width = ele_type.getIntOrFloatBitWidth();
    is_signed = ele_type.isSignlessIntOrFloat() || ele_type.isSignedInteger();
  } else if (QuantizedType qtype = ele_type.dyn_cast<QuantizedType>()) {
    bit_width = qtype.getStorageTypeIntegralWidth();
    is_signed = qtype.isSigned();
  } else {
    return input_type;
  }

  Type new_storage_type;
  if (is_signed) {
    switch (bit_width) {
      case 8:
        new_storage_type = mlir::TF::Qint8Type::get(ctx);
        break;
      default:
        return nullptr;  // Not yet supported
    }
  } else {
    return nullptr;  // Not yet supported
  }

  input_type = input_type.clone(new_storage_type);
  return input_type;
}

// Transfers the attributes of the corresponding ops from the float function
// to the quantized function using the attr_map attribute. In the quantized
// function, this map (map1) is in the {attr_name_1: attr_identifier} format,
// and in the float function, this map (map2) is in the
// {attr_identifier: attr_name_2} format. The attribute identifiers should
// match between the two maps; attr_name_1 is the name of the attribute that
// needs to be set in the quantized function, and attr_name_2 is the name of
// the attribute corresponding to the attribute identifier in the float
// function.
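// Example (illustrative; the attribute names are hypothetical): an op in the
// quantized function carrying
//   attr_map = "strides:0,padding:1"
// and an op in the float function carrying
//   attr_map = "0:strides,1:padding"
// cause the quantized op's "strides" and "padding" attributes to be set from
// the float op's attributes of the same names, after which the attr_map
// attribute is removed from the quantized op.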
LogicalResult TransferAttributes(func::FuncOp float_func,
                                 func::FuncOp quantized_func) {
  // A map to find an attribute from its identifier.
  llvm::StringMap<Attribute> identifier_to_attr;
  for (Operation& inner_op : float_func.getBody().front().getOperations()) {
    if (!inner_op.hasAttr(kAttrMapAttribute)) continue;
    std::string attr_map_str =
        inner_op.getAttrOfType<StringAttr>(kAttrMapAttribute).str();
    for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) {
      std::vector<absl::string_view> key_and_value_pair =
          absl::StrSplit(element_str, ':');
      if (key_and_value_pair.size() != 2) {
        float_func.emitError("The attr_map attribute is malformed");
        return failure();
      }
      identifier_to_attr.insert(
          {llvm::StringRef(std::string(key_and_value_pair[0])),
           inner_op.getAttr(
               llvm::StringRef(std::string(key_and_value_pair[1])))});
    }
  }

  // Set the attributes for ops with the attr_map attribute.
  for (Operation& inner_op : quantized_func.getBody().front().getOperations()) {
    if (!inner_op.hasAttr(kAttrMapAttribute)) continue;

    std::string attr_map_str =
        inner_op.getAttrOfType<StringAttr>(kAttrMapAttribute).str();
    for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) {
      std::vector<absl::string_view> key_and_value_pair =
          absl::StrSplit(element_str, ':');
      if (key_and_value_pair.size() != 2) {
        float_func.emitError("The attr_map attribute is malformed");
        return failure();
      }
      if (identifier_to_attr.count(
              llvm::StringRef(std::string(key_and_value_pair[1]))) == 0) {
        float_func.emitWarning(absl::StrCat("Using the default value for the '",
                                            key_and_value_pair[0],
                                            "' attribute"));
        continue;
      }
      inner_op.setAttr(llvm::StringRef(std::string(key_and_value_pair[0])),
                       identifier_to_attr[llvm::StringRef(
                           std::string(key_and_value_pair[1]))]);
    }
    inner_op.removeAttr(kAttrMapAttribute);
  }
  return success();
}

// Unwraps the quantization parameters of PartitionedCall ops with quantized
// input/output types that are created by QuantizePass.
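// For example (illustrative; the function names are schematic), a call to
// @composite_conv2d_fn_1 whose arguments and results carry uniform quantized
// element types is rewritten into a call to a clone of @quantized_conv2d_fn,
// with the arguments wrapped in quant.scast ops to their storage types and
// the (scale, zero point) pairs appended to the argument list.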
class QuantizeFunctionPattern
    : public mlir::OpRewritePattern<TF::PartitionedCallOp> {
 public:
  explicit QuantizeFunctionPattern(MLIRContext* context,
                                   QuantizationMethod quantization_method,
                                   OpSet target_opset)
      : OpRewritePattern<TF::PartitionedCallOp>(context),
        quantization_method_(quantization_method),
        target_opset_(target_opset) {}

 private:
  QuantizationMethod quantization_method_ =
      QuantizationMethod::kPostTrainingQuantization;
  OpSet target_opset_ = OpSet::TF;

  LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op,
                                PatternRewriter& rewriter) const override {
    const auto f_attr = call_op.fAttr().dyn_cast<FlatSymbolRefAttr>();
    // removeAttr returns nullptr if no attribute was removed.
    if (!call_op->removeAttr(kQuantTraitAttrName) || !f_attr) {
      return failure();
    }

    // Determines if all required float input/outputs are now quantized.
    bool has_quantized_types = false;
    if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) {
      has_quantized_types = IsQuantizedCallforDynamicRange(call_op);
      if (f_attr.getValue().startswith("composite_") && !has_quantized_types) {
        call_op->emitError(
            "Only quantizable ops need to be in composite function for "
            "dynamic-range PTQ case.");
        return failure();
      }
    } else {
      has_quantized_types = IsQuantizedCallforStaticRange(call_op);
    }

    if (!f_attr.getValue().startswith("composite_") || !has_quantized_types) {
      return failure();
    }

    SmallVector<Value, 4> args;
    SmallVector<Value, 4> qparam_args;
    for (Value arg : call_op.args()) {
      if (const auto arg_type = arg.getType().dyn_cast<TensorType>()) {
        QuantizedType qtype =
            arg_type.getElementType().dyn_cast<QuantizedType>();
        if (!qtype) continue;
        if (!qtype.isa<UniformQuantizedType,
                       quant::UniformQuantizedPerAxisType>()) {
          return failure();
        }
        Value scale, zero_point;
        if (failed(CreateQuantizationParams(qtype, arg.getLoc(), rewriter,
                                            scale, zero_point))) {
          // As the quantized types are already checked, this is unexpected.
          call_op->emitError(
              "Failed to create a quantization parameter for an argument.");
          return failure();
        }
        qparam_args.push_back(scale);
        qparam_args.push_back(zero_point);
      }
    }

    for (Value result : call_op->getResults()) {
      if (auto result_type = result.getType().dyn_cast<TensorType>()) {
        QuantizedType qtype =
            result_type.getElementType().dyn_cast<QuantizedType>();
        if (!qtype) continue;
        if (!qtype.isa<UniformQuantizedType,
                       quant::UniformQuantizedPerAxisType>()) {
          return failure();
        }
        Value scale, zero_point;
        if (failed(CreateQuantizationParams(qtype, result.getLoc(), rewriter,
                                            scale, zero_point))) {
          // As the quantized types are already checked, this is unexpected.
          call_op->emitError(
              "Failed to create a quantization parameter for a result.");
          return failure();
        }
        qparam_args.push_back(scale);
        qparam_args.push_back(zero_point);
      }
    }

    rewriter.setInsertionPoint(call_op);

    for (Value arg : call_op.args()) {
      TensorType arg_type = arg.getType().dyn_cast<TensorType>();
      if (!arg_type) {
        args.push_back(arg);
        continue;
      }
      QuantizedType qtype = arg_type.getElementType().dyn_cast<QuantizedType>();
      if (!qtype) {
        args.push_back(arg);
        continue;
      }

      quantfork::StorageCastOp scast_op;
      if (quantization_method_ ==
          QuantizationMethod::kDynamicRangeQuantization) {
        ShapedType new_arg_type = ConvertIntToQint(arg_type.cast<ShapedType>(),
                                                   rewriter.getContext());
        if (!new_arg_type) {
          call_op->emitError(
              "Failed to convert the type to the corresponding qtype.");
          return failure();
        }
        scast_op = rewriter.create<quantfork::StorageCastOp>(
            arg.getLoc(), new_arg_type.cast<TensorType>(), arg);
      } else {
        scast_op = rewriter.create<quantfork::StorageCastOp>(
            arg.getLoc(), arg_type.clone(qtype.getStorageType()), arg);
      }
      args.push_back(scast_op.getResult());
    }
    args.insert(args.end(), qparam_args.begin(), qparam_args.end());
    // For the XLA opset, try to merge the quantized function with a following
    // Dequantize op for optimization.
    if (target_opset_ == OpSet::XLA) {
      if (failed(mergeDequantizeOpFollowingQuantizedFunction(call_op, args,
                                                             rewriter))) {
        return failure();
      }
    }
    if (call_op->use_empty()) return success();

    DenseMap<Value, quantfork::StorageCastOp> replace_map;
    rewriter.setInsertionPointAfter(call_op);

    SmallVector<Type, 4> result_types;
    for (Value result : call_op->getResults()) {
      TensorType result_type = result.getType().dyn_cast<TensorType>();
      if (!result_type) {
        result_types.push_back(result.getType());
        continue;
      }
      QuantizedType qtype =
          result_type.getElementType().dyn_cast<QuantizedType>();
      if (!qtype) {
        result_types.push_back(result_type);
        continue;
      }
      auto scast_op = rewriter.create<quantfork::StorageCastOp>(
          call_op.getLoc(), result_type, result);
      replace_map.insert(std::make_pair(result, scast_op));

      result_types.push_back(result_type.clone(qtype.getStorageType()));
    }

    for (auto replace_pair : replace_map) {
      Value result = replace_pair.first;
      quantfork::StorageCastOp scast_op = replace_pair.second;
      result.replaceAllUsesExcept(scast_op, scast_op);
    }

    // Make a copy of the quantized function.
    auto module = call_op->getParentOfType<ModuleOp>();
    SymbolTable symbol_table(module);

    mlir::func::FuncOp float_func =
        dyn_cast<func::FuncOp>(symbol_table.lookup(f_attr.getValue()));
    rewriter.setInsertionPointAfter(float_func);

    // substr(10) == strip the "composite_" prefix.
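    // e.g. (illustrative) "composite_conv2d_fn_1" -> "quantized_conv2d_fn".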
    const llvm::Twine quantized_function_name = llvm::Twine(
        "quantized_", f_attr.getValue().substr(10).rsplit('_').first);
    const mlir::func::FuncOp quantized_func = dyn_cast<func::FuncOp>(
        symbol_table.lookup(quantized_function_name.str()));
    mlir::func::FuncOp new_quantized_func =
        dyn_cast<func::FuncOp>(quantized_func->clone());
    if (new_quantized_func == nullptr) {
      return failure();
    }
    new_quantized_func.setType(
        FunctionType::get(getContext(), TypeRange{ValueRange{args}},
                          new_quantized_func.getResultTypes()));
    for (auto [partitioned_call_arg, new_quantized_func_arg] :
         llvm::zip_first(args, new_quantized_func.getArguments())) {
      new_quantized_func_arg.setType(partitioned_call_arg.getType());
    }

    // Set the attributes for ops with the attr_map attribute.
    if (failed(TransferAttributes(float_func, new_quantized_func))) {
      return failure();
    }

    rewriter.setInsertionPoint(call_op);

    const StringAttr new_quant_func_name =
        symbol_table.insert(new_quantized_func);
    rewriter.replaceOpWithNewOp<TF::PartitionedCallOp>(
        call_op, result_types, args,
        FlatSymbolRefAttr::get(new_quant_func_name));

    return success();
  }

  // For a composite function followed by Dequantize ops, merges the
  // Dequantize ops into the function by creating a quantized function with
  // float output.
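  // For example (illustrative; the function names are schematic), a call to
  // @composite_conv2d_fn_1 whose results feed quant.dcast ops is replaced by
  // a single call to a clone of @quantized_conv2d_float_output_fn that
  // produces float results directly.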
  LogicalResult mergeDequantizeOpFollowingQuantizedFunction(
      TF::PartitionedCallOp call_op, const SmallVector<Value, 4>& args,
      PatternRewriter& rewriter) const {
    bool followed_by_dequantize = false;
    for (Operation* user : call_op->getUsers()) {
      if (llvm::isa<quantfork::DequantizeCastOp>(user)) {
        followed_by_dequantize = true;
        break;
      }
    }
    if (!followed_by_dequantize) return success();

    rewriter.setInsertionPointAfter(call_op);
    SmallVector<Type, 4> result_types;
    for (Value result : call_op->getResults()) {
      TensorType result_type = result.getType().dyn_cast<TensorType>();
      if (!result_type) {
        result_types.push_back(result.getType());
        continue;
      }
      QuantizedType qtype =
          result_type.getElementType().dyn_cast<QuantizedType>();
      if (!qtype) {
        result_types.push_back(result_type);
        continue;
      }

      result_types.push_back(result_type.clone(qtype.getExpressedType()));
    }

    // Make a copy of the quantized function.
    auto module = call_op->getParentOfType<ModuleOp>();
    SymbolTable symbol_table(module);

    const auto f_attr = call_op.fAttr().dyn_cast<FlatSymbolRefAttr>();
    const auto float_func =
        dyn_cast<func::FuncOp>(symbol_table.lookup(f_attr.getValue()));
    rewriter.setInsertionPointAfter(float_func);

    // substr(10) == strip the "composite_" prefix.
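    // e.g. (illustrative) "composite_conv2d_fn_1" ->
    // "quantized_conv2d_float_output_fn".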
    const std::string quantized_function_name =
        "quantized_" + f_attr.getValue().substr(10).rsplit("_fn_").first.str() +
        "_float_output_fn";
    const auto quantized_func =
        dyn_cast<func::FuncOp>(symbol_table.lookup(quantized_function_name));
    auto new_quantized_func = dyn_cast<func::FuncOp>(quantized_func->clone());
    if (new_quantized_func == nullptr) {
      return failure();
    }
    new_quantized_func.setType(
        FunctionType::get(getContext(), TypeRange{ValueRange{args}},
                          new_quantized_func.getResultTypes()));
    for (auto [partitioned_call_arg, new_quantized_func_arg] :
         llvm::zip_first(args, new_quantized_func.getArguments())) {
      new_quantized_func_arg.setType(partitioned_call_arg.getType());
    }

    // Set the attributes for ops with the attr_map attribute.
    if (failed(TransferAttributes(float_func, new_quantized_func))) {
      return failure();
    }

    rewriter.setInsertionPoint(call_op);
    const StringAttr new_quant_func_name =
        symbol_table.insert(new_quantized_func);
    auto quantized_call_op = rewriter.create<TF::PartitionedCallOp>(
        call_op.getLoc(), result_types, args,
        FlatSymbolRefAttr::get(new_quant_func_name));

    for (int result_idx : llvm::seq<int>(0, call_op->getNumResults())) {
      Value result = call_op->getResult(result_idx);
      for (Operation* user : result.getUsers()) {
        if (auto dequant_op =
                llvm::dyn_cast<quantfork::DequantizeCastOp>(user)) {
          dequant_op.getResult().replaceAllUsesWith(
              quantized_call_op->getResult(result_idx));
        }
      }
    }

    return success();
  }
};

// Converts the const -> quant.qcast pattern to a quantized constant, after
// the quantization parameters are safely included in each quantized
// composite function.
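// Example (illustrative; shapes and types are schematic):
//   %cst = "tf.Const"() {value = dense<...> : tensor<2x2xf32>}
//   %q = "quant.qcast"(%cst) : (tensor<2x2xf32>) -> tensor<2x2x!quant...>
// becomes
//   %cst = "tf.Const"() {value = ... : tensor<2x2xi8>}
//   %q = "quant.scast"(%cst) : (tensor<2x2xi8>) -> tensor<2x2x!quant...>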
class QuantizeConstPattern
    : public OpRewritePattern<quantfork::QuantizeCastOp> {
 public:
  // This pattern should have a larger benefit than ReplaceQuantizePattern.
  explicit QuantizeConstPattern(MLIRContext* context,
                                QuantizationMethod quantization_method)
      : OpRewritePattern<quantfork::QuantizeCastOp>(context, /*benefit=*/10),
        quantization_method_(quantization_method) {}

 private:
  QuantizationMethod quantization_method_ =
      QuantizationMethod::kPostTrainingQuantization;
  LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op,
                                PatternRewriter& rewriter) const override {
    DenseFPElementsAttr attr;
    if (!matchPattern(q_op.getArg(), m_Constant(&attr))) {
      return failure();
    }

    ShapedType tensor_qtype = q_op.getResult().getType().cast<ShapedType>();
    Attribute tensor_proto_attr = Quantize(attr, tensor_qtype);
    if (!tensor_proto_attr) {
      return failure();
    }

    Type storage_type =
        tensor_qtype.getElementType().cast<QuantizedType>().getStorageType();
    ShapedType new_type = tensor_qtype.clone(storage_type);
    Location loc = q_op.getArg().getLoc();
    // Convert the integer type to the quantized integer type. Currently this
    // is only applied for the dynamic-range quantization case.
    if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) {
      new_type = ConvertIntToQint(new_type, rewriter.getContext());
      tensor_qtype = ConvertIntToQint(tensor_qtype, rewriter.getContext());

      // TODO(b/225793355): It adds TensorProtoAttr to the constant as a
      // workaround.
      tensorflow::TensorProto tensor_proto;
      if (!mlir::tfg::ConvertToTensorProto(tensor_proto_attr, &tensor_proto)
               .ok()) {
        return failure();
      }

      tensor_proto.set_dtype(tensorflow::DT_QINT8);

      tensor_proto_attr = ElementsAttr(TF::TensorProtoAttr::get(
          new_type, tensorflow::mangling_util::MangleTensor(tensor_proto)));
    }
    auto const_op =
        rewriter.create<TF::ConstOp>(loc, new_type, tensor_proto_attr);
    // Add an scast op to match the quantize -> composite pattern. The added
    // scast is then removed by canonicalization. ([scast - scast] -> [])
    auto scast_op = rewriter.create<quantfork::StorageCastOp>(
        loc, tensor_qtype, const_op.output());
    q_op->replaceAllUsesWith(scast_op);
    return success();
  }
};

static PassRegistration<QuantizeCompositeFunctionsPass> pass;

#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.inc"

void QuantizeCompositeFunctionsPass::runOnOperation() {
  MLIRContext* ctx = &getContext();
  ModuleOp module = getOperation();

  PassManager pm(ctx);
  // Intermediate output from QuantizePass will have PartitionedCall ops with
  // quantized input and output types, which are not allowed in the TF
  // dialect. This can be removed when the composite call supports quantized
  // types.
  pm.enableVerifier(false);

  QuantizationSpecs quant_specs;
  if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) {
    quant_specs.weight_quantization = true;
    quant_specs.inference_type = tensorflow::DT_QINT8;
    pm.addNestedPass<func::FuncOp>(CreatePrepareQuantizeDRQPass());
  } else {
    pm.addNestedPass<func::FuncOp>(
        CreatePrepareQuantizePass(quantization_method_));
  }
  pm.addNestedPass<func::FuncOp>(CreateQuantizePass(quant_specs));

  pm.addNestedPass<func::FuncOp>(CreatePostQuantizePass());
  if (failed(pm.run(module))) {
    signalPassFailure();
  }

  RewritePatternSet patterns(ctx);
  patterns.add<QuantizeFunctionPattern>(ctx, quantization_method_,
                                        target_opset_);

  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    signalPassFailure();
  }

  // Constant quantization is a lossy transformation, so it is applied only
  // after all the other patterns have been applied.
  RewritePatternSet patterns_2(ctx);
  populateWithGenerated(patterns_2);
  patterns_2.add<ReplaceQuantizePattern, ReplaceDequantizePattern>(ctx);
  patterns_2.add<QuantizeConstPattern>(ctx, quantization_method_);
  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns_2))) ||
      failed(verify(module))) {
    signalPassFailure();
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateQuantizeCompositeFunctionsPass(
    QuantizationMethod quantization_method, OpSet target_opset) {
  return std::make_unique<QuantizeCompositeFunctionsPass>(quantization_method,
                                                          target_opset);
}

}  // namespace quant
}  // namespace mlir