/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <vector>

#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"  // from @llvm-project
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"

#define DEBUG_TYPE "freeze-global-tensor"

namespace mlir {
namespace tf_saved_model {
// The lattice value tracks the set of GlobalTensorOps that a resource value
// may be bound to.
struct ResourceLatticeValue {
  explicit ResourceLatticeValue(GlobalTensorOp op = nullptr) {
    if (op) ops.insert(op);
  }

  static ResourceLatticeValue getPessimisticValueState(MLIRContext *context) {
    return ResourceLatticeValue();
  }
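  // For a function block argument, the pessimistic state is the
  // GlobalTensorOp (if any) bound to that argument via tf_saved_model.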
  static ResourceLatticeValue getPessimisticValueState(Value value) {
    if (auto barg = value.dyn_cast<BlockArgument>()) {
      if (func::FuncOp func =
              dyn_cast<func::FuncOp>(barg.getOwner()->getParentOp())) {
        SymbolTable symbol_table(func->getParentOfType<ModuleOp>());
        auto global_tensor = LookupBoundInputOfType<GlobalTensorOp>(
            func, barg.getArgNumber(), symbol_table);
        return ResourceLatticeValue(global_tensor);
      }
    }
    return ResourceLatticeValue();
  }

  bool operator==(const ResourceLatticeValue &rhs) const {
    return ops == rhs.ops;
  }

  static ResourceLatticeValue join(const ResourceLatticeValue &lhs,
                                   const ResourceLatticeValue &rhs) {
    // Take union of both sets of possible GlobalTensorOp values that can be
    // referenced here.
    ResourceLatticeValue ret;
    ret.ops.insert(lhs.ops.begin(), lhs.ops.end());
    ret.ops.insert(rhs.ops.begin(), rhs.ops.end());
    return ret;
  }

  void print(raw_ostream &os) const {
    llvm::interleaveComma(ops, os << "["), os << "]";
  }

  // The set of GlobalTensorOps that this resource value may be bound to.
  // IR constructs (i.e., GlobalTensorOp) are not const-correct, hence mutable.
  mutable DenseSet<GlobalTensorOp> ops;
};

namespace {
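// Sparse dataflow analysis associating each resource value with the set of
// GlobalTensorOps it may originate from.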
class ResourceAnalysis : public dataflow::SparseDataFlowAnalysis<
                             dataflow::Lattice<ResourceLatticeValue>> {
 public:
  using StateT = dataflow::Lattice<ResourceLatticeValue>;
  using dataflow::SparseDataFlowAnalysis<StateT>::SparseDataFlowAnalysis;
  ~ResourceAnalysis() override = default;

  void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
                      ArrayRef<StateT *> results) override {
    LLVM_DEBUG(llvm::dbgs() << "ResAn: Visiting operation: " << *op << "\n");
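    // The interesting facts come from the pessimistic value state of block
    // arguments, which looks up bound global tensors; operations themselves
    // do not refine resource bindings, so all results stay pessimistic.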
    markAllPessimisticFixpoint(results);
  }
};

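// Pass that replaces bound inputs backed by immutable GlobalTensorOps with
// tf.Const ops in function bodies and then erases the frozen global tensors.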
struct FreezeGlobalTensorsPass
    : public FreezeGlobalTensorsPassBase<FreezeGlobalTensorsPass> {
  explicit FreezeGlobalTensorsPass(bool allow_mutable_tensors) {
    this->allow_mutable_tensors = allow_mutable_tensors;
  }
  void runOnOperation() override;
};

void FreezeGlobalTensorsPass::runOnOperation() {
  auto module = getOperation();
  if (!tf_saved_model::HasTfSavedModelSemantics(module)) return;

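  // Run the resource analysis to map each resource argument to the global
  // tensor it is bound to. DeadCodeAnalysis is loaded as well because the
  // sparse dataflow framework relies on it for reachability.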
  DataFlowSolver solver;
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<ResourceAnalysis>();
  if (failed(solver.initializeAndRun(module))) return signalPassFailure();

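  // Start with every global tensor in the module; frozen tensors are removed
  // from this set so that any leftovers can be diagnosed at the end.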
  DenseSet<GlobalTensorOp> remaining_global_tensor_ops;
  {
    auto ops = module.getOps<GlobalTensorOp>();
    remaining_global_tensor_ops.insert(ops.begin(), ops.end());
  }

  for (auto global_tensor : remaining_global_tensor_ops) {
    // This pass assumes that all global tensors are immutable (e.g. made so
    // by a previous optimize-global-tensors pass). If a mutable tensor is
    // found and mutable tensors are not allowed, the pass must fail.
    if (global_tensor.is_mutable()) {
      if (allow_mutable_tensors) continue;
      global_tensor.emitError()
          << "is not immutable, try removing mutable variables in your model "
             "since mutable variables are currently not supported through "
             "this converter";
      return signalPassFailure();
    }
  }

  // Collect the arguments that can be frozen. This is an extra scan, but it
  // enables the partial freezing behavior permitted by `allow_mutable_tensors`.
  DenseMap<BlockArgument, bool> freezeable;
  for (auto func : module.getOps<func::FuncOp>()) {
    for (BlockArgument val : func.getArguments()) {
      if (!getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>())
        continue;

      // Check that there is only a single global tensor associated with arg.
      const ResourceAnalysis::StateT *latticeElement =
          solver.lookupState<ResourceAnalysis::StateT>(val);
      if (!latticeElement || latticeElement->getValue().ops.size() != 1)
        continue;

      // Don't freeze mutable tensors.
      if (latticeElement->getValue().ops.begin()->is_mutable()) {
        freezeable[val] = false;
        continue;
      }

      freezeable[val] = true;

      // Verify that all users are of a supported kind.
      for (Operation *user : val.getUsers()) {
        if (!(isa<TF::ReadVariableOp>(user) || isa<CallOpInterface>(user))) {
          freezeable[val] = false;
          // Error out early if possible.
          if (!allow_mutable_tensors) {
            user->emitError()
                << "could not rewrite use of immutable bound input";
            return signalPassFailure();
          }
        }
      }
    }
  }

  DenseSet<GlobalTensorOp> frozen_global_tensors;
  for (auto func : module.getOps<func::FuncOp>()) {
    llvm::BitVector args_to_erase(func.getNumArguments());
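    // Records, per user operation, the operand positions to erase once uses
    // of the frozen argument have been rewritten.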
    DenseMap<Operation *, llvm::BitVector> remove_operands;
    OpBuilder builder(func.getBody());

    for (BlockArgument val : func.getArguments()) {
      if (!freezeable[val]) continue;

      const ResourceAnalysis::StateT *latticeElement =
          solver.lookupState<ResourceAnalysis::StateT>(val);
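      // Non-null with exactly one global tensor: `freezeable[val]` is only
      // set to true when that was verified above.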
      GlobalTensorOp global_tensor = *latticeElement->getValue().ops.begin();

      SmallVector<TF::ReadVariableOp, 4> read_variable_ops_to_erase;
      frozen_global_tensors.insert(global_tensor);

      for (Operation *user : val.getUsers()) {
        if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
          // Collect all read variable ops so that their uses can be replaced
          // with the tf.Const created from the global tensor op.
          read_variable_ops_to_erase.push_back(read_op);
        } else {
          llvm::BitVector &bvector = remove_operands[user];
          bvector.resize(user->getNumOperands());
          for (OpOperand &use : user->getOpOperands())
            bvector.set(use.getOperandNumber());
        }
      }

      // Replace the arg with a tf.Const op in the function body.
      builder.setInsertionPointToStart(&func.getBody().front());
      auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
                                                  global_tensor.value());
      args_to_erase.set(val.getArgNumber());
      for (auto read_op : read_variable_ops_to_erase) {
        read_op.getResult().replaceAllUsesWith(const_op.getResult());
        read_op.erase();
      }
    }
    // The remaining uses are call operations, so we only erase their operands
    // here; the corresponding callee arguments are erased when this loop
    // processes that function.
    for (auto it : remove_operands) {
      it.first->eraseOperands(it.second);
    }

    func.eraseArguments(args_to_erase);
  }

  // Erase all global tensors that were frozen.
  for (auto global_tensor : frozen_global_tensors) {
    remaining_global_tensor_ops.erase(global_tensor);
    global_tensor->erase();
  }

  // Verify that there are no remaining global tensors.
  if (!allow_mutable_tensors && !remaining_global_tensor_ops.empty()) {
    module.emitError() << "could not freeze all global tensors in the module";
    return signalPassFailure();
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
    bool allow_mutable_tensors) {
  return std::make_unique<FreezeGlobalTensorsPass>(allow_mutable_tensors);
}

}  // namespace tf_saved_model
}  // namespace mlir