1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
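// This pass optimizes tf_saved_model.global_tensor ops. It erases unused
// bound-input arguments, removes `is_mutable` from global tensors that it can
// prove are never written, and erases global tensors that are neither exported
// nor bound to any func argument.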
17 
18 #include <cstddef>
19 #include <map>
20 #include <set>
21 
22 #include "llvm/ADT/DenseMap.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
29 #include "mlir/IR/Types.h"  // from @llvm-project
30 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Support/LLVM.h"  // from @llvm-project
33 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
40 
41 namespace mlir {
42 namespace tf_saved_model {
43 namespace {
44 struct OptimizeGlobalTensorsPass
45     : public OptimizeGlobalTensorsPassBase<OptimizeGlobalTensorsPass> {
46   void runOnOperation() override;
47 };
48 
49 // A global tensor is bound to arguments of multiple funcs.
50 // This struct tracks which funcs (and which argument to that func) the global
51 // tensor is bound to.
52 struct GlobalTensorUse {
53   mutable func::FuncOp func;
54   size_t arg_index;
55 };
56 
57 using GlobalTensorUsesMap =
58     std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
59 
IsImmutable(GlobalTensorOp global_tensor,ArrayRef<GlobalTensorUse> global_tensor_uses,const TF::ResourceAnalyzer & resource_analyzer)60 bool IsImmutable(GlobalTensorOp global_tensor,
61                  ArrayRef<GlobalTensorUse> global_tensor_uses,
62                  const TF::ResourceAnalyzer& resource_analyzer) {
63   // Global tensor is already known to be immutable.
64   if (!global_tensor.is_mutable()) {
65     return false;
66   }
67   // An exported global tensor that is not already known to be immutable might
68   // be externally mutated.
69   if (IsExported(global_tensor)) {
70     return false;
71   }
72 
73   // A global tensor is immutable if the resource analyzer deems it so.
74   for (auto& global_tensor_use : global_tensor_uses) {
75     auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
76     if (resource_analyzer.IsPotentiallyWritten(arg)) {
77       return false;
78     }
79   }
80   return true;
81 }
82 
CreateGlobalTensorUsesMap(ModuleOp module)83 GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
84   GlobalTensorUsesMap global_tensor_uses;
85 
86   SymbolTable symbol_table(module);
87   for (auto func : module.getOps<func::FuncOp>()) {
88     for (size_t i = 0, e = func.getNumArguments(); i < e; i++) {
89       auto sym =
90           func.getArgAttrOfType<SymbolRefAttr>(i, "tf_saved_model.bound_input");
91       if (!sym) {
92         continue;
93       }
94       auto global_tensor = symbol_table.lookup<GlobalTensorOp>(
95           sym.cast<FlatSymbolRefAttr>().getValue());
96       if (!global_tensor) {
97         continue;
98       }
99       global_tensor_uses[global_tensor].push_back({func, i});
100     }
101   }
102 
103   return global_tensor_uses;
104 }
105 
106 // Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we
107 // can prove it is safe to do so.
MarkGlobalTensorsImmutable(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses_map,const TF::ResourceAnalyzer & resource_analyzer)108 void MarkGlobalTensorsImmutable(
109     ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
110     const TF::ResourceAnalyzer& resource_analyzer) {
111   for (const auto& kv : global_tensor_uses_map) {
112     auto global_tensor = kv.first;
113     const auto& global_tensor_uses = kv.second;
114     if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
115       global_tensor->removeAttr("is_mutable");
116     }
117   }
118 }
119 
EraseUnusedGlobalTensors(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses)120 void EraseUnusedGlobalTensors(ModuleOp module,
121                               const GlobalTensorUsesMap& global_tensor_uses) {
122   for (auto global_tensor :
123        llvm::make_early_inc_range(module.getOps<GlobalTensorOp>())) {
124     // If the tensor is exported, then it is used.
125     if (IsExported(global_tensor)) {
126       continue;
127     }
128     // If the tensor is bound to an argument, then it is used.
129     if (global_tensor_uses.find(global_tensor) != global_tensor_uses.end()) {
130       continue;
131     }
132     // Erase it.
133     global_tensor.erase();
134   }
135 }
136 
EraseUnusedBoundInputs(ModuleOp module)137 void EraseUnusedBoundInputs(ModuleOp module) {
138   for (auto func : module.getOps<func::FuncOp>()) {
139     llvm::BitVector args_to_erase(func.getNumArguments());
140     for (int i = 0, e = func.getNumArguments(); i < e; i++) {
141       if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
142           func.getArgument(i).use_empty()) {
143         args_to_erase.set(i);
144       }
145     }
146     func.eraseArguments(args_to_erase);
147   }
148 }
149 
runOnOperation()150 void OptimizeGlobalTensorsPass::runOnOperation() {
151   auto module = getOperation();
152   if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
153     return;
154   }
155 
156   EraseUnusedBoundInputs(module);
157 
158   TF::ResourceAnalyzer resource_analyzer(module);
159 
160   GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
161 
162   MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
163 
164   EraseUnusedGlobalTensors(module, global_tensor_uses);
165 }
166 
167 }  // namespace
168 
CreateOptimizeGlobalTensorsPass()169 std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
170   return std::make_unique<OptimizeGlobalTensorsPass>();
171 }
172 
173 }  // namespace tf_saved_model
174 }  // namespace mlir
175