1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/split_utils.h"
17
18 #include <string>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/ascii.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/grappler/op_types.h"
26 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
27 #include "tensorflow/core/grappler/utils.h"
28
29 namespace tensorflow {
30 namespace grappler {
31 namespace split_utils {
32
33 namespace {
34
ArgDefIsList(const OpDef::ArgDef & arg_def)35 bool ArgDefIsList(const OpDef::ArgDef& arg_def) {
36 return !arg_def.number_attr().empty() || !arg_def.type_list_attr().empty();
37 }
38
39 // Returns map from node name to NodeDef in a function.
NameToNode(const FunctionDef & function)40 absl::flat_hash_map<absl::string_view, const NodeDef*> NameToNode(
41 const FunctionDef& function) {
42 absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node;
43 for (const NodeDef& node : function.node_def()) {
44 name_to_node.insert({node.name(), &node});
45 }
46 return name_to_node;
47 }
48
49 // Returns true if the input string in a FunctionDef node refers to a function
50 // argument, as opposed to a node output.
IsFunctionArgument(absl::string_view input_str)51 bool IsFunctionArgument(absl::string_view input_str) {
52 // Arguments are in the form "fun_in" or "fun_in:number", where "fun_in" is
53 // the input arg name and "number" is the output index.
54 size_t pos = input_str.find(':');
55 return pos == absl::string_view::npos ||
56 absl::ascii_isdigit(input_str[pos + 1]);
57 }
58
FindArgDefIndex(const protobuf::RepeatedPtrField<OpDef::ArgDef> & arg_defs,absl::string_view name)59 size_t FindArgDefIndex(
60 const protobuf::RepeatedPtrField<OpDef::ArgDef>& arg_defs,
61 absl::string_view name) {
62 for (int i = 0; i < arg_defs.size(); i++) {
63 if (arg_defs[i].name() == name) {
64 return i;
65 }
66 }
67 return -1;
68 }
69
70 // Helper class to SplitFunction(). When adding nodes to `second`, some node
71 // inputs may refer to nodes in `first`. This class handles this case by adding
72 // an output argument to `first` and a corresponding input argument to `second`.
73 // The input of the node in `second` is rewritten to refer to the newly created
74 // input argument.
75 class InputRewriter {
76 public:
77 // Note `original_function` must not have any list arguments.
InputRewriter(const FunctionDef & original_function,const absl::flat_hash_set<absl::string_view> & nodes_in_first_func,int64_t num_captured_inputs,const FunctionLibraryDefinition & library,FunctionDef * first_function,FunctionDef * second_function,std::vector<DataType> * first_function_output_types)78 InputRewriter(
79 const FunctionDef& original_function,
80 const absl::flat_hash_set<absl::string_view>& nodes_in_first_func,
81 int64_t num_captured_inputs, const FunctionLibraryDefinition& library,
82 FunctionDef* first_function, FunctionDef* second_function,
83 std::vector<DataType>* first_function_output_types)
84 : original_function_(original_function),
85 nodes_in_first_func_(nodes_in_first_func),
86 num_captured_inputs_(num_captured_inputs),
87 library_(library),
88 name_to_node_(NameToNode(original_function)),
89 first_function_(first_function),
90 second_function_(second_function),
91 first_function_output_types_(first_function_output_types) {
92 for (const NodeDef& node_def : original_function_.node_def()) {
93 used_names_.insert(node_def.name());
94 }
95
96 for (const OpDef::ArgDef& input_arg :
97 original_function_.signature().input_arg()) {
98 used_names_.insert(input_arg.name());
99 }
100 }
101
102 // Rewrite an input of a node that is being moved to the second function.
103 // If the input is in the first function, an output argument will be added to
104 // the first function and a corresponding input argument will be added to the
105 // second function. In this case, the input argument's name will be returned.
106 // If the input is in the second function, the input will not be rewritten.
107 //
108 // *new_input_str will be set to the empty string if the input should be
109 // removed, which occurs if it is a control dependency for a node in the first
110 // function.
111 Status RewriteInput(absl::string_view input_str, string* new_input_str);
112
113 private:
IsInFirstFunction(absl::string_view node_name)114 bool IsInFirstFunction(absl::string_view node_name) {
115 return nodes_in_first_func_.contains(node_name);
116 }
117
118 // Rewrite a control input. input_str is in the form "^node_name"
119 Status RewriteControlInput(absl::string_view input_str,
120 string* new_input_str);
121
122 // Rewrite an input that is an argument to original_function_. input_str is in
123 // the form "fun_in" or "fun_in:number".
124 Status RewriteArgumentInput(absl::string_view input_str,
125 string* new_input_str);
126
127 // Rewrite an input that is the output of a node. input_str is in the form
128 // "node:out" or "node:out:number"
129 Status RewriteNodeInput(absl::string_view input_str, string* new_input_str);
130
131 // Rewrites an input, `input_str`, where the node producing `input_str` is in
132 // first_function_ and the node consuming `input_str` is in second_function_.
133 // This function adds an output argument to first_function_ and an input
134 // argument to second_function_. "input_arg_def" is the ArgDef corresponding
135 // to input_str, and must have the type() field set.
136 Status RewriteCrossFunctionInput(absl::string_view input_str,
137 const OpDef::ArgDef& input_arg_def,
138 string* new_input_str);
139
unique_name(const std::string & name)140 string unique_name(const std::string& name) {
141 if (used_names_.count(name) == 0) {
142 used_names_.insert(name);
143 return name;
144 }
145
146 for (int64_t suffix = 0; true; suffix++) {
147 string new_name = absl::StrCat(name, "_", suffix);
148 auto iter = used_names_.insert(new_name);
149 if (iter.second) {
150 return new_name;
151 }
152 }
153 }
154
155 const FunctionDef& original_function_;
156 const absl::flat_hash_set<absl::string_view>& nodes_in_first_func_;
157 const int64_t num_captured_inputs_;
158 const FunctionLibraryDefinition& library_;
159
160 // Map from node name to NodeDef in original_function_.node_def()
161 const absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node_;
162
163 FunctionDef* const first_function_;
164 FunctionDef* const second_function_;
165 std::vector<DataType>* const first_function_output_types_;
166
167 // Caches results of RewriteInput(), so that if the same input string is
168 // passed, it is rewritten to the same string.
169 absl::flat_hash_map<absl::string_view, string> input_map_;
170
171 // Node and argument names that are used in either function. Used to uniquify
172 // argument names.
173 std::unordered_set<string> used_names_;
174 };
175
RewriteInput(absl::string_view input_str,string * new_input_str)176 Status InputRewriter::RewriteInput(absl::string_view input_str,
177 string* new_input_str) {
178 auto iter = input_map_.find(input_str);
179 if (iter != input_map_.end()) {
180 *new_input_str = iter->second;
181 return OkStatus();
182 }
183
184 if (IsControlInput(input_str)) {
185 TF_RETURN_IF_ERROR(RewriteControlInput(input_str, new_input_str));
186 } else if (IsFunctionArgument(input_str)) {
187 TF_RETURN_IF_ERROR(RewriteArgumentInput(input_str, new_input_str));
188 } else {
189 TF_RETURN_IF_ERROR(RewriteNodeInput(input_str, new_input_str));
190 }
191 input_map_.insert({input_str, *new_input_str});
192 return OkStatus();
193 }
194
RewriteControlInput(absl::string_view input_str,string * new_input_str)195 Status InputRewriter::RewriteControlInput(absl::string_view input_str,
196 string* new_input_str) {
197 DCHECK_EQ(input_str.at(0), '^');
198 absl::string_view node_name = input_str.substr(1);
199 if (IsInFirstFunction(node_name)) {
200 *new_input_str = "";
201 } else {
202 *new_input_str = string{input_str};
203 }
204 return OkStatus();
205 }
206
RewriteArgumentInput(absl::string_view input_str,string * new_input_str)207 Status InputRewriter::RewriteArgumentInput(absl::string_view input_str,
208 string* new_input_str) {
209 std::vector<string> components = absl::StrSplit(input_str, ':');
210 if (components.size() != 1 && components.size() != 2) {
211 return errors::Internal("Found node with invalid argument input: ",
212 input_str);
213 }
214 string argument_name = components[0];
215 if (components.size() == 2 && components[1] != "0") {
216 // It is required that `original_function` must not have any list arguments.
217 return errors::Internal(
218 "Input string \"", input_str,
219 "\" has a last component which is not 0, but it is expected to be 0 "
220 "because corresponding argument is not a list");
221 }
222
223 int i = FindArgDefIndex(original_function_.signature().input_arg(),
224 argument_name);
225 if (i == -1) {
226 return errors::Internal(
227 "Input string \"", input_str,
228 "\" refers to an argument which does not exist. Argument \"",
229 argument_name, "\" does not appear in following FunctionDef: ",
230 original_function_.DebugString());
231 }
232 if (i >=
233 original_function_.signature().input_arg_size() - num_captured_inputs_) {
234 // Argument is a captured input. No need to modify argument string.
235 *new_input_str = string{input_str};
236 return OkStatus();
237 }
238 const OpDef::ArgDef* found_arg_def =
239 &original_function_.signature().input_arg(i);
240
241 if (ArgDefIsList(*found_arg_def)) {
242 return errors::Unimplemented(
243 "Splitting a function where an edge is a list of tensors is "
244 "unsupported. ArgDef representing edge: ",
245 found_arg_def->DebugString());
246 }
247 if (!found_arg_def->type_attr().empty()) {
248 return errors::Unimplemented(
249 "Splitting a function where an edge's ArgDef has a type attribute is "
250 "unsupported. ArgDef representing argument: ",
251 found_arg_def->DebugString());
252 }
253
254 return RewriteCrossFunctionInput(input_str, *found_arg_def, new_input_str);
255 }
256
RewriteNodeInput(absl::string_view input_str,string * new_input_str)257 Status InputRewriter::RewriteNodeInput(absl::string_view input_str,
258 string* new_input_str) {
259 std::vector<string> components = absl::StrSplit(input_str, ':');
260 if (components.size() != 2 && components.size() != 3) {
261 return errors::Internal("Found node with invalid node input: ", input_str);
262 }
263 const string& node_name = components[0];
264 const string& node_output_arg = components[1];
265 const string& list_output_index =
266 components.size() == 3 ? components[2] : "0";
267 if (!IsInFirstFunction(node_name)) {
268 *new_input_str = string{input_str};
269 return OkStatus();
270 }
271
272 auto index_iter = name_to_node_.find(node_name);
273 if (index_iter == name_to_node_.end()) {
274 return errors::Internal("Found input referring to nonexistent node: ",
275 node_name);
276 }
277 const NodeDef& node = *index_iter->second;
278
279 const OpRegistrationData* op_reg_data = nullptr;
280 TF_RETURN_IF_ERROR(library_.LookUp(node.op(), &op_reg_data));
281 int i = FindArgDefIndex(op_reg_data->op_def.output_arg(), node_output_arg);
282 if (i == -1) {
283 return errors::Internal("Could not found input \"", node_output_arg,
284 "\" for OpDef ", op_reg_data->op_def.name());
285 }
286 OpDef::ArgDef found_arg_def = op_reg_data->op_def.output_arg(i);
287
288 if (ArgDefIsList(found_arg_def)) {
289 return errors::Unimplemented(
290 "Splitting a function where an edge is a list of tensors is "
291 "unsupported. ArgDef representing edge: ",
292 found_arg_def.DebugString());
293 }
294 if (list_output_index != "0") {
295 return errors::Internal(
296 "Input string \"", input_str,
297 "\" has a last component which is not 0, but it is expected to be 0 "
298 "because corresponding output is not a list");
299 }
300
301 if (!found_arg_def.type_attr().empty()) {
302 const string& attr = found_arg_def.type_attr();
303 auto attr_iter = node.attr().find(attr);
304 if (attr_iter == node.attr().end()) {
305 return errors::Internal("Failed to find attr ", attr, " on node ",
306 node.name());
307 }
308 if (!attr_iter->second.placeholder().empty()) {
309 return errors::Unimplemented(
310 "Splitting a function where an edge between functions has an "
311 "AttrValue placeholder dtype is unsupported.");
312 }
313 DataType dtype = attr_iter->second.type();
314 if (dtype == DT_INVALID) {
315 return errors::Internal("Attr ", attr, " is not a dtype attr");
316 }
317 found_arg_def.mutable_type_attr()->clear();
318 found_arg_def.set_type(dtype);
319 }
320
321 return RewriteCrossFunctionInput(input_str, found_arg_def, new_input_str);
322 }
323
RewriteCrossFunctionInput(absl::string_view input_str,const OpDef::ArgDef & input_arg_def,string * new_input_str)324 Status InputRewriter::RewriteCrossFunctionInput(
325 absl::string_view input_str, const OpDef::ArgDef& input_arg_def,
326 string* new_input_str) {
327 DCHECK(input_arg_def.type() != DT_INVALID);
328 if (input_arg_def.is_ref() || IsRefType(input_arg_def.type())) {
329 // This case is untested and is not important to support, so an
330 // Unimplemented error is raised.
331 return errors::Unimplemented(
332 "Splitting a function where an edge between functions is a ref is "
333 "unsupported. Input ",
334 input_str, " is a ref type.");
335 }
336 OpDef::ArgDef* added_output_arg =
337 first_function_->mutable_signature()->add_output_arg();
338 *added_output_arg = input_arg_def;
339 size_t output_index = first_function_->signature().output_arg_size() - 1;
340 added_output_arg->set_name(absl::StrCat("output_", output_index));
341 added_output_arg->set_description(absl::StrCat(
342 "Output ", output_index, ", corresponding to input ", input_str));
343 first_function_->mutable_ret()->insert(
344 {added_output_arg->name(), string{input_str}});
345 first_function_output_types_->push_back(input_arg_def.type());
346
347 OpDef::ArgDef* added_input_arg =
348 second_function_->mutable_signature()->add_input_arg();
349 *added_input_arg = input_arg_def;
350 size_t input_index = second_function_->signature().input_arg_size() - 1;
351 added_input_arg->set_name(unique_name(absl::StrCat("input_", input_index)));
352 added_input_arg->set_description(absl::StrCat("Input ", input_index));
353
354 *new_input_str = added_input_arg->name();
355 return OkStatus();
356 }
357
InitializeSignatures(const FunctionDef & original_function_,FunctionDef * first_function_,FunctionDef * second_function_,const absl::flat_hash_set<absl::string_view> & nodes_in_first_function,const FunctionDefLibrary & func_def_lib_)358 void InitializeSignatures(
359 const FunctionDef& original_function_, FunctionDef* first_function_,
360 FunctionDef* second_function_,
361 const absl::flat_hash_set<absl::string_view>& nodes_in_first_function,
362 const FunctionDefLibrary& func_def_lib_) {
363 // Initialize first_function_->signature().
364 *first_function_->mutable_signature() = original_function_.signature();
365 graph_utils::SetUniqueGraphFunctionName(
366 original_function_.signature().name() + "_first_split", &func_def_lib_,
367 first_function_);
368 first_function_->mutable_signature()->clear_output_arg();
369 first_function_->mutable_signature()->clear_control_output();
370 first_function_->mutable_signature()->set_description(absl::StrCat(
371 "The function \"", original_function_.signature().name(),
372 "\" was split into two pieces in the make_deterministic Grappler pass. "
373 "This function is the first piece."));
374 first_function_->mutable_signature()->set_is_commutative(false);
375 first_function_->mutable_signature()->set_is_aggregate(false);
376
377 // Initialize second_function_->signature().
378 *second_function_->mutable_signature() = original_function_.signature();
379 graph_utils::SetUniqueGraphFunctionName(
380 original_function_.signature().name() + "_second_split", &func_def_lib_,
381 second_function_);
382 second_function_->mutable_signature()->clear_input_arg();
383 second_function_->mutable_signature()->clear_control_output();
384 second_function_->mutable_signature()->set_description(absl::StrCat(
385 "The function \"", original_function_.signature().name(),
386 "\" was split into two pieces in the make_deterministic Grappler pass. "
387 "This function is the second piece."));
388 second_function_->mutable_signature()->set_is_commutative(false);
389 second_function_->mutable_signature()->set_is_aggregate(false);
390
391 // Initialize the control_ret fields of the two signatures.
392 for (const auto& it : original_function_.control_ret()) {
393 if (nodes_in_first_function.contains(it.second)) {
394 first_function_->mutable_control_ret()->insert(it);
395 } else {
396 second_function_->mutable_control_ret()->insert(it);
397 }
398 }
399 }
400
401 } // namespace
402
SplitFunction(const FunctionDef & function,const absl::flat_hash_set<absl::string_view> & nodes_in_first_function,int64_t num_captured_inputs,const FunctionLibraryDefinition & library)403 StatusOr<SplitResults> SplitFunction(
404 const FunctionDef& function,
405 const absl::flat_hash_set<absl::string_view>& nodes_in_first_function,
406 int64_t num_captured_inputs, const FunctionLibraryDefinition& library) {
407 for (const auto& attr : function.attr()) {
408 if (attr.first != data::kTFDataFunction &&
409 attr.first != "_construction_context") {
410 return errors::Unimplemented(
411 "Cannot split function with unknown attribute key: ", attr.first);
412 }
413 }
414
415 for (int i = 0; i < function.signature().input_arg_size(); i++) {
416 // Processing list arguments is more complicated and not yet implemented.
417 if (ArgDefIsList(function.signature().input_arg(i))) {
418 return errors::Unimplemented(
419 "Cannot split function when an input argument is a list of tensors "
420 "instead of a single tensor.");
421 }
422 }
423
424 for (const NodeDef& node_def : function.node_def()) {
425 if (IsControlFlow(node_def)) {
426 return errors::Unimplemented(
427 "Cannot split function with control flow ops");
428 }
429 }
430
431 SplitResults results;
432 InitializeSignatures(function, &results.first_function,
433 &results.second_function, nodes_in_first_function,
434 library.ToProto());
435
436 // Insert _construction_context attribute into functions, if it exists on
437 // original_function_.
438 auto contruction_ctx_iter = function.attr().find("_construction_context");
439 if (contruction_ctx_iter != function.attr().end()) {
440 results.first_function.mutable_attr()->insert(
441 {contruction_ctx_iter->first, contruction_ctx_iter->second});
442 results.second_function.mutable_attr()->insert(
443 {contruction_ctx_iter->first, contruction_ctx_iter->second});
444 }
445
446 InputRewriter rewriter{function,
447 nodes_in_first_function,
448 num_captured_inputs,
449 library,
450 &results.first_function,
451 &results.second_function,
452 &results.first_function_output_types};
453
454 for (const NodeDef& orig_node_def : function.node_def()) {
455 if (!nodes_in_first_function.contains(orig_node_def.name())) {
456 // Add node to second function and rewrite its inputs.
457 NodeDef& new_node_def = *results.second_function.add_node_def();
458 new_node_def = orig_node_def;
459 new_node_def.clear_input();
460
461 for (const string& input_str : orig_node_def.input()) {
462 string* new_input_str = new_node_def.add_input();
463 TF_RETURN_IF_ERROR(rewriter.RewriteInput(input_str, new_input_str));
464 if (new_input_str->empty()) {
465 new_node_def.mutable_input()->RemoveLast();
466 VLOG(3) << "Removed input " << input_str << " from node "
467 << orig_node_def.name();
468 } else if (*new_input_str != input_str) {
469 VLOG(3) << "Rewrote input " << input_str << " to " << new_input_str
470 << " of node " << orig_node_def.name();
471 }
472 }
473 } else {
474 // Add node to first function, and check that all its inputs are also in
475 // the first function.
476 *results.first_function.add_node_def() = orig_node_def;
477 for (const string& input_str : orig_node_def.input()) {
478 std::vector<string> components = absl::StrSplit(input_str, ':');
479 if (!IsControlInput(input_str) && !IsFunctionArgument(input_str) &&
480 !nodes_in_first_function.contains(components[0])) {
481 return errors::Internal("Node ", orig_node_def.name(),
482 " is in first function but has input ",
483 input_str,
484 " which is not in first function.");
485 }
486 }
487 }
488 }
489
490 // Add return values to second_fuction.ret()
491 for (const OpDef::ArgDef& arg_def : function.signature().output_arg()) {
492 auto it = function.ret().find(arg_def.name());
493 if (it == function.ret().end()) {
494 return errors::Internal(
495 "Failed to find output_arg '", arg_def.name(),
496 "' in 'ret' section. FunctionDef: ", function.DebugString());
497 }
498 string& new_ret = (*results.second_function.mutable_ret())[arg_def.name()];
499 TF_RETURN_IF_ERROR(rewriter.RewriteInput(it->second, &new_ret));
500 DCHECK(!new_ret.empty());
501 }
502
503 // Add captured inputs to second_function.input_arg()
504 for (int i = function.signature().input_arg_size() - num_captured_inputs;
505 i < function.signature().input_arg_size(); i++) {
506 *results.second_function.mutable_signature()->add_input_arg() =
507 function.signature().input_arg(i);
508 }
509
510 return results;
511 }
512
513 } // namespace split_utils
514 } // namespace grappler
515 } // namespace tensorflow
516