/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/inline_function_utils.h"

#include <deque>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;

namespace {
// A few string constants used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;

// Represents the index-th output of a node.
struct Endpoint {
  Node* node;
  int index;

  // Returns the string name representing this endpoint.
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  DataType dtype() const { return node->output_type(index); }
};

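// Hash and equality functors so that an Endpoint can be used as a key in
// hash-based containers.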
struct EndpointHash {
  uint64 operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(StringPiece name, Graph* g) {
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("NoOp");
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
  DCHECK_LT(0, input.dtype());
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("Identity");
  ndef.add_input(input.name());
  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

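// Returns the devices of the caller's data inputs, indexed by the caller's
// input port. Control edges are skipped; the result is used by the placers
// below to place inlined function inputs.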
std::vector<string> InputDevices(const Node& caller) {
  std::vector<string> input_devices(caller.in_edges().size());
  std::vector<string> input_tensors(caller.in_edges().size());

  for (const Edge* edge : caller.in_edges()) {
    if (edge->IsControlEdge()) continue;
    const string& input_device = edge->src()->has_assigned_device_name()
                                     ? edge->src()->assigned_device_name()
                                     : edge->src()->requested_device();
    input_devices[edge->dst_input()] = input_device;
    input_tensors[edge->dst_input()] =
        absl::StrCat(edge->src()->name(), ":", edge->src_output());
  }

  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Function instantiation input devices:";
    for (int i = 0; i < input_devices.size(); ++i) {
      if (input_tensors[i].empty()) continue;  // skip control edges
      VLOG(4) << "    [index " << i << "]"
              << " device: " << input_devices[i]
              << " (input: " << input_tensors[i] << ")";
    }
  }

  return input_devices;
}

// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for other nodes.
class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit DefaultFunctionBodyPlacer(const Node& caller)
      : input_devices_(InputDevices(caller)) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return absl::nullopt;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return absl::nullopt;
  }

 private:
  const std::vector<string> input_devices_;
};

// Place all nodes on the same device as the caller node.
class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit SingleDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return caller_device_;
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return caller_device_;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return caller_device_;
  }

 private:
  const string caller_device_;
};

// Place input nodes on the same device as the corresponding caller input
// node. Do not place output nodes. Place control nodes on the same device as
// the caller node. For all function body nodes, override the job, replica and
// task parts of the device assignment to match the function caller node.
class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit MultiDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()),
        input_devices_(InputDevices(caller)) {
    has_parsed_caller_device_ =
        DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_);
  }

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return true; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    // LINT.IfChange
    // TODO(ezhulenev): If the function had been instantiated as a multi-device
    // function and executed via FunctionLibraryRuntime, it could potentially
    // be placed on any available device. However, there are multiple tests
    // relying on this assumption. Fix them, and remove this line.
    if (ndef.device().empty()) return caller_device_;

    if (!has_parsed_caller_device_) return ndef.device();

    DeviceNameUtils::ParsedName ndef_parsed_device;
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
      return ndef.device();

    DeviceNameUtils::MergeUnsetDevNames(&ndef_parsed_device,
                                        caller_parsed_device_);
    return DeviceNameUtils::ParsedNameToString(ndef_parsed_device);
    // LINT.ThenChange(../../compiler/mlir/tensorflow/ir/tf_ops.cc)
  }

 private:
  string caller_device_;
  bool has_parsed_caller_device_;
  DeviceNameUtils::ParsedName caller_parsed_device_;
  std::vector<string> input_devices_;
};

}  // namespace

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
                                         const Node& caller) {
  VLOG(3) << "Create default placer for inlined function body.";
  return std::make_unique<DefaultFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
                                              const Node& caller) {
  VLOG(3) << "Create single device placer for inlined function body.";
  return std::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
                                             const Node& caller) {
  VLOG(3) << "Create multi device placer for inlined function body.";
  return std::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
}

namespace {

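// Returns an error if `fbody` is marked with the '_noinline' attribute.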
Status ValidateNoInline(const FunctionBody* fbody) {
  const auto attr = AttrSlice(&fbody->fdef.attr());
  bool noinline = false;
  if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
    return errors::InvalidArgument(
        "Can't inline function marked with '_noinline'");
  }
  return OkStatus();
}

using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;

// Propagate the debug info of `nodes` in function `func` to the `target` node.
// If the debug info of any node is missing, its node name and function name
// are used.
void PropagateDebugInfoToNode(const string& func,
                              const std::vector<const Node*>& nodes,
                              NodeDef* target) {
  if (nodes.empty() || target->has_experimental_debug_info()) {
    return;
  }
  for (const Node* node : nodes) {
    const auto& node_def = node->def();
    if (node_def.has_experimental_debug_info()) {
      target->mutable_experimental_debug_info()->MergeFrom(
          node_def.experimental_debug_info());
    } else {
      target->mutable_experimental_debug_info()->add_original_node_names(
          node_def.name());
      target->mutable_experimental_debug_info()->add_original_func_names(func);
    }
  }
}
}  // namespace

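// Returns a human-readable summary of the inlining options, used in VLOG
// messages.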
string InlineFunctionBodyOptions::DebugString() const {
  const auto true_false = [](bool b) { return b ? "true" : "false"; };

  const auto keep_caller_node_str = [this]() -> string {
    switch (keep_caller_node) {
      case KeepCallerNode::kDoNotKeep:
        return "DoNotKeep";
      case KeepCallerNode::kFetchable:
        return "Fetchable";
      case KeepCallerNode::kTargetable:
        return "Targetable";
    }
  };

  return absl::StrCat(
      "disable_inlining=", true_false(disable_inlining),
      ", ignore_noinline=", true_false(ignore_noinline),
      ", inline_impl_selection_group_functions=",
      true_false(inline_impl_selection_group_functions),
      ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
      output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
                                                           : "ControlOutputs",
      ", inlined_function_body_placer=", inlined_function_body_placer.name,
      ", uniquify_frame_names=", true_false(uniquify_frame_names));
}

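// Validates that the function call `node` can be inlined with `fbody`: the
// number and dtypes of inputs/outputs must match, and inlining must not be
// disabled by `options` or by the function's attributes.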
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options) {
  // TODO(ezhulenev): Currently common_runtime function inlining can't
  // guarantee that all side-effectful ops will be executed after inlining. See
  // Grappler function_optimizer for details. Unify all function inlining
  // mechanisms. Do not inline if `!fbody->control_ret_nodes.empty()`.

  const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
  const auto num_node_outputs = static_cast<size_t>(node->num_outputs());

  if (num_node_inputs != fbody->arg_types.size() ||
      num_node_inputs != fbody->arg_nodes.size()) {
    return errors::InvalidArgument(
        "Node inputs do not match function arguments: inputs=", num_node_inputs,
        " arg_types=", fbody->arg_types.size(),
        " arg_nodes=", fbody->arg_nodes.size());
  }

  if (num_node_outputs != fbody->ret_types.size() ||
      num_node_outputs != fbody->ret_nodes.size()) {
    return errors::InvalidArgument(
        "Node outputs do not match function returns: outputs=",
        num_node_outputs, " ret_types=", fbody->ret_types.size(),
        " ret_nodes=", fbody->ret_nodes.size());
  }

  for (int i = 0; i < node->num_inputs(); ++i) {
    if (node->input_type(i) != fbody->arg_types[i]) {
      return errors::InvalidArgument(
          "Node input type doesn't match function argument type: ",
          node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
    }
  }
  for (int i = 0; i < node->num_outputs(); ++i) {
    if (node->output_type(i) != fbody->ret_types[i]) {
      return errors::InvalidArgument(
          "Node output type doesn't match function return type: ",
          node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
    }
  }

  if (options.disable_inlining) {
    return errors::InvalidArgument(
        "Function inlining explicitly disabled by 'options.disable_inlining'");
  }

  if (!options.inline_impl_selection_group_functions) {
    bool is_impl_selection_group_function =
        fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end();
    if (is_impl_selection_group_function) {
      return errors::InvalidArgument(
          "Inlining of implementation selection group function ",
          fbody->fdef.signature().name(),
          " is disabled by options.inline_impl_selection_group_functions");
    }
  }

  if (!options.ignore_noinline) {
    TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
  }

  return OkStatus();
}

// Function inlining must preserve function execution semantics with regards to
// side-effects visibility. Tensorflow in Eager mode has an automatic control
// dependencies tracking mechanism, which enforces a well-defined execution
// order of all side-effects. Any other frontend (e.g. Swift) must produce
// graphs following the same rules, to ensure that function inlining works
// correctly.
//
// IMPORTANT: Currently we do not have a true notion of a "side-effectful"
// node; we assume that all stateful nodes might have side-effects, though this
// is not true in practice, e.g. `ReadVariableOp` doesn't have an observable
// side-effect.
//
// Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
//
// 1) When a function has a resource (DT_RESOURCE data type) input argument it
//    "captures" the mutable resource.  This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
//    touching that resource, and an outgoing control edge to the next
//    side-effectful op using the same resource. This serializes the mutations
//    of the resource to make graph execution deterministic.
//
// 2) All stateful ops inside a function body are guaranteed to execute in
//    program order; this is achieved by adding control edges between stateful
//    ops at graph construction time. Stateful ops (or ops that must execute)
//    should be in the function control return set. Having a data edge to the
//    regular function output might not be enough, because after function
//    inlining the data output might be unused.
//
// 3) Furthermore, all ops accepting the same resource as an input are
//    guaranteed to run in program order. This is also done by adding control
//    edges at graph construction time. The last op touching the resource
//    must be in a control return set, which will guarantee that all side
//    effects to the resource will happen before function completion.
//
// Function inlining must preserve side-effect visibility:
//
// 1) All side-effects to the captured resources that happened before the
//    function call must be visible to the function body nodes using those
//    resources.
//
// 2) All side-effects to the captured resources that happened inside the
//    function body must be visible to every op/function using that resource
//    after the function call has completed.
//
// To guarantee that these properties are preserved after inlining we:
//
// 1) Create an "input_control_node" NoOp. The function call node's incoming
//    control edges will be forwarded *to* this node. Function inputs (Identity
//    nodes) will have a control edge *from* this node. If the function body
//    has nodes without inputs, they will have a control edge *from* this node.
//
// 2) Create an "output_control_node" NoOp. All nodes that had an incoming
//    control edge *from* the function call node will instead get one *from*
//    this node.
//
//    We have two options for choosing which nodes will have a control edge *to*
//    the "output control node":
//       a) control returns            (`control_ret` field in FunctionDef)
//       b) data returns               (`ret` field in FunctionDef)
//
//    We do a) for multi-device function calls in Tensorflow v2 and b)
//    for the rest, for compatibility with Tensorflow v1.
//
//    Following the automatic control dependencies tracking rules, a node that
//    has an incoming control edge from the function call node is dependent on
//    the side-effects happening inside the function body. The output control
//    node guarantees the execution order of those side-effects.
//
//    If the function call node doesn't have an outgoing control edge, it means
//    that no one is interested in observing the side-effects that might have
//    happened.
//
// Function inlining might leave the graph in a partially-placed state. The
// caller of function inlining must run the Placer to guarantee that all nodes
// are placed.
//
// Function inlining with `options.override_device=true` will leave the graph
// in a fully placed state, by overriding the devices of all inlined nodes with
// the caller node's device, but it makes the functions always single-device.
// After inlining, these functions will not be able to handle resources on
// multiple devices. This is currently acceptable for XLA use cases (an XLA
// cluster is always executed on a single device).
//
// TODO(ezhulenev): Documentation above is ahead of implementation below.
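//
// A minimal usage sketch (illustrative only; `flib_def`, `graph`,
// `caller_node` and `fbody` are assumed to be available from the surrounding
// pass):
//
//   InlineFunctionBodyOptions opts;
//   opts.keep_caller_node =
//       InlineFunctionBodyOptions::KeepCallerNode::kFetchable;
//   Status status =
//       InlineFunctionBody(flib_def, graph, caller_node, fbody, opts);
//   if (!status.ok()) { /* handle error */ }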
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options) {
  VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
          << options.DebugString() << "]";
  VLOG(4) << "Inlining function: " << fbody->fdef.DebugString();
  VLOG(4) << "Current graphdef: " << g->ToGraphDefDebug().DebugString();
  VLOG(4) << "Caller: " << caller->DebugString();

  Status validation = ValidateInlining(caller, fbody, options);
  if (!validation.ok()) {
    return errors::Internal("Inlining mismatch: ", validation.error_message());
  }

  // Placer is responsible for assigning devices for all nodes that we will add
  // to the graph.
  const std::unique_ptr<InlinedFunctionBodyPlacer> placer =
      options.inlined_function_body_placer.get(*g, *caller);

  // We can't possibly introduce a duplicate control edge during function
  // inlining, so we skip this check in calls to 'g->AddControlEdge(...)'.
  static constexpr bool kDoNotCheckDuplicates = true;

  // ------------------------------------------------------------------------ //
  // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
  // control nodes and inlined function inputs and outputs.

  // Add a NoOp node for function control inputs/outputs.
  const auto no_op = [&](StringPiece name) -> Node* {
    Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
    const absl::optional<string> device = placer->ControlNodeDevice();
    if (device.has_value()) node->set_requested_device(*device);
    return node;
  };

  // Add an Identity node for function input.
  const auto input_identity = [&](StringPiece name, Endpoint input,
                                  int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->InputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // Add an Identity node for function output.
  const auto output_identity = [&](StringPiece name, Endpoint input,
                                   int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->OutputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // ------------------------------------------------------------------------ //
  // Helper function to get an input/output argument name by index. For
  // functions instantiated from SymbolicGradient the corresponding FunctionDef
  // is empty, and the argument name is unknown.

  auto arg_name = [&](auto& args, size_t i) -> absl::string_view {
    if (i < args.size()) {
      return args[i].name();
    } else {
      return "<unknown>";
    }
  };

  // ------------------------------------------------------------------------ //
  // Input edges. For data edges coming into "caller", we first compute the
  // <src>:<src_output> for the i-th input in "inputs".
  // If "caller" has any input control dependencies, we add a NoOp
  // node "input_control_node", which depends on "caller"'s control inputs.
  std::vector<Endpoint> inputs(caller->num_inputs());
  Node* input_control_node = nullptr;
  for (const Edge* e : caller->in_edges()) {
    if (e->IsControlEdge()) {
      if (input_control_node == nullptr) {
        input_control_node = no_op("input_control_node");
      }
      g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates);
    } else {
      inputs[e->dst_input()] = {e->src(), e->src_output()};
    }
  }
  if (input_control_node != nullptr) {
    VLOG(3) << "Created input control node: " << input_control_node->name();
  }

  // We create one Identity node for each input.
  std::vector<Node*> input_nodes;
  std::map<absl::string_view, absl::string_view> input_node_name_map;
  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* n = input_identity("input", inputs[i], i);
    input_node_name_map[arg_name(fbody->fdef.signature().input_arg(), i)] =
        n->name();
    input_nodes.push_back(n);
  }

  // ------------------------------------------------------------------------ //
  // Duplicate fbody->graph into 'g'.  First, we copy the nodes of
  // fbody->graph into 'g' except the source and sink nodes.  We copy
  // edges among nodes in 'fbody->graph'.
  //
  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
  // remember 'y' in node_map[x->id()].
  std::unordered_set<string> fn_nodes;
  for (Node* n : fbody->graph->op_nodes()) {
    fn_nodes.insert(n->name());
  }
  std::vector<Node*> node_map(fbody->graph->num_node_ids());
  for (Node* n : fbody->graph->op_nodes()) {
    NodeDef ndef = n->def();

    // Maybe override requested node device assignment.
    const absl::optional<string> device = placer->BodyNodeDevice(ndef);
    if (device.has_value()) ndef.set_device(*device);

    // Add inlined function name to inlined node debug information.
    PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);

    // Add the function node name as a prefix:
    //  1) to node name to avoid collisions
    //  2) to frame name to avoid multiple LoopCond nodes in one frame
    //  3) to colocation attribute
    const string prefix = strings::StrCat(caller->name(), "/");
    TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
                                                options.uniquify_frame_names));

    // If the colocation attribute is an input arg, we need to change it to the
    // new input (Identity) node now.
    TF_RETURN_IF_ERROR(
        MaybeUpdateColocationConstraintsWithMap(input_node_name_map, &ndef));

    TF_RETURN_IF_ERROR(
        MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef));

    Status added_node;
    Node* clone = g->AddNode(std::move(ndef), &added_node);
    TF_CHECK_OK(added_node);
    node_map[n->id()] = clone;
    clone->SetStackTrace(n->GetStackTrace());

    // If there is an input control node, and one of:
    // a) the node has no data or control inputs, or
    // b) the node is a function call (including SymbolicGradient),
    //    then add a control edge from the input control node to the clone (only
    //    if it does not already have a control input).
    //
    // We must not execute any nodes if the original function call would not
    // have executed. This is especially critical when the function call is
    // inside a control-flow construct like tf.cond(). Case (a) ensures that
    // such nodes do not run.
    //
    // The purpose of case (b) is to ensure that instances of case (a) created
    // by further inlining steps also receive the control dependency.
    //
    // This edge is required to transfer the execution frame down to all
    // function body nodes of inlined nested function calls.
    if (input_control_node) {
      const auto is_input_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource();
      };
      const auto is_control_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource() && e->IsControlEdge();
      };

      // Forward the execution frame if:
      //
      // a) The node has no data or control inputs.
      // b) OR the node is a function call without control inputs (the control
      //    edge will be used in nested function inlining to forward the
      //    execution frame to constants inside the function body).
      //
      // c) Do not forward the control frame to function argument nodes; they
      //    will be connected to the corresponding function input later.
      const bool forward_execution_frame =
          (absl::c_none_of(n->in_edges(), is_input_edge) ||       // (a)
           (n->IsFunctionCall() &&                                // (b)
            absl::c_none_of(n->in_edges(), is_control_edge))) &&  //
          !n->IsArg();                                            // (c)

      if (forward_execution_frame) {
        VLOG(4) << "Add control edge from input control node to: "
                << clone->name();
        g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates);
      }
    }
  }
  for (const Edge* e : fbody->graph->edges()) {
    if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
        e->dst()->IsSink()) {
      continue;
    }
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // ------------------------------------------------------------------------ //
  // Connect input edges.
  //
  // Then, we connect inputs[i] to the i-th identity node added. The nodes that
  // were previously connected to the j-th output of the i-th arg node are
  // reconnected to the i-th identity node.
  //
  // The added identity nodes depend on "input_control_node".
  VLOG(4) << "Add input Identity nodes for each function argument:";

  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg = node_map[fbody->arg_nodes[i]->id()];
    Node* n = input_nodes[i];
    VLOG(4) << "    [index " << i << "] "
            << arg_name(fbody->fdef.signature().input_arg(), i) << " as "
            << n->name() << " (input: " << inputs[i].name()
            << ", requested_device: " << n->requested_device() << ")";

    if (input_control_node) {
      g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates);
    }
    for (const Edge* e : arg->out_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates);
      } else {
        g->AddEdge(n, 0, e->dst(), e->dst_input());
      }
    }
    node_map[fbody->arg_nodes[i]->id()] = n;
    g->RemoveNode(arg);  // 'arg' is disconnected.
  }

  // ------------------------------------------------------------------------ //
  // Connect output edges.
  //
  // For the i-th return node in fbody->graph, we add in "g" an identity node
  // (outputs[i-th]). We then reconnect every incoming edge into the i-th return
  // node to the added identity node.
  //
  // For every data edge coming out of "callee"'s i-th output, we reconnect it
  // to the i-th identity added above.
  //
  // If "callee" is control-depended upon by any other nodes, we add a NoOp node
  // "output_control_node". "output_control_node" depends on all identity nodes
  // added above or on all control return nodes (controlled by the
  // `options.output_control_src` value). Nodes that previously depended on
  // "callee" are changed to depend on "output_control_node".
  //
  // If `keep_node_fetchable` is `true` we always add an output control node, to
  // guarantee that executing a fetchable node will execute all side-effects.
  VLOG(4) << "Add output Identity nodes for each function output argument:";

  std::vector<Node*> outputs(caller->num_outputs());
  for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
    Node* ret = node_map[fbody->ret_nodes[i]->id()];
    Endpoint data;  // Data input for the ret node.
    for (const Edge* e : ret->in_edges()) {
      if (!e->IsControlEdge()) {
        data = {e->src(), e->src_output()};
        break;
      }
    }
    CHECK(data.node != nullptr);
    Node* n = output_identity("output", data, i);
    outputs[i] = n;
    VLOG(4) << "    [index " << i << "] "
            << arg_name(fbody->fdef.signature().output_arg(), i) << " as "
            << n->name() << " (ret: " << data.node->name() << ":" << data.index
            << ", requested_device: " << n->requested_device() << ")";
    for (const Edge* e : ret->in_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates);
      }
    }
    g->RemoveNode(ret);  // 'ret' is disconnected.
  }

  Node* output_control_node = nullptr;
  const bool has_control_outputs = absl::c_any_of(
      caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });

  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
  const bool keep_caller_node =
      options.keep_caller_node == KeepCallerNode::kFetchable ||
      options.keep_caller_node == KeepCallerNode::kTargetable;

  if (has_control_outputs || keep_caller_node) {
    output_control_node = no_op("output_control_node");
    VLOG(4) << "Add output control node: " << output_control_node->name();
    if (options.output_control_src == OutputControlSrc::kDataOutputs) {
      for (Node* n : outputs) {
        VLOG(4) << "    [data output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    } else {
      for (Node* fbody_node : fbody->control_ret_nodes) {
        Node* n = node_map[fbody_node->id()];
        VLOG(4) << "    [control output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    }
  }

  // We can't leave the output control node without incoming control edges,
  // because in this case the outgoing control edges will lose execution frame
  // information. We connect input_control_node and output_control_node with a
  // control edge to forward the execution frame to the controlled nodes. Above
  // we add a control edge to all function calls inside the function body, to
  // guarantee that we will always have input_control_node when we need it.
  if (output_control_node && output_control_node->in_edges().empty()) {
    if (input_control_node) {
      VLOG(4) << "Add a control edge between input and output control nodes: "
              << input_control_node->name() << " to "
              << output_control_node->name();
      g->AddControlEdge(input_control_node, output_control_node,
                        kDoNotCheckDuplicates);
    } else {
      VLOG(4) << "Function inlining potentially dropped execution frame "
                 "information from outgoing control edges.";
    }
  }

  for (const Edge* e : caller->out_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates);
    } else {
      g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }

  // ------------------------------------------------------------------------ //
  // Add an IdentityN or NoOp node in place of the caller node to keep `caller`
  // fetchable or targetable.

  if (keep_caller_node) {
    std::vector<NodeBuilder::NodeOut> output_tensors;
    absl::c_transform(outputs, std::back_inserter(output_tensors),
                      [](Node* n) { return NodeBuilder::NodeOut(n, 0); });

    Node* caller_substitute_node;
    if (options.keep_caller_node == KeepCallerNode::kTargetable ||
        output_tensors.empty()) {
      // An IdentityN node must have at least one data input. If the function
      // has no data outputs, we can't keep it fetchable.
      TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp")
                      .Device(caller->requested_device())
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));

    } else if (options.keep_caller_node == KeepCallerNode::kFetchable) {
      TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
                      .Device(caller->requested_device())
                      .Input(output_tensors)
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));
    }
  }

  // ------------------------------------------------------------------------ //
  // 'caller' is replaced with the inlined function body nodes and maybe an
  // IdentityN to keep it fetchable.
  VLOG(3) << "Successfully inlined function call node: " << caller->name();
  g->RemoveNode(caller);

  VLOG(4) << "Final graph: " << g->ToGraphDefDebug().DebugString();

  return OkStatus();
}

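// Inlines every eligible function call node in `graph` (skipping calls marked
// with the '_noinline' attribute and calls that fail to instantiate). Returns
// true if at least one call was inlined.
//
// A minimal caller-side sketch (illustrative only; `lib` and `graph` are
// assumed to come from the surrounding runtime):
//
//   ExpandInlineFunctionsOptions expand_options;
//   while (ExpandInlineFunctions(lib, graph, expand_options)) {
//     // Keep expanding until no more calls can be inlined.
//   }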
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
                           const ExpandInlineFunctionsOptions& options) {
  std::vector<std::pair<Node*, const FunctionBody*>> candidates;

  const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();

  for (Node* node : graph->nodes()) {
    // Skip nodes that are not function calls or SymbolicGradient calls.
    if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
      continue;
    }
    // Skip function calls that are marked noinline.
    bool noinline;
    if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
      VLOG(3) << "noinline: " << SummarizeNode(*node);
      continue;
    }
    FunctionLibraryRuntime::Handle handle;
    Status s = InstantiateFunctionCall(node->def(), lib, &handle);
    if (!s.ok()) {
      LOG(ERROR) << "Failed to instantiate a function:  " << s.error_message();
      continue;
    }
    const FunctionBody* fbody = lib->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    candidates.emplace_back(node, fbody);
  }

  bool inlined_any = false;
  for (const auto& p : candidates) {
    Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
                                        p.first->IsPartitionedCall()
                                            ? options.multi_device_options
                                            : options.native_options);
    if (inlined.ok()) {
      inlined_any = true;
    } else {
      VLOG(1) << "Failed to inline function call: node=" << p.first->name()
              << " error=" << inlined.error_message();
    }
  }

  // TODO(ezhulenev): Release handles for inlined function calls.

  return inlined_any;
}

}  // end namespace tensorflow