1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
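// This pass optimizes tf_saved_model.global_tensor ops: it erases unused bound
// inputs, marks global tensors that are provably never written as immutable,
// and erases global tensors that are neither exported nor bound to any
// function argument.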
17
#include <cstddef>
#include <map>
#include <memory>
#include <set>
#include <vector>

#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
40
41 namespace mlir {
42 namespace tf_saved_model {
43 namespace {
44 struct OptimizeGlobalTensorsPass
45 : public OptimizeGlobalTensorsPassBase<OptimizeGlobalTensorsPass> {
46 void runOnOperation() override;
47 };
48
49 // A global tensor is bound to arguments of multiple funcs.
50 // This struct tracks which funcs (and which argument to that func) the global
51 // tensor is bound to.
52 struct GlobalTensorUse {
53 mutable func::FuncOp func;
54 size_t arg_index;
55 };
56
57 using GlobalTensorUsesMap =
58 std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
59
IsImmutable(GlobalTensorOp global_tensor,ArrayRef<GlobalTensorUse> global_tensor_uses,const TF::ResourceAnalyzer & resource_analyzer)60 bool IsImmutable(GlobalTensorOp global_tensor,
61 ArrayRef<GlobalTensorUse> global_tensor_uses,
62 const TF::ResourceAnalyzer& resource_analyzer) {
63 // Global tensor is already known to be immutable.
64 if (!global_tensor.is_mutable()) {
65 return false;
66 }
67 // An exported global tensor that is not already known to be immutable might
68 // be externally mutated.
69 if (IsExported(global_tensor)) {
70 return false;
71 }
72
73 // A global tensor is immutable if the resource analyzer deems it so.
74 for (auto& global_tensor_use : global_tensor_uses) {
75 auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
76 if (resource_analyzer.IsPotentiallyWritten(arg)) {
77 return false;
78 }
79 }
80 return true;
81 }
82
CreateGlobalTensorUsesMap(ModuleOp module)83 GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
84 GlobalTensorUsesMap global_tensor_uses;
85
86 SymbolTable symbol_table(module);
87 for (auto func : module.getOps<func::FuncOp>()) {
88 for (size_t i = 0, e = func.getNumArguments(); i < e; i++) {
89 auto sym =
90 func.getArgAttrOfType<SymbolRefAttr>(i, "tf_saved_model.bound_input");
91 if (!sym) {
92 continue;
93 }
94 auto global_tensor = symbol_table.lookup<GlobalTensorOp>(
95 sym.cast<FlatSymbolRefAttr>().getValue());
96 if (!global_tensor) {
97 continue;
98 }
99 global_tensor_uses[global_tensor].push_back({func, i});
100 }
101 }
102
103 return global_tensor_uses;
104 }
105
106 // Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we
107 // can prove it is safe to do so.
MarkGlobalTensorsImmutable(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses_map,const TF::ResourceAnalyzer & resource_analyzer)108 void MarkGlobalTensorsImmutable(
109 ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
110 const TF::ResourceAnalyzer& resource_analyzer) {
111 for (const auto& kv : global_tensor_uses_map) {
112 auto global_tensor = kv.first;
113 const auto& global_tensor_uses = kv.second;
114 if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
115 global_tensor->removeAttr("is_mutable");
116 }
117 }
118 }
119
EraseUnusedGlobalTensors(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses)120 void EraseUnusedGlobalTensors(ModuleOp module,
121 const GlobalTensorUsesMap& global_tensor_uses) {
122 for (auto global_tensor :
123 llvm::make_early_inc_range(module.getOps<GlobalTensorOp>())) {
124 // If the tensor is exported, then it is used.
125 if (IsExported(global_tensor)) {
126 continue;
127 }
128 // If the tensor is bound to an argument, then it is used.
129 if (global_tensor_uses.find(global_tensor) != global_tensor_uses.end()) {
130 continue;
131 }
132 // Erase it.
133 global_tensor.erase();
134 }
135 }
136
EraseUnusedBoundInputs(ModuleOp module)137 void EraseUnusedBoundInputs(ModuleOp module) {
138 for (auto func : module.getOps<func::FuncOp>()) {
139 llvm::BitVector args_to_erase(func.getNumArguments());
140 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
141 if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
142 func.getArgument(i).use_empty()) {
143 args_to_erase.set(i);
144 }
145 }
146 func.eraseArguments(args_to_erase);
147 }
148 }
149
runOnOperation()150 void OptimizeGlobalTensorsPass::runOnOperation() {
151 auto module = getOperation();
152 if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
153 return;
154 }
155
156 EraseUnusedBoundInputs(module);
157
158 TF::ResourceAnalyzer resource_analyzer(module);
159
160 GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
161
162 MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
163
164 EraseUnusedGlobalTensors(module, global_tensor_uses);
165 }
166
167 } // namespace
168
CreateOptimizeGlobalTensorsPass()169 std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
170 return std::make_unique<OptimizeGlobalTensorsPass>();
171 }
172
173 } // namespace tf_saved_model
174 } // namespace mlir
175