/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <set>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"

namespace mlir {
namespace TF {
namespace internal {

// The name prefix of Flex ops.
constexpr absl::string_view kFlexOpNamePrefix = "Flex";
// Don't fall back to a Flex op if this attribute is set. The attribute is
// transient and only used inside this pass: the pass first matches the
// predefined patterns and sets this attribute on the ops they cover; then,
// while walking the function, any op carrying the attribute has it removed
// and is skipped from further processing.
constexpr char kNoFallbackAttr[] = "no_fallback";
// TF Quantization modes. These constants are defined as char arrays so they
// can be parsed by the pass option.
constexpr char kDefaultMode[] = "DEFAULT";
constexpr char kLegacyIntegerMode[] = "LEGACY_INTEGER";

// Checks if the operation is a TF FakeQuant op.
bool IsTfFakeQuantOp(Operation *op) {
  return llvm::isa<
      // clang-format off
      TF::FakeQuantWithMinMaxArgsOp,
      TF::FakeQuantWithMinMaxVarsOp,
      TF::FakeQuantWithMinMaxVarsPerChannelOp
      // clang-format on
      >(op);
}

// Checks if the operation is allowlisted in both modes. These ops are not
// quantizable but are necessary to make the conversion successful.
bool IsAlwaysAllowlistedOp(Operation *op) {
  return llvm::isa<
      // clang-format off
      // go/keep-sorted start
      TF::ConstOp,
      TF::IdentityOp,
      TF::PartitionedCallOp,
      TF::StatefulPartitionedCallOp
      // go/keep-sorted end
      // clang-format on
      >(op);
}

// LINT.IfChange
// The list of quantizable ops in the Legacy Integer mode.
ABSL_ATTRIBUTE_NOINLINE const std::set<std::string>
    &QuantizableOpsInLegacyMode() {
  static const std::set<std::string> *legacy_op_list =
      new std::set<std::string>({
          // clang-format off
          // go/keep-sorted start
          TF::AbsOp::getOperationName().str(),
          TF::AddOp::getOperationName().str(),
          TF::AddV2Op::getOperationName().str(),
          TF::ArgMaxOp::getOperationName().str(),
          TF::AvgPoolOp::getOperationName().str(),
          TF::BiasAddOp::getOperationName().str(),
          TF::BucketizeOp::getOperationName().str(),
          TF::ConcatV2Op::getOperationName().str(),
          TF::Conv2DBackpropInputOp::getOperationName().str(),
          TF::Conv2DOp::getOperationName().str(),
          TF::DepthwiseConv2dNativeOp::getOperationName().str(),
          TF::FusedBatchNormV3Op::getOperationName().str(),
          TF::GatherV2Op::getOperationName().str(),
          TF::MatMulOp::getOperationName().str(),
          TF::MaxPoolOp::getOperationName().str(),
          TF::MaximumOp::getOperationName().str(),
          TF::MeanOp::getOperationName().str(),
          TF::MinimumOp::getOperationName().str(),
          TF::MulOp::getOperationName().str(),
          TF::PadOp::getOperationName().str(),
          TF::PadV2Op::getOperationName().str(),
          TF::Relu6Op::getOperationName().str(),
          TF::ReluOp::getOperationName().str(),
          TF::ReshapeOp::getOperationName().str(),
          TF::SoftmaxOp::getOperationName().str(),
          TF::SubOp::getOperationName().str(),
          TF::TransposeOp::getOperationName().str(),
          // go/keep-sorted end
          // clang-format on
      });
  return *legacy_op_list;
}

// The list of quantizable ops in the Default mode.
ABSL_ATTRIBUTE_NOINLINE const std::set<std::string>
    &QuantizableOpsInDefaultMode() {
  static const std::set<std::string> *default_op_list =
      new std::set<std::string>({
          // clang-format off
          // go/keep-sorted start
          TF::BiasAddOp::getOperationName().str(),
          TF::Conv2DBackpropInputOp::getOperationName().str(),
          TF::Conv2DOp::getOperationName().str(),
          TF::DepthwiseConv2dNativeOp::getOperationName().str(),
          TF::FusedBatchNormV3Op::getOperationName().str(),
          TF::MatMulOp::getOperationName().str(),
          TF::Relu6Op::getOperationName().str(),
          TF::ReluOp::getOperationName().str(),
          // go/keep-sorted end
          // clang-format on
      });
  return *default_op_list;
}
// LINT.ThenChange(Google-internal path)

// Checks if the operation can be fused with bias.
inline bool IsFusibleWithBiasOp(Operation *op) {
  return llvm::isa<
      // clang-format off
      TF::MatMulOp,
      TF::Conv2DOp,
      TF::DepthwiseConv2dNativeOp,
      TF::Conv2DBackpropInputOp,
      TF::Conv3DOp,
      TF::Conv3DBackpropInputV2Op
      // clang-format on
      >(op);
}

// Creates the custom options buffer for a Flex op.
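// The buffer is a flexbuffer vector holding two strings, the TF op name and
// the serialized NodeDef; this is assumed to be the layout the TFLite Flex
// delegate expects for "Flex"-prefixed custom ops.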
inline void CreateFlexOpCustomOptions(const std::string &op_name,
                                      const std::string &node_def_str,
                                      std::string &custom_option_buffer) {
  auto flex_builder = std::make_unique<flexbuffers::Builder>();
  flex_builder->Vector([&]() {
    flex_builder->String(op_name);
    flex_builder->String(node_def_str);
  });
  flex_builder->Finish();
  custom_option_buffer.assign(flex_builder->GetBuffer().begin(),
                              flex_builder->GetBuffer().end());
}

// Creates a ConstBytesAttr wrapping the custom option buffer.
inline TFL::ConstBytesAttr CustomOptionForFlexOp(OpBuilder *builder,
                                                 const std::string &content) {
  return TFL::ConstBytesAttr::get(builder->getContext(),
                                  StringRef(content.data(), content.size()));
}

// Falls back ops that are not supported by TF Quantization to TFLite Flex
// ops.
class FallbackToFlexOps
    : public PassWrapper<FallbackToFlexOps, OperationPass<func::FuncOp>> {
 public:
  FallbackToFlexOps() {}
  explicit FallbackToFlexOps(const std::string &mode) { mode_ = mode; }
  FallbackToFlexOps(const FallbackToFlexOps &other) { mode_ = other.mode_; }

  void runOnOperation() override;

  StringRef getArgument() const final { return "quant-raise-flex-fallback"; }

  StringRef getDescription() const final {
    return "Fallback TF-Quantization-unsupported ops to TFLite Flex ops.";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect>();
  }

 private:
  // The mode of TF Quantization; it might indicate different users/devices.
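  // Expected values are kDefaultMode ("DEFAULT") and kLegacyIntegerMode
  // ("LEGACY_INTEGER"); any other non-empty value is reported as an error by
  // IsAllowListedOp.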
  Option<std::string> mode_{*this, "mode",
                            llvm::cl::desc("The mode of TF Quantization."),
                            llvm::cl::init("")};

  // Checks if the operation is allowlisted in the current mode.
  bool IsAllowListedOp(Operation *op) {
    std::string op_name = op->getName().getStringRef().str();
    if (IsAlwaysAllowlistedOp(op) || IsTfFakeQuantOp(op)) {
      return true;
    } else if (mode_ == kDefaultMode) {
      return QuantizableOpsInDefaultMode().count(op_name) > 0;
    } else if (mode_ == kLegacyIntegerMode) {
      return QuantizableOpsInLegacyMode().count(op_name) > 0;
    } else {
      mlir::emitError(getOperation().getLoc(), "Unrecognized mode: " + mode_);
      signalPassFailure();
      return true;
    }
  }

  // Converts the operation to a TFLite Flex op.
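  // Returns true on success; on failure it emits an error on the op and
  // returns false.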
  bool ConvertToFlexOp(Operation *op);
};

bool FallbackToFlexOps::ConvertToFlexOp(Operation *op) {
  tensorflow::StatusOr<std::unique_ptr<tensorflow::NodeDef>> node_def =
      tensorflow::ConvertTFDialectOpToNodeDef(
          op, /*name=*/"", /*ignore_unregistered_attrs=*/true);
  if (!node_def.ok()) {
    op->emitError("Failed to obtain TensorFlow NodeDef: " +
                  node_def.status().ToString());
    return false;
  }
  std::string node_def_str;
  if (!(*node_def)->SerializeToString(&node_def_str)) {
    op->emitError("Failed to serialize TensorFlow NodeDef");
    return false;
  }
  std::string op_name = (*node_def)->op();

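  // Build a TFL custom op named "Flex<op name>" whose custom options carry
  // the serialized NodeDef, then replace the original op with it.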
  OpBuilder builder(op);
  std::string flex_op_name = std::string(kFlexOpNamePrefix) + op_name;
  std::string custom_option_buffer;
  CreateFlexOpCustomOptions(op_name, node_def_str, custom_option_buffer);
  auto flex_op = builder.create<TFL::CustomOp>(
      op->getLoc(), op->getResultTypes(), op->getOperands(), flex_op_name,
      CustomOptionForFlexOp(&builder, custom_option_buffer));
  op->replaceAllUsesWith(flex_op);
  op->erase();
  return true;
}

// Sets the "no_fallback" attribute on the defining op of the given value.
Value SetNoFallbackAttr(PatternRewriter &rewriter, Value val) {
  val.getDefiningOp()->setAttr(kNoFallbackAttr, rewriter.getUnitAttr());
  return val;
}

// Returns true if the attribute is a float attribute whose elements are all
// exactly equal to the given value.
static bool FloatValueEquals(const Attribute &attr, double value) {
  auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
  if (fp_attr == nullptr) return false;

  if (fp_attr.isSplat()) {
    return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
  }
  return llvm::all_of(fp_attr.getValues<APFloat>(), [value](const APFloat &f) {
    return f.isExactlyValue(value);
  });
}

// Returns true if the rank of the value equals the given rank.
bool RankEquals(Value value, int rank) {
  auto rank_type = value.getType().template dyn_cast<RankedTensorType>();
  return (rank_type && rank_type.getRank() == rank);
}

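// The helpers above (SetNoFallbackAttr, FloatValueEquals, RankEquals) are
// intended for use by the TableGen-generated rewrite patterns included below.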
#include "tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_patterns.inc"

void FallbackToFlexOps::runOnOperation() {
  if (mode_.empty()) return;

  func::FuncOp func = getOperation();
  MLIRContext *ctx = &getContext();

  // Convert binary ops to BiasAdd ops if possible.
  RewritePatternSet patterns(ctx);
  populateWithGenerated(patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Convert unsupported ops to Flex ops.
  auto tf_dialect = ctx->getLoadedDialect<TF::TensorFlowDialect>();
  func.walk([&](Operation *op) {
    if (op->getDialect() != tf_dialect) return;
    if (IsAllowListedOp(op)) return;
    if (op->hasAttr(kNoFallbackAttr)) {
      op->removeAttr(kNoFallbackAttr);
      return;
    }
    if (!ConvertToFlexOp(op)) signalPassFailure();
  });
}
}  // namespace internal

std::unique_ptr<OperationPass<func::FuncOp>> CreateFallbackToFlexOpsPass(
    const std::string &mode) {
  return std::make_unique<internal::FallbackToFlexOps>(mode);
}

static PassRegistration<internal::FallbackToFlexOps> pass([] {
  return CreateFallbackToFlexOpsPass(/*mode=*/internal::kDefaultMode);
});
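
// Example (assumed tooling): with the registration above, an mlir-opt-style
// driver can typically run this pass as
//   <driver> input.mlir --quant-raise-flex-fallback="mode=LEGACY_INTEGER"
// where the exact driver binary and flag spelling depend on the build.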

}  // namespace TF
}  // namespace mlir