1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
17
18 #include "llvm/Support/Casting.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
21 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
25
26 namespace mlir {
27 namespace TF {
28 namespace {
IsResourceType(Type type)29 bool IsResourceType(Type type) {
30 if (auto tensor_type = type.dyn_cast<TensorType>()) {
31 return tensor_type.getElementType().isa<TF::ResourceType>();
32 }
33 return false;
34 }
35
IsResource(Value value)36 bool IsResource(Value value) { return IsResourceType(value.getType()); }
37
38 // Helper that returns the FuncOp that is the SessionInit function which
39 // will be called to initialize all resources.
40 // Returns nullptr if no function is found.
GetSessionInitializerFunc(ModuleOp module)41 func::FuncOp GetSessionInitializerFunc(ModuleOp module) {
42 auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
43 if (session_init_op && !session_init_op.initializers().empty()) {
44 SymbolTable symbol_table(module);
45 func::FuncOp init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
46 session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
47 return init_func_op;
48 }
49 return nullptr;
50 }
51
52 // Returns ID for identifying a resource.
GetResourceKey(Operation * op)53 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
54 Operation* op) {
55 llvm::StringRef device;
56 if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
57 device = attr.getValue();
58 }
59
60 llvm::StringRef container;
61 if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
62 container = attr.getValue();
63 }
64
65 llvm::StringRef shared_name;
66 if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
67 shared_name = attr.getValue();
68 }
69
70 return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
71 device, container, shared_name};
72 }
73 } // namespace
ResourceAnalyzer(ModuleOp module,bool skip_session_init)74 ResourceAnalyzer::ResourceAnalyzer(ModuleOp module, bool skip_session_init) {
75 auto session_init_func = GetSessionInitializerFunc(module);
76 for (auto func : module.getOps<func::FuncOp>()) {
77 if (skip_session_init && func == session_init_func) continue;
78 (void)AnalyzeRegion(func.getRegion());
79 }
80 }
81
SetPotentiallyWritten(Value resource)82 void ResourceAnalyzer::SetPotentiallyWritten(Value resource) {
83 assert(IsResource(resource));
84 resource_infos_[resource].potentially_written = true;
85 auto* operation = resource.getDefiningOp();
86 if (operation && llvm::isa<TF::VarHandleOp>(operation)) {
87 mutable_variables_.insert(GetResourceKey(operation));
88 }
89 }
90
IsPotentiallyWritten(Value resource) const91 bool ResourceAnalyzer::IsPotentiallyWritten(Value resource) const {
92 assert(IsResource(resource));
93 auto* operation = resource.getDefiningOp();
94 if (operation && llvm::isa<TF::VarHandleOp>(operation))
95 return mutable_variables_.contains(GetResourceKey(operation));
96 auto it = resource_infos_.find(resource);
97 if (it == resource_infos_.end()) {
98 return false;
99 }
100 return it->second.potentially_written;
101 }
102
103 // Analyze the specified region for resource mutating operations, namely
104 // TF::AssignVariableOp, if so, set the resource associated as "potentially
105 // written". Do this recursively across the chain of regions via call or
106 // control flow ops.
107 // TODO(ashwinm): Move to iterative traversal.
AnalyzeRegion(Region & region)108 LogicalResult ResourceAnalyzer::AnalyzeRegion(Region& region) {
109 // Avoid infinite recursion.
110 if (!discovered_.insert(®ion).second) {
111 return success();
112 }
113
114 region.walk([&](Operation* op) {
115 if (isa<TF::ReadVariableOp, func::ReturnOp, YieldOp>(op)) {
116 return;
117 }
118 if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
119 SetPotentiallyWritten(assign_variable.resource());
120 return;
121 }
122 if (auto call = dyn_cast<CallOpInterface>(op)) {
123 if (auto func = dyn_cast<func::FuncOp>(call.resolveCallable())) {
124 PropagatePotentiallyWrittenUpFromCallee(func.getRegion(),
125 call.getArgOperands());
126 }
127 return;
128 }
129 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
130 for (auto callee : {if_op.then_function(), if_op.else_function()}) {
131 PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
132 if_op.input());
133 }
134 return;
135 }
136 if (auto if_op = dyn_cast<TF::IfRegionOp>(op)) {
137 PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(),
138 if_op.getODSOperands(1));
139 PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(),
140 if_op.getODSOperands(1));
141 return;
142 }
143 if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
144 for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
145 PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
146 while_op.input());
147 }
148 return;
149 }
150 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
151 PropagatePotentiallyWrittenUpFromCallee(while_op.cond(),
152 while_op.input());
153 PropagatePotentiallyWrittenUpFromCallee(while_op.body(),
154 while_op.input());
155 return;
156 }
157 // For all other ops, we assume it mutates all resources it uses, so
158 // this errs on the side of being conservative. We should improve
159 // this by using either a property or a trait that clearly
160 // identifies ops with resource mutating behavior.
161 PropagatePotentiallyWrittenWithinUnhandledOp(op);
162 });
163 return success();
164 }
165
PropagatePotentiallyWrittenWithinUnhandledOp(Operation * op)166 void ResourceAnalyzer::PropagatePotentiallyWrittenWithinUnhandledOp(
167 Operation* op) {
168 for (auto operand : op->getOperands()) {
169 if (IsResource(operand)) {
170 SetPotentiallyWritten(operand);
171 }
172 }
173 visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
174 if (IsResource(operand->get())) {
175 SetPotentiallyWritten(operand->get());
176 }
177 });
178 }
179
PropagatePotentiallyWrittenUpFromCallee(Region & region,Operation::operand_range propagate_to)180 void ResourceAnalyzer::PropagatePotentiallyWrittenUpFromCallee(
181 Region& region, Operation::operand_range propagate_to) {
182 (void)AnalyzeRegion(region);
183 for (auto t : llvm::zip(region.getArguments(), propagate_to)) {
184 if (!IsResource(std::get<0>(t))) {
185 continue;
186 }
187 if (IsPotentiallyWritten(std::get<0>(t))) {
188 SetPotentiallyWritten(std::get<1>(t));
189 }
190 }
191 }
192 } // namespace TF
193 } // namespace mlir
194