1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_split.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringMap.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "llvm/Support/Regex.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
29 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
30 #include "mlir/IR/AffineExpr.h" // from @llvm-project
31 #include "mlir/IR/AffineMap.h" // from @llvm-project
32 #include "mlir/IR/Attributes.h" // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
34 #include "mlir/IR/Location.h" // from @llvm-project
35 #include "mlir/IR/PatternMatch.h" // from @llvm-project
36 #include "mlir/Pass/Pass.h" // from @llvm-project
37 #include "mlir/Support/LLVM.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
39 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
41 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/location_utils.h"
44
45 // NOLINTNEXTLINE
46 static llvm::cl::opt<std::string> quantize_stats(
47 "quant-test-stats", llvm::cl::value_desc("string"),
48 llvm::cl::desc("serialized quant info string. Only used in tests"),
49 llvm::cl::init(""));
50
51 //===----------------------------------------------------------------------===//
52 // The Pass to import quantization stats to the ops in a function. This requires
53 // a custom method to retrieve the unique name of the operation.
54
55 namespace mlir {
56 namespace quant {
57
// Shorthand for a single entry of `QuantizationInfo`. The code in this file
// reads its `name`, `name_regex`, `params` and `meta` fields.
using QuantParamsEntry = QuantizationInfo::QuantParams;
59
60 namespace {
61 class ImportQuantStatsPass
62 : public PassWrapper<ImportQuantStatsPass, OperationPass<func::FuncOp>> {
63 public:
64 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ImportQuantStatsPass)
65
ImportQuantStatsPass(OperationToName op_to_name)66 explicit ImportQuantStatsPass(OperationToName op_to_name)
67 : op_to_name_(op_to_name) {}
68
getArgument() const69 StringRef getArgument() const final {
70 // This is the argument used to refer to the pass in
71 // the textual format (on the commandline for example).
72 return "quant-import-stats";
73 }
getDescription() const74 StringRef getDescription() const final {
75 // This is a brief description of the pass.
76 return "Import quantization stats to the model";
77 }
78
79 void runOnOperation() override;
80
getDependentDialects(DialectRegistry & registry) const81 void getDependentDialects(DialectRegistry ®istry) const override {
82 registry.insert<quant::QuantizationDialect,
83 quantfork::QuantizationForkDialect>();
84 }
85
86 // Parses the serialized quant stats protobuf and initialize the internal
87 // data structure. This method must be called after the pass is created.
88 bool ParseQuantStats(const std::string &stats_str);
89
90 private:
91 void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
92 const QuantParamsEntry &info);
93
94 void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
95 ElementsAttr axis_stats, IntegerAttr axis);
96
97 // If the index is out of range, this method returns false. Otherwise it
98 // returns true if the value is a float tensor.
IsQuantizableResult(Operation * op,int index)99 bool IsQuantizableResult(Operation *op, int index) {
100 if (index < 0 || index >= static_cast<int>(op->getNumResults()))
101 return false;
102 Value res = op->getResult(index);
103 return res.getType().isa<ShapedType>() &&
104 res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
105 }
106
107 // A method to retrieve the name for the given op.
108 OperationToName op_to_name_;
109
110 // We split the normal names and regex names, since the former can use hash
111 // map to lookup and the latter needs to iterate all the regex to find the
112 // match.
113 // The `int` in the following two containers are to specify the result index
114 // of the given op. -1 indicates all the floating-point results.
115 llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
116 llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
117 };
118 } // namespace
119
ParseQuantStats(const std::string & stats_str)120 bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
121 QuantizationInfo quant_stats;
122 if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
123 return true;
124 }
125
126 for (const auto &entry : quant_stats.entries()) {
127 if (!entry.name().empty()) {
128 std::vector<std::string> name_and_port =
129 absl::StrSplit(entry.name(), ':');
130 int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
131 name_to_info_.insert({name_and_port[0], {port, entry}});
132 } else if (!entry.name_regex().empty()) {
133 std::vector<std::string> name_and_port =
134 absl::StrSplit(entry.name_regex(), ':');
135 int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
136 regex_to_info_.insert({name_and_port[0], {port, entry}});
137 }
138 }
139 return false;
140 }
141
InsertStatsOpAtResult(OpBuilder b,Value res,ElementsAttr layer_stats,ElementsAttr axis_stats,IntegerAttr axis)142 void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
143 ElementsAttr layer_stats,
144 ElementsAttr axis_stats,
145 IntegerAttr axis) {
146 auto stats_op = b.create<quantfork::StatisticsOp>(
147 b.getUnknownLoc(), res, layer_stats, axis_stats, axis);
148 res.replaceAllUsesWith(stats_op);
149 stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
150 }
151
ImportAsStatsOps(OpBuilder b,Operation * op,int index,const QuantParamsEntry & info)152 void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
153 int index,
154 const QuantParamsEntry &info) {
155 if (info.params_size() == 0) return;
156
157 SmallVector<APFloat, 4> min_maxs;
158 min_maxs.reserve(info.params_size() * 2);
159 for (const auto ¶m : info.params()) {
160 llvm::APFloat min(param.min_max().min());
161 llvm::APFloat max(param.min_max().max());
162 min_maxs.push_back(min);
163 min_maxs.push_back(max);
164 }
165 // The layer stats contain only the first min/max pairs.
166 ElementsAttr layer_stats = DenseFPElementsAttr::get(
167 RankedTensorType::get({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
168 ElementsAttr axis_stats;
169 IntegerAttr axis;
170
171 if (info.params_size() > 1) {
172 SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
173 axis_stats = DenseFPElementsAttr::get(
174 RankedTensorType::get(axis_stats_shape, b.getF32Type()), min_maxs);
175 axis = b.getI64IntegerAttr(info.meta().quantize_axis());
176 }
177
178 b.setInsertionPointAfter(op);
179 if (IsQuantizableResult(op, index)) {
180 InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
181 axis);
182 } else {
183 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
184 if (IsQuantizableResult(op, i)) {
185 InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
186 axis);
187 }
188 }
189 }
190 }
191
runOnOperation()192 void ImportQuantStatsPass::runOnOperation() {
193 func::FuncOp func = getOperation();
194 OpBuilder builder(func);
195
196 func.walk([&](Operation *op) {
197 if (op->hasTrait<OpTrait::IsTerminator>()) return;
198 auto op_name = op_to_name_(op);
199
200 // Check the named info collection first.
201 auto it = name_to_info_.find(op_name);
202 if (it != name_to_info_.end()) {
203 ImportAsStatsOps(builder, op, it->second.first, it->second.second);
204 return;
205 }
206
207 // Iterate all the regex names and matches the first one.
208 for (auto ®ex : regex_to_info_) {
209 if (llvm::Regex(regex.first()).match(op_name)) {
210 ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
211 break;
212 }
213 }
214 });
215 }
216
217 // Creates an instance of the default quant parameters pass.
CreateImportQuantStatsPass(OperationToName op_to_name,const std::string & stats_str)218 std::unique_ptr<OperationPass<func::FuncOp>> CreateImportQuantStatsPass(
219 OperationToName op_to_name, const std::string &stats_str) {
220 auto pass = std::make_unique<ImportQuantStatsPass>(op_to_name);
221 if (pass->ParseQuantStats(stats_str)) return nullptr;
222 return pass;
223 }
224
225 // Creates an instance pass to import quantization stats to the operations in
226 // the function. A custom method to get the name from the op is used because
227 // different dialect ops might have different ways to assign the name.
228 std::unique_ptr<OperationPass<func::FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string & stats_str)229 CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
230 auto get_name_func = [](Operation *op) {
231 Location loc = tensorflow::GetLocationWithoutOpType(op->getLoc());
232 if (auto name = loc.dyn_cast<NameLoc>()) {
233 return name.getName().strref();
234 } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
235 for (auto sub_loc : fused_name.getLocations()) {
236 if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
237 return named_sub_loc.getName().strref();
238 }
239 }
240 }
241 return llvm::StringRef("");
242 };
243
244 return CreateImportQuantStatsPass(get_name_func, stats_str);
245 }
246
247 // Registers this pass with default values, only for test
__anon3f358b260402null248 static PassRegistration<ImportQuantStatsPass> pass([] {
249 return CreateImportQuantStatsPassForTFControlDialect(quantize_stats);
250 });
251
252 } // namespace quant
253 } // namespace mlir
254