xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/functionalize_control_flow.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17 
#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <optional>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
27 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/graph_optimizer.h"
35 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
36 #include "tensorflow/core/framework/graph_to_functiondef.h"
37 #include "tensorflow/core/framework/node_def_builder.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/control_flow.h"
40 #include "tensorflow/core/graph/node_builder.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/cleanup.h"
43 #include "tensorflow/core/public/session_options.h"
44 #include "tensorflow/core/public/version.h"
45 #include "tensorflow/core/util/dump_graph.h"
46 
47 namespace tensorflow {
48 
// Helper functions for functionalizing control flow in functions.

// Bookkeeping of visited functions. Keys are canonicalized names (the result
// of `Canonicalize` on a function name and its attributes); each value is
// - the name of the functionalized copy, if the function body was rewritten;
// - std::nullopt, if the function was visited but needed no rewrite.
using FuncMap = std::map<string, std::optional<string>>;
using FuncMapIter = FuncMap::const_iterator;
56 
57 // Returns whether function has been processed before.
FunctionHasBeenProcessed(FuncMapIter func_iter,const FuncMap * func_map)58 bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) {
59   return func_iter != func_map->end();
60 }
61 
62 // Returns whether function has been modified (i.e., functionalized) before.
FunctionHasBeenModified(FuncMapIter func_iter)63 bool FunctionHasBeenModified(FuncMapIter func_iter) {
64   return func_iter->second.has_value();
65 }
66 
67 // Returns a name for the new functionalized version of a function.
GetNewFunctionName(const string & func_name,Node * n,AssociatedFunctionInfo::AssociatedFunctionType func_type,FunctionLibraryDefinition * fld)68 string GetNewFunctionName(
69     const string& func_name, Node* n,
70     AssociatedFunctionInfo::AssociatedFunctionType func_type,
71     FunctionLibraryDefinition* fld) {
72   // For SymbolicGradient, `func_name` is always "SymbolicGradient" which
73   // is not very informative. Use node name instead.
74   return (
75       func_type ==
76               AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient
77           ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"))
78           : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_")));
79 }
80 
81 // Returns name to which a modified function has been mapped.
GetMappedFunctionName(FuncMapIter func_iter)82 const string& GetMappedFunctionName(FuncMapIter func_iter) {
83   DCHECK(func_iter->second.has_value());
84   return func_iter->second.value();
85 }
86 
87 // Updates `func_map` with function given by `canonicalized_name`.
UpdateFunctionMap(FuncMap * func_map,const string & canonicalized_name,const string & new_func_name,bool function_modified)88 void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name,
89                        const string& new_func_name, bool function_modified) {
90   // If function was modified store its new name, otherwise add empty entry to
91   // record that function has been processed and does not need to be rewritten.
92   (*func_map)[canonicalized_name] =
93       function_modified ? absl::make_optional(new_func_name) : std::nullopt;
94 }
95 
96 // Adds new function def to graph's function library if necessary.
AddFunctionDefToGraphLibrary(const string & func_name,const AssociatedFunctionInfo & associated_function,Graph * graph,FunctionLibraryDefinition * fld)97 Status AddFunctionDefToGraphLibrary(
98     const string& func_name, const AssociatedFunctionInfo& associated_function,
99     Graph* graph, FunctionLibraryDefinition* fld) {
100   const OpRegistrationData* op_reg_data;
101   // We have to be careful with adding the function def since there are three
102   // different `OpRegistryInterface`s involved here:
103   // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`.
104   // We have already added the function def to `fld` before calling this
105   // function but for the subsequent `RewriteAssociatedFunction` call we need
106   // the function def to be in one of the other two registries, otherwise
107   // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case
108   // because it cannot find the associated function def.
109   // On the other hand, we should not add the function def if it is already
110   // contained in one of the last two registries, this would lead to errors when
111   // the function def is already in one registry and we try to add it to the
112   // other one (if we try to add it to the same it's fine). This can happen in
113   // cases where one of the last two registries is identical to `fld` (which we
114   // already updated).
115   // Therefore, before adding the function def we have to check if it's already
116   // contained in either `graph->flib_def()` or
117   // `graph->flib_def().default_registry()` which is done in the following line
118   // (we have to use `LookUp` instead of `Contains` or `Find` because the latter
119   // both don't check the default registry).
120   if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) return OkStatus();
121 
122   const FunctionDef* new_fdef = fld->Find(func_name);
123   DCHECK(new_fdef != nullptr);
124   FunctionDefLibrary fdef_lib;
125   *(fdef_lib.add_function()) = *new_fdef;
126   return graph->AddFunctionLibrary(fdef_lib);
127 }
128 
129 // Functionalizes function given by `func_name`. Update `func_map` accordingly.
130 Status FunctionalizeControlFlowForFunction(
131     const string& func_name, const string& new_func_name,
132     const protobuf::Map<string, tensorflow::AttrValue>& attrs,
133     FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
134     FuncMap* func_map, bool* function_modified,
135     const NodeFilter& node_filter = {});
136 
137 // Functionalizes all functions that are (directly or indirectly) associated to
138 // any node in `graph`. Adds processed functions to `func_map`.
FunctionalizeControlFlowForNodeAssociatedFunctions(FuncMap * func_map,Graph * graph,FunctionLibraryDefinition * fld,FunctionLibraryRuntime * flr,bool * any_function_modified,const NodeFilter & node_filter)139 Status FunctionalizeControlFlowForNodeAssociatedFunctions(
140     FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld,
141     FunctionLibraryRuntime* flr, bool* any_function_modified,
142     const NodeFilter& node_filter) {
143   std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
144       nodes_to_associated_functions;
145   for (auto* n : graph->nodes()) {
146     auto associated_functions = GetAssociatedFunctions(*n, fld);
147     if (!associated_functions.empty()) {
148       nodes_to_associated_functions.push_back({n, associated_functions});
149     }
150   }
151   for (const auto& pair : nodes_to_associated_functions) {
152     Node* n = pair.first;
153     auto associated_functions = pair.second;
154     for (auto& associated_function : associated_functions) {
155       // Note that if `n` is a function call node, then potential calls of
156       // `RewriteAssociatedFunction` below might delete `n` and create a new
157       // node instead, making `n` an invalid pointer. That's fine because in
158       // that case `n` only has one associated function, so this loop has only
159       // one iteration and we don't use `n` again after the rewrite.
160       // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed
161       // below.
162       DCHECK(associated_function.type() !=
163                  AssociatedFunctionInfo::kFunctionCallNode ||
164              associated_functions.size() == 1);
165 
166       // Process one node-function-pair.
167       string func_name = associated_function.func_name();
168       string canonicalized_name =
169           Canonicalize(func_name, AttrSlice(&associated_function.attrs()));
170       auto func_iter = func_map->find(canonicalized_name);
171       string new_func_name;
172       if (FunctionHasBeenProcessed(func_iter, func_map)) {
173         if (FunctionHasBeenModified(func_iter)) {
174           *any_function_modified = true;
175           new_func_name = GetMappedFunctionName(func_iter);
176           TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
177               graph, n, fld, associated_function, new_func_name));
178         }
179         continue;
180       }
181       // Function is processed for the first time.
182       bool function_modified = false;
183       new_func_name =
184           GetNewFunctionName(func_name, n, associated_function.type(), fld);
185       // Perform functionalization for current function.
186       TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
187           func_name, new_func_name, associated_function.attrs(), fld, flr,
188           func_map, &function_modified, node_filter));
189       UpdateFunctionMap(func_map, canonicalized_name, new_func_name,
190                         function_modified);
191       if (function_modified) {
192         *any_function_modified = true;
193         TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary(
194             new_func_name, associated_function, graph, fld));
195         TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
196             graph, n, fld, associated_function, new_func_name));
197       }
198     }
199   }
200   return OkStatus();
201 }
202 
FunctionalizeControlFlowForFunction(const string & func_name,const string & new_func_name,const protobuf::Map<string,tensorflow::AttrValue> & attrs,FunctionLibraryDefinition * fld,FunctionLibraryRuntime * flr,FuncMap * func_map,bool * function_modified,const NodeFilter & node_filter)203 Status FunctionalizeControlFlowForFunction(
204     const string& func_name, const string& new_func_name,
205     const protobuf::Map<string, tensorflow::AttrValue>& attrs,
206     FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
207     FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) {
208   *function_modified = false;
209 
210   // Convert the function to a graph.
211   FunctionLibraryRuntime::Handle handle;
212   TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
213   Status ret_status = OkStatus();
214   auto cleanup_handle = gtl::MakeCleanup([&]() {
215     auto s = flr->ReleaseHandle(handle);
216     if (!s.ok()) {
217       ret_status.Update(s);
218     }
219   });
220   const FunctionBody* body = flr->GetFunctionBody(handle);
221   Graph* g = body->graph;
222 
223   // Check if the graph has Switch or Merge node.
224   bool has_switch_or_merge = false;
225   for (Node* n : body->graph->nodes()) {
226     // Skip nodes that are filtered out.
227     if (node_filter && !node_filter(n)) continue;
228     if (n->type_string() == "Switch" || n->type_string() == "Merge") {
229       has_switch_or_merge = true;
230       break;
231     }
232   }
233   // Before functionalizing control flow in `g` we functionalize control flow
234   // in functions (directly or indirectly) associated with nodes in `g`.
235   TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
236       func_map, g, fld, flr, function_modified, node_filter));
237 
238   if (has_switch_or_merge) {
239     *function_modified = true;
240 
241     // Functionalize the function body.
242     if (VLOG_IS_ON(4)) {
243       DumpGraphToFile(
244           absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
245           *g, fld);
246     }
247     TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter));
248     if (VLOG_IS_ON(4)) {
249       DumpGraphToFile(
250           absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
251           fld);
252     }
253   }
254   if (*function_modified) {
255     // Add rewritten FunctionDef into library.
256     FunctionDef functionalized_fdef;
257     TF_RETURN_IF_ERROR(
258         GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
259     if (func_name == new_func_name) {
260       VLOG(2) << "Replacing function " << func_name;
261       TF_RETURN_IF_ERROR(
262           fld->ReplaceFunction(new_func_name, functionalized_fdef));
263     } else {
264       VLOG(2) << "Adding function " << new_func_name;
265       TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
266     }
267   }
268 
269   return ret_status;
270 }
271 
FunctionalizeControlFlow(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter,bool include_functions)272 Status FunctionalizeControlFlow(Graph* graph,
273                                 FunctionLibraryDefinition* library,
274                                 const NodeFilter& node_filter,
275                                 bool include_functions) {
276   VLOG(2) << "FunctionalizeControlFlow (initial): "
277           << DumpGraphToFile("functionalize_initial", *graph, library);
278 
279   if (include_functions) {
280     // Functionalize control flow in functions that are (directly or indirectly)
281     // associated with a node in `graph`.
282     auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
283         /*device_mgr=*/nullptr, tensorflow::Env::Default(),
284         /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library,
285         tensorflow::OptimizerOptions());
286     // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice`
287     // (because we constructed it with `device_mgr = nullptr`).
288     FunctionLibraryRuntime* flr =
289         pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
290 
291     FuncMap func_map;
292     bool modified = false;
293     TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
294         &func_map, graph, library, flr, &modified, node_filter));
295   }
296   // Functionalize and remove while loops from graph.
297   TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter));
298 
299   // FunctionalizeControlFlow is invoked for every function, so the loops's
300   // bodies and conditionals that were extracted into functions will be handled
301   // in successive invocations.
302   TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library, node_filter));
303 
304   VLOG(2) << "FunctionalizeControlFlow (final): "
305           << DumpGraphToFile("functionalize_final", *graph, library);
306 
307   return OkStatus();
308 }
309 
FunctionalizeControlFlowForGraphDef(GraphDef * graph_def,FunctionLibraryDefinition * library,const NodeFilter & node_filter,bool include_functions)310 Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
311                                            FunctionLibraryDefinition* library,
312                                            const NodeFilter& node_filter,
313                                            bool include_functions) {
314   FunctionDefLibrary function_lib = graph_def->library();
315   Graph graph(OpRegistry::Global());
316 
317   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
318   TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter,
319                                               include_functions));
320   graph.ToGraphDef(graph_def);
321   std::swap(*graph_def->mutable_library(), function_lib);
322   return OkStatus();
323 }
324 
Run(const GraphOptimizationPassOptions & options)325 Status FunctionalizeControlFlowForXlaPass::Run(
326     const GraphOptimizationPassOptions& options) {
327   Graph* graph = options.graph->get();
328   if (VLOG_IS_ON(4)) {
329     DumpGraphToFile("functionalize_control_flow_before", *graph,
330                     options.flib_def);
331   }
332   const auto* config = &options.session_options->config;
333   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
334       new ProcessFunctionLibraryRuntime(
335           /*device_mgr=*/nullptr, options.session_options->env, config,
336           TF_GRAPH_DEF_VERSION, options.flib_def,
337           config->graph_options().optimizer_options()));
338   FunctionLibraryRuntime* flr =
339       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
340 
341   // Find XLA compile ops and its corresponding FunctionDef.
342   // TPUCompile op is not in the map because graph rewriting might happen
343   // multiple times, and we want to avoid functionalize it again.
344   static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
345       new std::map<string, string>{
346           // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
347           {"_TPUReplicate", "computation"},
348           // XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
349           {"XlaLaunch", "function"},
350       };
351   FuncMap func_map;
352   bool fld_modified = false;
353   for (Node* n : graph->nodes()) {
354     auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
355     if (it == kNodeTypeToFunctionAttrMapping->end()) {
356       continue;
357     }
358     const string func_attr = it->second;
359     NameAttrList func;
360     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
361     VLOG(2) << "Graph has node " << n->type_string()
362             << ". Corresponding function: " << func.name();
363     string new_func_name = options.flib_def->UniqueFunctionName(
364         absl::StrCat(func.name(), "_f15n_"));
365     bool modified;
366     TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
367         func.name(), new_func_name, func.attr(), options.flib_def, flr,
368         &func_map, &modified));
369     if (modified) {
370       n->ClearAttr(func_attr);
371       func.set_name(new_func_name);
372       n->AddAttr(func_attr, func);
373       fld_modified = true;
374     }
375   }
376 
377   // TODO(ylc, endlessroad): Change this to "if (fld_modified")"
378   if (false) {
379     if (VLOG_IS_ON(4)) {
380       DumpGraphToFile("functionalize_control_flow_before_prune", *graph,
381                       options.flib_def);
382     }
383     TF_RETURN_IF_ERROR(
384         PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
385   }
386 
387   if (VLOG_IS_ON(4)) {
388     DumpGraphToFile("functionalize_control_flow_after", *graph,
389                     options.flib_def);
390   }
391   return OkStatus();
392 }
393 
394 }  // namespace tensorflow
395