xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/encapsulate_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/encapsulate_util.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/jit/shape_inference.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/node_builder.h"
29 #include "tensorflow/core/protobuf/error_codes.pb.h"
30 #include "tensorflow/stream_executor/lib/statusor.h"
31 
32 using stream_executor::port::StatusOr;
33 
34 namespace tensorflow {
35 
36 namespace {
37 
38 // Returns string attribute value for the node if the attribute is present,
39 // otherwise returns empty optional value.
GetStringAttr(const Node & n,const string & attr_name)40 std::optional<string> GetStringAttr(const Node& n, const string& attr_name) {
41   auto attr = n.attrs().Find(attr_name);
42   if (!attr) {
43     return std::nullopt;
44   } else {
45     return attr->s();
46   }
47 }
48 
49 // Adds a value to the node's list attribute.
50 template <typename T>
AppendToListAttr(Node * n,const string & attr_name,const string & value)51 Status AppendToListAttr(Node* n, const string& attr_name, const string& value) {
52   std::vector<T> attr_value;
53   Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value);
54   if (!s.ok() && s.code() != error::NOT_FOUND) {
55     return s;
56   }
57 
58   n->ClearAttr(attr_name);
59   attr_value.push_back(value);
60   n->AddAttr(attr_name, attr_value);
61   return OkStatus();
62 }
63 
64 // Replaces attribute value.
65 template <typename T>
ReplaceAttr(Node * n,const string & attr_name,const T & value)66 void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
67   n->ClearAttr(attr_name);
68   n->AddAttr(attr_name, value);
69 }
70 
71 // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
72 // `PreprocessEdgesBetweenOutsideCompilations` for details.
PreprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)73 Status PreprocessControlEdgesBetweenOutsideCompilations(
74     Graph* g, const string& outside_compilation_attr_name) {
75   // Gather edges to remove. We should not remove the edge while iterating.
76   std::vector<const Edge*> edges_to_remove;
77   for (const Edge* e : g->edges()) {
78     if (!e->IsControlEdge()) {
79       continue;
80     }
81 
82     auto src_outside_compilation =
83         GetStringAttr(*e->src(), outside_compilation_attr_name);
84     auto dst_outside_compilation =
85         GetStringAttr(*e->dst(), outside_compilation_attr_name);
86 
87     if (src_outside_compilation && dst_outside_compilation) {
88       if (*src_outside_compilation != *dst_outside_compilation) {
89         // Case 1a: outside compilation to outside compilation control edge.
90         edges_to_remove.push_back(e);
91 
92         TF_RETURN_IF_ERROR(AppendToListAttr<string>(
93             e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
94             e->src()->name()));
95       }
96     } else if (src_outside_compilation && !dst_outside_compilation) {
97       // Case 1b: outside compilation to its XLA computation control edge.
98       ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
99     } else if (!src_outside_compilation && dst_outside_compilation) {
100       // Case 1b: XLA computation to outside compilation in it control edge.
101       ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
102     }
103   }
104 
105   for (auto e : edges_to_remove) {
106     g->RemoveEdge(e);
107   }
108   return OkStatus();
109 }
110 
111 // Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
112 // `PreprocessEdgesBetweenOutsideCompilations` for details.
PreprocessDataEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)113 Status PreprocessDataEdgesBetweenOutsideCompilations(
114     Graph* g, const string& outside_compilation_attr_name) {
115   // Gather edges between outside compilation and host computation. Notice that
116   // we do not store `Edge*` directly because we remove some nodes while adding
117   // Identity nodes, and those Edge pointers might be invalidated.
118   struct EdgeInfo {
119     int dst_input, dst_node_id;
120   };
121   std::vector<EdgeInfo> edges;
122   for (const Edge* e : g->edges()) {
123     if (e->IsControlEdge()) {
124       continue;
125     }
126 
127     auto src_outside_compilation =
128         GetStringAttr(*e->src(), outside_compilation_attr_name);
129     auto dst_outside_compilation =
130         GetStringAttr(*e->dst(), outside_compilation_attr_name);
131 
132     if (src_outside_compilation && dst_outside_compilation &&
133         *src_outside_compilation != *dst_outside_compilation) {
134       edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
135       VLOG(4) << "Oc -> oc edge: " << e->DebugString();
136     }
137   }
138 
139   // Remove the edge from host to outside compilation. Add a placeholder as
140   // outside compilation node input.
141   std::map<std::pair<string, int>, Node*> placeholders;
142   for (int i = 0, end = edges.size(); i < end; i++) {
143     Node* dst = g->FindNodeId(edges[i].dst_node_id);
144     const Edge* e;
145     TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
146     Node* src = e->src();
147     int src_output = e->src_output(), dst_input = e->dst_input();
148     g->RemoveEdge(e);
149 
150     // Find or create placeholder node.
151     string new_name =
152         absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
153     auto placeholder_index = std::make_pair(src->name(), src_output);
154     auto iter = placeholders.find(placeholder_index);
155     Node* placeholder_node;
156     if (iter == placeholders.end()) {
157       NodeDefBuilder placeholder_builder(new_name, "Placeholder");
158       placeholder_builder.Attr("dtype", src->output_type(src_output));
159       string outside_compilation_attr;
160       TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
161                                      outside_compilation_attr_name,
162                                      &outside_compilation_attr));
163       placeholder_builder.Attr(outside_compilation_attr_name,
164                                outside_compilation_attr);
165       placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
166                                src->name());
167       placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
168                                src_output);
169       NodeDef placeholder_def;
170       TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
171       TF_ASSIGN_OR_RETURN(placeholder_node, g->AddNode(placeholder_def));
172       placeholders[placeholder_index] = placeholder_node;
173     } else {
174       placeholder_node = iter->second;
175     }
176     g->AddEdge(placeholder_node, 0, dst, dst_input);
177 
178     // Replace `e->dst()` because its input node changed.
179     NodeDef new_def = dst->def();
180     *new_def.mutable_input(dst_input) = placeholder_node->name();
181     TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
182 
183     // Other edge in `edges` might have `e->dst()` as src or dst
184     // node. Before removing `e->dst()`, replace those edges with
185     // corresponding edges for `dst_replace_node`.
186     for (int j = i + 1, end = edges.size(); j < end; j++) {
187       if (edges[j].dst_node_id == edges[i].dst_node_id) {
188         edges[j].dst_node_id = dst_replace_node->id();
189       }
190     }
191   }
192   return OkStatus();
193 }
194 
195 // Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
196 // `PostprocessEdgesBetweenOutsideCompilations` for details.
PostprocessDataEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)197 Status PostprocessDataEdgesBetweenOutsideCompilations(
198     Graph* g, const string& outside_compilation_attr_name) {
199   // Gather all outside compilation to outside compilation nodes.
200   std::vector<Node*> placeholder_nodes;
201   for (Node* n : g->nodes()) {
202     if (n->type_string() == "Placeholder" &&
203         HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
204       placeholder_nodes.push_back(n);
205     }
206   }
207 
208   // Remove the placeholder nodes, and reconnect original edge.
209   auto node_name_index = g->BuildNodeNameIndex();
210   for (auto n : placeholder_nodes) {
211     string node_name;
212     int node_src_output;
213     TF_RETURN_IF_ERROR(GetNodeAttr(
214         n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
215     TF_RETURN_IF_ERROR(GetNodeAttr(
216         n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
217     auto iter = node_name_index.find(node_name);
218     if (iter == node_name_index.end()) {
219       return errors::Internal(
220           "Cannot find original node for oc -> host placeholder node ",
221           node_name);
222     }
223 
224     // Change all usage node to use the original node instead.
225     Node* original_node = iter->second;
226     std::vector<const Edge*> control_edges;
227     std::vector<OutEdgeInfo> data_edges;
228     for (auto e : n->out_edges()) {
229       if (e->IsControlEdge()) {
230         control_edges.push_back(e);
231       } else {
232         data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
233       }
234     }
235     for (const Edge* e : control_edges) {
236       g->AddControlEdge(original_node, e->dst());
237       g->RemoveEdge(e);
238     }
239     for (int i = 0, end = data_edges.size(); i < end; i++) {
240       Node* dst = data_edges[i].dst;
241       NodeDef new_def = dst->def();
242       int dst_input = data_edges[i].dst_input;
243       *new_def.mutable_input(dst_input) =
244           absl::StrCat(original_node->name(), ":", node_src_output);
245       TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
246 
247       const Edge* edge_to_replace = nullptr;
248       TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
249       g->RemoveEdge(edge_to_replace);
250       g->AddEdge(original_node, node_src_output, replace_node, dst_input);
251 
252       // Other edges might have `dst` as dst node. Update those edges with
253       // `replace_node`.
254       for (int j = i + 1, end = data_edges.size(); j < end; j++) {
255         if (data_edges[j].dst == dst) {
256           data_edges[j].dst = replace_node;
257         }
258       }
259 
260       // Other placeholder node might have `dst` as original node. Update
261       // `node_name_index` with `replace_node`.
262       node_name_index[replace_node->name()] = replace_node;
263     }
264 
265     // Remove placeholder node.
266     g->RemoveNode(n);
267   }
268   return OkStatus();
269 }
270 
271 // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
272 // `PostprocessEdgesBetweenOutsideCompilations` for details.
PostprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)273 Status PostprocessControlEdgesBetweenOutsideCompilations(
274     Graph* g, const string& outside_compilation_attr_name) {
275   auto node_name_index = g->BuildNodeNameIndex();
276 
277   // Reconnect outside compilation to outside compilation control edge.
278   for (Node* n : g->nodes()) {
279     std::vector<string> control_deps;
280     Status s =
281         GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
282                     &control_deps);
283     if (!s.ok()) {
284       if (s.code() != error::NOT_FOUND) {
285         return s;
286       } else {
287         continue;
288       }
289     } else {
290       n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
291       for (const string& control_input : control_deps) {
292         auto iter = node_name_index.find(control_input);
293         if (iter == node_name_index.end()) {
294           return errors::Internal("Cannot find original node for ",
295                                   control_input);
296         }
297         g->AddControlEdge(iter->second, n);
298       }
299     }
300   }
301   return OkStatus();
302 }
303 }  // namespace
304 
// Attribute under which `PerformStaticShapeInferenceBeforeEncapsulation`
// stores a node's inferred output shapes.
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";

// Set to true on an outside compilation node that has a control edge to a
// node without the outside compilation attribute.
const char kXlaConnectedToXlaComputationAttrName[] =
    "_xla_connected_to_xla_computation";

// Set to true on an outside compilation node that has a control edge from a
// node without the outside compilation attribute.
const char kXlaConnectedFromXlaComputationAttrName[] =
    "_xla_connected_from_xla_computation";

// On a placeholder that stands in for a data edge between two outside
// compilation clusters: name of the original source node.
const char kOutsideCompilationOriginalNodeAttrName[] =
    "_xla_oc_to_oc_node_name";

// On the same placeholder: output index of the original source node.
const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";

// List attribute holding the names of source nodes of control edges that were
// removed between different outside compilation clusters.
const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
    "_xla_control_dependencies_within_xla_cluster";

// The attribute names below are defined here but not referenced elsewhere in
// this file.
const char kXlaIsLiftedArgAttrName[] = "_xla_is_lifted_arg";
const char kXlaLiftedArgOutsideCompilationAttrName[] = "_xla_lifted_arg_oc";
const char kXlaOutsideCompilationInputsAttrName[] = "_xla_oc_inputs";
const char kXlaIsPlaceholderForArg[] = "_xla_is_placeholder_for_arg";
320 
PerformStaticShapeInferenceBeforeEncapsulation(Graph * g)321 Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
322   // Perform shape inference.
323   std::map<int, InferredShape> arg_shapes;
324   GraphShapeInfo shape_info;
325   TF_RETURN_IF_ERROR(
326       InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
327 
328   // Add attribute for output shapes.
329   auto node_name_index = g->BuildNodeNameIndex();
330   for (auto iter : shape_info) {
331     std::vector<PartialTensorShape> output_shapes;
332     std::transform(iter.second.begin(), iter.second.end(),
333                    std::back_inserter(output_shapes),
334                    [](const InferredShape& inferred_shape) {
335                      return inferred_shape.shape;
336                    });
337     Node* n = node_name_index[iter.first];
338     n->AddAttr(kXlaInferredShapesAttrName, output_shapes);
339   }
340 
341   return OkStatus();
342 }
343 
344 StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(const Graph * g,const string & outside_compilation_attr_name)345 OutsideCompilationClusterDependencies(
346     const Graph* g, const string& outside_compilation_attr_name) {
347   auto cluster_deps = std::make_unique<
348       absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
349 
350   for (const Edge* e : g->edges()) {
351     auto src_outside_compilation =
352         GetStringAttr(*e->src(), outside_compilation_attr_name);
353     auto dst_outside_compilation =
354         GetStringAttr(*e->dst(), outside_compilation_attr_name);
355 
356     if (src_outside_compilation && dst_outside_compilation &&
357         *src_outside_compilation != *dst_outside_compilation) {
358       auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
359       if (dst_deps_it == cluster_deps->end()) {
360         cluster_deps->insert(std::make_pair(
361             *dst_outside_compilation,
362             absl::flat_hash_set<string>({*src_outside_compilation})));
363       } else {
364         dst_deps_it->second.insert(*src_outside_compilation);
365       }
366     }
367   }
368 
369   auto cluster_deps_ordered =
370       std::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
371 
372   for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
373     std::vector<string> ordered_deps(it->second.begin(), it->second.end());
374     std::sort(ordered_deps.begin(), ordered_deps.end());
375     cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
376   }
377 
378   return std::move(cluster_deps_ordered);
379 }
380 
PreprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)381 Status PreprocessEdgesBetweenOutsideCompilations(
382     Graph* g, const string& outside_compilation_attr_name) {
383   // Remove edges from source node to outside compilation nodes, and edges
384   // from outside compilation nodes to sink node.
385   std::vector<const Edge*> edges_to_remove;
386   for (const Edge* e : g->source_node()->out_edges()) {
387     if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
388       edges_to_remove.push_back(e);
389     }
390   }
391   for (const Edge* e : g->sink_node()->in_edges()) {
392     if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
393       edges_to_remove.push_back(e);
394     }
395   }
396   for (auto e : edges_to_remove) {
397     g->RemoveEdge(e);
398   }
399 
400   TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
401       g, outside_compilation_attr_name));
402   TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
403       g, outside_compilation_attr_name));
404   return OkStatus();
405 }
406 
PostprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)407 Status PostprocessEdgesBetweenOutsideCompilations(
408     Graph* g, const string& outside_compilation_attr_name) {
409   TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
410       g, outside_compilation_attr_name));
411   TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
412       g, outside_compilation_attr_name));
413   return OkStatus();
414 }
415 
416 }  // namespace tensorflow
417