xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/utils/functions.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/utils/functions.h"
16 
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_replace.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph_def_util.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/utils.h"
34 #include "tensorflow/core/lib/strings/scanner.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
GrapplerFunctionItem(string func_name,string description,AttrSlice func_attr,std::vector<const FunctionDef::ArgAttrs * > arg_attr,std::vector<InputArgInstantiation> input_args,std::vector<OutputArgInstantiation> output_args,std::vector<ControlOutput> control_outputs,const int graph_def_version,const bool is_stateful,GraphDef && function_body)39 GrapplerFunctionItem::GrapplerFunctionItem(
40     string func_name, string description, AttrSlice func_attr,
41     std::vector<const FunctionDef::ArgAttrs*> arg_attr,
42     std::vector<InputArgInstantiation> input_args,
43     std::vector<OutputArgInstantiation> output_args,
44     std::vector<ControlOutput> control_outputs, const int graph_def_version,
45     const bool is_stateful, GraphDef&& function_body)
46     : description_(std::move(description)),
47       func_attr_(func_attr),
48       arg_attr_(std::move(arg_attr)),
49       input_args_(std::move(input_args)),
50       output_args_(std::move(output_args)),
51       control_outputs_(std::move(control_outputs)),
52       is_stateful_(is_stateful) {
53   id = std::move(func_name);
54   graph = std::move(function_body);
55   graph.mutable_versions()->set_producer(graph_def_version);
56 
57   // Fill the feed nodes with function input arguments.
58   for (const InputArgInstantiation& input_arg : input_args_) {
59     feed.push_back({input_arg.node_name, Tensor()});
60   }
61   // Fill the fetch nodes with outputs.
62   for (const OutputArgInstantiation& output_arg : output_args_) {
63     fetch.push_back(output_arg.node_name);
64   }
65   // We must keep all control output nodes.
66   for (const ControlOutput& control_output : control_outputs_) {
67     keep_ops.push_back(control_output.node_name);
68   }
69 
70   // Tensorflow functions execution semantics is different from the main graph,
71   // and we need to preserve it when we do graph optimizations.
72   optimization_options().allow_pruning_stateful_and_dataset_ops = false;
73 }
74 
description() const75 const string& GrapplerFunctionItem::description() const { return description_; }
76 
inputs() const77 const std::vector<InputArgInstantiation>& GrapplerFunctionItem::inputs() const {
78   return input_args_;
79 }
80 
input(int i) const81 const InputArgInstantiation& GrapplerFunctionItem::input(int i) const {
82   return input_args_[i];
83 }
84 
input_size() const85 const std::size_t GrapplerFunctionItem::input_size() const {
86   return input_args_.size();
87 }
88 
outputs() const89 const std::vector<OutputArgInstantiation>& GrapplerFunctionItem::outputs()
90     const {
91   return output_args_;
92 }
93 
output(int i) const94 const OutputArgInstantiation& GrapplerFunctionItem::output(int i) const {
95   return output_args_[i];
96 }
97 
output_size() const98 const std::size_t GrapplerFunctionItem::output_size() const {
99   return output_args_.size();
100 }
101 
control_outputs() const102 const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
103     const {
104   return control_outputs_;
105 }
106 
control_output_size() const107 const std::size_t GrapplerFunctionItem::control_output_size() const {
108   return control_outputs_.size();
109 }
110 
func_attr() const111 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
112 
113 const std::vector<const FunctionDef::ArgAttrs*>&
arg_attr() const114 GrapplerFunctionItem::arg_attr() const {
115   return arg_attr_;
116 }
117 
function_body() const118 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
119 
mutable_function_body()120 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
121 
is_stateful() const122 bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; }
123 
SwapFunctionBody(GraphDef && other)124 GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
125   graph = std::move(other);
126   return *this;
127 }
128 
HasParametrizedType(const FunctionDef & func)129 bool HasParametrizedType(const FunctionDef& func) {
130   const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
131     return !arg.type_attr().empty() || !arg.number_attr().empty() ||
132            !arg.type_list_attr().empty();
133   };
134 
135   const auto& input = func.signature().input_arg();
136   const auto& output = func.signature().output_arg();
137   return std::any_of(input.begin(), input.end(), is_type_parametrized) ||
138          std::any_of(output.begin(), output.end(), is_type_parametrized);
139 }
140 
HasParametrizedBody(const FunctionDef & func)141 bool HasParametrizedBody(const FunctionDef& func) {
142   const auto is_parametrized = [&](const NodeDef& node) {
143     for (const auto& attr : node.attr()) {
144       if (!attr.second.placeholder().empty()) return true;
145     }
146     return false;
147   };
148   return std::any_of(func.node_def().begin(), func.node_def().end(),
149                      is_parametrized);
150 }
151 
IsParametrized(const FunctionDef & func)152 bool IsParametrized(const FunctionDef& func) {
153   return HasParametrizedType(func) || HasParametrizedBody(func);
154 }
155 
InstantiationTypeParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,DataType> * type_parameters)156 Status InstantiationTypeParameters(
157     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
158     absl::flat_hash_map<string, DataType>* type_parameters) {
159   if (!type_parameters->empty()) {
160     return errors::InvalidArgument("Type parameters output map must be empty");
161   }
162 
163   const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) -> Status {
164     if (!arg.type_attr().empty()) {
165       DataType dtype;
166       TF_RETURN_IF_ERROR(
167           GetNodeAttr(func_instantiation_attr, arg.type_attr(), &dtype));
168       type_parameters->emplace(arg.type_attr(), dtype);
169 
170     } else if (!arg.type_list_attr().empty()) {
171       std::vector<DataType> dtypes;
172       TF_RETURN_IF_ERROR(
173           GetNodeAttr(func_instantiation_attr, arg.type_list_attr(), &dtypes));
174       int index = 0;
175       for (const DataType& dtype : dtypes) {
176         type_parameters->emplace(absl::StrCat(arg.type_list_attr(), ":", index),
177                                  dtype);
178         ++index;
179       }
180     }
181     return OkStatus();
182   };
183 
184   for (const auto& input : func.signature().input_arg())
185     TF_RETURN_IF_ERROR(resolve_type_attr(input));
186   for (const auto& output : func.signature().output_arg())
187     TF_RETURN_IF_ERROR(resolve_type_attr(output));
188 
189   return OkStatus();
190 }
191 
InstantiationBodyParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,AttrValue> * body_parameters)192 Status InstantiationBodyParameters(
193     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
194     absl::flat_hash_map<string, AttrValue>* body_parameters) {
195   if (!body_parameters->empty()) {
196     return errors::InvalidArgument("Body parameters output map must be empty");
197   }
198 
199   for (const NodeDef& func_body_node : func.node_def()) {
200     for (auto& attr : func_body_node.attr()) {
201       const string& placeholder = attr.second.placeholder();
202 
203       if (placeholder.empty() || body_parameters->contains(placeholder)) {
204         continue;
205       }
206 
207       const AttrValue* placeholder_value =
208           func_instantiation_attr.Find(placeholder);
209       if (placeholder_value) {
210         body_parameters->insert({placeholder, *placeholder_value});
211       } else {
212         return errors::InvalidArgument("Can't resolve placeholder: ",
213                                        placeholder);
214       }
215     }
216   }
217 
218   return OkStatus();
219 }
220 
MakeGrapplerFunctionItem(const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)221 Status MakeGrapplerFunctionItem(const FunctionDef& func,
222                                 const AttrSlice& func_instantiation_attr,
223                                 const FunctionLibraryDefinition& flib,
224                                 const int graph_def_version,
225                                 GrapplerFunctionItem* item) {
226   const OpDef& signature = func.signature();
227 
228   if (signature.name().empty()) {
229     return errors::InvalidArgument("Function name must be specified");
230   }
231 
232   // Function types will be resolved from function instantiation attributes. All
233   // other attributes will be lost during conversion to FunctionDef.
234   for (const OpDef::AttrDef& attr : signature.attr()) {
235     if (attr.type() != "type") {
236       return errors::InvalidArgument(
237           "Function signature must have only type attributes");
238     }
239   }
240 
241   // Instantiate function into a statically defined FunctionBody Graph.
242   std::unique_ptr<FunctionBody> fbody;
243   TF_RETURN_IF_ERROR(
244       FunctionDefToBodyHelper(func, func_instantiation_attr, &flib, &fbody));
245 
246   GraphDef function_body;
247   fbody->graph->ToGraphDef(&function_body);
248 
249   // Function body shares the library with the graph that instantiated it. We do
250   // not need a full copy of the function library, just the reachable subset.
251   *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
252 
253   VLOG(3) << absl::Substitute(
254       "Deleted $0 unreachable functions from the Grappler function item "
255       "instantiation of $1 (library size = $2)",
256       flib.num_functions() - function_body.library().function_size(),
257       signature.name(), function_body.library().function_size());
258 
259   const int num_instantiated_inputs = fbody->arg_types.size();
260   const int num_instantiated_outputs = fbody->ret_types.size();
261 
262   std::vector<InputArgInstantiation> inputs;
263   inputs.reserve(num_instantiated_inputs);
264 
265   for (int in_id = 0; in_id < num_instantiated_inputs; ++in_id) {
266     const Node* node = fbody->arg_nodes[in_id];
267     const DataType& dtype = fbody->arg_types[in_id];
268     inputs.emplace_back(node->name(), dtype);
269   }
270 
271   std::vector<OutputArgInstantiation> outputs;
272   outputs.reserve(num_instantiated_outputs);
273 
274   for (int out_id = 0; out_id < num_instantiated_outputs; ++out_id) {
275     const Node* node = fbody->ret_nodes[out_id];
276     const DataType& dtype = fbody->ret_types[out_id];
277     outputs.emplace_back(node->name(), dtype);
278   }
279 
280   // Control outputs ensure that all side-effectful nodes in the function body
281   // will execute, even if they are not required to compute regular output args.
282   std::vector<ControlOutput> control_outputs;
283   control_outputs.reserve(func.control_ret_size());
284   for (const auto& control_ret : func.control_ret()) {
285     control_outputs.push_back({control_ret.first, control_ret.second});
286   }
287   // Sort control outputs to keep FunctionDef output stable. The sort order of
288   // map entries in func.control_ret() are not stable.
289   // See b/174715578 for context on why stability is desired.
290   std::sort(control_outputs.begin(), control_outputs.end());
291 
292   std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr);
293   for (const auto& attr : func.arg_attr()) {
294     arg_attr.at(attr.first) = &attr.second;
295   }
296 
297   *item = GrapplerFunctionItem(
298       /*func_name=*/signature.name(),
299       /*description=*/signature.description(),
300       /*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr),
301       std::move(inputs), std::move(outputs), std::move(control_outputs),
302       graph_def_version, signature.is_stateful(), std::move(function_body));
303   return OkStatus();
304 }
305 
MakeGrapplerFunctionItem(const FunctionDef & func,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)306 Status MakeGrapplerFunctionItem(const FunctionDef& func,
307                                 const FunctionLibraryDefinition& flib,
308                                 const int graph_def_version,
309                                 GrapplerFunctionItem* item) {
310   return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version,
311                                   item);
312 }
313 
ReplaceInputWithConst(const NodeDef & input_const,int input_index,GrapplerFunctionItem * item)314 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
315                              GrapplerFunctionItem* item) {
316   if (!IsConstant(input_const)) {
317     return errors::InvalidArgument("Input node is not a constant: ",
318                                    SummarizeNodeDef(input_const));
319   }
320   const int item_input_size = item->input_size();
321   if (input_index < 0 || input_index >= item_input_size) {
322     return errors::InvalidArgument(
323         "Function input index is out of bound: index=", input_index,
324         " input_size=", item->input_size());
325   }
326 
327   const InputArgInstantiation& input_arg = item->input(input_index);
328 
329   for (NodeDef& node : *item->graph.mutable_node()) {
330     // Replace '_Arg' node in the function body with a 'Const' node.
331     if (node.name() == input_arg.node_name) {
332       node = input_const;
333       node.set_name(input_arg.node_name);
334       node.clear_input();
335       node.clear_device();  // device placement is defined by instantiating node
336     }
337 
338     // Update index in all inputs after the removed const input.
339     if (IsArg(node)) {
340       auto attrs = AttrSlice(node);
341       int index;
342       TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
343       if (index >= input_index) {
344         (*node.mutable_attr())["index"].set_i(index - 1);
345       }
346     }
347   }
348 
349   item->input_args_.erase(item->input_args_.begin() + input_index);
350   item->arg_attr_.erase(item->arg_attr_.begin() + input_index);
351 
352   return OkStatus();
353 }
354 
RemoveFunctionOutputs(const absl::flat_hash_set<int> & remove_outputs,GrapplerFunctionItem * item,std::vector<std::pair<int,int>> * output_mapping)355 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
356                              GrapplerFunctionItem* item,
357                              std::vector<std::pair<int, int>>* output_mapping) {
358   DCHECK(output_mapping->empty());
359 
360   // Do some sanity checking of the removed outputs positions.
361   for (int remove_output : remove_outputs) {
362     const int item_output_size = item->output_size();
363     if (remove_output < 0 || remove_output >= item_output_size) {
364       return errors::InvalidArgument(
365           "Function output index is out of bound: index=", remove_output,
366           " output_size=", item->output_size());
367     }
368   }
369 
370   absl::flat_hash_set<const OutputArgInstantiation*> remove_output_args;
371   const auto is_remove_output_arg = [&](const OutputArgInstantiation& output) {
372     return remove_output_args.find(&output) != remove_output_args.end();
373   };
374 
375   for (int i = 0, end = item->output_size(); i < end; ++i) {
376     const OutputArgInstantiation& output = item->output(i);
377     if (remove_outputs.contains(i)) {
378       VLOG(3) << "Remove functions output: name=" << output.node_name
379               << "(index = " << i << ")";
380       remove_output_args.insert(&output);
381     } else if (!remove_output_args.empty()) {
382       // Add output mapping only if output position changed.
383       output_mapping->push_back({i, i - remove_output_args.size()});
384     }
385   }
386 
387   // Update 'index' attribute in all '_Retval' nodes that are in output mapping.
388   for (NodeDef& node : *item->graph.mutable_node()) {
389     if (IsRetval(node)) {
390       auto attrs = AttrSlice(node);
391       int index;
392       TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
393 
394       for (const auto& mapping : *output_mapping) {
395         const int from = mapping.first;
396         const int to = mapping.second;
397         if (index == from) {
398           (*node.mutable_attr())["index"].set_i(to);
399         }
400       }
401     }
402   }
403 
404   auto& o = item->output_args_;
405   o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
406 
407   return OkStatus();
408 }
409 
410 namespace {
411 
412 // FunctionDef uses different connectivity encoding for the function body nodes,
413 // than a GraphDef (see function.proto for details). This is a helper class that
414 // converts inputs in GraphDef format (node[:position]) to the FunctionDef
415 // format (node:output[:position]).
416 class MakeFunctionDefHelper {
417  public:
418   MakeFunctionDefHelper() = default;
419 
420   Status Initialize(const GrapplerFunctionItem& item,
421                     const FunctionLibraryDefinition& flib);
422 
423   // Converts input name from GraphDef format (name[:position]) to the
424   // FunctionDef input format (name[:output][:position]) using registered input
425   // arg instantiations and function body outputs.
426   Status AsFunctionDefInput(const string& graph_def_input,
427                             string* func_def_input) const;
428 
429   // Updates Node inputs from GraphDef to FunctionDef format.
430   Status AsFunctionDefNode(NodeDef* function_body_node) const;
431 
IsInputNode(const NodeDef & node) const432   bool IsInputNode(const NodeDef& node) const {
433     return input_nodes_.contains(node.name());
434   }
435 
IsOutputNode(const NodeDef & node) const436   bool IsOutputNode(const NodeDef& node) const {
437     return output_nodes_.contains(node.name());
438   }
439 
440  private:
441   absl::flat_hash_set<absl::string_view> input_nodes_;
442   absl::flat_hash_set<absl::string_view> output_nodes_;
443   // Mapping from function body node name to output names range map.
444   absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
445 };
446 
Initialize(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib)447 Status MakeFunctionDefHelper::Initialize(
448     const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib) {
449   for (const InputArgInstantiation& input_arg : item.inputs()) {
450     input_nodes_.insert(input_arg.node_name);
451   }
452   for (const OutputArgInstantiation& output_arg : item.outputs()) {
453     output_nodes_.insert(output_arg.node_name);
454   }
455 
456   for (const NodeDef& node : item.function_body().node()) {
457     const OpRegistrationData* registration;
458     TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
459 
460     tensorflow::NameRangeMap outputs_range_map;
461     TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
462         node, registration->op_def, nullptr, &outputs_range_map));
463 
464     function_body_outputs_.emplace(node.name(), std::move(outputs_range_map));
465   }
466 
467   return OkStatus();
468 }
469 
AsFunctionDefInput(const string & graph_def_input,string * func_def_input) const470 Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input,
471                                                  string* func_def_input) const {
472   if (IsControlInput(graph_def_input)) {
473     *func_def_input = graph_def_input;
474     return OkStatus();
475   }
476 
477   const SafeTensorId tensor = ParseTensorName(graph_def_input);
478   DCHECK_GE(tensor.index(), 0);
479 
480   // Graph def input corresponds to one of the function inputs.
481   const auto is_input = input_nodes_.find(tensor.node());
482   if (is_input != input_nodes_.end()) {
483     DCHECK_EQ(tensor.index(), 0);
484     *func_def_input = tensor.node();
485     return OkStatus();
486   }
487 
488   // Or it must be output from one of the function body nodes
489   const auto is_body_output = function_body_outputs_.find(tensor.node());
490   if (is_body_output != function_body_outputs_.end()) {
491     const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
492 
493     for (const auto& el : outputs_range_map) {
494       const auto& output_name = el.first;
495       const auto& output_range = el.second;
496       if (tensor.index() >= output_range.first &&
497           tensor.index() < output_range.second) {
498         *func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":",
499                                        tensor.index() - output_range.first);
500         return OkStatus();
501       }
502     }
503   }
504 
505   return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
506 }
507 
AsFunctionDefNode(NodeDef * function_body_node) const508 Status MakeFunctionDefHelper::AsFunctionDefNode(
509     NodeDef* function_body_node) const {
510   string func_def_input;
511 
512   for (int i = 0; i < function_body_node->input_size(); ++i) {
513     TF_RETURN_IF_ERROR(
514         AsFunctionDefInput(function_body_node->input(i), &func_def_input));
515     function_body_node->set_input(i, func_def_input);
516   }
517 
518   return OkStatus();
519 }
520 
521 }  // namespace
522 
MakeFunctionDef(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib,FunctionDef * func)523 Status MakeFunctionDef(const GrapplerFunctionItem& item,
524                        const FunctionLibraryDefinition& flib,
525                        FunctionDef* func) {
526   func->mutable_signature()->set_name(item.id);
527   func->mutable_signature()->set_description(item.description());
528   func->mutable_signature()->set_is_stateful(item.is_stateful());
529 
530   MakeFunctionDefHelper helper;
531   TF_RETURN_IF_ERROR(helper.Initialize(item, flib));
532 
533   // Mapping from the '_Retval' node name to the output tensor.
534   absl::flat_hash_map<absl::string_view, string> output_tensors;
535   for (const NodeDef& func_body_node : item.function_body().node()) {
536     if (!helper.IsOutputNode(func_body_node)) continue;
537     if (func_body_node.input_size() != 1) {
538       return errors::Internal("_Retval node must have single input: ",
539                               SummarizeNodeDef(func_body_node));
540     }
541     output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
542   }
543 
544   for (const InputArgInstantiation& input_arg : item.inputs()) {
545     OpDef::ArgDef arg_def;
546     arg_def.set_name(input_arg.node_name);
547     arg_def.set_type(input_arg.data_type);
548     arg_def.set_is_ref(IsRefType(input_arg.data_type));
549     *func->mutable_signature()->add_input_arg() = arg_def;
550   }
551 
552   // Add function output arguments.
553   for (const OutputArgInstantiation& output_arg : item.outputs()) {
554     const string output_name =
555         absl::StrReplaceAll(output_arg.node_name, {{"_RetVal", ""}});
556 
557     OpDef::ArgDef arg_def;
558     arg_def.set_name(output_name);
559     arg_def.set_type(output_arg.data_type);
560     arg_def.set_is_ref(IsRefType(output_arg.data_type));
561     *func->mutable_signature()->add_output_arg() = arg_def;
562 
563     auto it = output_tensors.find(output_arg.node_name);
564     if (it == output_tensors.end()) {
565       return errors::Internal(
566           "Can't find an output tensor for the output node: ",
567           output_arg.node_name);
568     }
569 
570     TF_RETURN_IF_ERROR(helper.AsFunctionDefInput(
571         it->second, &(*func->mutable_ret())[output_name]));
572   }
573 
574   // Add function control outputs.
575   for (const ControlOutput& control_out : item.control_outputs()) {
576     func->mutable_control_ret()->insert(
577         {control_out.output_name, control_out.node_name});
578     *func->mutable_signature()->add_control_output() = control_out.output_name;
579   }
580 
581   // Copy function definition specific attributes.
582   for (const auto& attr : item.func_attr()) {
583     const auto& attr_name = attr.first;
584     const auto& attr_value = attr.second;
585     (*func->mutable_attr())[attr_name] = attr_value;
586   }
587 
588   // Copy function arg attributes.
589   for (int i = 0, end = item.arg_attr().size(); i < end; ++i) {
590     const auto* attr = item.arg_attr().at(i);
591     if (attr != nullptr) {
592       (*func->mutable_arg_attr())[i] = *attr;
593     }
594   }
595 
596   // Copy function body nodes to the FunctionDef and update input format
597   for (const NodeDef& func_node : item.function_body().node()) {
598     // Skip original `_Arg` and `_Retval` nodes. If node was converted to some
599     // other type (e.g. inputs converted to placeholders), we need to check that
600     // it's not registered as function input or output node.
601     if (IsArg(func_node) || IsRetval(func_node) ||
602         helper.IsInputNode(func_node) || helper.IsOutputNode(func_node))
603       continue;
604 
605     NodeDef* func_def_node = func->add_node_def();
606     *func_def_node = func_node;
607     TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node));
608   }
609 
610   return OkStatus();
611 }
612 
613 }  // end namespace grappler
614 }  // end namespace tensorflow
615