xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/split_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/split_utils.h"
17 
18 #include <string>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/ascii.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/grappler/op_types.h"
26 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
27 #include "tensorflow/core/grappler/utils.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 namespace split_utils {
32 
33 namespace {
34 
ArgDefIsList(const OpDef::ArgDef & arg_def)35 bool ArgDefIsList(const OpDef::ArgDef& arg_def) {
36   return !arg_def.number_attr().empty() || !arg_def.type_list_attr().empty();
37 }
38 
39 // Returns map from node name to NodeDef in a function.
NameToNode(const FunctionDef & function)40 absl::flat_hash_map<absl::string_view, const NodeDef*> NameToNode(
41     const FunctionDef& function) {
42   absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node;
43   for (const NodeDef& node : function.node_def()) {
44     name_to_node.insert({node.name(), &node});
45   }
46   return name_to_node;
47 }
48 
49 // Returns true if the input string in a FunctionDef node refers to a function
50 // argument, as opposed to a node output.
IsFunctionArgument(absl::string_view input_str)51 bool IsFunctionArgument(absl::string_view input_str) {
52   // Arguments are in the form "fun_in" or "fun_in:number", where "fun_in" is
53   // the input arg name and "number" is the output index.
54   size_t pos = input_str.find(':');
55   return pos == absl::string_view::npos ||
56          absl::ascii_isdigit(input_str[pos + 1]);
57 }
58 
FindArgDefIndex(const protobuf::RepeatedPtrField<OpDef::ArgDef> & arg_defs,absl::string_view name)59 size_t FindArgDefIndex(
60     const protobuf::RepeatedPtrField<OpDef::ArgDef>& arg_defs,
61     absl::string_view name) {
62   for (int i = 0; i < arg_defs.size(); i++) {
63     if (arg_defs[i].name() == name) {
64       return i;
65     }
66   }
67   return -1;
68 }
69 
70 // Helper class to SplitFunction(). When adding nodes to `second`, some node
71 // inputs may refer to nodes in `first`. This class handles this case by adding
72 // an output argument to `first` and a corresponding input argument to `second`.
73 // The input of the node in `second` is rewritten to refer to the newly created
74 // input argument.
75 class InputRewriter {
76  public:
77   // Note `original_function` must not have any list arguments.
InputRewriter(const FunctionDef & original_function,const absl::flat_hash_set<absl::string_view> & nodes_in_first_func,int64_t num_captured_inputs,const FunctionLibraryDefinition & library,FunctionDef * first_function,FunctionDef * second_function,std::vector<DataType> * first_function_output_types)78   InputRewriter(
79       const FunctionDef& original_function,
80       const absl::flat_hash_set<absl::string_view>& nodes_in_first_func,
81       int64_t num_captured_inputs, const FunctionLibraryDefinition& library,
82       FunctionDef* first_function, FunctionDef* second_function,
83       std::vector<DataType>* first_function_output_types)
84       : original_function_(original_function),
85         nodes_in_first_func_(nodes_in_first_func),
86         num_captured_inputs_(num_captured_inputs),
87         library_(library),
88         name_to_node_(NameToNode(original_function)),
89         first_function_(first_function),
90         second_function_(second_function),
91         first_function_output_types_(first_function_output_types) {
92     for (const NodeDef& node_def : original_function_.node_def()) {
93       used_names_.insert(node_def.name());
94     }
95 
96     for (const OpDef::ArgDef& input_arg :
97          original_function_.signature().input_arg()) {
98       used_names_.insert(input_arg.name());
99     }
100   }
101 
102   // Rewrite an input of a node that is being moved to the second function.
103   // If the input is in the first function, an output argument will be added to
104   // the first function and a corresponding input argument will be added to the
105   // second function. In this case, the input argument's name will be returned.
106   // If the input is in the second function, the input will not be rewritten.
107   //
108   // *new_input_str will be set to the empty string if the input should be
109   // removed, which occurs if it is a control dependency for a node in the first
110   // function.
111   Status RewriteInput(absl::string_view input_str, string* new_input_str);
112 
113  private:
IsInFirstFunction(absl::string_view node_name)114   bool IsInFirstFunction(absl::string_view node_name) {
115     return nodes_in_first_func_.contains(node_name);
116   }
117 
118   // Rewrite a control input. input_str is in the form "^node_name"
119   Status RewriteControlInput(absl::string_view input_str,
120                              string* new_input_str);
121 
122   // Rewrite an input that is an argument to original_function_. input_str is in
123   // the form "fun_in" or "fun_in:number".
124   Status RewriteArgumentInput(absl::string_view input_str,
125                               string* new_input_str);
126 
127   // Rewrite an input that is the output of a node. input_str is in the form
128   // "node:out" or "node:out:number"
129   Status RewriteNodeInput(absl::string_view input_str, string* new_input_str);
130 
131   // Rewrites an input, `input_str`, where the node producing `input_str` is in
132   // first_function_ and the node consuming `input_str` is in second_function_.
133   // This function adds an output argument to first_function_ and an input
134   // argument to second_function_. "input_arg_def" is the ArgDef corresponding
135   // to input_str, and must have the type() field set.
136   Status RewriteCrossFunctionInput(absl::string_view input_str,
137                                    const OpDef::ArgDef& input_arg_def,
138                                    string* new_input_str);
139 
unique_name(const std::string & name)140   string unique_name(const std::string& name) {
141     if (used_names_.count(name) == 0) {
142       used_names_.insert(name);
143       return name;
144     }
145 
146     for (int64_t suffix = 0; true; suffix++) {
147       string new_name = absl::StrCat(name, "_", suffix);
148       auto iter = used_names_.insert(new_name);
149       if (iter.second) {
150         return new_name;
151       }
152     }
153   }
154 
155   const FunctionDef& original_function_;
156   const absl::flat_hash_set<absl::string_view>& nodes_in_first_func_;
157   const int64_t num_captured_inputs_;
158   const FunctionLibraryDefinition& library_;
159 
160   // Map from node name to NodeDef in original_function_.node_def()
161   const absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node_;
162 
163   FunctionDef* const first_function_;
164   FunctionDef* const second_function_;
165   std::vector<DataType>* const first_function_output_types_;
166 
167   // Caches results of RewriteInput(), so that if the same input string is
168   // passed, it is rewritten to the same string.
169   absl::flat_hash_map<absl::string_view, string> input_map_;
170 
171   // Node and argument names that are used in either function. Used to uniquify
172   // argument names.
173   std::unordered_set<string> used_names_;
174 };
175 
RewriteInput(absl::string_view input_str,string * new_input_str)176 Status InputRewriter::RewriteInput(absl::string_view input_str,
177                                    string* new_input_str) {
178   auto iter = input_map_.find(input_str);
179   if (iter != input_map_.end()) {
180     *new_input_str = iter->second;
181     return OkStatus();
182   }
183 
184   if (IsControlInput(input_str)) {
185     TF_RETURN_IF_ERROR(RewriteControlInput(input_str, new_input_str));
186   } else if (IsFunctionArgument(input_str)) {
187     TF_RETURN_IF_ERROR(RewriteArgumentInput(input_str, new_input_str));
188   } else {
189     TF_RETURN_IF_ERROR(RewriteNodeInput(input_str, new_input_str));
190   }
191   input_map_.insert({input_str, *new_input_str});
192   return OkStatus();
193 }
194 
RewriteControlInput(absl::string_view input_str,string * new_input_str)195 Status InputRewriter::RewriteControlInput(absl::string_view input_str,
196                                           string* new_input_str) {
197   DCHECK_EQ(input_str.at(0), '^');
198   absl::string_view node_name = input_str.substr(1);
199   if (IsInFirstFunction(node_name)) {
200     *new_input_str = "";
201   } else {
202     *new_input_str = string{input_str};
203   }
204   return OkStatus();
205 }
206 
RewriteArgumentInput(absl::string_view input_str,string * new_input_str)207 Status InputRewriter::RewriteArgumentInput(absl::string_view input_str,
208                                            string* new_input_str) {
209   std::vector<string> components = absl::StrSplit(input_str, ':');
210   if (components.size() != 1 && components.size() != 2) {
211     return errors::Internal("Found node with invalid argument input: ",
212                             input_str);
213   }
214   string argument_name = components[0];
215   if (components.size() == 2 && components[1] != "0") {
216     // It is required that `original_function` must not have any list arguments.
217     return errors::Internal(
218         "Input string \"", input_str,
219         "\" has a last component which is not 0, but it is expected to be 0 "
220         "because corresponding argument is not a list");
221   }
222 
223   int i = FindArgDefIndex(original_function_.signature().input_arg(),
224                           argument_name);
225   if (i == -1) {
226     return errors::Internal(
227         "Input string \"", input_str,
228         "\" refers to an argument which does not exist. Argument \"",
229         argument_name, "\" does not appear in following FunctionDef: ",
230         original_function_.DebugString());
231   }
232   if (i >=
233       original_function_.signature().input_arg_size() - num_captured_inputs_) {
234     // Argument is a captured input. No need to modify argument string.
235     *new_input_str = string{input_str};
236     return OkStatus();
237   }
238   const OpDef::ArgDef* found_arg_def =
239       &original_function_.signature().input_arg(i);
240 
241   if (ArgDefIsList(*found_arg_def)) {
242     return errors::Unimplemented(
243         "Splitting a function where an edge is a list of tensors is "
244         "unsupported. ArgDef representing edge: ",
245         found_arg_def->DebugString());
246   }
247   if (!found_arg_def->type_attr().empty()) {
248     return errors::Unimplemented(
249         "Splitting a function where an edge's ArgDef has a type attribute is "
250         "unsupported. ArgDef representing argument: ",
251         found_arg_def->DebugString());
252   }
253 
254   return RewriteCrossFunctionInput(input_str, *found_arg_def, new_input_str);
255 }
256 
RewriteNodeInput(absl::string_view input_str,string * new_input_str)257 Status InputRewriter::RewriteNodeInput(absl::string_view input_str,
258                                        string* new_input_str) {
259   std::vector<string> components = absl::StrSplit(input_str, ':');
260   if (components.size() != 2 && components.size() != 3) {
261     return errors::Internal("Found node with invalid node input: ", input_str);
262   }
263   const string& node_name = components[0];
264   const string& node_output_arg = components[1];
265   const string& list_output_index =
266       components.size() == 3 ? components[2] : "0";
267   if (!IsInFirstFunction(node_name)) {
268     *new_input_str = string{input_str};
269     return OkStatus();
270   }
271 
272   auto index_iter = name_to_node_.find(node_name);
273   if (index_iter == name_to_node_.end()) {
274     return errors::Internal("Found input referring to nonexistent node: ",
275                             node_name);
276   }
277   const NodeDef& node = *index_iter->second;
278 
279   const OpRegistrationData* op_reg_data = nullptr;
280   TF_RETURN_IF_ERROR(library_.LookUp(node.op(), &op_reg_data));
281   int i = FindArgDefIndex(op_reg_data->op_def.output_arg(), node_output_arg);
282   if (i == -1) {
283     return errors::Internal("Could not found input \"", node_output_arg,
284                             "\" for OpDef ", op_reg_data->op_def.name());
285   }
286   OpDef::ArgDef found_arg_def = op_reg_data->op_def.output_arg(i);
287 
288   if (ArgDefIsList(found_arg_def)) {
289     return errors::Unimplemented(
290         "Splitting a function where an edge is a list of tensors is "
291         "unsupported. ArgDef representing edge: ",
292         found_arg_def.DebugString());
293   }
294   if (list_output_index != "0") {
295     return errors::Internal(
296         "Input string \"", input_str,
297         "\" has a last component which is not 0, but it is expected to be 0 "
298         "because corresponding output is not a list");
299   }
300 
301   if (!found_arg_def.type_attr().empty()) {
302     const string& attr = found_arg_def.type_attr();
303     auto attr_iter = node.attr().find(attr);
304     if (attr_iter == node.attr().end()) {
305       return errors::Internal("Failed to find attr ", attr, " on node ",
306                               node.name());
307     }
308     if (!attr_iter->second.placeholder().empty()) {
309       return errors::Unimplemented(
310           "Splitting a function where an edge between functions has an "
311           "AttrValue placeholder dtype is unsupported.");
312     }
313     DataType dtype = attr_iter->second.type();
314     if (dtype == DT_INVALID) {
315       return errors::Internal("Attr ", attr, " is not a dtype attr");
316     }
317     found_arg_def.mutable_type_attr()->clear();
318     found_arg_def.set_type(dtype);
319   }
320 
321   return RewriteCrossFunctionInput(input_str, found_arg_def, new_input_str);
322 }
323 
RewriteCrossFunctionInput(absl::string_view input_str,const OpDef::ArgDef & input_arg_def,string * new_input_str)324 Status InputRewriter::RewriteCrossFunctionInput(
325     absl::string_view input_str, const OpDef::ArgDef& input_arg_def,
326     string* new_input_str) {
327   DCHECK(input_arg_def.type() != DT_INVALID);
328   if (input_arg_def.is_ref() || IsRefType(input_arg_def.type())) {
329     // This case is untested and is not important to support, so an
330     // Unimplemented error is raised.
331     return errors::Unimplemented(
332         "Splitting a function where an edge between functions is a ref is "
333         "unsupported. Input ",
334         input_str, " is a ref type.");
335   }
336   OpDef::ArgDef* added_output_arg =
337       first_function_->mutable_signature()->add_output_arg();
338   *added_output_arg = input_arg_def;
339   size_t output_index = first_function_->signature().output_arg_size() - 1;
340   added_output_arg->set_name(absl::StrCat("output_", output_index));
341   added_output_arg->set_description(absl::StrCat(
342       "Output ", output_index, ", corresponding to input ", input_str));
343   first_function_->mutable_ret()->insert(
344       {added_output_arg->name(), string{input_str}});
345   first_function_output_types_->push_back(input_arg_def.type());
346 
347   OpDef::ArgDef* added_input_arg =
348       second_function_->mutable_signature()->add_input_arg();
349   *added_input_arg = input_arg_def;
350   size_t input_index = second_function_->signature().input_arg_size() - 1;
351   added_input_arg->set_name(unique_name(absl::StrCat("input_", input_index)));
352   added_input_arg->set_description(absl::StrCat("Input ", input_index));
353 
354   *new_input_str = added_input_arg->name();
355   return OkStatus();
356 }
357 
InitializeSignatures(const FunctionDef & original_function_,FunctionDef * first_function_,FunctionDef * second_function_,const absl::flat_hash_set<absl::string_view> & nodes_in_first_function,const FunctionDefLibrary & func_def_lib_)358 void InitializeSignatures(
359     const FunctionDef& original_function_, FunctionDef* first_function_,
360     FunctionDef* second_function_,
361     const absl::flat_hash_set<absl::string_view>& nodes_in_first_function,
362     const FunctionDefLibrary& func_def_lib_) {
363   // Initialize first_function_->signature().
364   *first_function_->mutable_signature() = original_function_.signature();
365   graph_utils::SetUniqueGraphFunctionName(
366       original_function_.signature().name() + "_first_split", &func_def_lib_,
367       first_function_);
368   first_function_->mutable_signature()->clear_output_arg();
369   first_function_->mutable_signature()->clear_control_output();
370   first_function_->mutable_signature()->set_description(absl::StrCat(
371       "The function \"", original_function_.signature().name(),
372       "\" was split into two pieces in the make_deterministic Grappler pass. "
373       "This function is the first piece."));
374   first_function_->mutable_signature()->set_is_commutative(false);
375   first_function_->mutable_signature()->set_is_aggregate(false);
376 
377   // Initialize second_function_->signature().
378   *second_function_->mutable_signature() = original_function_.signature();
379   graph_utils::SetUniqueGraphFunctionName(
380       original_function_.signature().name() + "_second_split", &func_def_lib_,
381       second_function_);
382   second_function_->mutable_signature()->clear_input_arg();
383   second_function_->mutable_signature()->clear_control_output();
384   second_function_->mutable_signature()->set_description(absl::StrCat(
385       "The function \"", original_function_.signature().name(),
386       "\" was split into two pieces in the make_deterministic Grappler pass. "
387       "This function is the second piece."));
388   second_function_->mutable_signature()->set_is_commutative(false);
389   second_function_->mutable_signature()->set_is_aggregate(false);
390 
391   // Initialize the control_ret fields of the two signatures.
392   for (const auto& it : original_function_.control_ret()) {
393     if (nodes_in_first_function.contains(it.second)) {
394       first_function_->mutable_control_ret()->insert(it);
395     } else {
396       second_function_->mutable_control_ret()->insert(it);
397     }
398   }
399 }
400 
401 }  // namespace
402 
SplitFunction(const FunctionDef & function,const absl::flat_hash_set<absl::string_view> & nodes_in_first_function,int64_t num_captured_inputs,const FunctionLibraryDefinition & library)403 StatusOr<SplitResults> SplitFunction(
404     const FunctionDef& function,
405     const absl::flat_hash_set<absl::string_view>& nodes_in_first_function,
406     int64_t num_captured_inputs, const FunctionLibraryDefinition& library) {
407   for (const auto& attr : function.attr()) {
408     if (attr.first != data::kTFDataFunction &&
409         attr.first != "_construction_context") {
410       return errors::Unimplemented(
411           "Cannot split function with unknown attribute key: ", attr.first);
412     }
413   }
414 
415   for (int i = 0; i < function.signature().input_arg_size(); i++) {
416     // Processing list arguments is more complicated and not yet implemented.
417     if (ArgDefIsList(function.signature().input_arg(i))) {
418       return errors::Unimplemented(
419           "Cannot split function when an input argument is a list of tensors "
420           "instead of a single tensor.");
421     }
422   }
423 
424   for (const NodeDef& node_def : function.node_def()) {
425     if (IsControlFlow(node_def)) {
426       return errors::Unimplemented(
427           "Cannot split function with control flow ops");
428     }
429   }
430 
431   SplitResults results;
432   InitializeSignatures(function, &results.first_function,
433                        &results.second_function, nodes_in_first_function,
434                        library.ToProto());
435 
436   // Insert _construction_context attribute into functions, if it exists on
437   // original_function_.
438   auto contruction_ctx_iter = function.attr().find("_construction_context");
439   if (contruction_ctx_iter != function.attr().end()) {
440     results.first_function.mutable_attr()->insert(
441         {contruction_ctx_iter->first, contruction_ctx_iter->second});
442     results.second_function.mutable_attr()->insert(
443         {contruction_ctx_iter->first, contruction_ctx_iter->second});
444   }
445 
446   InputRewriter rewriter{function,
447                          nodes_in_first_function,
448                          num_captured_inputs,
449                          library,
450                          &results.first_function,
451                          &results.second_function,
452                          &results.first_function_output_types};
453 
454   for (const NodeDef& orig_node_def : function.node_def()) {
455     if (!nodes_in_first_function.contains(orig_node_def.name())) {
456       // Add node to second function and rewrite its inputs.
457       NodeDef& new_node_def = *results.second_function.add_node_def();
458       new_node_def = orig_node_def;
459       new_node_def.clear_input();
460 
461       for (const string& input_str : orig_node_def.input()) {
462         string* new_input_str = new_node_def.add_input();
463         TF_RETURN_IF_ERROR(rewriter.RewriteInput(input_str, new_input_str));
464         if (new_input_str->empty()) {
465           new_node_def.mutable_input()->RemoveLast();
466           VLOG(3) << "Removed input " << input_str << " from node "
467                   << orig_node_def.name();
468         } else if (*new_input_str != input_str) {
469           VLOG(3) << "Rewrote input " << input_str << " to " << new_input_str
470                   << " of node " << orig_node_def.name();
471         }
472       }
473     } else {
474       // Add node to first function, and check that all its inputs are also in
475       // the first function.
476       *results.first_function.add_node_def() = orig_node_def;
477       for (const string& input_str : orig_node_def.input()) {
478         std::vector<string> components = absl::StrSplit(input_str, ':');
479         if (!IsControlInput(input_str) && !IsFunctionArgument(input_str) &&
480             !nodes_in_first_function.contains(components[0])) {
481           return errors::Internal("Node ", orig_node_def.name(),
482                                   " is in first function but has input ",
483                                   input_str,
484                                   " which is not in first function.");
485         }
486       }
487     }
488   }
489 
490   // Add return values to second_fuction.ret()
491   for (const OpDef::ArgDef& arg_def : function.signature().output_arg()) {
492     auto it = function.ret().find(arg_def.name());
493     if (it == function.ret().end()) {
494       return errors::Internal(
495           "Failed to find output_arg '", arg_def.name(),
496           "' in 'ret' section. FunctionDef: ", function.DebugString());
497     }
498     string& new_ret = (*results.second_function.mutable_ret())[arg_def.name()];
499     TF_RETURN_IF_ERROR(rewriter.RewriteInput(it->second, &new_ret));
500     DCHECK(!new_ret.empty());
501   }
502 
503   // Add captured inputs to second_function.input_arg()
504   for (int i = function.signature().input_arg_size() - num_captured_inputs;
505        i < function.signature().input_arg_size(); i++) {
506     *results.second_function.mutable_signature()->add_input_arg() =
507         function.signature().input_arg(i);
508   }
509 
510   return results;
511 }
512 
513 }  // namespace split_utils
514 }  // namespace grappler
515 }  // namespace tensorflow
516