/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/grappler/utils/functions.h"
16
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_replace.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph_def_util.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/utils.h"
34 #include "tensorflow/core/lib/strings/scanner.h"
35
36 namespace tensorflow {
37 namespace grappler {
38
GrapplerFunctionItem(string func_name,string description,AttrSlice func_attr,std::vector<const FunctionDef::ArgAttrs * > arg_attr,std::vector<InputArgInstantiation> input_args,std::vector<OutputArgInstantiation> output_args,std::vector<ControlOutput> control_outputs,const int graph_def_version,const bool is_stateful,GraphDef && function_body)39 GrapplerFunctionItem::GrapplerFunctionItem(
40 string func_name, string description, AttrSlice func_attr,
41 std::vector<const FunctionDef::ArgAttrs*> arg_attr,
42 std::vector<InputArgInstantiation> input_args,
43 std::vector<OutputArgInstantiation> output_args,
44 std::vector<ControlOutput> control_outputs, const int graph_def_version,
45 const bool is_stateful, GraphDef&& function_body)
46 : description_(std::move(description)),
47 func_attr_(func_attr),
48 arg_attr_(std::move(arg_attr)),
49 input_args_(std::move(input_args)),
50 output_args_(std::move(output_args)),
51 control_outputs_(std::move(control_outputs)),
52 is_stateful_(is_stateful) {
53 id = std::move(func_name);
54 graph = std::move(function_body);
55 graph.mutable_versions()->set_producer(graph_def_version);
56
57 // Fill the feed nodes with function input arguments.
58 for (const InputArgInstantiation& input_arg : input_args_) {
59 feed.push_back({input_arg.node_name, Tensor()});
60 }
61 // Fill the fetch nodes with outputs.
62 for (const OutputArgInstantiation& output_arg : output_args_) {
63 fetch.push_back(output_arg.node_name);
64 }
65 // We must keep all control output nodes.
66 for (const ControlOutput& control_output : control_outputs_) {
67 keep_ops.push_back(control_output.node_name);
68 }
69
70 // Tensorflow functions execution semantics is different from the main graph,
71 // and we need to preserve it when we do graph optimizations.
72 optimization_options().allow_pruning_stateful_and_dataset_ops = false;
73 }
74
description() const75 const string& GrapplerFunctionItem::description() const { return description_; }
76
inputs() const77 const std::vector<InputArgInstantiation>& GrapplerFunctionItem::inputs() const {
78 return input_args_;
79 }
80
input(int i) const81 const InputArgInstantiation& GrapplerFunctionItem::input(int i) const {
82 return input_args_[i];
83 }
84
input_size() const85 const std::size_t GrapplerFunctionItem::input_size() const {
86 return input_args_.size();
87 }
88
outputs() const89 const std::vector<OutputArgInstantiation>& GrapplerFunctionItem::outputs()
90 const {
91 return output_args_;
92 }
93
output(int i) const94 const OutputArgInstantiation& GrapplerFunctionItem::output(int i) const {
95 return output_args_[i];
96 }
97
output_size() const98 const std::size_t GrapplerFunctionItem::output_size() const {
99 return output_args_.size();
100 }
101
control_outputs() const102 const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
103 const {
104 return control_outputs_;
105 }
106
control_output_size() const107 const std::size_t GrapplerFunctionItem::control_output_size() const {
108 return control_outputs_.size();
109 }
110
func_attr() const111 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
112
113 const std::vector<const FunctionDef::ArgAttrs*>&
arg_attr() const114 GrapplerFunctionItem::arg_attr() const {
115 return arg_attr_;
116 }
117
function_body() const118 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
119
mutable_function_body()120 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
121
is_stateful() const122 bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; }
123
SwapFunctionBody(GraphDef && other)124 GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
125 graph = std::move(other);
126 return *this;
127 }
128
HasParametrizedType(const FunctionDef & func)129 bool HasParametrizedType(const FunctionDef& func) {
130 const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
131 return !arg.type_attr().empty() || !arg.number_attr().empty() ||
132 !arg.type_list_attr().empty();
133 };
134
135 const auto& input = func.signature().input_arg();
136 const auto& output = func.signature().output_arg();
137 return std::any_of(input.begin(), input.end(), is_type_parametrized) ||
138 std::any_of(output.begin(), output.end(), is_type_parametrized);
139 }
140
HasParametrizedBody(const FunctionDef & func)141 bool HasParametrizedBody(const FunctionDef& func) {
142 const auto is_parametrized = [&](const NodeDef& node) {
143 for (const auto& attr : node.attr()) {
144 if (!attr.second.placeholder().empty()) return true;
145 }
146 return false;
147 };
148 return std::any_of(func.node_def().begin(), func.node_def().end(),
149 is_parametrized);
150 }
151
IsParametrized(const FunctionDef & func)152 bool IsParametrized(const FunctionDef& func) {
153 return HasParametrizedType(func) || HasParametrizedBody(func);
154 }
155
InstantiationTypeParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,DataType> * type_parameters)156 Status InstantiationTypeParameters(
157 const FunctionDef& func, const AttrSlice& func_instantiation_attr,
158 absl::flat_hash_map<string, DataType>* type_parameters) {
159 if (!type_parameters->empty()) {
160 return errors::InvalidArgument("Type parameters output map must be empty");
161 }
162
163 const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) -> Status {
164 if (!arg.type_attr().empty()) {
165 DataType dtype;
166 TF_RETURN_IF_ERROR(
167 GetNodeAttr(func_instantiation_attr, arg.type_attr(), &dtype));
168 type_parameters->emplace(arg.type_attr(), dtype);
169
170 } else if (!arg.type_list_attr().empty()) {
171 std::vector<DataType> dtypes;
172 TF_RETURN_IF_ERROR(
173 GetNodeAttr(func_instantiation_attr, arg.type_list_attr(), &dtypes));
174 int index = 0;
175 for (const DataType& dtype : dtypes) {
176 type_parameters->emplace(absl::StrCat(arg.type_list_attr(), ":", index),
177 dtype);
178 ++index;
179 }
180 }
181 return OkStatus();
182 };
183
184 for (const auto& input : func.signature().input_arg())
185 TF_RETURN_IF_ERROR(resolve_type_attr(input));
186 for (const auto& output : func.signature().output_arg())
187 TF_RETURN_IF_ERROR(resolve_type_attr(output));
188
189 return OkStatus();
190 }
191
InstantiationBodyParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,AttrValue> * body_parameters)192 Status InstantiationBodyParameters(
193 const FunctionDef& func, const AttrSlice& func_instantiation_attr,
194 absl::flat_hash_map<string, AttrValue>* body_parameters) {
195 if (!body_parameters->empty()) {
196 return errors::InvalidArgument("Body parameters output map must be empty");
197 }
198
199 for (const NodeDef& func_body_node : func.node_def()) {
200 for (auto& attr : func_body_node.attr()) {
201 const string& placeholder = attr.second.placeholder();
202
203 if (placeholder.empty() || body_parameters->contains(placeholder)) {
204 continue;
205 }
206
207 const AttrValue* placeholder_value =
208 func_instantiation_attr.Find(placeholder);
209 if (placeholder_value) {
210 body_parameters->insert({placeholder, *placeholder_value});
211 } else {
212 return errors::InvalidArgument("Can't resolve placeholder: ",
213 placeholder);
214 }
215 }
216 }
217
218 return OkStatus();
219 }
220
MakeGrapplerFunctionItem(const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)221 Status MakeGrapplerFunctionItem(const FunctionDef& func,
222 const AttrSlice& func_instantiation_attr,
223 const FunctionLibraryDefinition& flib,
224 const int graph_def_version,
225 GrapplerFunctionItem* item) {
226 const OpDef& signature = func.signature();
227
228 if (signature.name().empty()) {
229 return errors::InvalidArgument("Function name must be specified");
230 }
231
232 // Function types will be resolved from function instantiation attributes. All
233 // other attributes will be lost during conversion to FunctionDef.
234 for (const OpDef::AttrDef& attr : signature.attr()) {
235 if (attr.type() != "type") {
236 return errors::InvalidArgument(
237 "Function signature must have only type attributes");
238 }
239 }
240
241 // Instantiate function into a statically defined FunctionBody Graph.
242 std::unique_ptr<FunctionBody> fbody;
243 TF_RETURN_IF_ERROR(
244 FunctionDefToBodyHelper(func, func_instantiation_attr, &flib, &fbody));
245
246 GraphDef function_body;
247 fbody->graph->ToGraphDef(&function_body);
248
249 // Function body shares the library with the graph that instantiated it. We do
250 // not need a full copy of the function library, just the reachable subset.
251 *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
252
253 VLOG(3) << absl::Substitute(
254 "Deleted $0 unreachable functions from the Grappler function item "
255 "instantiation of $1 (library size = $2)",
256 flib.num_functions() - function_body.library().function_size(),
257 signature.name(), function_body.library().function_size());
258
259 const int num_instantiated_inputs = fbody->arg_types.size();
260 const int num_instantiated_outputs = fbody->ret_types.size();
261
262 std::vector<InputArgInstantiation> inputs;
263 inputs.reserve(num_instantiated_inputs);
264
265 for (int in_id = 0; in_id < num_instantiated_inputs; ++in_id) {
266 const Node* node = fbody->arg_nodes[in_id];
267 const DataType& dtype = fbody->arg_types[in_id];
268 inputs.emplace_back(node->name(), dtype);
269 }
270
271 std::vector<OutputArgInstantiation> outputs;
272 outputs.reserve(num_instantiated_outputs);
273
274 for (int out_id = 0; out_id < num_instantiated_outputs; ++out_id) {
275 const Node* node = fbody->ret_nodes[out_id];
276 const DataType& dtype = fbody->ret_types[out_id];
277 outputs.emplace_back(node->name(), dtype);
278 }
279
280 // Control outputs ensure that all side-effectful nodes in the function body
281 // will execute, even if they are not required to compute regular output args.
282 std::vector<ControlOutput> control_outputs;
283 control_outputs.reserve(func.control_ret_size());
284 for (const auto& control_ret : func.control_ret()) {
285 control_outputs.push_back({control_ret.first, control_ret.second});
286 }
287 // Sort control outputs to keep FunctionDef output stable. The sort order of
288 // map entries in func.control_ret() are not stable.
289 // See b/174715578 for context on why stability is desired.
290 std::sort(control_outputs.begin(), control_outputs.end());
291
292 std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr);
293 for (const auto& attr : func.arg_attr()) {
294 arg_attr.at(attr.first) = &attr.second;
295 }
296
297 *item = GrapplerFunctionItem(
298 /*func_name=*/signature.name(),
299 /*description=*/signature.description(),
300 /*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr),
301 std::move(inputs), std::move(outputs), std::move(control_outputs),
302 graph_def_version, signature.is_stateful(), std::move(function_body));
303 return OkStatus();
304 }
305
MakeGrapplerFunctionItem(const FunctionDef & func,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)306 Status MakeGrapplerFunctionItem(const FunctionDef& func,
307 const FunctionLibraryDefinition& flib,
308 const int graph_def_version,
309 GrapplerFunctionItem* item) {
310 return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version,
311 item);
312 }
313
ReplaceInputWithConst(const NodeDef & input_const,int input_index,GrapplerFunctionItem * item)314 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
315 GrapplerFunctionItem* item) {
316 if (!IsConstant(input_const)) {
317 return errors::InvalidArgument("Input node is not a constant: ",
318 SummarizeNodeDef(input_const));
319 }
320 const int item_input_size = item->input_size();
321 if (input_index < 0 || input_index >= item_input_size) {
322 return errors::InvalidArgument(
323 "Function input index is out of bound: index=", input_index,
324 " input_size=", item->input_size());
325 }
326
327 const InputArgInstantiation& input_arg = item->input(input_index);
328
329 for (NodeDef& node : *item->graph.mutable_node()) {
330 // Replace '_Arg' node in the function body with a 'Const' node.
331 if (node.name() == input_arg.node_name) {
332 node = input_const;
333 node.set_name(input_arg.node_name);
334 node.clear_input();
335 node.clear_device(); // device placement is defined by instantiating node
336 }
337
338 // Update index in all inputs after the removed const input.
339 if (IsArg(node)) {
340 auto attrs = AttrSlice(node);
341 int index;
342 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
343 if (index >= input_index) {
344 (*node.mutable_attr())["index"].set_i(index - 1);
345 }
346 }
347 }
348
349 item->input_args_.erase(item->input_args_.begin() + input_index);
350 item->arg_attr_.erase(item->arg_attr_.begin() + input_index);
351
352 return OkStatus();
353 }
354
RemoveFunctionOutputs(const absl::flat_hash_set<int> & remove_outputs,GrapplerFunctionItem * item,std::vector<std::pair<int,int>> * output_mapping)355 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
356 GrapplerFunctionItem* item,
357 std::vector<std::pair<int, int>>* output_mapping) {
358 DCHECK(output_mapping->empty());
359
360 // Do some sanity checking of the removed outputs positions.
361 for (int remove_output : remove_outputs) {
362 const int item_output_size = item->output_size();
363 if (remove_output < 0 || remove_output >= item_output_size) {
364 return errors::InvalidArgument(
365 "Function output index is out of bound: index=", remove_output,
366 " output_size=", item->output_size());
367 }
368 }
369
370 absl::flat_hash_set<const OutputArgInstantiation*> remove_output_args;
371 const auto is_remove_output_arg = [&](const OutputArgInstantiation& output) {
372 return remove_output_args.find(&output) != remove_output_args.end();
373 };
374
375 for (int i = 0, end = item->output_size(); i < end; ++i) {
376 const OutputArgInstantiation& output = item->output(i);
377 if (remove_outputs.contains(i)) {
378 VLOG(3) << "Remove functions output: name=" << output.node_name
379 << "(index = " << i << ")";
380 remove_output_args.insert(&output);
381 } else if (!remove_output_args.empty()) {
382 // Add output mapping only if output position changed.
383 output_mapping->push_back({i, i - remove_output_args.size()});
384 }
385 }
386
387 // Update 'index' attribute in all '_Retval' nodes that are in output mapping.
388 for (NodeDef& node : *item->graph.mutable_node()) {
389 if (IsRetval(node)) {
390 auto attrs = AttrSlice(node);
391 int index;
392 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
393
394 for (const auto& mapping : *output_mapping) {
395 const int from = mapping.first;
396 const int to = mapping.second;
397 if (index == from) {
398 (*node.mutable_attr())["index"].set_i(to);
399 }
400 }
401 }
402 }
403
404 auto& o = item->output_args_;
405 o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
406
407 return OkStatus();
408 }
409
410 namespace {
411
412 // FunctionDef uses different connectivity encoding for the function body nodes,
413 // than a GraphDef (see function.proto for details). This is a helper class that
414 // converts inputs in GraphDef format (node[:position]) to the FunctionDef
415 // format (node:output[:position]).
416 class MakeFunctionDefHelper {
417 public:
418 MakeFunctionDefHelper() = default;
419
420 Status Initialize(const GrapplerFunctionItem& item,
421 const FunctionLibraryDefinition& flib);
422
423 // Converts input name from GraphDef format (name[:position]) to the
424 // FunctionDef input format (name[:output][:position]) using registered input
425 // arg instantiations and function body outputs.
426 Status AsFunctionDefInput(const string& graph_def_input,
427 string* func_def_input) const;
428
429 // Updates Node inputs from GraphDef to FunctionDef format.
430 Status AsFunctionDefNode(NodeDef* function_body_node) const;
431
IsInputNode(const NodeDef & node) const432 bool IsInputNode(const NodeDef& node) const {
433 return input_nodes_.contains(node.name());
434 }
435
IsOutputNode(const NodeDef & node) const436 bool IsOutputNode(const NodeDef& node) const {
437 return output_nodes_.contains(node.name());
438 }
439
440 private:
441 absl::flat_hash_set<absl::string_view> input_nodes_;
442 absl::flat_hash_set<absl::string_view> output_nodes_;
443 // Mapping from function body node name to output names range map.
444 absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
445 };
446
Initialize(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib)447 Status MakeFunctionDefHelper::Initialize(
448 const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib) {
449 for (const InputArgInstantiation& input_arg : item.inputs()) {
450 input_nodes_.insert(input_arg.node_name);
451 }
452 for (const OutputArgInstantiation& output_arg : item.outputs()) {
453 output_nodes_.insert(output_arg.node_name);
454 }
455
456 for (const NodeDef& node : item.function_body().node()) {
457 const OpRegistrationData* registration;
458 TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration));
459
460 tensorflow::NameRangeMap outputs_range_map;
461 TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
462 node, registration->op_def, nullptr, &outputs_range_map));
463
464 function_body_outputs_.emplace(node.name(), std::move(outputs_range_map));
465 }
466
467 return OkStatus();
468 }
469
AsFunctionDefInput(const string & graph_def_input,string * func_def_input) const470 Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input,
471 string* func_def_input) const {
472 if (IsControlInput(graph_def_input)) {
473 *func_def_input = graph_def_input;
474 return OkStatus();
475 }
476
477 const SafeTensorId tensor = ParseTensorName(graph_def_input);
478 DCHECK_GE(tensor.index(), 0);
479
480 // Graph def input corresponds to one of the function inputs.
481 const auto is_input = input_nodes_.find(tensor.node());
482 if (is_input != input_nodes_.end()) {
483 DCHECK_EQ(tensor.index(), 0);
484 *func_def_input = tensor.node();
485 return OkStatus();
486 }
487
488 // Or it must be output from one of the function body nodes
489 const auto is_body_output = function_body_outputs_.find(tensor.node());
490 if (is_body_output != function_body_outputs_.end()) {
491 const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
492
493 for (const auto& el : outputs_range_map) {
494 const auto& output_name = el.first;
495 const auto& output_range = el.second;
496 if (tensor.index() >= output_range.first &&
497 tensor.index() < output_range.second) {
498 *func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":",
499 tensor.index() - output_range.first);
500 return OkStatus();
501 }
502 }
503 }
504
505 return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
506 }
507
AsFunctionDefNode(NodeDef * function_body_node) const508 Status MakeFunctionDefHelper::AsFunctionDefNode(
509 NodeDef* function_body_node) const {
510 string func_def_input;
511
512 for (int i = 0; i < function_body_node->input_size(); ++i) {
513 TF_RETURN_IF_ERROR(
514 AsFunctionDefInput(function_body_node->input(i), &func_def_input));
515 function_body_node->set_input(i, func_def_input);
516 }
517
518 return OkStatus();
519 }
520
521 } // namespace
522
MakeFunctionDef(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib,FunctionDef * func)523 Status MakeFunctionDef(const GrapplerFunctionItem& item,
524 const FunctionLibraryDefinition& flib,
525 FunctionDef* func) {
526 func->mutable_signature()->set_name(item.id);
527 func->mutable_signature()->set_description(item.description());
528 func->mutable_signature()->set_is_stateful(item.is_stateful());
529
530 MakeFunctionDefHelper helper;
531 TF_RETURN_IF_ERROR(helper.Initialize(item, flib));
532
533 // Mapping from the '_Retval' node name to the output tensor.
534 absl::flat_hash_map<absl::string_view, string> output_tensors;
535 for (const NodeDef& func_body_node : item.function_body().node()) {
536 if (!helper.IsOutputNode(func_body_node)) continue;
537 if (func_body_node.input_size() != 1) {
538 return errors::Internal("_Retval node must have single input: ",
539 SummarizeNodeDef(func_body_node));
540 }
541 output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
542 }
543
544 for (const InputArgInstantiation& input_arg : item.inputs()) {
545 OpDef::ArgDef arg_def;
546 arg_def.set_name(input_arg.node_name);
547 arg_def.set_type(input_arg.data_type);
548 arg_def.set_is_ref(IsRefType(input_arg.data_type));
549 *func->mutable_signature()->add_input_arg() = arg_def;
550 }
551
552 // Add function output arguments.
553 for (const OutputArgInstantiation& output_arg : item.outputs()) {
554 const string output_name =
555 absl::StrReplaceAll(output_arg.node_name, {{"_RetVal", ""}});
556
557 OpDef::ArgDef arg_def;
558 arg_def.set_name(output_name);
559 arg_def.set_type(output_arg.data_type);
560 arg_def.set_is_ref(IsRefType(output_arg.data_type));
561 *func->mutable_signature()->add_output_arg() = arg_def;
562
563 auto it = output_tensors.find(output_arg.node_name);
564 if (it == output_tensors.end()) {
565 return errors::Internal(
566 "Can't find an output tensor for the output node: ",
567 output_arg.node_name);
568 }
569
570 TF_RETURN_IF_ERROR(helper.AsFunctionDefInput(
571 it->second, &(*func->mutable_ret())[output_name]));
572 }
573
574 // Add function control outputs.
575 for (const ControlOutput& control_out : item.control_outputs()) {
576 func->mutable_control_ret()->insert(
577 {control_out.output_name, control_out.node_name});
578 *func->mutable_signature()->add_control_output() = control_out.output_name;
579 }
580
581 // Copy function definition specific attributes.
582 for (const auto& attr : item.func_attr()) {
583 const auto& attr_name = attr.first;
584 const auto& attr_value = attr.second;
585 (*func->mutable_attr())[attr_name] = attr_value;
586 }
587
588 // Copy function arg attributes.
589 for (int i = 0, end = item.arg_attr().size(); i < end; ++i) {
590 const auto* attr = item.arg_attr().at(i);
591 if (attr != nullptr) {
592 (*func->mutable_arg_attr())[i] = *attr;
593 }
594 }
595
596 // Copy function body nodes to the FunctionDef and update input format
597 for (const NodeDef& func_node : item.function_body().node()) {
598 // Skip original `_Arg` and `_Retval` nodes. If node was converted to some
599 // other type (e.g. inputs converted to placeholders), we need to check that
600 // it's not registered as function input or output node.
601 if (IsArg(func_node) || IsRetval(func_node) ||
602 helper.IsInputNode(func_node) || helper.IsOutputNode(func_node))
603 continue;
604
605 NodeDef* func_def_node = func->add_node_def();
606 *func_def_node = func_node;
607 TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node));
608 }
609
610 return OkStatus();
611 }
612
613 } // end namespace grappler
614 } // end namespace tensorflow
615