1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
17 
18 #include "llvm/Support/Casting.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
21 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
25 
26 namespace mlir {
27 namespace TF {
28 namespace {
IsResourceType(Type type)29 bool IsResourceType(Type type) {
30   if (auto tensor_type = type.dyn_cast<TensorType>()) {
31     return tensor_type.getElementType().isa<TF::ResourceType>();
32   }
33   return false;
34 }
35 
IsResource(Value value)36 bool IsResource(Value value) { return IsResourceType(value.getType()); }
37 
38 // Helper that returns the FuncOp that is the SessionInit function which
39 // will be called to initialize all resources.
40 // Returns nullptr if no function is found.
GetSessionInitializerFunc(ModuleOp module)41 func::FuncOp GetSessionInitializerFunc(ModuleOp module) {
42   auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
43   if (session_init_op && !session_init_op.initializers().empty()) {
44     SymbolTable symbol_table(module);
45     func::FuncOp init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
46         session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
47     return init_func_op;
48   }
49   return nullptr;
50 }
51 
52 // Returns ID for identifying a resource.
GetResourceKey(Operation * op)53 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
54     Operation* op) {
55   llvm::StringRef device;
56   if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
57     device = attr.getValue();
58   }
59 
60   llvm::StringRef container;
61   if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
62     container = attr.getValue();
63   }
64 
65   llvm::StringRef shared_name;
66   if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
67     shared_name = attr.getValue();
68   }
69 
70   return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
71       device, container, shared_name};
72 }
73 }  // namespace
ResourceAnalyzer(ModuleOp module,bool skip_session_init)74 ResourceAnalyzer::ResourceAnalyzer(ModuleOp module, bool skip_session_init) {
75   auto session_init_func = GetSessionInitializerFunc(module);
76   for (auto func : module.getOps<func::FuncOp>()) {
77     if (skip_session_init && func == session_init_func) continue;
78     (void)AnalyzeRegion(func.getRegion());
79   }
80 }
81 
SetPotentiallyWritten(Value resource)82 void ResourceAnalyzer::SetPotentiallyWritten(Value resource) {
83   assert(IsResource(resource));
84   resource_infos_[resource].potentially_written = true;
85   auto* operation = resource.getDefiningOp();
86   if (operation && llvm::isa<TF::VarHandleOp>(operation)) {
87     mutable_variables_.insert(GetResourceKey(operation));
88   }
89 }
90 
IsPotentiallyWritten(Value resource) const91 bool ResourceAnalyzer::IsPotentiallyWritten(Value resource) const {
92   assert(IsResource(resource));
93   auto* operation = resource.getDefiningOp();
94   if (operation && llvm::isa<TF::VarHandleOp>(operation))
95     return mutable_variables_.contains(GetResourceKey(operation));
96   auto it = resource_infos_.find(resource);
97   if (it == resource_infos_.end()) {
98     return false;
99   }
100   return it->second.potentially_written;
101 }
102 
103 // Analyze the specified region for resource mutating operations, namely
104 // TF::AssignVariableOp, if so, set the resource associated as "potentially
105 // written". Do this recursively across the chain of regions via call or
106 // control flow ops.
107 // TODO(ashwinm): Move to iterative traversal.
AnalyzeRegion(Region & region)108 LogicalResult ResourceAnalyzer::AnalyzeRegion(Region& region) {
109   // Avoid infinite recursion.
110   if (!discovered_.insert(&region).second) {
111     return success();
112   }
113 
114   region.walk([&](Operation* op) {
115     if (isa<TF::ReadVariableOp, func::ReturnOp, YieldOp>(op)) {
116       return;
117     }
118     if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
119       SetPotentiallyWritten(assign_variable.resource());
120       return;
121     }
122     if (auto call = dyn_cast<CallOpInterface>(op)) {
123       if (auto func = dyn_cast<func::FuncOp>(call.resolveCallable())) {
124         PropagatePotentiallyWrittenUpFromCallee(func.getRegion(),
125                                                 call.getArgOperands());
126       }
127       return;
128     }
129     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
130       for (auto callee : {if_op.then_function(), if_op.else_function()}) {
131         PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
132                                                 if_op.input());
133       }
134       return;
135     }
136     if (auto if_op = dyn_cast<TF::IfRegionOp>(op)) {
137       PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(),
138                                               if_op.getODSOperands(1));
139       PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(),
140                                               if_op.getODSOperands(1));
141       return;
142     }
143     if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
144       for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
145         PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
146                                                 while_op.input());
147       }
148       return;
149     }
150     if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
151       PropagatePotentiallyWrittenUpFromCallee(while_op.cond(),
152                                               while_op.input());
153       PropagatePotentiallyWrittenUpFromCallee(while_op.body(),
154                                               while_op.input());
155       return;
156     }
157     // For all other ops, we assume it mutates all resources it uses, so
158     // this errs on the side of being conservative. We should improve
159     // this by using either a property or a trait that clearly
160     // identifies ops with resource mutating behavior.
161     PropagatePotentiallyWrittenWithinUnhandledOp(op);
162   });
163   return success();
164 }
165 
PropagatePotentiallyWrittenWithinUnhandledOp(Operation * op)166 void ResourceAnalyzer::PropagatePotentiallyWrittenWithinUnhandledOp(
167     Operation* op) {
168   for (auto operand : op->getOperands()) {
169     if (IsResource(operand)) {
170       SetPotentiallyWritten(operand);
171     }
172   }
173   visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
174     if (IsResource(operand->get())) {
175       SetPotentiallyWritten(operand->get());
176     }
177   });
178 }
179 
PropagatePotentiallyWrittenUpFromCallee(Region & region,Operation::operand_range propagate_to)180 void ResourceAnalyzer::PropagatePotentiallyWrittenUpFromCallee(
181     Region& region, Operation::operand_range propagate_to) {
182   (void)AnalyzeRegion(region);
183   for (auto t : llvm::zip(region.getArguments(), propagate_to)) {
184     if (!IsResource(std::get<0>(t))) {
185       continue;
186     }
187     if (IsPotentiallyWritten(std::get<0>(t))) {
188       SetPotentiallyWritten(std::get<1>(t));
189     }
190   }
191 }
192 }  // namespace TF
193 }  // namespace mlir
194