1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <vector>
18
19 #include "llvm/ADT/BitVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/Debug.h"
23 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
24 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
25 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
28 #include "mlir/IR/SymbolTable.h" // from @llvm-project
29 #include "mlir/IR/Value.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "mlir/Support/LLVM.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
36
37 #define DEBUG_TYPE "freeze-global-tensor"
38
39 namespace mlir {
40 namespace tf_saved_model {
41
42 // The value of our lattice represents the GlobalTensorOp matching the value.
43 struct ResourceLatticeValue {
ResourceLatticeValuemlir::tf_saved_model::ResourceLatticeValue44 explicit ResourceLatticeValue(GlobalTensorOp op = nullptr) {
45 if (op) ops.insert(op);
46 }
47
getPessimisticValueStatemlir::tf_saved_model::ResourceLatticeValue48 static ResourceLatticeValue getPessimisticValueState(MLIRContext *context) {
49 return ResourceLatticeValue();
50 }
getPessimisticValueStatemlir::tf_saved_model::ResourceLatticeValue51 static ResourceLatticeValue getPessimisticValueState(Value value) {
52 if (auto barg = value.dyn_cast<BlockArgument>()) {
53 if (func::FuncOp func =
54 dyn_cast<func::FuncOp>(barg.getOwner()->getParentOp())) {
55 SymbolTable symbol_table(func->getParentOfType<ModuleOp>());
56 auto global_tensor = LookupBoundInputOfType<GlobalTensorOp>(
57 func, barg.getArgNumber(), symbol_table);
58 return ResourceLatticeValue(global_tensor);
59 }
60 }
61 return ResourceLatticeValue();
62 }
63
operator ==mlir::tf_saved_model::ResourceLatticeValue64 bool operator==(const ResourceLatticeValue &rhs) const {
65 return ops == rhs.ops;
66 }
67
joinmlir::tf_saved_model::ResourceLatticeValue68 static ResourceLatticeValue join(const ResourceLatticeValue &lhs,
69 const ResourceLatticeValue &rhs) {
70 // Take union of both sets of possible GlobalTensorOp values that can be
71 // referenced here.
72 ResourceLatticeValue ret;
73 ret.ops.insert(lhs.ops.begin(), lhs.ops.end());
74 ret.ops.insert(rhs.ops.begin(), rhs.ops.end());
75 return ret;
76 }
77
printmlir::tf_saved_model::ResourceLatticeValue78 void print(raw_ostream &os) const {
79 llvm::interleaveComma(ops, os << "["), os << "]";
80 }
81
82 // The location which originated the int value.
83 // IR constructs (i.e., GlobalTensorOp) are not const-correct.
84 mutable DenseSet<GlobalTensorOp> ops;
85 };
86
87 namespace {
88 class ResourceAnalysis : public dataflow::SparseDataFlowAnalysis<
89 dataflow::Lattice<ResourceLatticeValue>> {
90 public:
91 using StateT = dataflow::Lattice<ResourceLatticeValue>;
92 using dataflow::SparseDataFlowAnalysis<StateT>::SparseDataFlowAnalysis;
93 ~ResourceAnalysis() override = default;
94
visitOperation(Operation * op,ArrayRef<const StateT * > operands,ArrayRef<StateT * > results)95 void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
96 ArrayRef<StateT *> results) override {
97 LLVM_DEBUG(llvm::dbgs() << "ResAn: Visiting operation: " << *op << "\n");
98 markAllPessimisticFixpoint(results);
99 }
100 };
101
102 struct FreezeGlobalTensorsPass
103 : public FreezeGlobalTensorsPassBase<FreezeGlobalTensorsPass> {
FreezeGlobalTensorsPassmlir::tf_saved_model::__anon21b2196a0111::FreezeGlobalTensorsPass104 explicit FreezeGlobalTensorsPass(bool allow_mutable_tensors) {
105 this->allow_mutable_tensors = allow_mutable_tensors;
106 }
107 void runOnOperation() override;
108 };
109
runOnOperation()110 void FreezeGlobalTensorsPass::runOnOperation() {
111 auto module = getOperation();
112 if (!tf_saved_model::HasTfSavedModelSemantics(module)) return;
113
114 DataFlowSolver solver;
115 solver.load<dataflow::DeadCodeAnalysis>();
116 solver.load<ResourceAnalysis>();
117 if (failed(solver.initializeAndRun(module))) return signalPassFailure();
118
119 DenseSet<GlobalTensorOp> remaining_global_tensor_ops;
120 {
121 auto ops = module.getOps<GlobalTensorOp>();
122 remaining_global_tensor_ops.insert(ops.begin(), ops.end());
123 }
124
125 for (auto global_tensor : remaining_global_tensor_ops) {
126 // This pass assumes that all global tensors as immutable (e.g. by a
127 // previous optimize global tensors pass). If not, this pass has to fail
128 // since it cannot perform one of its goals.
129 if (global_tensor.is_mutable()) {
130 if (allow_mutable_tensors) continue;
131 global_tensor.emitError()
132 << "is not immutable, try removing mutable variables in your model "
133 "since mutable variables are currently not supported through "
134 "this converter";
135 return signalPassFailure();
136 }
137 }
138
139 // Collect all those freezable. This is an extra scan but allows for the
140 // partial behavior from `allow_mutable_tensor`.
141 DenseMap<BlockArgument, bool> freezeable;
142 for (auto func : module.getOps<func::FuncOp>()) {
143 for (BlockArgument val : func.getArguments()) {
144 if (!getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>())
145 continue;
146
147 // Check that there is only a single global tensor associated with arg.
148 const ResourceAnalysis::StateT *latticeElement =
149 solver.lookupState<ResourceAnalysis::StateT>(val);
150 if (!latticeElement || latticeElement->getValue().ops.size() != 1)
151 continue;
152
153 // Don't freeze mutable tensors.
154 if (latticeElement->getValue().ops.begin()->is_mutable()) {
155 freezeable[val] = false;
156 continue;
157 }
158
159 freezeable[val] = true;
160
161 // Verify users are supported kind.
162 for (Operation *user : val.getUsers()) {
163 if (!(isa<TF::ReadVariableOp>(user) || isa<CallOpInterface>(user))) {
164 freezeable[val] = false;
165 // Error out early if possible.
166 if (!allow_mutable_tensors) {
167 user->emitError()
168 << "could not rewrite use of immutable bound input";
169 return signalPassFailure();
170 }
171 }
172 }
173 }
174 }
175
176 DenseSet<GlobalTensorOp> frozen_global_tensors;
177 for (auto func : module.getOps<func::FuncOp>()) {
178 llvm::BitVector args_to_erase(func.getNumArguments());
179 DenseMap<Operation *, llvm::BitVector> remove_operands;
180 OpBuilder builder(func.getBody());
181
182 for (BlockArgument val : func.getArguments()) {
183 if (!freezeable[val]) continue;
184
185 const ResourceAnalysis::StateT *latticeElement =
186 solver.lookupState<ResourceAnalysis::StateT>(val);
187 GlobalTensorOp global_tensor = *latticeElement->getValue().ops.begin();
188
189 SmallVector<TF::ReadVariableOp, 4> read_variable_ops_to_erase;
190 frozen_global_tensors.insert(global_tensor);
191
192 for (Operation *user : val.getUsers()) {
193 if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
194 // Collect all read variable ops so that all its uses can be replaced
195 // with the tf.constant corresponding to the global tensor op.
196 read_variable_ops_to_erase.push_back(read_op);
197 } else {
198 llvm::BitVector &bvector = remove_operands[user];
199 bvector.resize(user->getNumOperands());
200 for (OpOperand &use : user->getOpOperands())
201 bvector.set(use.getOperandNumber());
202 }
203 }
204
205 // Replace the arg with a tf.Const op in the function body.
206 builder.setInsertionPointToStart(&func.getBody().front());
207 auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
208 global_tensor.value());
209 args_to_erase.set(val.getArgNumber());
210 for (auto read_op : read_variable_ops_to_erase) {
211 read_op.getResult().replaceAllUsesWith(const_op.getResult());
212 read_op.erase();
213 }
214 }
215 // As the other uses are call operations, we simply remove the arguments
216 // as the function arguments will be removed below once that function is
217 // processed.
218 for (auto it : remove_operands) {
219 it.first->eraseOperands(it.second);
220 }
221
222 func.eraseArguments(args_to_erase);
223 }
224
225 // Erase all global tensors that were frozen.
226 for (auto global_tensor : frozen_global_tensors) {
227 remaining_global_tensor_ops.erase(global_tensor);
228 global_tensor->erase();
229 }
230
231 // Verify that there are no remaining global tensors.
232 if (!allow_mutable_tensors && !remaining_global_tensor_ops.empty()) {
233 module.emitError() << "could not freeze all global tensors in the module";
234 return signalPassFailure();
235 }
236 }
237
238 } // namespace
239
CreateFreezeGlobalTensorsPass(bool allow_mutable_tensors)240 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
241 bool allow_mutable_tensors) {
242 return std::make_unique<FreezeGlobalTensorsPass>(allow_mutable_tensors);
243 }
244
245 } // namespace tf_saved_model
246 } // namespace mlir
247