/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <set>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"

namespace mlir {
namespace TF {
namespace internal {

// The name prefix of Flex ops.
constexpr absl::string_view kFlexOpNamePrefix = "Flex";
// Don't fall back to a Flex op if this attribute is set. The attribute is
// transient and only used inside this pass: first, the pass matches
// predefined patterns and sets this attribute on the ops they cover; then,
// while walking the function, when it finds an op carrying this attribute it
// removes the attribute and skips further processing of that op.
constexpr char kNoFallbackAttr[] = "no_fallback";
// TF Quantization modes. These constants are defined as char arrays so they
// can be parsed by the pass option.
constexpr char kDefaultMode[] = "DEFAULT";
constexpr char kLegacyIntegerMode[] = "LEGACY_INTEGER";
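// With the standard MLIR pass-option syntax, a pipeline such as
// "func.func(quant-raise-flex-fallback{mode=LEGACY_INTEGER})" (hypothetical
// invocation) would select the legacy allowlist below.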

// Checks if the operation is a TF FakeQuant op.
bool IsTfFakeQuantOp(Operation *op) {
  return llvm::isa<
      // clang-format off
      TF::FakeQuantWithMinMaxArgsOp,
      TF::FakeQuantWithMinMaxVarsOp,
      TF::FakeQuantWithMinMaxVarsPerChannelOp
      // clang-format on
      >(op);
}

// Checks if the operation is allowlisted in both modes. These ops are not
// quantizable but are necessary to make the conversion successful.
bool IsAlwaysAllowlistedOp(Operation *op) {
  return llvm::isa<
      // clang-format off
      // go/keep-sorted start
      TF::ConstOp,
      TF::IdentityOp,
      TF::PartitionedCallOp,
      TF::StatefulPartitionedCallOp
      // go/keep-sorted end
      // clang-format on
      >(op);
}

// LINT.IfChange
// The list of quantizable ops in the Legacy Integer mode.
ABSL_ATTRIBUTE_NOINLINE const std::set<std::string>
    &QuantizableOpsInLegacyMode() {
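  // Heap-allocate the set and never free it: the usual function-local-static
  // pattern that sidesteps destruction-order issues at shutdown.
  // ABSL_ATTRIBUTE_NOINLINE presumably keeps the bulky one-time
  // initialization out of callers.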
  static const std::set<std::string> *legacy_op_list =
      new std::set<std::string>({
          // clang-format off
          // go/keep-sorted start
          TF::AbsOp::getOperationName().str(),
          TF::AddOp::getOperationName().str(),
          TF::AddV2Op::getOperationName().str(),
          TF::ArgMaxOp::getOperationName().str(),
          TF::AvgPoolOp::getOperationName().str(),
          TF::BiasAddOp::getOperationName().str(),
          TF::BucketizeOp::getOperationName().str(),
          TF::ConcatV2Op::getOperationName().str(),
          TF::Conv2DBackpropInputOp::getOperationName().str(),
          TF::Conv2DOp::getOperationName().str(),
          TF::DepthwiseConv2dNativeOp::getOperationName().str(),
          TF::FusedBatchNormV3Op::getOperationName().str(),
          TF::GatherV2Op::getOperationName().str(),
          TF::MatMulOp::getOperationName().str(),
          TF::MaxPoolOp::getOperationName().str(),
          TF::MaximumOp::getOperationName().str(),
          TF::MeanOp::getOperationName().str(),
          TF::MinimumOp::getOperationName().str(),
          TF::MulOp::getOperationName().str(),
          TF::PadOp::getOperationName().str(),
          TF::PadV2Op::getOperationName().str(),
          TF::Relu6Op::getOperationName().str(),
          TF::ReluOp::getOperationName().str(),
          TF::ReshapeOp::getOperationName().str(),
          TF::SoftmaxOp::getOperationName().str(),
          TF::SubOp::getOperationName().str(),
          TF::TransposeOp::getOperationName().str(),
          // go/keep-sorted end
          // clang-format on
      });
  return *legacy_op_list;
}

// The list of quantizable ops in the Default mode.
ABSL_ATTRIBUTE_NOINLINE const std::set<std::string>
    &QuantizableOpsInDefaultMode() {
  static const std::set<std::string> *default_op_list =
      new std::set<std::string>({
          // clang-format off
          // go/keep-sorted start
          TF::BiasAddOp::getOperationName().str(),
          TF::Conv2DBackpropInputOp::getOperationName().str(),
          TF::Conv2DOp::getOperationName().str(),
          TF::DepthwiseConv2dNativeOp::getOperationName().str(),
          TF::FusedBatchNormV3Op::getOperationName().str(),
          TF::MatMulOp::getOperationName().str(),
          TF::Relu6Op::getOperationName().str(),
          TF::ReluOp::getOperationName().str(),
          // go/keep-sorted end
          // clang-format on
      });
  return *default_op_list;
}
// LINT.ThenChange(Google-internal path)

// Checks if the operation can be fused with bias.
inline bool IsFusibleWithBiasOp(Operation *op) {
  return llvm::isa<
      // clang-format off
      TF::MatMulOp,
      TF::Conv2DOp,
      TF::DepthwiseConv2dNativeOp,
      TF::Conv2DBackpropInputOp,
      TF::Conv3DOp,
      TF::Conv3DBackpropInputV2Op
      // clang-format on
      >(op);
}

// Creates the custom options of a Flex op.
inline void CreateFlexOpCustomOptions(const std::string &op_name,
                                      const std::string &node_def_str,
                                      std::string &custom_option_buffer) {
  auto flex_builder = std::make_unique<flexbuffers::Builder>();
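  // The custom options are a FlexBuffer vector of [op_name, serialized
  // NodeDef]; the TFLite Flex delegate reads this pair back at runtime.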
  flex_builder->Vector([&]() {
    flex_builder->String(op_name);
    flex_builder->String(node_def_str);
  });
  flex_builder->Finish();
  custom_option_buffer.assign(flex_builder->GetBuffer().begin(),
                              flex_builder->GetBuffer().end());
}

// Creates the ConstBytesAttr that holds the custom options.
inline TFL::ConstBytesAttr CustomOptionForFlexOp(OpBuilder *builder,
                                                 const std::string &content) {
  return TFL::ConstBytesAttr::get(builder->getContext(),
                                  StringRef(content.data(), content.size()));
}

// Falls back ops that are not supported by TF Quantization to TFLite Flex
// ops.
class FallbackToFlexOps
    : public PassWrapper<FallbackToFlexOps, OperationPass<func::FuncOp>> {
 public:
  FallbackToFlexOps() {}
  explicit FallbackToFlexOps(const std::string &mode) { mode_ = mode; }
  FallbackToFlexOps(const FallbackToFlexOps &other) { mode_ = other.mode_; }

  void runOnOperation() override;

  StringRef getArgument() const final { return "quant-raise-flex-fallback"; }

  StringRef getDescription() const final {
    return "Fall back TF-Quantization-unsupported ops to TFLite Flex ops.";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect>();
  }

 private:
  // The mode of TF Quantization; it may indicate different users/devices.
  Option<std::string> mode_{*this, "mode",
                            llvm::cl::desc("The mode of TF Quantization."),
                            llvm::cl::init("")};

  // Checks if the operation is allowlisted in the current mode.
  bool IsAllowListedOp(Operation *op) {
    std::string op_name = op->getName().getStringRef().str();
    if (IsAlwaysAllowlistedOp(op) || IsTfFakeQuantOp(op)) {
      return true;
    } else if (mode_ == kDefaultMode) {
      return QuantizableOpsInDefaultMode().count(op_name) > 0;
    } else if (mode_ == kLegacyIntegerMode) {
      return QuantizableOpsInLegacyMode().count(op_name) > 0;
    } else {
      mlir::emitError(getOperation().getLoc(), "Unrecognized mode: " + mode_);
      signalPassFailure();
      return true;
    }
  }

  // Converts the operation to a TFLite Flex op.
  bool ConvertToFlexOp(Operation *op);
};

bool FallbackToFlexOps::ConvertToFlexOp(Operation *op) {
  tensorflow::StatusOr<std::unique_ptr<tensorflow::NodeDef>> node_def =
      tensorflow::ConvertTFDialectOpToNodeDef(
          op, /*name=*/"", /*ignore_unregistered_attrs=*/true);
  if (!node_def.ok()) {
    op->emitError("Failed to obtain TensorFlow NodeDef: " +
                  node_def.status().ToString());
    return false;
  }
  std::string node_def_str;
  if (!(*node_def)->SerializeToString(&node_def_str)) {
    op->emitError("Failed to serialize tensorflow NodeDef");
    return false;
  }
  std::string op_name = (*node_def)->op();

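  // Build a TFL::CustomOp whose custom code is "Flex<op name>" and whose
  // custom options carry the serialized NodeDef, then swap it in for the
  // original TF op.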
  OpBuilder builder(op);
  std::string flex_op_name = std::string(kFlexOpNamePrefix) + op_name;
  std::string custom_option_buffer;
  CreateFlexOpCustomOptions(op_name, node_def_str, custom_option_buffer);
  auto flex_op = builder.create<TFL::CustomOp>(
      op->getLoc(), op->getResultTypes(), op->getOperands(), flex_op_name,
      CustomOptionForFlexOp(&builder, custom_option_buffer));
  op->replaceAllUsesWith(flex_op);
  op->erase();
  return true;
}

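// The helpers below are referenced from the rewrite patterns included via
// fallback_to_flex_patterns.inc further down.
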
// Sets the "no_fallback" attribute.
Value SetNoFallbackAttr(PatternRewriter &rewriter, Value val) {
  val.getDefiningOp()->setAttr(kNoFallbackAttr, rewriter.getUnitAttr());
  return val;
}

// Returns true if the attribute is a float attribute and all of its values
// equal `value`.
static bool FloatValueEquals(const Attribute &attr, double value) {
  auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
  if (fp_attr == nullptr) return false;

  if (fp_attr.isSplat()) {
    return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
  }
  return llvm::all_of(fp_attr.getValues<APFloat>(), [value](const APFloat &f) {
    return f.isExactlyValue(value);
  });
}

// Returns true if the rank of the value equals the given rank.
bool RankEquals(Value value, int rank) {
  auto rank_type = value.getType().template dyn_cast<RankedTensorType>();
  return (rank_type && rank_type.getRank() == rank);
}

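// TableGen-generated rewrite patterns, e.g. the binary-op-to-BiasAdd rewrites
// applied in runOnOperation below.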
#include "tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_patterns.inc"

void FallbackToFlexOps::runOnOperation() {
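  // An empty mode disables the fallback entirely.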
  if (mode_.empty()) return;

  func::FuncOp func = getOperation();
  MLIRContext *ctx = &getContext();

  // Convert binary ops to BiasAdd ops if possible.
  RewritePatternSet patterns(ctx);
  populateWithGenerated(patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Convert unsupported ops to Flex ops.
  auto tf_dialect = ctx->getLoadedDialect<TF::TensorFlowDialect>();
  func.walk([&](Operation *op) {
    if (op->getDialect() != tf_dialect) return;
    if (IsAllowListedOp(op)) return;
    if (op->hasAttr(kNoFallbackAttr)) {
      op->removeAttr(kNoFallbackAttr);
      return;
    }
    if (!ConvertToFlexOp(op)) signalPassFailure();
  });
}
}  // namespace internal

std::unique_ptr<OperationPass<func::FuncOp>> CreateFallbackToFlexOpsPass(
    const std::string &mode) {
  return std::make_unique<internal::FallbackToFlexOps>(mode);
}

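// Registers the pass (with the default mode) so it is available to pass
// pipelines and tools under "quant-raise-flex-fallback".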
static PassRegistration<internal::FallbackToFlexOps> pass([] {
  return CreateFallbackToFlexOpsPass(/*mode=*/internal::kDefaultMode);
});

}  // namespace TF
}  // namespace mlir