1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
16 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
17 #include "mlir/Pass/Pass.h" // from @llvm-project
18 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
19 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
21
22 namespace mlir {
23 namespace TFL {
24 namespace {
25
26 #define GEN_PASS_CLASSES
27 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
28
// Name of the boolean attribute that this pass attaches to the module. Its
// value says whether resource variables in the module may be legalized to
// native TFLite variables.
constexpr char kLegalizeTflVariables[] = "tfl._legalize_tfl_variables";
32
33 // Returns true if 'op' is TF op that accepts resource type, but is
34 // supported by TFLite.
IsSupportedTFLiteResourceOp(Operation * op)35 bool IsSupportedTFLiteResourceOp(Operation* op) {
36 return llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp, TF::VarHandleOp,
37 TF::LookupTableFindV2Op, TF::LookupTableImportV2Op,
38 TF::LookupTableSizeV2Op>(op);
39 }
40
41 // Returns true if 'op' is TF/TFLite control flow op that can accept resource
42 // type. Usually these ops are just pass through, they call another subgraph and
43 // pass the operands to.
IsSupportedTFLiteControlFlow(Operation * op)44 bool IsSupportedTFLiteControlFlow(Operation* op) {
45 return llvm::isa<TFL::WhileOp, TFL::IfOp, TFL::CallOnceOp>(op);
46 }
47
48 // Returns true if the 'op' is one of the supported TF control flow ops or
49 // dataset ops. Those ops just forward the operands to other subgraphs.
IsSupportedTFDataForwardingOp(Operation * op)50 bool IsSupportedTFDataForwardingOp(Operation* op) {
51 return llvm::isa<TF::MapDatasetOp, TF::ReduceDatasetOp,
52 TF::TakeWhileDatasetOp, TF::IfOp, TF::WhileOp>(op);
53 }
54
55 class AnalyzeVariablesPass
56 : public AnalyzeVariablesPassBase<AnalyzeVariablesPass> {
57 public:
AnalyzeVariablesPass()58 explicit AnalyzeVariablesPass() {}
59 void runOnOperation() override;
60 };
61
runOnOperation()62 void AnalyzeVariablesPass::runOnOperation() {
63 auto* context = &getContext();
64 auto module = getOperation();
65 bool legalize_to_tfl = true;
66
67 module.walk([&](Operation* op) {
68 // Skip ops that are supported natively by TFLite.
69 if (IsSupportedTFLiteResourceOp(op)) return WalkResult::advance();
70 if (IsSupportedTFLiteControlFlow(op)) return WalkResult::advance();
71
72 // Check for ops that are legalized to TFLite.
73 if (op->getDialect()->getNamespace() == "tfl") {
74 return WalkResult::advance();
75 }
76 // Check for ops that are not legalized to TFLite.
77 if (IsSupportedTFDataForwardingOp(op)) {
78 return WalkResult::advance();
79 }
80
81 // If any of the operands is a resource type, then we break
82 // and mark the module as not valid for TFLite legalization.
83 // Note: this might disable native variables in more than needed cases.
84 // TODO(b/189370197): Enhance variable analysis.
85 for (auto operand : op->getOperands()) {
86 if (getElementTypeOrSelf(operand.getType()).isa<TF::ResourceType>()) {
87 legalize_to_tfl = false;
88 return WalkResult::interrupt();
89 }
90 }
91 return WalkResult::advance();
92 });
93 module->setAttr(kLegalizeTflVariables,
94 BoolAttr::get(context, legalize_to_tfl));
95 }
96
97 } // namespace
98
CreateAnalyzeVariablesPass()99 std::unique_ptr<OperationPass<ModuleOp>> CreateAnalyzeVariablesPass() {
100 return std::make_unique<AnalyzeVariablesPass>();
101 }
102
103 } // namespace TFL
104 } // namespace mlir
105