xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_split.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringMap.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "llvm/Support/Regex.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
29 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
30 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
31 #include "mlir/IR/AffineMap.h"  // from @llvm-project
32 #include "mlir/IR/Attributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/Location.h"  // from @llvm-project
35 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
36 #include "mlir/Pass/Pass.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
39 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
41 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/location_utils.h"
44 
45 // NOLINTNEXTLINE
46 static llvm::cl::opt<std::string> quantize_stats(
47     "quant-test-stats", llvm::cl::value_desc("string"),
48     llvm::cl::desc("serialized quant info string. Only used in tests"),
49     llvm::cl::init(""));
50 
51 //===----------------------------------------------------------------------===//
52 // The Pass to import quantization stats to the ops in a function. This requires
53 // a custom method to retrieve the unique name of the operation.
54 
55 namespace mlir {
56 namespace quant {
57 
58 using QuantParamsEntry = QuantizationInfo::QuantParams;
59 
60 namespace {
61 class ImportQuantStatsPass
62     : public PassWrapper<ImportQuantStatsPass, OperationPass<func::FuncOp>> {
63  public:
64   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ImportQuantStatsPass)
65 
ImportQuantStatsPass(OperationToName op_to_name)66   explicit ImportQuantStatsPass(OperationToName op_to_name)
67       : op_to_name_(op_to_name) {}
68 
getArgument() const69   StringRef getArgument() const final {
70     // This is the argument used to refer to the pass in
71     // the textual format (on the commandline for example).
72     return "quant-import-stats";
73   }
getDescription() const74   StringRef getDescription() const final {
75     // This is a brief description of the pass.
76     return "Import quantization stats to the model";
77   }
78 
79   void runOnOperation() override;
80 
getDependentDialects(DialectRegistry & registry) const81   void getDependentDialects(DialectRegistry &registry) const override {
82     registry.insert<quant::QuantizationDialect,
83                     quantfork::QuantizationForkDialect>();
84   }
85 
86   // Parses the serialized quant stats protobuf and initialize the internal
87   // data structure. This method must be called after the pass is created.
88   bool ParseQuantStats(const std::string &stats_str);
89 
90  private:
91   void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
92                         const QuantParamsEntry &info);
93 
94   void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
95                              ElementsAttr axis_stats, IntegerAttr axis);
96 
97   // If the index is out of range, this method returns false. Otherwise it
98   // returns true if the value is a float tensor.
IsQuantizableResult(Operation * op,int index)99   bool IsQuantizableResult(Operation *op, int index) {
100     if (index < 0 || index >= static_cast<int>(op->getNumResults()))
101       return false;
102     Value res = op->getResult(index);
103     return res.getType().isa<ShapedType>() &&
104            res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
105   }
106 
107   // A method to retrieve the name for the given op.
108   OperationToName op_to_name_;
109 
110   // We split the normal names and regex names, since the former can use hash
111   // map to lookup and the latter needs to iterate all the regex to find the
112   // match.
113   // The `int` in the following two containers are to specify the result index
114   // of the given op. -1 indicates all the floating-point results.
115   llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
116   llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
117 };
118 }  // namespace
119 
ParseQuantStats(const std::string & stats_str)120 bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
121   QuantizationInfo quant_stats;
122   if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
123     return true;
124   }
125 
126   for (const auto &entry : quant_stats.entries()) {
127     if (!entry.name().empty()) {
128       std::vector<std::string> name_and_port =
129           absl::StrSplit(entry.name(), ':');
130       int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
131       name_to_info_.insert({name_and_port[0], {port, entry}});
132     } else if (!entry.name_regex().empty()) {
133       std::vector<std::string> name_and_port =
134           absl::StrSplit(entry.name_regex(), ':');
135       int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
136       regex_to_info_.insert({name_and_port[0], {port, entry}});
137     }
138   }
139   return false;
140 }
141 
InsertStatsOpAtResult(OpBuilder b,Value res,ElementsAttr layer_stats,ElementsAttr axis_stats,IntegerAttr axis)142 void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
143                                                  ElementsAttr layer_stats,
144                                                  ElementsAttr axis_stats,
145                                                  IntegerAttr axis) {
146   auto stats_op = b.create<quantfork::StatisticsOp>(
147       b.getUnknownLoc(), res, layer_stats, axis_stats, axis);
148   res.replaceAllUsesWith(stats_op);
149   stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
150 }
151 
ImportAsStatsOps(OpBuilder b,Operation * op,int index,const QuantParamsEntry & info)152 void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
153                                             int index,
154                                             const QuantParamsEntry &info) {
155   if (info.params_size() == 0) return;
156 
157   SmallVector<APFloat, 4> min_maxs;
158   min_maxs.reserve(info.params_size() * 2);
159   for (const auto &param : info.params()) {
160     llvm::APFloat min(param.min_max().min());
161     llvm::APFloat max(param.min_max().max());
162     min_maxs.push_back(min);
163     min_maxs.push_back(max);
164   }
165   // The layer stats contain only the first min/max pairs.
166   ElementsAttr layer_stats = DenseFPElementsAttr::get(
167       RankedTensorType::get({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
168   ElementsAttr axis_stats;
169   IntegerAttr axis;
170 
171   if (info.params_size() > 1) {
172     SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
173     axis_stats = DenseFPElementsAttr::get(
174         RankedTensorType::get(axis_stats_shape, b.getF32Type()), min_maxs);
175     axis = b.getI64IntegerAttr(info.meta().quantize_axis());
176   }
177 
178   b.setInsertionPointAfter(op);
179   if (IsQuantizableResult(op, index)) {
180     InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
181                           axis);
182   } else {
183     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
184       if (IsQuantizableResult(op, i)) {
185         InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
186                               axis);
187       }
188     }
189   }
190 }
191 
runOnOperation()192 void ImportQuantStatsPass::runOnOperation() {
193   func::FuncOp func = getOperation();
194   OpBuilder builder(func);
195 
196   func.walk([&](Operation *op) {
197     if (op->hasTrait<OpTrait::IsTerminator>()) return;
198     auto op_name = op_to_name_(op);
199 
200     // Check the named info collection first.
201     auto it = name_to_info_.find(op_name);
202     if (it != name_to_info_.end()) {
203       ImportAsStatsOps(builder, op, it->second.first, it->second.second);
204       return;
205     }
206 
207     // Iterate all the regex names and matches the first one.
208     for (auto &regex : regex_to_info_) {
209       if (llvm::Regex(regex.first()).match(op_name)) {
210         ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
211         break;
212       }
213     }
214   });
215 }
216 
217 // Creates an instance of the default quant parameters pass.
CreateImportQuantStatsPass(OperationToName op_to_name,const std::string & stats_str)218 std::unique_ptr<OperationPass<func::FuncOp>> CreateImportQuantStatsPass(
219     OperationToName op_to_name, const std::string &stats_str) {
220   auto pass = std::make_unique<ImportQuantStatsPass>(op_to_name);
221   if (pass->ParseQuantStats(stats_str)) return nullptr;
222   return pass;
223 }
224 
225 // Creates an instance pass to import quantization stats to the operations in
226 // the function. A custom method to get the name from the op is used because
227 // different dialect ops might have different ways to assign the name.
228 std::unique_ptr<OperationPass<func::FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string & stats_str)229 CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
230   auto get_name_func = [](Operation *op) {
231     Location loc = tensorflow::GetLocationWithoutOpType(op->getLoc());
232     if (auto name = loc.dyn_cast<NameLoc>()) {
233       return name.getName().strref();
234     } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
235       for (auto sub_loc : fused_name.getLocations()) {
236         if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
237           return named_sub_loc.getName().strref();
238         }
239       }
240     }
241     return llvm::StringRef("");
242   };
243 
244   return CreateImportQuantStatsPass(get_name_func, stats_str);
245 }
246 
247 // Registers this pass with default values, only for test
__anon3f358b260402null248 static PassRegistration<ImportQuantStatsPass> pass([] {
249   return CreateImportQuantStatsPassForTFControlDialect(quantize_stats);
250 });
251 
252 }  // namespace quant
253 }  // namespace mlir
254