1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"

#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <optional>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_while.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
46
47 namespace tensorflow {
48
49 // Helper functions for functionalizing control flow in functions.
50
51 // Maps function name to
52 // - new function name, if the function body was functionalized
53 // - std::nullopt, if not
54 using FuncMap = std::map<string, std::optional<string>>;
55 using FuncMapIter = std::map<string, std::optional<string>>::const_iterator;
56
57 // Returns whether function has been processed before.
FunctionHasBeenProcessed(FuncMapIter func_iter,const FuncMap * func_map)58 bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) {
59 return func_iter != func_map->end();
60 }
61
62 // Returns whether function has been modified (i.e., functionalized) before.
FunctionHasBeenModified(FuncMapIter func_iter)63 bool FunctionHasBeenModified(FuncMapIter func_iter) {
64 return func_iter->second.has_value();
65 }
66
67 // Returns a name for the new functionalized version of a function.
GetNewFunctionName(const string & func_name,Node * n,AssociatedFunctionInfo::AssociatedFunctionType func_type,FunctionLibraryDefinition * fld)68 string GetNewFunctionName(
69 const string& func_name, Node* n,
70 AssociatedFunctionInfo::AssociatedFunctionType func_type,
71 FunctionLibraryDefinition* fld) {
72 // For SymbolicGradient, `func_name` is always "SymbolicGradient" which
73 // is not very informative. Use node name instead.
74 return (
75 func_type ==
76 AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient
77 ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"))
78 : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_")));
79 }
80
81 // Returns name to which a modified function has been mapped.
GetMappedFunctionName(FuncMapIter func_iter)82 const string& GetMappedFunctionName(FuncMapIter func_iter) {
83 DCHECK(func_iter->second.has_value());
84 return func_iter->second.value();
85 }
86
87 // Updates `func_map` with function given by `canonicalized_name`.
UpdateFunctionMap(FuncMap * func_map,const string & canonicalized_name,const string & new_func_name,bool function_modified)88 void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name,
89 const string& new_func_name, bool function_modified) {
90 // If function was modified store its new name, otherwise add empty entry to
91 // record that function has been processed and does not need to be rewritten.
92 (*func_map)[canonicalized_name] =
93 function_modified ? absl::make_optional(new_func_name) : std::nullopt;
94 }
95
96 // Adds new function def to graph's function library if necessary.
AddFunctionDefToGraphLibrary(const string & func_name,const AssociatedFunctionInfo & associated_function,Graph * graph,FunctionLibraryDefinition * fld)97 Status AddFunctionDefToGraphLibrary(
98 const string& func_name, const AssociatedFunctionInfo& associated_function,
99 Graph* graph, FunctionLibraryDefinition* fld) {
100 const OpRegistrationData* op_reg_data;
101 // We have to be careful with adding the function def since there are three
102 // different `OpRegistryInterface`s involved here:
103 // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`.
104 // We have already added the function def to `fld` before calling this
105 // function but for the subsequent `RewriteAssociatedFunction` call we need
106 // the function def to be in one of the other two registries, otherwise
107 // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case
108 // because it cannot find the associated function def.
109 // On the other hand, we should not add the function def if it is already
110 // contained in one of the last two registries, this would lead to errors when
111 // the function def is already in one registry and we try to add it to the
112 // other one (if we try to add it to the same it's fine). This can happen in
113 // cases where one of the last two registries is identical to `fld` (which we
114 // already updated).
115 // Therefore, before adding the function def we have to check if it's already
116 // contained in either `graph->flib_def()` or
117 // `graph->flib_def().default_registry()` which is done in the following line
118 // (we have to use `LookUp` instead of `Contains` or `Find` because the latter
119 // both don't check the default registry).
120 if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) return OkStatus();
121
122 const FunctionDef* new_fdef = fld->Find(func_name);
123 DCHECK(new_fdef != nullptr);
124 FunctionDefLibrary fdef_lib;
125 *(fdef_lib.add_function()) = *new_fdef;
126 return graph->AddFunctionLibrary(fdef_lib);
127 }
128
129 // Functionalizes function given by `func_name`. Update `func_map` accordingly.
130 Status FunctionalizeControlFlowForFunction(
131 const string& func_name, const string& new_func_name,
132 const protobuf::Map<string, tensorflow::AttrValue>& attrs,
133 FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
134 FuncMap* func_map, bool* function_modified,
135 const NodeFilter& node_filter = {});
136
137 // Functionalizes all functions that are (directly or indirectly) associated to
138 // any node in `graph`. Adds processed functions to `func_map`.
FunctionalizeControlFlowForNodeAssociatedFunctions(FuncMap * func_map,Graph * graph,FunctionLibraryDefinition * fld,FunctionLibraryRuntime * flr,bool * any_function_modified,const NodeFilter & node_filter)139 Status FunctionalizeControlFlowForNodeAssociatedFunctions(
140 FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld,
141 FunctionLibraryRuntime* flr, bool* any_function_modified,
142 const NodeFilter& node_filter) {
143 std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
144 nodes_to_associated_functions;
145 for (auto* n : graph->nodes()) {
146 auto associated_functions = GetAssociatedFunctions(*n, fld);
147 if (!associated_functions.empty()) {
148 nodes_to_associated_functions.push_back({n, associated_functions});
149 }
150 }
151 for (const auto& pair : nodes_to_associated_functions) {
152 Node* n = pair.first;
153 auto associated_functions = pair.second;
154 for (auto& associated_function : associated_functions) {
155 // Note that if `n` is a function call node, then potential calls of
156 // `RewriteAssociatedFunction` below might delete `n` and create a new
157 // node instead, making `n` an invalid pointer. That's fine because in
158 // that case `n` only has one associated function, so this loop has only
159 // one iteration and we don't use `n` again after the rewrite.
160 // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed
161 // below.
162 DCHECK(associated_function.type() !=
163 AssociatedFunctionInfo::kFunctionCallNode ||
164 associated_functions.size() == 1);
165
166 // Process one node-function-pair.
167 string func_name = associated_function.func_name();
168 string canonicalized_name =
169 Canonicalize(func_name, AttrSlice(&associated_function.attrs()));
170 auto func_iter = func_map->find(canonicalized_name);
171 string new_func_name;
172 if (FunctionHasBeenProcessed(func_iter, func_map)) {
173 if (FunctionHasBeenModified(func_iter)) {
174 *any_function_modified = true;
175 new_func_name = GetMappedFunctionName(func_iter);
176 TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
177 graph, n, fld, associated_function, new_func_name));
178 }
179 continue;
180 }
181 // Function is processed for the first time.
182 bool function_modified = false;
183 new_func_name =
184 GetNewFunctionName(func_name, n, associated_function.type(), fld);
185 // Perform functionalization for current function.
186 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
187 func_name, new_func_name, associated_function.attrs(), fld, flr,
188 func_map, &function_modified, node_filter));
189 UpdateFunctionMap(func_map, canonicalized_name, new_func_name,
190 function_modified);
191 if (function_modified) {
192 *any_function_modified = true;
193 TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary(
194 new_func_name, associated_function, graph, fld));
195 TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
196 graph, n, fld, associated_function, new_func_name));
197 }
198 }
199 }
200 return OkStatus();
201 }
202
FunctionalizeControlFlowForFunction(const string & func_name,const string & new_func_name,const protobuf::Map<string,tensorflow::AttrValue> & attrs,FunctionLibraryDefinition * fld,FunctionLibraryRuntime * flr,FuncMap * func_map,bool * function_modified,const NodeFilter & node_filter)203 Status FunctionalizeControlFlowForFunction(
204 const string& func_name, const string& new_func_name,
205 const protobuf::Map<string, tensorflow::AttrValue>& attrs,
206 FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
207 FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) {
208 *function_modified = false;
209
210 // Convert the function to a graph.
211 FunctionLibraryRuntime::Handle handle;
212 TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
213 Status ret_status = OkStatus();
214 auto cleanup_handle = gtl::MakeCleanup([&]() {
215 auto s = flr->ReleaseHandle(handle);
216 if (!s.ok()) {
217 ret_status.Update(s);
218 }
219 });
220 const FunctionBody* body = flr->GetFunctionBody(handle);
221 Graph* g = body->graph;
222
223 // Check if the graph has Switch or Merge node.
224 bool has_switch_or_merge = false;
225 for (Node* n : body->graph->nodes()) {
226 // Skip nodes that are filtered out.
227 if (node_filter && !node_filter(n)) continue;
228 if (n->type_string() == "Switch" || n->type_string() == "Merge") {
229 has_switch_or_merge = true;
230 break;
231 }
232 }
233 // Before functionalizing control flow in `g` we functionalize control flow
234 // in functions (directly or indirectly) associated with nodes in `g`.
235 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
236 func_map, g, fld, flr, function_modified, node_filter));
237
238 if (has_switch_or_merge) {
239 *function_modified = true;
240
241 // Functionalize the function body.
242 if (VLOG_IS_ON(4)) {
243 DumpGraphToFile(
244 absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
245 *g, fld);
246 }
247 TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter));
248 if (VLOG_IS_ON(4)) {
249 DumpGraphToFile(
250 absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
251 fld);
252 }
253 }
254 if (*function_modified) {
255 // Add rewritten FunctionDef into library.
256 FunctionDef functionalized_fdef;
257 TF_RETURN_IF_ERROR(
258 GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
259 if (func_name == new_func_name) {
260 VLOG(2) << "Replacing function " << func_name;
261 TF_RETURN_IF_ERROR(
262 fld->ReplaceFunction(new_func_name, functionalized_fdef));
263 } else {
264 VLOG(2) << "Adding function " << new_func_name;
265 TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
266 }
267 }
268
269 return ret_status;
270 }
271
FunctionalizeControlFlow(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter,bool include_functions)272 Status FunctionalizeControlFlow(Graph* graph,
273 FunctionLibraryDefinition* library,
274 const NodeFilter& node_filter,
275 bool include_functions) {
276 VLOG(2) << "FunctionalizeControlFlow (initial): "
277 << DumpGraphToFile("functionalize_initial", *graph, library);
278
279 if (include_functions) {
280 // Functionalize control flow in functions that are (directly or indirectly)
281 // associated with a node in `graph`.
282 auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
283 /*device_mgr=*/nullptr, tensorflow::Env::Default(),
284 /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library,
285 tensorflow::OptimizerOptions());
286 // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice`
287 // (because we constructed it with `device_mgr = nullptr`).
288 FunctionLibraryRuntime* flr =
289 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
290
291 FuncMap func_map;
292 bool modified = false;
293 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
294 &func_map, graph, library, flr, &modified, node_filter));
295 }
296 // Functionalize and remove while loops from graph.
297 TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter));
298
299 // FunctionalizeControlFlow is invoked for every function, so the loops's
300 // bodies and conditionals that were extracted into functions will be handled
301 // in successive invocations.
302 TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library, node_filter));
303
304 VLOG(2) << "FunctionalizeControlFlow (final): "
305 << DumpGraphToFile("functionalize_final", *graph, library);
306
307 return OkStatus();
308 }
309
FunctionalizeControlFlowForGraphDef(GraphDef * graph_def,FunctionLibraryDefinition * library,const NodeFilter & node_filter,bool include_functions)310 Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
311 FunctionLibraryDefinition* library,
312 const NodeFilter& node_filter,
313 bool include_functions) {
314 FunctionDefLibrary function_lib = graph_def->library();
315 Graph graph(OpRegistry::Global());
316
317 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
318 TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter,
319 include_functions));
320 graph.ToGraphDef(graph_def);
321 std::swap(*graph_def->mutable_library(), function_lib);
322 return OkStatus();
323 }
324
Run(const GraphOptimizationPassOptions & options)325 Status FunctionalizeControlFlowForXlaPass::Run(
326 const GraphOptimizationPassOptions& options) {
327 Graph* graph = options.graph->get();
328 if (VLOG_IS_ON(4)) {
329 DumpGraphToFile("functionalize_control_flow_before", *graph,
330 options.flib_def);
331 }
332 const auto* config = &options.session_options->config;
333 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
334 new ProcessFunctionLibraryRuntime(
335 /*device_mgr=*/nullptr, options.session_options->env, config,
336 TF_GRAPH_DEF_VERSION, options.flib_def,
337 config->graph_options().optimizer_options()));
338 FunctionLibraryRuntime* flr =
339 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
340
341 // Find XLA compile ops and its corresponding FunctionDef.
342 // TPUCompile op is not in the map because graph rewriting might happen
343 // multiple times, and we want to avoid functionalize it again.
344 static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
345 new std::map<string, string>{
346 // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
347 {"_TPUReplicate", "computation"},
348 // XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
349 {"XlaLaunch", "function"},
350 };
351 FuncMap func_map;
352 bool fld_modified = false;
353 for (Node* n : graph->nodes()) {
354 auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
355 if (it == kNodeTypeToFunctionAttrMapping->end()) {
356 continue;
357 }
358 const string func_attr = it->second;
359 NameAttrList func;
360 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
361 VLOG(2) << "Graph has node " << n->type_string()
362 << ". Corresponding function: " << func.name();
363 string new_func_name = options.flib_def->UniqueFunctionName(
364 absl::StrCat(func.name(), "_f15n_"));
365 bool modified;
366 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
367 func.name(), new_func_name, func.attr(), options.flib_def, flr,
368 &func_map, &modified));
369 if (modified) {
370 n->ClearAttr(func_attr);
371 func.set_name(new_func_name);
372 n->AddAttr(func_attr, func);
373 fld_modified = true;
374 }
375 }
376
377 // TODO(ylc, endlessroad): Change this to "if (fld_modified")"
378 if (false) {
379 if (VLOG_IS_ON(4)) {
380 DumpGraphToFile("functionalize_control_flow_before_prune", *graph,
381 options.flib_def);
382 }
383 TF_RETURN_IF_ERROR(
384 PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
385 }
386
387 if (VLOG_IS_ON(4)) {
388 DumpGraphToFile("functionalize_control_flow_after", *graph,
389 options.flib_def);
390 }
391 return OkStatus();
392 }
393
394 } // namespace tensorflow
395