xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <memory>
18 
19 #include "llvm/ADT/STLExtras.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/Block.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Support/LLVM.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 
33 namespace mlir {
34 namespace TFL {
35 namespace {
36 #define GEN_PASS_CLASSES
37 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
38 
IsResourceTensor(Value value)39 bool IsResourceTensor(Value value) {
40   const auto tensor_type = value.getType().dyn_cast<TensorType>();
41   return tensor_type &&
42          tensor_type.getElementType().isa<mlir::TF::ResourceType>();
43 }
44 
45 // The default criterion for operations being considered as causing or being
46 // dependent on side effects. Reflects the current runtime logic; see below.
OpHasSideEffects(Operation * op)47 bool OpHasSideEffects(Operation *op) {
48   // Note that TFL::IfOp are only ever instantiated in flatbuffer_export; until
49   // then, they are represented as mlir::TF::IfOp. We add them here, anyway, to
50   // be future-proof.
51   if (llvm::isa<TF::IfOp, TFL::IfOp, TFL::CallOnceOp, TFL::WhileOp>(op))
52     return true;
53   for (auto operand : op->getOperands()) {
54     if (IsResourceTensor(operand)) return true;
55   }
56   for (auto result : op->getResults()) {
57     if (IsResourceTensor(result)) return true;
58   }
59   return false;
60 }
61 
62 // This transformation pass takes an operation that has or depends on side
63 // effects and wraps it in a TFL::ControlNodeOp, which is made to depend on the
64 // control token generated by the most recent preceding such operation, if
65 // any. This copies the logic that is currently executed at runtime (in
66 // tensorflow/lite/core/subgraph). That runtime logic will now be a no-op for
67 // models that were generated with this pass.
68 //
69 // For purposes of this pass, an operator is considered to have/depend on side
70 // effects if
71 //  - it involves calling a different function
72 //  - it involves accessing resource variables
73 //
74 // Note that these criteria are more restrictive than necessary:
75 //  - they will force a fixed order on operations that read from/write to
76 //    *different* variables
77 //  - they make the blanket assumption that any functions called cause or depend
78 //    on side effects (i.e., access resource variables.)
79 //
80 // By moving the logic to compile time, we will be able to do a finer-grained
81 // data flow analysis in the future, which will enable more optimizations.
82 // This could happen in two steps:
83 // (1) build multiple dependency chains (one per variable), still treating
84 //     function/subgraph calls as black boxes (i.e., all variables would
85 //     be assumed to be read and modified within control operations)
86 // (2) Extend the variable dependency analysis across function boundaries.
87 class PinOpsWithSideEffectsPass
88     : public PinOpsWithSideEffectsPassBase<PinOpsWithSideEffectsPass> {
89  public:
90   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PinOpsWithSideEffectsPass)
91 
PinOpsWithSideEffectsPass(const std::function<bool (Operation *)> & op_has_side_effects=OpHasSideEffects)92   explicit PinOpsWithSideEffectsPass(const std::function<bool(Operation *)> &
93                                          op_has_side_effects = OpHasSideEffects)
94       : op_has_side_effects_(op_has_side_effects) {}
95 
96   void runOnOperation() override;
97 
98  private:
99   // This will be used recursively.
100   const std::function<bool(Operation *)> op_has_side_effects_;
101 };
102 
runOnOperation()103 void PinOpsWithSideEffectsPass::runOnOperation() {
104   auto fn = getOperation();
105   // We're assuming (and checking) that there are no tfl::ControlNodeOps present
106   // before this pass. We could relax this requirement by defining folding logic
107   // for them.
108   if (fn.walk([&](TFL::ControlNodeOp) {
109           return WalkResult::interrupt();
110         }).wasInterrupted()) {
111     fn.emitOpError("Can't have control ops in this pass.");
112     signalPassFailure();
113   }
114 
115   llvm::SmallVector<Operation *, 4> ops_with_side_effects;
116 
117   // We're iterating over all operations at the top block level, excluding
118   // the return operation (which otherwise would be recognized as being
119   // susceptible to side effects when returning a resource variable.)
120   // We only need to consider functions with single-block bodies, as
121   // this is an assumption flatbuffer_export makes, and this pass is
122   // operating on IRs ready for exporting.
123   for (Operation &op : fn.getBody().front().without_terminator()) {
124     // We have to recurse, since we might have wrapped a side-effectful operator
125     // in a tfl::CustomTfOp.
126     if (op.walk([&](Operation *inner_op) {
127             return op_has_side_effects_(inner_op) ? WalkResult::interrupt()
128                                                   : WalkResult::advance();
129           }).wasInterrupted()) {
130       ops_with_side_effects.push_back(&op);
131     }
132   }
133 
134   OpBuilder builder(fn.getContext());
135   // The control tokens generated by the last ControlNodeOp wrapping.  Will be
136   // empty until the first ControlNodeOp was generated, then have constant size
137   // 1.
138   llvm::SmallVector<Value, 1> control_tokens;
139   for (auto *op : ops_with_side_effects) {
140     // Wrap all side-effect producing/dependent operations in a ControlNodeOp.
141     builder.setInsertionPoint(op);
142     Location loc = op->getLoc();
143     auto outer_op = builder.create<ControlNodeOp>(
144         loc, op->getResultTypes(), ControlType::get(op->getContext()),
145         control_tokens);
146     Region region;
147     Block *new_block = new Block;
148     region.push_back(new_block);
149     builder.setInsertionPointToEnd(&region.front());
150     Operation *inner_op = builder.clone(*op);
151     builder.create<YieldOp>(loc, inner_op->getResults());
152     outer_op.body().takeBody(region);
153     // Careful: We can't use outer_op.getResults(), because that also includes
154     // the control token.
155     op->replaceAllUsesWith(outer_op.outputs());
156     op->erase();
157     // Control token is last result of outer_op.
158     control_tokens.assign(1, outer_op.getResults().back());
159   }
160 }
161 }  // namespace
162 
CreatePinOpsWithSideEffectsPass()163 std::unique_ptr<OperationPass<func::FuncOp>> CreatePinOpsWithSideEffectsPass() {
164   return std::make_unique<PinOpsWithSideEffectsPass>();
165 }
166 
167 static PassRegistration<PinOpsWithSideEffectsPass> pass;
168 
169 }  // namespace TFL
170 }  // namespace mlir
171