1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <functional>
17 #include <memory>
18
19 #include "llvm/ADT/STLExtras.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/Block.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Support/LLVM.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32
33 namespace mlir {
34 namespace TFL {
35 namespace {
36 #define GEN_PASS_CLASSES
37 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
38
IsResourceTensor(Value value)39 bool IsResourceTensor(Value value) {
40 const auto tensor_type = value.getType().dyn_cast<TensorType>();
41 return tensor_type &&
42 tensor_type.getElementType().isa<mlir::TF::ResourceType>();
43 }
44
45 // The default criterion for operations being considered as causing or being
46 // dependent on side effects. Reflects the current runtime logic; see below.
OpHasSideEffects(Operation * op)47 bool OpHasSideEffects(Operation *op) {
48 // Note that TFL::IfOp are only ever instantiated in flatbuffer_export; until
49 // then, they are represented as mlir::TF::IfOp. We add them here, anyway, to
50 // be future-proof.
51 if (llvm::isa<TF::IfOp, TFL::IfOp, TFL::CallOnceOp, TFL::WhileOp>(op))
52 return true;
53 for (auto operand : op->getOperands()) {
54 if (IsResourceTensor(operand)) return true;
55 }
56 for (auto result : op->getResults()) {
57 if (IsResourceTensor(result)) return true;
58 }
59 return false;
60 }
61
62 // This transformation pass takes an operation that has or depends on side
63 // effects and wraps it in a TFL::ControlNodeOp, which is made to depend on the
64 // control token generated by the most recent preceding such operation, if
65 // any. This copies the logic that is currently executed at runtime (in
66 // tensorflow/lite/core/subgraph). That runtime logic will now be a no-op for
67 // models that were generated with this pass.
68 //
69 // For purposes of this pass, an operator is considered to have/depend on side
70 // effects if
71 // - it involves calling a different function
72 // - it involves accessing resource variables
73 //
74 // Note that these criteria are more restrictive than necessary:
75 // - they will force a fixed order on operations that read from/write to
76 // *different* variables
77 // - they make the blanket assumption that any functions called cause or depend
78 // on side effects (i.e., access resource variables.)
79 //
80 // By moving the logic to compile time, we will be able to do a finer-grained
81 // data flow analysis in the future, which will enable more optimizations.
82 // This could happen in two steps:
83 // (1) build multiple dependency chains (one per variable), still treating
84 // function/subgraph calls as black boxes (i.e., all variables would
85 // be assumed to be read and modified within control operations)
86 // (2) Extend the variable dependency analysis across function boundaries.
class PinOpsWithSideEffectsPass
    : public PinOpsWithSideEffectsPassBase<PinOpsWithSideEffectsPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PinOpsWithSideEffectsPass)

  // `op_has_side_effects` decides which operations get pinned; it defaults to
  // OpHasSideEffects (control-flow ops and ops touching resource tensors), but
  // is injectable so tests or callers can supply a different criterion.
  explicit PinOpsWithSideEffectsPass(const std::function<bool(Operation *)> &
                                         op_has_side_effects = OpHasSideEffects)
      : op_has_side_effects_(op_has_side_effects) {}

  void runOnOperation() override;

 private:
  // This will be used recursively: runOnOperation applies it to every op
  // nested inside each top-level op (e.g., ops wrapped in tfl::CustomTfOp).
  const std::function<bool(Operation *)> op_has_side_effects_;
};
102
runOnOperation()103 void PinOpsWithSideEffectsPass::runOnOperation() {
104 auto fn = getOperation();
105 // We're assuming (and checking) that there are no tfl::ControlNodeOps present
106 // before this pass. We could relax this requirement by defining folding logic
107 // for them.
108 if (fn.walk([&](TFL::ControlNodeOp) {
109 return WalkResult::interrupt();
110 }).wasInterrupted()) {
111 fn.emitOpError("Can't have control ops in this pass.");
112 signalPassFailure();
113 }
114
115 llvm::SmallVector<Operation *, 4> ops_with_side_effects;
116
117 // We're iterating over all operations at the top block level, excluding
118 // the return operation (which otherwise would be recognized as being
119 // susceptible to side effects when returning a resource variable.)
120 // We only need to consider functions with single-block bodies, as
121 // this is an assumption flatbuffer_export makes, and this pass is
122 // operating on IRs ready for exporting.
123 for (Operation &op : fn.getBody().front().without_terminator()) {
124 // We have to recurse, since we might have wrapped a side-effectful operator
125 // in a tfl::CustomTfOp.
126 if (op.walk([&](Operation *inner_op) {
127 return op_has_side_effects_(inner_op) ? WalkResult::interrupt()
128 : WalkResult::advance();
129 }).wasInterrupted()) {
130 ops_with_side_effects.push_back(&op);
131 }
132 }
133
134 OpBuilder builder(fn.getContext());
135 // The control tokens generated by the last ControlNodeOp wrapping. Will be
136 // empty until the first ControlNodeOp was generated, then have constant size
137 // 1.
138 llvm::SmallVector<Value, 1> control_tokens;
139 for (auto *op : ops_with_side_effects) {
140 // Wrap all side-effect producing/dependent operations in a ControlNodeOp.
141 builder.setInsertionPoint(op);
142 Location loc = op->getLoc();
143 auto outer_op = builder.create<ControlNodeOp>(
144 loc, op->getResultTypes(), ControlType::get(op->getContext()),
145 control_tokens);
146 Region region;
147 Block *new_block = new Block;
148 region.push_back(new_block);
149 builder.setInsertionPointToEnd(®ion.front());
150 Operation *inner_op = builder.clone(*op);
151 builder.create<YieldOp>(loc, inner_op->getResults());
152 outer_op.body().takeBody(region);
153 // Careful: We can't use outer_op.getResults(), because that also includes
154 // the control token.
155 op->replaceAllUsesWith(outer_op.outputs());
156 op->erase();
157 // Control token is last result of outer_op.
158 control_tokens.assign(1, outer_op.getResults().back());
159 }
160 }
161 } // namespace
162
CreatePinOpsWithSideEffectsPass()163 std::unique_ptr<OperationPass<func::FuncOp>> CreatePinOpsWithSideEffectsPass() {
164 return std::make_unique<PinOpsWithSideEffectsPass>();
165 }
166
167 static PassRegistration<PinOpsWithSideEffectsPass> pass;
168
169 } // namespace TFL
170 } // namespace mlir
171