1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <optional>
21 #include <vector>
22 
23 // clang-format off
24 // Required for IS_MOBILE_PLATFORM
25 #include "absl/container/btree_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_replace.h"
28 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
29 #include "tensorflow/core/framework/cancellation.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/kernel_def.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/lib/core/refcount.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/platform.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 
41 // clang-format on
42 
43 #include "absl/container/inlined_vector.h"
44 #include "absl/strings/match.h"
45 #include "absl/strings/str_cat.h"
46 #include "absl/types/optional.h"
47 #include "tensorflow/c/tf_tensor_internal.h"
48 #include "tensorflow/compiler/jit/defs.h"
49 #include "tensorflow/core/common_runtime/colocation_graph.h"
50 #include "tensorflow/core/common_runtime/device.h"
51 #include "tensorflow/core/common_runtime/device_set.h"
52 #include "tensorflow/core/common_runtime/eager/context.h"
53 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
54 #include "tensorflow/core/common_runtime/eager/execute_node.h"
55 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
56 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
57 #include "tensorflow/core/framework/dataset.h"
58 #include "tensorflow/core/framework/function.h"
59 #include "tensorflow/core/framework/logging.h"
60 #include "tensorflow/core/framework/node_def_util.h"
61 #include "tensorflow/core/framework/tensor_reference.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 #include "tensorflow/core/lib/core/errors.h"
64 #include "tensorflow/core/platform/statusor.h"
65 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
66 #include "tensorflow/core/profiler/lib/traceme.h"
67 #include "tensorflow/core/protobuf/error_codes.pb.h"
68 #include "tensorflow/core/util/device_name_utils.h"
69 #if !defined(IS_MOBILE_PLATFORM)
70 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
71 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
72 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
73 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
74 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
75 #endif  // IS_MOBILE_PLATFORM
76 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
77 #include "tensorflow/core/framework/step_stats.pb.h"
78 #include "tensorflow/core/framework/tensor.h"
79 #include "tensorflow/core/framework/types.h"
80 #include "tensorflow/core/lib/core/status.h"
81 #include "tensorflow/core/lib/gtl/cleanup.h"
82 #include "tensorflow/core/lib/gtl/flatset.h"
83 #include "tensorflow/core/lib/random/random.h"
84 #include "tensorflow/core/platform/env.h"
85 #include "tensorflow/core/platform/mutex.h"
86 #include "tensorflow/core/util/ptr_util.h"
87 #include "tensorflow/core/util/util.h"
88 
89 #ifdef INTEL_MKL
90 #include "tensorflow/core/graph/mkl_graph_util.h"
91 #endif
92 
93 namespace tensorflow {
94 
95 namespace {
96 
97 const string& DeviceNameOrUnspecified(Device* device) {
98   static string* unspecified_string = new string("<unspecified>");
99   return (device == nullptr) ? *unspecified_string : device->name();
100 }
101 
102 // Returns whether a kernel should be cached.
103 bool KernelCacheEnabled(const OpDef& op_def) {
104   if (data::DatasetOpKernel::IsDatasetOp(op_def)) {
105     return false;
106   }
107   // TODO(b/162540360): Revisit a way to mark kernels as uncachable once we have
108   // 5+ kernels to exclude.
109   return true;
110 }
111 
112 // This function expects *handle to point to an existing tensor handle that is
113 // currently on "handle_device", but where the operation expects that input to
114 // reside on "expected_input_device".  The function will arrange for this
115 // transfer to happen and will return OK on success and will store a new
116 // handle to the equivalent tensor on the correct device in "*result".  Or if an
117 // error is encountered, it will return a non-OK status and set "*result" to
118 // nullptr.
119 //
120 // `op_device` is passed in explicitly because `op->device()` might be
121 // unset and we might have selected some specific device to run this op on.
122 Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
123                                  Device* op_device,
124                                  TensorHandle* handle,  // op->Inputs()[i]
125                                  int i, Device* handle_device,
126                                  Device* expected_input_device,
127                                  TensorHandle** result) {
128   VLOG(6) << "Expected input device: " << expected_input_device->name()
129           << "; handle_device: " << handle_device->name();
130   // Should only be called when these don't match
131   DCHECK(expected_input_device != handle_device);
132   *result = nullptr;
133   const string& op_device_name = DeviceNameOrUnspecified(op_device);
134 
135   switch (ctx->GetDevicePlacementPolicy()) {
136     case DEVICE_PLACEMENT_SILENT_FOR_INT32:
137       // TODO(xpan): See if we could bubble python related error up
138       // to python level.
139       if (handle->dtype == DT_INT32) {
140         // Note: enabling silent copies of int32 tensors to match behavior
141         // of graph mode.
142         break;
143       }
144       VLOG(6) << "DevicePlacementPolicy: DEVICE_PLACEMENT_SILENT_FOR_INT32 but "
145                  "input type is not INT32.";
146       TF_FALLTHROUGH_INTENDED;
147     case DEVICE_PLACEMENT_EXPLICIT:
148       // tf.identity is allowed to copy, as indicated in the error message
149       // below.
150       if (op->Name() == "Identity" ||
151           op->Name() == "IdentityN"
152           // Constants start on CPU:0 and are copied via EagerConst to the
153           // current device.
154           || op->Name() == "_EagerConst") {
155         break;
156       }
157       return errors::InvalidArgument(
158           "Tensors on conflicting devices:"
159           " cannot compute ",
160           op->Name(), " as input #", i, " was expected to be on ",
161           expected_input_device->name(), " but is actually on ",
162           handle_device->name(), " (operation running on ", op_device_name, ")",
163           " Tensors can be copied explicitly using:"
164           " `with tf.device(device_name): x = tf.identity(x)`"
165           " or transparently copied by using"
166           " tf.config.experimental.set_device_policy('silent')."
167           " Copying tensors between devices may slow down your model");
168     case DEVICE_PLACEMENT_WARN:
169       LOG(WARNING) << "before computing " << op->Name() << " input #" << i
170                    << " was expected to be on " << expected_input_device->name()
171                    << " but is actually on " << handle_device->name()
172                    << " (operation running on " << op_device_name
173                    << "). This triggers a copy which can be a performance "
174                       "bottleneck.";
175       break;
176     case DEVICE_PLACEMENT_SILENT:  // Do nothing.
177       break;
178   }
179   // We are only here if the policy is warn or silent copies, so we should
180   // trigger a copy.
181   TensorHandle* result_handle = nullptr;
182   profiler::TraceMe activity(
183       [&] {
184         return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
185                             " to ", expected_input_device->name());
186       },
187       profiler::TraceMeLevel::kInfo);
188   Status status =
189       EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
190                         /* mirror= */ true, &result_handle);
191   activity.Stop();
192   if (!status.ok()) {
193     return Status(
194         status.code(),
195         absl::StrCat("Failed copying input tensor from ", handle_device->name(),
196                      " to ", expected_input_device->name(), " in order to run ",
197                      op->Name(), ": ", status.error_message()));
198   }
199 
200   *result = result_handle;
201 
202   return OkStatus();
203 }
204 
205 // `op_device_name` is the name of the device on which the op will run, if any.
206 // For functions running using function library runtime, the device can be
207 // unspecified.
208 Status ValidateInputTypeAndPlacement(
209     EagerContext* ctx, EagerOperation* op,
210     const core::RefCountPtr<KernelAndDevice>& kernel) {
211   profiler::TraceMe activity("ValidateInputTypeAndPlacement",
212                              profiler::TraceMeLevel::kInfo);
213   const int n_inputs = op->Inputs().size();
214   if (kernel->num_inputs() != n_inputs) {
215     return errors::InvalidArgument("expected ", kernel->num_inputs(),
216                                    " inputs, got ", n_inputs);
217   }
218   const bool is_function = kernel->IsFunction();
219   if (n_inputs > 0) {
220     const DataType* input_types = &kernel->input_dtypes()[0];
221     const absl::InlinedVector<TensorHandle*, 4>* handles;
222     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
223     for (int i = 0; i < n_inputs; ++i) {
224       TensorHandle* handle = (*handles)[i];
225       Device* expected_device = kernel->InputDevice(i);
226       if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
227         // Extract a handle on the op device from a packed input.
228         // This happens when a function is marked for XLA compilation.
229         // MaybePackInputTensor guarantees that a primitive op has no packed
230         // input at this point.
231         for (int j = 0; j < handle->NumPackedHandles(); ++j) {
232           TensorHandle* h = nullptr;
233           TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
234           if ((h->op_device() != nullptr) &&
235               (h->op_device()->name() == op->DeviceName())) {
236             op->UpdateInput(i, h);
237             handle = h;
238             break;
239           }
240         }
241       }
242       Device* handle_device = handle->DeviceOrHostCPU(*ctx);
243       const bool maybe_copy =
244           !is_function || handle->Type() != TensorHandle::REMOTE;
245       VLOG(6) << "!is_function: " << !is_function;
246       VLOG(6) << "handle->Type(): " << handle->Type();
247       // If the input is already on the right device, then nothing to do.
248       if (expected_device != handle_device && maybe_copy) {
249         TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
250                                                      handle, i, handle_device,
251                                                      expected_device, &handle));
252         op->UpdateInput(i, handle);
253         // Unref handle since it has a ref as an input now
254         handle->Unref();
255       }
256       if (handle->dtype != input_types[i]) {
257         return errors::InvalidArgument(
258             "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
259             " was expected to be a ", DataTypeString(input_types[i]),
260             " tensor but is a ", DataTypeString(handle->dtype), " tensor");
261       }
262     }
263   }
264   return OkStatus();
265 }
266 
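// Computes the output dtypes for `op` from its NodeDef, using the signature of
// the corresponding FunctionDef if `op` names a registered function and the
// registered OpDef otherwise.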
267 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
268   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
269   const OpDef* op_def = nullptr;
270 
271   const FunctionDef* function_def =
272       op->EagerContext().FuncLibDef()->Find(op->Name());
273   if (function_def != nullptr) {
274     op_def = &(function_def->signature());
275   } else {
276     TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
277   }
278 
279   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
280 
281   return OkStatus();
282 }
283 
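// Helpers for folding additional data into a 128-bit fingerprint. They are
// used below to mix input devices, dtypes and shapes into the device/kernel
// cache keys.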
284 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
285                                                const tensorflow::Fprint128& b) {
286   return {tensorflow::FingerprintCat64(a.low64, b.low64),
287           tensorflow::FingerprintCat64(a.high64, b.high64)};
288 }
289 
290 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
291                                                const int64_t b) {
292   auto x = tensorflow::FingerprintCat64(a.low64, b);
293   return {x, tensorflow::FingerprintCat64(a.high64, x)};
294 }
295 
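// Returns the KernelDef registered for `op` on `op_device`'s device type, or
// nullptr if the NodeDef/device is unavailable or no kernel is registered.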
296 const KernelDef* GetKernelDef(const EagerOperation& op, const NodeDef* node_def,
297                               const Device* op_device) {
298   if (node_def == nullptr || op_device == nullptr) return nullptr;
299   const KernelDef* kernel_def = nullptr;
300   Status s = FindKernelDef(DeviceType(op_device->device_type()), *node_def,
301                            &kernel_def,
302                            /*kernel_class_name=*/nullptr);
303   if (!s.ok()) return nullptr;
304   return kernel_def;
305 }
306 
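// Returns true if input `port_id` of the primitive op is declared as a
// HostMemory argument in `kernel_def`. Always false for functions and when the
// NodeDef, KernelDef or device is unavailable.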
307 bool IsHostMemoryArg(const EagerOperation& op, const NodeDef* node_def,
308                      const Device* op_device, const KernelDef* kernel_def,
309                      const int port_id) {
310   if (op.is_function()) return false;
311   if (node_def == nullptr) return false;
312   if (kernel_def == nullptr || op_device == nullptr) return false;
313   const auto& host_memory_args = kernel_def->host_memory_arg();
314   const OpDef& op_def = OpRegistry::Global()->LookUp(op.Name())->op_def;
315   const int arg_id = OpPortIdToArgId(*node_def, op_def.input_arg(), port_id);
316   return std::find(host_memory_args.begin(), host_memory_args.end(),
317                    op_def.input_arg(arg_id).name()) != host_memory_args.end();
318 }
319 
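// Picks the device on which input `tensor_handle` should be provided to `op`:
// remote and packed handles keep their current device, DT_RESOURCE inputs use
// the device recorded in the resource handle, and other local tensors are
// assigned to the host CPU or to a device (the op's device when an eager op
// runs as a function) based on the memory type of their dtype.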
320 Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx,
321                          const bool is_host_memory_arg,
322                          TensorHandle* tensor_handle, Device** result) {
323   Device* cpu_device = ctx.HostCPU();
324   string device_name;
325   if (tensor_handle->Type() != TensorHandle::LOCAL) {
326     Device* device = tensor_handle->device();
327     device_name = device != nullptr ? device->name() : cpu_device->name();
328     *result = (device == nullptr ? cpu_device : device);
329   } else if (tensor_handle->dtype == DT_RESOURCE) {
330     // influence the partitioning of the multi-device function.
331     // influence partitioning the multi-device function.
332     const Tensor* tensor;
333     // TODO(fishx): Avoid blocking here.
334     TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
335     if (tensor->NumElements() == 0) {
336       return errors::InvalidArgument("Empty resource handle");
337     }
338     const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
339     device_name = handle.device();
340 
341     Device* input_device;
342     TF_RETURN_IF_ERROR(
343         ctx.FindDeviceFromName(device_name.c_str(), &input_device));
344     *result = input_device;
345   } else {
346     Device* device = tensor_handle->device();
347     const bool is_tpu = device != nullptr && device->device_type() == "TPU";
348     // int32 return values can be placed on TPUs.
349     // int32 return values can be placed on device for eager operations.
350     const bool use_host_memory =
351         is_tpu || (!op.is_function() && device != cpu_device &&
352                    !is_host_memory_arg)
353             ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
354             : MTypeFromDType(tensor_handle->dtype);
355     if (use_host_memory) {
356       *result = cpu_device;
357     } else {
358       // Eager ops executing as functions should have their preferred input
359       // device set to the op's device. This allows us to avoid expensive D2H
360       // copies if a mirror of the tensor already exists on the op's device.
361       if (!op.is_function() && device != cpu_device && !is_host_memory_arg) {
362         device = absl::get<Device*>(op.Device());
363       }
364       *result = (device == nullptr ? cpu_device : device);
365     }
366   }
367   return OkStatus();
368 }
369 
370 // Appends a TensorShape object to Fprint128 hash.
371 // For best performance, we would like to avoid dynamic memory allocation in
372 // this function.
373 // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
374 // attach every dim size to hashed content.
375 void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
376                                     Fprint128* fingerprint) {
377   if (shape.unknown_rank()) {
378     char c = '?';
379     *fingerprint = FingerprintCat128(*fingerprint, c);
380   } else {
381     for (int i = 0; i < shape.dims(); i++) {
382       int64_t dim = shape.dim_size(i);
383       *fingerprint = FingerprintCat128(*fingerprint, dim);
384     }
385   }
386 }
387 
388 Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx,
389                    const char* attr_name, bool* value) {
390   Status status = op->Attrs().Get(attr_name, value);
391   if (status.ok()) {
392     VLOG(2) << "Caller explicitly specifies "
393             << attr_name << (*value ? "=true " : "=false, ")
394             << op->DebugString();
394     return OkStatus();
395   }
396 
397   const FunctionDef* function_def =
398       ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
399   if (function_def == nullptr) {
400     return errors::NotFound("Failed to find function '", op->Name(), "'");
401   }
402 
403   status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value);
404   if (status.ok()) {
405     VLOG(2) << "Function definition explicitly specifies "
406             << attr_name << (*value ? "=true" : "=false");
407     return OkStatus();
408   }
409   return status;
410 }
411 
412 Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
413                           bool* compile_with_xla) {
414   if (!op->is_function()) {
415     *compile_with_xla = false;
416     return OkStatus();
417   }
418 
419   if (op->eager_func_params().has_value() &&
420       op->eager_func_params().value().is_component_function) {
421     // If the op is a component of a multi-device function, don't compile it
422     // with XLA.
423     *compile_with_xla = false;
424     return OkStatus();
425   }
426 
427   Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla);
428   if (status.ok()) {
429     return OkStatus();
430   }
431 
432   // No explicit requests. Compile for XLA devices by default.
433   if (op->GetDeviceParsedName().type == "TPU" ||
434       op->GetDeviceParsedName().type == "XLA_GPU" ||
435       op->GetDeviceParsedName().type == "XLA_CPU") {
436     VLOG(2) << "Compiling " << op->Name()
437             << " with XLA because it is running on an XLA device "
438             << op->GetDeviceParsedName().type;
439     *compile_with_xla = true;
440   } else {
441     *compile_with_xla = false;
442   }
443 
444   return OkStatus();
445 }
446 
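// Returns an error if `op` carries an attr that is not declared in `opdef`
// (i.e. a private attr), since such ops cannot be wrapped in a call op yet.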
447 Status VerifyWrappableInCallOp(const OpDef& opdef, EagerOperation* op) {
448   absl::flat_hash_set<string> opdef_attrs;
449   for (const auto& attr : opdef.attr()) {
450     opdef_attrs.insert(attr.name());
451   }
452   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
453   for (const auto& attr : node_def.attr()) {
454     if (opdef_attrs.find(attr.first) == opdef_attrs.end()) {
455       return errors::Unimplemented("EagerOperation: ", op->Name(),
456                                    " has a private attr '", attr.first, "'.");
457     }
458   }
459   return OkStatus();
460 }
461 
462 using ProtoArgListType = protobuf::RepeatedPtrField<OpDef_ArgDef>;
463 
464 string EscapeOrigName(const string& orig_name) {
465   // Replace _ with __ in the original name to avoid name conflicts.
466   return absl::StrReplaceAll(orig_name, {{"_", "__"}});
467 }
468 
469 // Variadic args are flattened during wrapping. This utility returns the name
470 // of a flattened arg/attr.
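// For example, GetFlatName("T", 2) returns "T_2".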
471 string GetFlatName(const string orig_name, int index) {
472   return absl::StrCat(EscapeOrigName(orig_name), "_", index);
473 }
474 
475 // Builds the name of the wrapping FunctionDef for an eager op.
476 //
477 // For ops without variadic inputs/outputs, the name is simply __wrapped__OpType.
478 //
479 // For ops with variadic inputs/outputs, the arity of each variadic attr is
480 // encoded in the name. For example:
481 //
482 // IdentityN[T:[DT_FLOAT, DT_INT64]] -> __wrapped__IdentityN_T_2
483 // Concat[N:2, T:DT_FLOAT] -> __wrapped__Concat_N_2
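// The requested device is also encoded via the "_device_" suffix appended
// below, so the same op placed on different devices gets a distinct wrapped
// FunctionDef.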
484 Status BuildWrappedOpName(EagerOperation* op, const OpDef& opdef,
485                           const AbstractOpAttrs* op_attrs, string* name) {
486   string fname = absl::StrCat("__wrapped__", EscapeOrigName(op->Name()));
487   // For every variadic arg in `args`, populates `attr_to_len` with
488   // (attr_name, len(arg)).
489   auto FillAttrToLen = [op_attrs, op](
490                            const ProtoArgListType& args,
491                            absl::btree_map<string, int>* attr_to_len) {
492     for (const auto& arg : args) {
493       if (!arg.type_list_attr().empty()) {
494         gtl::InlinedVector<DataType, 4> type_list;
495         TF_RETURN_IF_ERROR(
496             op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
497         (*attr_to_len)[arg.type_list_attr()] = type_list.size();
498       } else if (!arg.number_attr().empty()) {
499         int64_t number_attr;
500         if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
501           return errors::Internal("Unable to read attr ", arg.number_attr(),
502                                   " for op ", op->Name());
503         }
504         (*attr_to_len)[arg.number_attr()] = number_attr;
505       }
506     }
507     return OkStatus();
508   };
509   absl::btree_map<string, int> attr_to_len;
510   TF_RETURN_IF_ERROR(FillAttrToLen(opdef.input_arg(), &attr_to_len));
511   TF_RETURN_IF_ERROR(FillAttrToLen(opdef.output_arg(), &attr_to_len));
512   for (auto& name_len : attr_to_len) {
513     absl::StrAppend(&fname, "_", name_len.first, "_", name_len.second);
514   }
515   // The NodeDef in the FunctionDef gets placed on `op->DeviceName()` to ensure
516   // placement consistency with eager mode.
517   // TODO(b/200153278): Ideally we would just forward the call op's device at
518   // runtime but currently there is no way to do it so we incur the cost of
519   // creating extra FunctionDefs.
520   absl::StrAppend(&fname, "_device_", op->DeviceName());
521   *name = fname;
522   return OkStatus();
523 }
524 
525 // Validates the node def. This is required when running in eager op as function
526 // mode because this code path does not go through the _apply_op_helper's
527 // validation (which is reached when executing in graph mode)
528 // or the eager execution's validation (which is reached via the CreateOpKernel
529 // call).
530 Status ValidateOp(EagerOperation* op) {
531   const NodeDef& node_def = op->MutableAttrs()->BuildNodeDef();
532   const OpDef* op_def;
533   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
534   return ValidateNodeDef(node_def, *op_def);
535 }
536 
537 // Builds the signature of the wrapping FunctionDef for an eager op.
538 //
539 // For ops without variadic inputs/outputs, the signature is the same as the
540 // OpDef of the original op.
541 //
542 // Variadic inputs/outputs get flattened since we do not support executing
543 // functions with variadic signatures.
544 //
545 // TODO(srbs): These examples should be tests.
546 //
547 // Examples:
548 //
549 // Mixed type list:
550 //
551 // op {
552 //   name: "IdentityN"
553 //   input_arg {
554 //     name: "input"
555 //     type_list_attr: "T"
556 //   }
557 //   output_arg {
558 //     name: "output"
559 //     type_list_attr: "T"
560 //   }
561 //   attr {
562 //     name: "T"
563 //     type: "list(type)"
564 //     has_minimum: true
565 //     minimum: 1
566 //   }
567 // }
568 //
569 // With two inputs T=[DT_FLOAT, DT_INT64] would convert to
570 //
571 // op {
572 //   name: "__wrapped__IdentityN_T_2"
573 //   input_arg {
574 //     name: "input_0"
575 //     type_attr: "T_0"
576 //   }
577 //   input_arg {
578 //     name: "input_1"
579 //     type_attr: "T_1"
580 //   }
581 //   output_arg {
582 //     name: "output_0"
583 //     type_attr: "T_0"
584 //   }
585 //   output_arg {
586 //     name: "output_1"
587 //     type_attr: "T_1"
588 //   }
589 //   attr {
590 //     name: "T_0"
591 //     type: "type"
592 //   }
593 //   attr {
594 //     name: "T_1"
595 //     type: "type"
596 //   }
597 //   attr {
598 //     name: "T"
599 //     type: "list(type)"
600 //     has_minimum: true
601 //     minimum: 1
602 //   }
603 // }
604 //
605 // Note that the list(type) attr is preserved so that it can get copied to the
606 // inner op via a placeholder. This allows additional verification.
607 //
608 // Single type list:
609 //
610 // op {
611 //   name: "ConcatV2"
612 //   input_arg {
613 //     name: "values"
614 //     type_attr: "T"
615 //     number_attr: "N"
616 //   }
617 //   attr {
618 //     name: "N"
619 //     type: "int"
620 //     has_minimum: true
621 //     minimum: 2
622 //   }
623 //   attr {
624 //     name: "T"
625 //     type: "type"
626 //   }
627 //   [axis, output, Tidx are simply copied]
628 // }
629 //
630 // With two inputs N=2 would convert to:
631 //
632 // op {
633 //   name: "__wrapped__ConcatV2_N_2"
634 //   input_arg {
635 //     name: "values_0"
636 //     type_attr: "T"
637 //   }
638 //   input_arg {
639 //     name: "values_1"
640 //     type_attr: "T"
641 //   }
642 //   attr {
643 //     name: "N"
644 //     type: "int"
645 //     has_minimum: true
646 //     minimum: 2
647 //   }
648 //   attr {
649 //     name: "T"
650 //     type: "type"
651 //   }
652 //   [axis, output, Tidx are simply copied]
653 // }
654 //
655 // Note that the N attr is preserved so that it can get copied to the
656 // inner op via a placeholder. This allows additional verification.
657 Status BuildWrappedOpSignature(EagerOperation* op, const OpDef& opdef,
658                                const string& fname, OpDef& signature) {
659   signature = opdef;
660   signature.clear_input_arg();
661   signature.clear_output_arg();
662   signature.set_name(fname);
663   auto op_attrs = op->GetOpAttrs();
664   auto FillSignatureArgs = [op_attrs, op](
665                                const ProtoArgListType& opdef_args,
666                                ProtoArgListType* sig_args,
667                                absl::flat_hash_set<string>& new_attrs) {
668     for (const auto& arg : opdef_args) {
669       if (!arg.type_list_attr().empty()) {
670         gtl::InlinedVector<DataType, 4> type_list;
671         TF_RETURN_IF_ERROR(
672             op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
673         for (size_t i = 0; i < type_list.size(); i++) {
674           auto arg_def = sig_args->Add();
675           arg_def->set_name(GetFlatName(arg.name(), i));
676           auto attr_name = GetFlatName(arg.type_list_attr(), i);
677           new_attrs.insert(attr_name);
678           arg_def->set_type_attr(std::move(attr_name));
679         }
680       } else if (!arg.number_attr().empty()) {
681         int64_t number_attr;
682         if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
683           return errors::Internal("Unable to read attr ", arg.number_attr(),
684                                   " for op ", op->Name());
685         }
686         for (int64_t i = 0; i < number_attr; i++) {
687           auto arg_def = sig_args->Add();
688           arg_def->set_name(GetFlatName(arg.name(), i));
689           if (!arg.type_attr().empty()) {
690             arg_def->set_type_attr(arg.type_attr());
691           } else {
692             arg_def->set_type(arg.type());
693           }
694         }
695       } else {
696         auto arg_def = sig_args->Add();
697         *arg_def = arg;
698         arg_def->set_name(EscapeOrigName(arg.name()));
699         if (!arg.type_attr().empty()) {
700           // Don't escape: type attrs are still referenced by the original name.
701           arg_def->set_type_attr(arg.type_attr());
702         }
703       }
704     }
705     return OkStatus();
706   };
707   absl::flat_hash_set<string> new_attrs;
708   TF_RETURN_IF_ERROR(FillSignatureArgs(
709       opdef.input_arg(), signature.mutable_input_arg(), new_attrs));
710   TF_RETURN_IF_ERROR(FillSignatureArgs(
711       opdef.output_arg(), signature.mutable_output_arg(), new_attrs));
712   for (auto& attr_name : new_attrs) {
713     auto attr_def = signature.mutable_attr()->Add();
714     attr_def->set_name(attr_name);
715     attr_def->set_type("type");
716   }
717   return OkStatus();
718 }
719 
720 // For mixed type inputs "list(type)" we create new attributes in the signature
721 // for each element tensor (See examples in BuildWrappedOpSignature). Here
722 // we construct the values for those attributes and set them on the wrapped op.
723 Status AddMixedTypeListAttrs(EagerOperation* wrapped_op,
724                              const AbstractOpAttrs* op_attrs,
725                              const OpDef& opdef) {
726   auto FillAttrsToAdd =
727       [op_attrs](const ProtoArgListType& opdef_args,
728                  absl::flat_hash_map<string, DataType>* attrs_to_add) {
729         for (const auto& arg : opdef_args) {
730           if (!arg.type_list_attr().empty()) {
731             gtl::InlinedVector<DataType, 4> type_list;
732             TF_RETURN_IF_ERROR(
733                 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
734             for (size_t i = 0; i < type_list.size(); i++) {
735               auto attr_name = GetFlatName(arg.type_list_attr(), i);
736               (*attrs_to_add)[attr_name] = type_list[i];
737             }
738           }
739         }
740         return OkStatus();
741       };
742   absl::flat_hash_map<string, DataType> attrs_to_add;
743   TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.input_arg(), &attrs_to_add));
744   TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.output_arg(), &attrs_to_add));
745   for (auto& name_type : attrs_to_add) {
746     TF_RETURN_IF_ERROR(
747         wrapped_op->SetAttrType(name_type.first.data(), name_type.second));
748   }
749   // TODO(srbs): Rename all original attributes using EscapeOrigName.
750   return OkStatus();
751 }
752 
753 // Maps the op's outputs to the function outputs. Mainly useful for variadic
754 // outputs which need to be flattened.
755 Status PopulateRetMap(FunctionDef* fdef, const AbstractOpAttrs* op_attrs,
756                       const EagerOperation* op, const OpDef& opdef,
757                       const OpDef& signature, const string& node_name) {
758   int next_sig_output = 0;
759   for (size_t i = 0; i < opdef.output_arg_size(); i++) {
760     const auto& output_arg = opdef.output_arg(i);
761     if (!output_arg.type_list_attr().empty()) {
762       gtl::InlinedVector<DataType, 4> type_list;
763       TF_RETURN_IF_ERROR(
764           op_attrs->GetTypeList(output_arg.type_list_attr(), &type_list));
765       for (int j = 0; j < type_list.size(); j++) {
766         (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
767             absl::StrCat(node_name, ":", output_arg.name(), ":", j);
768       }
769     } else if (!output_arg.number_attr().empty()) {
770       int64_t number_attr;
771       if (!op_attrs->GetInt(output_arg.number_attr(), &number_attr)) {
772         return errors::Internal("Unable to read attr ",
773                                 output_arg.number_attr(), " for op ",
774                                 op->Name());
775       }
776       for (int j = 0; j < number_attr; j++) {
777         (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
778             absl::StrCat(node_name, ":", output_arg.name(), ":", j);
779       }
780     } else {
781       (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
782           absl::StrCat(node_name, ":", output_arg.name(), ":0");
783     }
784   }
785   return OkStatus();
786 }
787 
788 #ifdef INTEL_MKL
789 inline void GetMKLNodeDef(NodeDef* ndef) {
790   // All MKL eager ops have `_kernel` private attribute that needs to be set
791   // to a fixed label.
792   AttrValue attr_kernel;
793   attr_kernel.set_s(mkl_op_registry::kMklNameChangeOpLabel);
794   (*ndef->mutable_attr()).insert({"_kernel", attr_kernel});
795 }
796 #endif  // INTEL_MKL
797 
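// Wraps the primitive eager op `op` into a call to a generated FunctionDef
// containing a single node, registers that FunctionDef with the context if it
// is not already present, and returns the call op in `*wrapped_op`.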
798 Status WrapInCallOp(EagerOperation* op, EagerOperation** wrapped_op) {
799   DCHECK(!op->is_function());
800   const OpDef& opdef = OpRegistry::Global()->LookUp(op->Name())->op_def;
801   // Raise an error for ops which don't support wrapping yet. This includes
802   // ops with list inputs/outputs and ops with private attrs.
803   // TODO(srbs): Support list inputs/outputs.
804   TF_RETURN_IF_ERROR(VerifyWrappableInCallOp(opdef, op));
805 
806   // Build a FunctionDef containing op as a node and register with context.
807   // TODO(srbs): Here we are unable to distinguish between a FunctionDef for
808   // a wrapped eager op and an existing user defined function registered with
809   // the context e.g. with something like
810   // @tf.function
811   // def __wrapped__Add(x, y):
812   //   ...
813   // This can be avoided by introducing a dict in EagerContext that stores a
814   // mapping from the eager op's name to its unique FunctionDef name.
815   auto op_attrs = op->GetOpAttrs();
816   string fname;
817   TF_RETURN_IF_ERROR(BuildWrappedOpName(op, opdef, op_attrs, &fname));
818   if (!op->EagerContext().GetFunctionDef(fname)) {
819     FunctionDef fdef;
820     // Set signature.
821     TF_RETURN_IF_ERROR(
822         BuildWrappedOpSignature(op, opdef, fname, *fdef.mutable_signature()));
823     // Add node.
824     NodeDef* ndef = fdef.add_node_def();
825     ndef->set_op(op->Name());
826     ndef->set_name(op->Name());  // This could be anything.
827     const auto& signature = fdef.signature();
828     for (size_t i = 0; i < signature.input_arg_size(); i++) {
829       ndef->add_input(absl::StrCat(fdef.signature().input_arg(i).name(), ":0"));
830     }
831     // TODO(srbs): Private attrs on the op are dropped here and applied to
832     // the call op instead. If this causes problems we might have to copy those
833     // attrs to this ndef. That would require updating fname to contain a hash
834     // of such attributes.
835     for (const auto& attr : opdef.attr()) {
836       (*ndef->mutable_attr())[attr.name()].set_placeholder(attr.name());
837     }
838     // Set the device of this node to be the exact same one that eager mode
839     // would have used.
840     // TODO(b/200153278): Ideally we would just forward the call op's device at
841     // runtime but currently there is no way to do it.
842     ndef->set_device(op->DeviceName());
843 
844 #ifdef INTEL_MKL
845     if (IsMKLEnabled() &&
846         absl::StartsWith(op->Name(), mkl_op_registry::kMklOpPrefix)) {
847       GetMKLNodeDef(ndef);
848     }
849 #endif  // INTEL_MKL
850 
851     // Set `ret` map.
852     TF_RETURN_IF_ERROR(
853         PopulateRetMap(&fdef, op_attrs, op, opdef, signature, ndef->name()));
854     VLOG(1) << fdef.DebugString();
855     TF_RETURN_IF_ERROR(op->EagerContext().AddFunctionDef(std::move(fdef)));
856   }
857   // Build the call op.
858   auto& ctx = op->EagerContext();
859   AbstractOperationPtr call_op(ctx.CreateOperation());
860   TF_RETURN_IF_ERROR(call_op->Reset(fname.c_str(), op->DeviceName().c_str()));
861   for (auto t : op->Inputs()) {
862     TF_RETURN_IF_ERROR(call_op->AddInput(t));
863   }
864   *wrapped_op = down_cast<EagerOperation*>(call_op.release());
865   // Attributes on the elementary eager operation are applied to the call op and
866   // to the NodeDef inside the FunctionDef. This allows us to have a single
867   // FunctionDef for different attribute values. When the function is
868   // instantiated, these attributes get forwarded to the NodeDef. This is done
869   // by setting the AttrValue.placeholder field for the NodeDef attrs.
870   (*wrapped_op)->AddAttrs(op_attrs);
871   return AddMixedTypeListAttrs(*wrapped_op, op_attrs, opdef);
872 }
873 
874 // Necessary condition to place int args/retvals on device but not sufficient.
875 // For eager operations return values can be placed on the device for use
876 // by subsequent eager ops. E.g.
877 // with tf.device("/GPU:0"):
878 //   x = tf.random_uniform(shape=(2, 2), maxval=5, dtype=tf.int32)
879 //   y = tf.random_uniform(shape=(2, 2), maxval=5, dtype=tf.int32)
880 //   z = tf.bitwise.bitwise_and(x, y)
881 // In the above example `z` can use the outputs of `x` and `y` without needing
882 // an H2D copy if x and y are left on-device.
883 bool IntArgsAndRetvalsOnDevice(EagerOperation* op,
884                                const KernelDef* kernel_def) {
885   // We choose to leave `EagerConsts`
886   // on HOST to avoid `shape` and other arguments that are traditionally pinned
887   // to HostMemory from being placed on-device and then being copied to host via
888   // an expensive D2H transfer.
889   if (op->Name() == "_EagerConst") return false;
890 
891   // Check if any of the Op's output_arg(s) are pinned to Host.
892   if (kernel_def == nullptr) return false;
893   const OpDef& op_def = OpRegistry::Global()->LookUp(op->Name())->op_def;
894   for (const string& host_memory_arg : kernel_def->host_memory_arg()) {
895     for (const auto& output_arg : op_def.output_arg()) {
896       if (output_arg.name() == host_memory_arg) {
897         return false;
898       }
899     }
900   }
901 
902   return true;
903 }
904 
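// Builds the kernel cache key for `op` by folding the placement policy, the
// eager-op-as-function and rendezvous-reuse settings, the input devices, and
// resource dtypes/shapes into `op_cache_key`.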
905 StatusOr<Fprint128> GetKernelCacheKey(
906     const EagerOperation& op, const Fprint128& op_cache_key,
907     const std::vector<Device*>& input_device_ptrs,
908     const std::unordered_map<int, DtypeAndPartialTensorShape>&
909         input_resource_variable_dtypes_and_shapes) {
910   EagerContext& ctx = op.EagerContext();
911 
912   Fprint128 cache_key = op_cache_key;
913   // Include soft placement policy in cache key since the placement strategy
914   // can change and thus affect which kernel is picked.
915   cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
916 
917   // Include run_eager_op_as_function policy in cache key since the execution
918   // strategy can change and affect which kernel is picked.
919   VLOG(3) << "ctx.RunEagerOpAsFunction(): " << ctx.RunEagerOpAsFunction();
920   cache_key = FingerprintCat128(cache_key, ctx.RunEagerOpAsFunction());
921 
922   // When running in eager_op_as_function mode Send/Recv ops need to be
923   // placed on the same rendezvous to match the behaviour of eager mode.
924   bool reuse_rendezvous_for_functions =
925       (ctx.RunEagerOpAsFunction() && !op.is_function()) ||
926       ctx.GetReuseRendezvousForFunctions();
927   // The launch-time rendezvous reuse setting is bundled with the kernel, so we
928   // need to include it in the cache key.
929   cache_key = FingerprintCat128(cache_key, reuse_rendezvous_for_functions);
930 
931   for (int i = 0, end = input_device_ptrs.size(); i < end; ++i) {
932     cache_key = FingerprintCat128(cache_key,
933                                   Fingerprint128(input_device_ptrs[i]->name()));
934 
935     auto input_resource = input_resource_variable_dtypes_and_shapes.find(i);
936     if (input_resource != input_resource_variable_dtypes_and_shapes.end()) {
938       const DtypeAndPartialTensorShape& dtype_and_shape =
939           input_resource->second;
940       // Add _Arg index, dtype and shape to "cache_key".
941       cache_key = FingerprintCat128(cache_key, i);
942       cache_key = FingerprintCat128(cache_key, dtype_and_shape.dtype);
943       AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
944     }
945   }
946 
947   return cache_key;
948 }
949 
950 // Extracts function input info for `op` with `kernel_def`.
951 // The following are extracted:
952 //   `input_device_ptrs` - The input devices of `op`.
953 //   `composite_devices` - Maps from a CompositeDevice name to a list of
954 //     physical device names.
955 //   `input_resource_variable_dtypes_shape` - A map from input index
956 //     to dtype and shapes for resource inputs.
957 Status ExtractFunctionInputInfo(
958     EagerOperation* op, const KernelDef* kernel_def,
959     std::vector<Device*>& input_device_ptrs,
960     absl::flat_hash_map<string, const std::vector<string>*>& composite_devices,
961     std::unordered_map<int, DtypeAndPartialTensorShape>&
962         input_resource_variable_dtypes_and_shapes) {
963   profiler::TraceMe activity("EagerCopyToDevice",
964                              profiler::TraceMeLevel::kInfo);
965   EagerContext& ctx = op->EagerContext();
966   input_device_ptrs.reserve(op->Inputs().size());
967   const absl::InlinedVector<TensorHandle*, 4>* inputs;
968   TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
969   Device* op_device = nullptr;
970   const NodeDef* node_def = nullptr;
971   if (!op->is_function()) {
972     op_device = absl::get<Device*>(op->Device());
973     node_def = &op->MutableAttrs()->BuildNodeDef();
974   }
975   for (int i = 0, end = inputs->size(); i < end; ++i) {
976     TensorHandle* input = (*inputs)[i];
977 
978     Device* input_device;
979     bool is_host_memory_arg =
980         IsHostMemoryArg(*op, node_def, op_device, kernel_def, i);
981     TF_RETURN_IF_ERROR(
982         GetDeviceForInput(*op, ctx, is_host_memory_arg, input, &input_device));
983     VLOG(1) << op->Name() << ":input:" << i << " " << input_device->name();
984     input_device_ptrs.push_back(input_device);
985     CompositeDevice* composite_device = nullptr;
986     if (ctx.FindCompositeDeviceFromName(input_device->name(), &composite_device)
987             .ok()) {
988       composite_devices[input_device->name()] =
989           composite_device->underlying_devices();
990     }
991     if (input->dtype == DT_RESOURCE) {
992       // We only care about data type and shape for resource variable inputs.
993       // But we have no way to tell if input is resource variable (other than
994       // looking it up in ResourceMgr, which is slow). So we just get
995       // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
996       // resource_dtypes_and_shapes is not empty, take the first element.
997       std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
998       TF_RETURN_IF_ERROR(
999           input->GetResourceHandleDtypesAndShapes(&resource_dtypes_and_shapes));
1000       if (!resource_dtypes_and_shapes.empty()) {
1001         const DtypeAndPartialTensorShape& dtype_and_shape =
1002             resource_dtypes_and_shapes.at(0);
1003         input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
1004       }
1005     }
1006   }
1007   return OkStatus();
1008 }
1009 
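// Runs device placement for `op` and records the selected device both on the
// operation and in `*device`.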
1010 Status SetOpDevice(EagerContext& ctx, EagerOperation* op, Device** device) {
1011   // Here in local execute, set preferred device to be on the local task to
1012   // avoid placing op on a remote device with higher priority.
1013   const DeviceNameUtils::ParsedName& preferred_device =
1014       DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())
1015           ? op->GetDeviceParsedName()
1016           : DeviceNameUtils::AddressSpace(ctx.HostCPUParsedName());
1017   // Note: We use the unwrapped op for inferring the device.
1018   // Without this, when wrapping CPU-only ops like RangeDataset we would
1019   // place the wrapped op on a GPU (if one is available) which leads to
1020   // errors because placer pins the function output nodes to GPU thereby
1021   // forcing a H2D copy of the dataset variant which is not supported.
1022   auto ndef = op->MutableAttrs()->BuildNodeDef();
1023 #ifdef INTEL_MKL
1024   if (IsMKLEnabled() &&
1025       absl::StartsWith(op->Name(), mkl_op_registry::kMklOpPrefix)) {
1026     GetMKLNodeDef(&ndef);
1027   }
1028 #endif  // INTEL_MKL
1029 
1030   TF_RETURN_IF_ERROR(ctx.SelectDevice(preferred_device, ndef, device));
1031 
1032   VLOG(1) << "PreferredDevice " << op->Name() << ": " << preferred_device;
1033   VLOG(1) << "Placer place op [" << op->Name()
1034           << "] on device: " << (*device)->name();
1035   VLOG(4) << "Available kernels for " << op->Name() << " are"
1036           << KernelsRegisteredForOp(op->Name());
1037   op->SetDevice(*device);
1038   return OkStatus();
1039 }
1040 
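// Cache key for the device placement decision of `op`. The soft placement
// policy is included because it affects which device gets selected.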
1041 Fprint128 GetDeviceCacheKey(EagerOperation* op, const EagerContext& ctx) {
1042   Fprint128 device_cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
1043   device_cache_key =
1044       FingerprintCat128(device_cache_key, ctx.AllowSoftPlacement());
1045   return device_cache_key;
1046 }
1047 
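// Returns a cached KernelAndDevice for `op` if one exists for its cache key.
// Otherwise places the op, optionally wraps a primitive op in a call op,
// builds a new KernelAndDeviceOp/KernelAndDeviceFunc, and caches it when the
// kernel is cacheable.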
1048 Status GetOrCreateKernelAndDevice(
1049     EagerOperation* op, TensorHandle** retvals, int* num_retvals,
1050     core::RefCountPtr<KernelAndDevice>* out_kernel) {
1051   EagerContext& ctx = op->EagerContext();
1052   Device* device = absl::get<Device*>(op->Device());
1053 
1054   // Set the EagerOperation's device prior to extracting the input_device_ptrs
1055   // to avoid any redundant H2D/D2H copies.
1056   if (device == nullptr && !op->is_function()) {
1057     Fprint128 device_cache_key = GetDeviceCacheKey(op, ctx);
1058     device = ctx.GetCachedDevice(device_cache_key);
1059     if (device == nullptr) {
1060       TF_RETURN_IF_ERROR(SetOpDevice(ctx, op, &device));
1061       ctx.AddDeviceToCache(device_cache_key, device);
1062     } else {
1063       op->SetDevice(device);
1064     }
1065   }
1066 
1067   // Save the original value of reuse_rendezvous_for_functions from the context.
1068   bool reuse_rendezvous_for_functions_original_value =
1069       ctx.GetReuseRendezvousForFunctions();
1070   // When running in eager_op_as_function mode Send/Recv ops need to be
1071   // placed on the same rendezvous to match the behaviour of eager mode.
1072   bool reuse_rendezvous_for_functions =
1073       (ctx.RunEagerOpAsFunction() && !op->is_function()) ||
1074       reuse_rendezvous_for_functions_original_value;
1075 
1076   std::vector<Device*> input_device_ptrs;
1077   absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
1078   std::unordered_map<int, DtypeAndPartialTensorShape>
1079       input_resource_variable_dtypes_and_shapes;
1080   const KernelDef* kernel_def = nullptr;
1081   if (!op->is_function()) {
1082     const NodeDef* node_def = &op->MutableAttrs()->BuildNodeDef();
1083     kernel_def = GetKernelDef(*op, node_def, device);
1084   }
1085   if (op->is_function() || ctx.RunEagerOpAsFunction()) {
1086     TF_RETURN_IF_ERROR(ExtractFunctionInputInfo(
1087         op, kernel_def, input_device_ptrs, composite_devices,
1088         input_resource_variable_dtypes_and_shapes));
1089   }
1090 
1091   TF_ASSIGN_OR_RETURN(
1092       Fprint128 cache_key,
1093       GetKernelCacheKey(*op, op->MutableAttrs()->CacheKey(op->DeviceName()),
1094                         input_device_ptrs,
1095                         input_resource_variable_dtypes_and_shapes));
1096   core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
1097   AbstractOperationPtr wrapped_op_releaser;
1098   // We can eliminate some overhead by running simple functions using regular
1099   // CallOp kernel. However, it is tricky to figure out which functions should
1100   // be run using CallOp. Also, currently CallOp runs neither optimization
1101   // passes (needed for TPU/XLA) nor grappler.
1102   // Here are some cases where a function should be run in multi-device mode:
1103   //  - Function takes at least two resources on different devices.
1104   //  - Function takes a resource on deviceA and a body op explicitly placed
1105   //  on deviceB.
1106   //  - Function has a colocation constraint.
1107   //  - Function has an explicit device annotation (which might not be using
1108   //    full canonical device name) different from op_device. Note that false
1109   //    positives are ok.
1110   //  - Function has a node or a (node) attribute that can potentially make
1111   //    the function multi-device after a rewrite pass (e.g. various XLA/TPU
1112   //    special nodes and attributes)
1113   if (kernel == nullptr) {
1114     VLOG(2) << "Creating new kernel for " << op->Name() << " on device "
1115             << DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
1116     bool run_function_with_flr = false;
1117     bool function_outputs_on_op_device = false;
1118     absl::optional<string> xla_compile_device_type;
1119     if (op->is_function()) {
1120       bool compile_with_xla;
1121       TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
1122       if (compile_with_xla) {
1123         if (ctx.JitCompileRewrite()) {
1124           xla_compile_device_type = op->GetDeviceParsedName().type;
1125           run_function_with_flr = true;
1126         } else {
1127           // Note that it is not ideal, but currently correct, to set this
1128           // attribute after computing the kernel cache key above.
1129           // Note: If the attribute is already set to true, this is a noop.
1130           op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
1131         }
1132       } else {
1133         run_function_with_flr = true;
1134       }
1135       GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device)
1136           .IgnoreError();
1137     }
1138 
1139     VLOG(2) << op->Name() << " function_outputs_on_op_device: "
1140             << function_outputs_on_op_device;
1141     if (device == nullptr) {
1142       TF_RETURN_IF_ERROR(SetOpDevice(ctx, op, &device));
1143     } else {
1144       VLOG(1) << "Device for [" << op->Name()
1145               << "] already set to: " << device->name();
1146     }
1147 
1148     // Note: We wrap the eager op AFTER the device has been inferred to ensure
1149     // that placement of the NodeDef in the function is exactly the same as in
1150     // eager mode. This is specially important for cases where the
1151     // preferred device is not the actual device on which the op is run.
1152     // E.g. the preferred device for a `RangeDataset` op could be set to `GPU`
1153     // but `ctx->SelectDevice` would still place it on CPU. Placer on the other
1154     // hand would throw an error.
1155     //
1156     // Note: The wrapped function is never jit compiled but rather run via the
1157     // FLR. This is needed because certain ops e.g. `VarHandleOp` can not be
1158     // jit compiled. Ideally we would run this via the jit compiled path and
1159     // expect unsupported ops to be outside compiled but that is not supported
1160     // on GPUs right now.
1161     bool allow_small_function_optimizations = false;
1162     bool int_args_and_retvals_on_device = false;
1163     bool allow_control_flow_sync_execution = false;
1164     // TODO(b/176491312): Remove this if shape inference on import flag is
1165     // removed.
1166     bool shape_inference_on_tfe_dialect_import = true;
1167     if (ctx.RunEagerOpAsFunction() && !op->is_function()) {
1168       EagerOperation* wrapped_op = nullptr;
1169       TF_RETURN_IF_ERROR(ValidateOp(op));
1170       TF_RETURN_IF_ERROR(WrapInCallOp(op, &wrapped_op));
1171       DCHECK(wrapped_op);
1172       DCHECK(wrapped_op->is_function());
1173       wrapped_op_releaser.reset(wrapped_op);
1174       run_function_with_flr = true;
1175       allow_small_function_optimizations = true;
1176       allow_control_flow_sync_execution = true;
1177       shape_inference_on_tfe_dialect_import = false;
1178       int_args_and_retvals_on_device =
1179           IntArgsAndRetvalsOnDevice(op, kernel_def);
1180       op = wrapped_op;
1181       if (int_args_and_retvals_on_device) {
1182         op->MutableAttrs()->Set(FunctionLibraryDefinition::kIntsOnDeviceAttr,
1183                                 true);
1184       }
1185     }
1186     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
1187 
1188     FunctionLibraryRuntime* flr =
1189         device == nullptr ? nullptr : ctx.func_lib(device);
1190     if (device != nullptr && flr == nullptr) {
1191       return errors::NotFound(
1192           "Unable to find a FunctionLibraryRuntime corresponding to device ",
1193           device->name());
1194     }
1195     auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
1196                                                                : ctx.runner();
1197     GraphCollector* graph_collector = nullptr;
1198     if (ctx.ShouldStoreGraphs()) {
1199       graph_collector = ctx.GetGraphCollector();
1200     }
1201     // Treat the function as multi_device only when we are not compiling
1202     // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
1203     // will create an XlaLaunchOp kernel to compile and run the function.
1204     if (run_function_with_flr) {
1205       // Multi-device functions don't use the rendezvous from eager context.
1206       // If we use that rendezvous, multiple concurrent calls to the same
1207       // function will likely result in collisions. However, this also means
1208       // that we don't support legitimate sending/receiving across function
1209       // boundaries.
1210       VLOG(2) << "Running " << ndef.op() << " using multi-device function. "
1211               << "Full node_def=" << ndef.DebugString();
1212       std::function<int64_t()> get_op_id = nullptr;
1213 #if !defined(IS_MOBILE_PLATFORM)
1214       get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
1215 #endif  // IS_MOBILE_PLATFORM
1216 
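      // Temporarily apply the requested reuse-rendezvous-for-functions setting
      // while capturing the rendezvous creator, then restore the original
      // value so that concurrent callers observe a consistent context state.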
1217       ctx.reuse_rendezvous_for_functions_mu()->lock();
1218       ctx.SetReuseRendezvousForFunctions(reuse_rendezvous_for_functions);
1219       auto rendezvous_creator = ctx.RendezvousCreator();
1220       ctx.SetReuseRendezvousForFunctions(
1221           reuse_rendezvous_for_functions_original_value);
1222       ctx.reuse_rendezvous_for_functions_mu()->unlock();
1223       kernel.reset(new KernelAndDeviceFunc(
1224           flr, ctx.pflr(), std::move(input_device_ptrs),
1225           std::move(composite_devices),
1226           std::move(input_resource_variable_dtypes_and_shapes), runner,
1227           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
1228           function_outputs_on_op_device, allow_small_function_optimizations,
1229           allow_control_flow_sync_execution,
1230           shape_inference_on_tfe_dialect_import, int_args_and_retvals_on_device,
1231           xla_compile_device_type, std::move(rendezvous_creator), get_op_id));
1232     } else {
1233       VLOG(2) << "Running " << ndef.op() << " using op kernel. "
1234               << "Full node_def=" << ndef.DebugString();
1235       kernel.reset(new KernelAndDeviceOp(
1236           ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
1237           ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
1238     }
1239 
1240     TF_RETURN_IF_ERROR(
1241         kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));
1242 
1243     if (op->is_function()) {
1244       ctx.AddKernelToCache(cache_key, kernel.get());
1245     } else {
1246       // Exclude tf.data op kernels from being cached. The reason for this is
1247       // that tf.data op kernels that accept a user-defined function will have a
1248       // unique cache key every time they are executed (because the user-defined
1249       // function is traced every time). Caching such kernels provides no
1250       // benefit and in some cases results in linear memory growth of user
1251       // programs that build input pipeline graphs in a loop.
1252       const OpDef* op_def;
1253       TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
1254       if (KernelCacheEnabled(*op_def)) {
1255         ctx.AddKernelToCache(cache_key, kernel.get());
1256       }
1257     }
1258   }
1259 
1260   int num_outputs = kernel->num_outputs();
1261   if (num_outputs > *num_retvals) {
1262     return errors::InvalidArgument("Expecting ", num_outputs,
1263                                    " outputs, but *num_retvals is ",
1264                                    *num_retvals);
1265   }
1266   *num_retvals = num_outputs;
1267 
1268   kernel->Ref();  // Ownership of reference is passed to out_kernel.
1269   out_kernel->reset(kernel.get());
1270   return OkStatus();
1271 }
1272 
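// Creates a remote output handle for output `output_num` of `kernel` on
// `output_device`: an unshaped remote handle when this process is the master,
// otherwise a lazy remote handle. Requires a remote op id provided via
// `eager_func_params`.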
1273 Status CreateUnshapedOutput(
1274     const KernelAndDevice& kernel, const int output_num, Device* output_device,
1275     const DataType& output_dtype,
1276     const absl::optional<EagerFunctionParams>& eager_func_params,
1277     EagerContext* ctx, TensorHandle** output) {
1278 #if defined(IS_MOBILE_PLATFORM)
1279   return errors::Unimplemented(
1280       "Remote outputs are not available on mobile devices.");
1281 #else  // !IS_MOBILE_PLATFORM
1282   int64_t op_id;
1283   if (eager_func_params.has_value()) {
1284     op_id = eager_func_params.value().op_id;
1285   } else {
1286     return errors::InvalidArgument(
1287         "Unable to find a remote op id for a remote output of ", kernel.name());
1288   }
1289   string remote_task;
1290   if (!DeviceNameUtils::GetTaskName(output_device->parsed_name(),
1291                                     &remote_task)) {
1292     return errors::InvalidArgument(
1293         "Unable to find remote task corresponding to device ",
1294         output_device->name());
1295   }
1296   if (ctx->RemoteMgr()->IsMaster()) {
1297     *output = TensorHandle::CreateUnshapedRemoteHandle(
1298         op_id, output_num, remote_task, output_dtype, output_device, ctx);
1299   } else {
1300     *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
1301                                                    output_dtype, output_device,
1302                                                    /*is_ready=*/false, ctx);
1303   }
1304   return OkStatus();
1305 #endif  // !IS_MOBILE_PLATFORM
1306 }
1307 
1308 Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
1309                         EagerOperation* op, TensorHandle** retvals) {
1310   EagerExecutor& executor = op->Executor();
1311   EagerContext& ctx = op->EagerContext();
1312   GraphCollector* graph_collector = nullptr;
1313   if (ctx.ShouldStoreGraphs()) {
1314     graph_collector = ctx.GetGraphCollector();
1315   }
1316   const int num_outputs = kernel->num_outputs();
1317   absl::optional<EagerFunctionParams> eager_func_params =
1318       op->eager_func_params();
1319   if (kernel->IsCrossProcess() && !eager_func_params.has_value()) {
1320     // Create an eager op id for a cross-process function if one does not exist.
1321 #if defined(IS_MOBILE_PLATFORM)
1322     return errors::Unimplemented(
1323         "Cross-process functions are not supported on mobile devices.");
1324 #else  // !IS_MOBILE_PLATFORM
1325     const int64_t op_id = ctx.RemoteMgr()->NextOpId();
1326     eager_func_params = EagerFunctionParams{
1327         op_id, /* is_component_function= */ false, /* step_id= */ std::nullopt};
1328 #endif  // !IS_MOBILE_PLATFORM
1329   }
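  // Two paths below: in async mode, pre-create the output handles and enqueue
  // an AsyncExecuteNode on the executor; in sync mode, run an ExecuteNode
  // inline and release the inputs only after execution completes.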
1330   if (executor.Async()) {
1331     const DataTypeVector& output_dtypes = kernel->output_dtypes();
1332     for (int i = 0, end = num_outputs; i < end; ++i) {
1333       Device* output_device = ctx.CanonicalDevice(kernel->OutputDevice(i));
1334       if (output_device == nullptr || output_device->IsLocal()) {
1335         retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1336             /* d= */ output_device, /* op_device= */ kernel->device(),
1337             /* resource_device= */ kernel->OutputResourceDevice(i),
1338             output_dtypes[i], &ctx);
1339       } else {
1340         TF_RETURN_IF_ERROR(
1341             CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1342                                  eager_func_params, &ctx, &retvals[i]));
1343       }
1344     }
1345     const absl::InlinedVector<TensorHandle*, 4>* inputs;
1346     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1347     auto node = std::make_unique<AsyncExecuteNode>(
1348         &ctx, *inputs, eager_func_params, std::move(kernel), graph_collector,
1349         op->GetCancellationManager(),
1350         absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
1351     // Release the inputs from the eager operation since the AsyncExecuteNode
1352     // would have taken ownership. This allows the inputs to be forwarded if
1353     // possible.
1354     op->Clear();
1355     // In async mode, the executor's ordering guarantees that all input
1356     // handles are ready before the node executes.
1357     // TODO(b/137118203): Consider executing "cheap" kernels inline for
1358     // performance.
1359     return executor.AddOrExecute(std::move(node));
1360   } else {
1361     for (int i = 0, end = num_outputs; i < end; ++i) {
1362       retvals[i] = nullptr;
1363     }
1364     const absl::InlinedVector<TensorHandle*, 4>* inputs;
1365     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1366     ExecuteNode node(&ctx, *inputs, eager_func_params, kernel, graph_collector,
1367                      op->GetCancellationManager(),
1368                      {retvals, static_cast<size_t>(num_outputs)},
1369                      op->GetStackTrace());
1370     Status s = executor.SyncExecute(&node);
1371     // We release the inputs AFTER executing the operation in sync mode since
1372     // ExecuteNode does not increment the reference count and thus does not have
1373     // ownership of the inputs while executing.
1374     op->Clear();
1375     return s;
1376   }
1377 }
1378 
1379 // There are a lot of references to devices in this function and around.
1380 // Here is what they mean:
1381 //  EagerOperation::Device(): The device on which the user requested the op
1382 //    be executed, except if we had to change the device due to resource inputs
1383 //    or CPU pinning. If the user did not request a device, the op does not
1384 //    take resources, and we did not pin it to CPU, the device can be nullptr.
1385 //  KernelAndDevice::Device(): The first time we see an op (combined with
1386 //    its attributes), we need to create a KernelAndDevice object for it.
1387 //    If op->Device() is a nullptr, we select a device for the op when
1388 //    creating the KernelAndDevice. A concrete device will always be selected
1389 //    here except when `op` is a function to be executed using the function
1390 //    library runtime. In this case, we don't select a device because running
1391 //    a function with an explicitly requested device behaves differently than
1392 //    running without an explicitly requested device.
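// Illustration (hypothetical): a user may request "GPU:0" for an op.
// EagerOperation::Device() then reflects that request (possibly adjusted for
// resource inputs or CPU pinning), while KernelAndDevice::Device() is the
// device actually selected when the kernel was created.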
1393 Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
1394                          int* num_retvals) {
1395   profiler::ScopedMemoryDebugAnnotation op_annotation(
1396       op->op_name(), op->eager_func_params().has_value()
1397                          ? op->eager_func_params().value().step_id.value_or(0)
1398                          : 0);
1399   profiler::TraceMe activity(
1400       [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
1401       profiler::TraceMeLevel::kInfo);
1402   EagerContext& ctx = op->EagerContext();
1403   auto& executor = op->Executor();
1404   TF_RETURN_IF_ERROR(executor.status());
1405 
1406   core::RefCountPtr<KernelAndDevice> kernel;
1407   auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1408 
1409 #ifdef INTEL_MKL
1410   if (IsMKLEnabled() && kernel != nullptr &&
1411       op->Device() == kVariantDeviceNull) {
1412     // The oneDNN optimization pass relies on the op's assigned device to
1413     // determine whether it can be rewritten.
1414     op->SetDevice(kernel->device());
1415   }
1416 #endif  // INTEL_MKL
1417 
1418   // Run all the registered rewrite passes after placement, regardless of
1419   // whether the placement succeeded. The passes can either create new ops
1420   // (without placement) or update some fields of the input op.
1421   std::unique_ptr<tensorflow::EagerOperation> out_op;
1422   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1423       EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op));
1424   if (out_op) {
1425     op = out_op.get();
1426     // If the out op doesn't have a device, either because it is a new op or
1427     // the op wasn't placed successfully, then we do the placement again.
1428     if (op->Device() == kVariantDeviceNull) {
1429       status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1430     }
1431   }
1432   if (!status.ok()) return status;
1433 
1434   int num_outputs = kernel->num_outputs();
1435   TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
1436 
1437   if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1438     string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1439                                  kernel->device()->name());
1440     if (!logging::LogToListeners(msg)) {
1441       LOG(INFO) << msg;
1442     }
1443   }
1444 
1445   Status s = AddOrExecuteNode(std::move(kernel), op, retvals);
1446   // If the operation failed, unref any outputs that were allocated.
1448   if (!s.ok()) {
1449     for (int i = 0, end = num_outputs; i < end; ++i) {
1450       if (retvals[i] != nullptr) {
1451         retvals[i]->Unref();
1452         retvals[i] = nullptr;
1453       }
1454     }
1455   }
1456 
1457   return s;
1458 }
1459 
1460 // Run a Pack op to pack the tensors pointed by a packed input TensorHandle if
1461 // the op is a primitive op.
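// Conceptually: a PACKED input handle holding N component handles is replaced
// by the single output of a "Pack" op (attrs "N" and "T"), so primitive
// kernels never see packed inputs.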
1462 Status MaybePackInputTensor(EagerOperation* op) {
1463   if (op->is_function() || op->EagerContext().RunEagerOpAsFunction()) {
1464     // Functions could take packed TensorHandles as inputs.
1465     return OkStatus();
1466   }
1467   EagerContext& ctx = op->EagerContext();
1468   const absl::InlinedVector<TensorHandle*, 4>* inputs;
1469   TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1470   for (int i = 0; i < inputs->size(); ++i) {
1471     TensorHandle* handle = (*inputs)[i];
1472     if (handle->Type() == TensorHandle::PACKED) {
1473       EagerOperation pack_op(&ctx);
1474       TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
1475                                        /*remote=*/false, /*executor=*/nullptr));
1476       pack_op.MutableAttrs()->Set("N", handle->NumPackedHandles());
1477       pack_op.MutableAttrs()->Set("T", handle->dtype);
1478       for (int i = 0; i < handle->NumPackedHandles(); ++i) {
1479         tensorflow::TensorHandle* h = nullptr;
1480         TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(i, &h));
1481         TF_RETURN_IF_ERROR(pack_op.AddInput(h));
1482       }
1483       int num_retvals = 1;
1484       absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
1485       TF_RETURN_IF_ERROR(
1486           EagerLocalExecute(&pack_op, retvals.data(), &num_retvals));
1487       tensorflow::TensorHandle* ret = retvals.at(0);
1488       op->UpdateInput(i, ret);
1489       ret->Unref();
1490     }
1491   }
1492   return OkStatus();
1493 }
1494 
1495 #if !defined(IS_MOBILE_PLATFORM)
1496 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
1497   EagerContext& ctx = op->EagerContext();
1498 
1499   remote_op->set_id(ctx.RemoteMgr()->NextOpId());
1500   remote_op->set_name(op->Name());
1501 
1502   op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
1503   remote_op->set_device(absl::get<Device*>(op->Device())->name());
1504   remote_op->set_is_function(op->is_function());
1505 }
1506 
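// For a remote "VarHandleOp", records the variable's dtype and shape on the
// locally created output handle so that a remote function taking the variable
// as input can use them during instantiation.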
1507 Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
1508                                     const DataTypeVector& output_dtypes,
1509                                     TensorHandle** retvals) {
1510   if (remote_op.name() == "VarHandleOp") {
1511     if (output_dtypes.size() != 1) {
1512       return errors::Internal("VarHandleOp should only have one output.");
1513     }
1514     if (output_dtypes[0] != DT_RESOURCE) {
1515       return errors::Internal(
1516           "The output of VarHandleOp should be a DT_RESOURCE.");
1517     }
1518     AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
1519     const AttrValue* dtype;
1520     TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
1521     const AttrValue* shape;
1522     TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
1523     retvals[0]->SetResourceHandleDtypeAndShape(
1524         {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
1525   }
1526   return OkStatus();
1527 }
1528 
1529 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
1530                           int* num_retvals) {
1531   EagerContext& ctx = op->EagerContext();
1532 
1533   // TODO(fishx): Remove following code when lazy tensor copy is ready.
1534   if (op->Device() == kVariantDeviceNull) {
1535     tensorflow::Device* device = nullptr;
1536     string device_name = op->DeviceName();
1537     TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
1538     op->SetDevice(device);
1539   }
1540 
1541   core::RefCountPtr<eager::EagerClient> eager_client;
1542   uint64 context_id = ctx.GetContextId();
1543   TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
1544   string remote_task;
1545   if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
1546     return errors::InvalidArgument(
1547         "Unable to find remote task corresponding to device ",
1548         op->DeviceName());
1549   }
1550 
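  // Build an EnqueueRequest: each input is serialized as a remote tensor
  // handle, the op itself is described by `remote_op`, and the request is
  // later queued on the executor via a RemoteExecuteNode.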
1551   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
1552   request->set_context_id(context_id);
1553 
1554   eager::Operation* remote_op = request->add_queue()->mutable_operation();
1555 
1556   tensorflow::Device* op_device = absl::get<Device*>(op->Device());
1557   {
1558     profiler::TraceMe activity("CopyInputToExpectedDevice",
1559                                profiler::TraceMeLevel::kInfo);
1560     const bool is_function = op->is_function();
1561     const absl::InlinedVector<TensorHandle*, 4>* inputs;
1562     TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1563     for (int i = 0, end = inputs->size(); i < end; i++) {
1564       tensorflow::TensorHandle* input = (*inputs)[i];
1565       tensorflow::Device* input_device = input->device();
1566       tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
1567       const string* input_device_name = &input_device_or_cpu->name();
1568       bool serialize_resource_dtype_and_shape = false;
1569       if (op_device != input_device &&
1570           // If the expected and actual devices are on the same task, don't
1571           // explicitly copy, and instead depend on the copy to happen locally
1572           // when the op is executed on the device.
1573           !ctx.OnSameTask(op_device, input_device)) {
1574         if (!is_function || input_device_or_cpu->IsLocal()) {
1575           tensorflow::Device* remote_cpu_device;
1576           TF_RETURN_IF_ERROR(
1577               ctx.CPUDeviceOnTask(op_device, &remote_cpu_device));
1578           // Always copy to the remote CPU so that the actual device can be
1579           // correctly determined after the kernel is selected/instantiated,
1580           // since the op might have its inputs on host memory.
1581           TensorHandle* handle = input;
1582           Device* handle_device = handle->DeviceOrHostCPU(ctx);
1583           // If the input is already on the right device, then nothing to do.
1584           if (remote_cpu_device != handle_device) {
1585             VLOG(6) << "remote_cpu_device != handle_device";
1586             TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
1587                 &ctx, op, op_device, handle, i, handle_device,
1588                 remote_cpu_device, &handle));
1589             op->UpdateInput(i, handle);
1590             input = handle;
1591             input_device = remote_cpu_device;
1592             input_device_name = &remote_cpu_device->name();
1593             // Unref handle since it has a ref as an input now
1594             handle->Unref();
1595           }
1596         } else {
1597           serialize_resource_dtype_and_shape =
1598               (input->dtype == DT_RESOURCE) &&
1599               (!input->HasResourceShapeMirror(op_device,
1600                                               ctx.GetContextViewId()));
1601         }
1602       }
1603       auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
1604       // For a remote component function, a function execution request and an
1605       // input generation request may come from different workers. We need to
1606       // guarantee that the input generation request is processed before the
1607       // function execution request, so wait until the remote input is ready
1608       // before sending it to the multi-device function device.
1609       const bool wait_until_ready = op->is_function();
1610       TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
1611           input, wait_until_ready, input_handle, input_device,
1612           *input_device_name, serialize_resource_dtype_and_shape));
1613       if (!input_handle->resource_dtypes_and_shapes().empty()) {
1614         TF_RETURN_IF_ERROR(
1615             input->AddResourceShapeMirror(op_device, input_handle->op_id(),
1616                                           input_handle->output_num(), &ctx));
1617       }
1618     }
1619   }
1620 
1621   PrepareRemoteOp(remote_op, op);
1622 
1623   DataTypeVector output_dtypes;
1624   TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
1625 
1626   const size_t num_outputs = output_dtypes.size();
1627   if (num_outputs != *num_retvals) {
1628     return errors::InvalidArgument(
1629         "num_retvals does not match expected output dtypes");
1630   }
1631   *num_retvals = num_outputs;
1632 
1633   const tensorflow::uint64 id = remote_op->id();
1634   for (size_t i = 0; i < num_outputs; ++i) {
1635     // TODO(nareshmodi): Change the callback to instead add the decref to a
1636     // list of pending decrefs that we can send as a batch with the next
1637     // execute.
1638 
1639     // The device_ and resource_device_ of this TensorHandle might be
1640     // incorrect. For multi-device functions, we don't know the output device
1641     // until the function is instantiated on a remote worker. Luckily, we don't
1642     // need to know the correct remote device here. We just need to know that it
1643     // is remote. If we need to copy this tensor to this process or run any ops
1644     // which take this tensor as an input, block until the correct device is
1645     // set.
1646     const bool unknown_device = op->is_function();
1647     retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
1648         id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
1649   }
1650 
1651   // Store the data type and shape of a remote resource variable on the
1652   // corresponding remote TensorHandle (output of 'VarHandleOp').
1653   // If the variable is an input of a remote function, the function may need
1654   // the type and shape during function instantiation. Store the type and
1655   // shape on the eager master and send them to the default function device along
1656   // with the EnqueueRequest.
1657   TF_RETURN_IF_ERROR(
1658       StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
1659 
1660   auto& executor = op->Executor();
1661   VLOG(4) << "Execute remote eager op: " << op->Name()
1662           << " (is async?: " << executor.Async() << ").";
1663 
1664   const absl::InlinedVector<TensorHandle*, 4>* inputs;
1665   TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1666 
1667   std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
1668       &op->EagerContext(), std::move(request), op_device,
1669       ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
1670       op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
1671       *inputs, {retvals, num_outputs}));
1672 
1673   if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
1674     string msg = strings::StrCat(
1675         "Executing op ", op->Name(), " on task ",
1676         DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
1677     if (!logging::LogToListeners(msg)) {
1678       LOG(INFO) << msg;
1679     }
1680   }
1681 
1682   Status s = executor.AddOrExecute(std::move(node));
1683   // If the operation failed, unref any outputs that were allocated.
1685   if (!s.ok()) {
1686     for (size_t i = 0; i < num_outputs; ++i) {
1687       retvals[i]->Unref();
1688       // Ensure that any smart pointers created to wrap results become noops
1689       // rather than operating on invalid memory.
1690       retvals[i] = nullptr;
1691     }
1692   }
1693 
1694   return s;
1695 }
1696 #endif  // IS_MOBILE_PLATFORM
1697 
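// Converts the kernel's results into TensorHandles: a Tensor result fills (or
// creates) a local handle, while a remote result carries only a TensorShape,
// which is attached to the corresponding remote handle.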
1698 Status GetKernelOutputs(
1699     std::vector<EagerKernelRet>* outputs, int num_outputs,
1700     TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel,
1701     const absl::optional<EagerFunctionParams>& eager_func_params) {
1702   for (int i = 0, end = num_outputs; i < end; ++i) {
1703     if (retvals[i] == nullptr) {
1704       EagerKernelRet& ret = (*outputs)[i];
1705       Device* output_device = ctx->CanonicalDevice(kernel->OutputDevice(i));
1706       if (ret.index() == 0) {
1707         retvals[i] = TensorHandle::CreateLocalHandle(
1708             std::move(absl::get<Tensor>(ret)),
1709             /* d= */ output_device,
1710             /* op_device= */ kernel->device(),
1711             /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
1712       } else {
1713         const DataTypeVector& output_dtypes = kernel->output_dtypes();
1714         TF_RETURN_IF_ERROR(
1715             CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1716                                  eager_func_params, ctx, &retvals[i]));
1717 #if !defined(IS_MOBILE_PLATFORM)
1718         TF_RETURN_IF_ERROR(
1719             retvals[i]->SetRemoteShape(absl::get<TensorShape>(ret),
1720                                        output_device, ctx->GetContextViewId()));
1721 #endif  // IS_MOBILE_PLATFORM
1722       }
1723     } else {
1724       if (!kernel->IsFunction() &&
1725           TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
1726         return errors::Internal(
1727             "Kernel output tensor handle has a different op device than the "
1728             "kernel. This should never happen.");
1729       }
1730       if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
1731                            retvals[i]->device())) {
1732         return errors::Internal(
1733             "Kernel output tensor handle locates on a different device than "
1734             "the specified kernel output device. This should never happen.");
1735       }
1736 
1737       EagerKernelRet& ret = (*outputs)[i];
1738       if (ret.index() == 0) {
1739         TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
1740             std::move(absl::get<Tensor>(ret)),
1741             ctx->CanonicalDevice(kernel->OutputDevice(i))));
1742       } else {
1743 #if defined(IS_MOBILE_PLATFORM)
1744         return errors::Unimplemented(
1745             "Remote outputs are not available on mobile devices.");
1746 #else  // !IS_MOBILE_PLATFORM
1747         TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
1748             absl::get<TensorShape>(ret), retvals[i]->device(),
1749             ctx->GetContextViewId()));
1750 #endif  // !IS_MOBILE_PLATFORM
1751       }
1752     }
1753   }
1754   return OkStatus();
1755 }
1756 
1757 void CollectGraphs(EagerContext* ctx) {
1758   mutex_lock ml(*ctx->MetadataMu());
1759 
1760   GraphCollector* collector = ctx->GetGraphCollector();
1761   mutex_lock mll(collector->mu);
1762 
1763   // Adding to partition graphs for backward compatibility.
1764   for (const auto& graph : collector->partitioned_graphs) {
1765     *ctx->RunMetadataProto()->add_partition_graphs() = graph;
1766   }
1767 
1768   if (collector->dirty) {
1769     auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
1770     *function_graphs->mutable_post_optimization_graph() =
1771         collector->optimized_graph;
1772     *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
1773     for (const auto& graph : collector->partitioned_graphs) {
1774       *function_graphs->add_partition_graphs() = graph;
1775     }
1776   }
1777 
1778   collector->ClearGraphs();
1779 }
1780 }  // namespace
1781 
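// Entry point for executing an eager op: after optional PRE_EXECUTION
// rewrites, dispatches to EagerLocalExecute for local ops and, on non-mobile
// builds, to EagerRemoteExecute otherwise.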
1782 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
1783                     int* num_retvals) {
1784   profiler::TraceMe activity([&] {
1785     return ::tensorflow::profiler::TraceMeEncode(
1786         "EagerExecute",
1787         {{"eager_op", op->Name()}, {"is_func", op->is_function()}});
1788   });
1789 
1790   if (!op->Executor().Async()) {
1791     VLOG(6) << "op: " << op->Name() << " is not Async.";
1792     if (!op->EagerContext()
1793              .GetGlobalRendezvousForFunctionLocalRendezvousStatus()
1794              .ok()) {
1795       VLOG(6) << "global_rendezvous_for_functions_ is in bad state. Resetting.";
1796       op->EagerContext().ResetGlobalRendezvousForFunction();
1797     }
1798     // In sync mode, always clear error to maintain the same behavior as before.
1799     // TODO(b/141004939): Remove this.
1800     op->Executor().ClearError();
1801   }
1802 
1803   std::unique_ptr<tensorflow::EagerOperation> out_op;
1804   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1805       EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
1806 
1807   if (op->IsLocal()) {
1808     if (out_op) {
1809       op = out_op.get();
1810     }
1811     TF_RETURN_IF_ERROR(MaybePackInputTensor(op));
1812     return EagerLocalExecute(op, retvals, num_retvals);
1813   }
1814 
1815 #if defined(IS_MOBILE_PLATFORM)
1816   return errors::Unimplemented(
1817       "Eager's remote execution is not available on mobile devices.");
1818 #else   // !IS_MOBILE_PLATFORM
1819   if (out_op) {
1820     op = out_op.get();
1821   }
1822   return EagerRemoteExecute(op, retvals, num_retvals);
1823 #endif  // !IS_MOBILE_PLATFORM
1824 }
1825 
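// Synchronously runs an already-instantiated kernel with the given inputs and
// converts the results into `retvals` via GetKernelOutputs.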
1826 // TODO(gjn): Consider moving into ExecuteNode class
1827 Status EagerKernelExecute(
1828     EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1829     const absl::optional<EagerFunctionParams>& eager_func_params,
1830     const core::RefCountPtr<KernelAndDevice>& kernel,
1831     GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1832     absl::Span<TensorHandle*> retvals,
1833     const absl::optional<ManagedStackTrace>& stack_trace) {
1834   profiler::TraceMe activity("EagerKernelExecute",
1835                              profiler::TraceMeLevel::kInfo);
1836   std::vector<EagerKernelRet> outputs(1);
1837 
1838   ExecuteNodeArgs inputs(op_inputs.size());
1839   TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs, kernel));
1840   // TODO(apassos) figure out how to record stats for ops which are a part of
1841   // functions.
1842   // TODO(b/111859745): When we support recovering from kernel/device errors, we
1843   // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
1844   // device. We don't call it now because it is an unneeded overhead (it
1845   // acquires a lock) and we can't recover from errors anyway.
1846   ScopedStepContainer* container = ctx->StepContainer();
1847   CoordinationServiceAgent* coord_agent = nullptr;
1848 #if !defined(IS_MOBILE_PLATFORM)
1849   if (ctx->GetDistributedManager() != nullptr)
1850     coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
1851 #endif  // !IS_MOBILE_PLATFORM
1852   TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
1853                                  cancellation_manager, eager_func_params,
1854                                  stack_trace, coord_agent));
1855   if (graph_collector != nullptr) {
1856     CollectGraphs(ctx);
1857   }
1858 
1859   if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
1860     return errors::Internal(
1861         "EagerKernelExecute returns a list of ", outputs.size(),
1862         " tensors but ", retvals.size(),
1863         " is expected. This should never "
1864         "happen. Please file a bug with the TensorFlow team.");
1865   }
1866   return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
1867                           kernel.get(), eager_func_params);
1868 }
1869 
1870 namespace {
1871 
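// Copies (or mirrors) `h` onto `dstd` within this process. With mirror=true
// the original handle is returned and, in async mode, an empty local mirror is
// added up front; the actual copy is performed by a CopyToDeviceNode.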
1872 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1873                               EagerExecutor* executor, Device* dstd,
1874                               bool mirror, TensorHandle** result) {
1875   TF_RETURN_IF_ERROR(executor->status());
1876   Device* d = ctx->CanonicalDevice(dstd);
1877   if (mirror && h->HasLocalMirror(d)) {
1878     h->Ref();
1879     *result = h;
1880     return OkStatus();
1881   }
1882 
1883   bool async = executor->Async();
1884   if (mirror) {
1885     h->Ref();
1886     *result = h;
1887 
1888     if (h->HasLocalMirror(d)) {
1889       return OkStatus();
1890     }
1891 
1892     // We don't bother adding an empty local mirror in sync mode since we'll be
1893     // executing the operation directly and calling AddLocalMirror. A
1894     // reference count is still needed, which will be removed if the operation
1895     // fails.
1896     if (async) {
1897       Status s = h->AddEmptyLocalMirror(d);
1898       if (!s.ok()) {
1899         // If a mirror was added since we called HasLocalMirror then just return
1900         // since another thread has already added the mirror.
1901         if (s.code() == error::Code::ALREADY_EXISTS) {
1902           return OkStatus();
1903         }
1904 
1905         // Remove the previously added reference count since adding the mirror
1906         // failed.
1907         h->Unref();
1908         *result = nullptr;
1909         return s;
1910       }
1911     }
1912   } else {
1913     *result = TensorHandle::CreateEmptyLocalHandle(
1914         d, dstd, h->resource_device(), h->dtype, ctx);
1915   }
1916 
1917   Status s;
1918   if (async) {
1919     // Note that `h` may not be currently ready. However, execution order will
1920     // make sure that `h` is ready before the copy is actually done.
1921     std::unique_ptr<EagerNode> node(
1922         new CopyToDeviceNode(h, *result, d, *ctx, async, mirror));
1923     s = executor->AddOrExecute(std::move(node));
1924   } else {
1925     CopyToDeviceNode node(h, *result, d, *ctx, async, mirror);
1926     s = executor->SyncExecute(&node);
1927   }
1928 
1929   // If the operation failed, unref any outputs that were allocated.
1931   if (!s.ok()) {
1932     (*result)->Unref();
1933     *result = nullptr;
1934   }
1935 
1936   return s;
1937 }
1938 
1939 }  // namespace
1940 
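// Copies `h` to `device`, either locally via LocalEagerCopyToDevice or across
// processes. In the remote case the receiving handle is created first (an
// empty local handle, a remote mirror, or an unshaped remote handle) and a
// RemoteCopyNode performs the transfer.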
1941 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1942                          EagerExecutor* executor, Device* device, bool mirror,
1943                          TensorHandle** result) {
1944   TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
1945   auto send_device = h->DeviceOrHostCPU(*ctx);
1946   bool sender_is_local = send_device->IsLocal();
1947 
1948   bool receiver_is_local = device->IsLocal();
1949 
1950   if (!executor->Async()) {
1951     // In sync mode, always clear error to maintain the same behavior as before.
1952     // TODO(b/141004939): Remove this.
1953     executor->ClearError();
1954   }
1955 
1956   if (sender_is_local && receiver_is_local) {
1957     return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result);
1958   } else {
1959 #if defined(IS_MOBILE_PLATFORM)
1960     return errors::Unimplemented(
1961         "Eager's remote execution is not available on mobile devices.");
1962 #else   // !IS_MOBILE_PLATFORM
1963     uint64 recv_op_id = 0;
1964     if (receiver_is_local) {
1965       Device* d = ctx->CanonicalDevice(device);
1966       // TODO(gjn): Need to add support for async execution. Note that if the
1967       // receiver is local, we first need to add support in TensorHandle to wait
1968       // on local mirrors.
1969       if (mirror) {
1970         h->Ref();
1971         *result = h;
1972 
1973         if (h->HasLocalMirror(d)) {
1974           return OkStatus();
1975         }
1976 
1977         Status s = h->AddEmptyLocalMirror(d);
1978         if (!s.ok()) {
1979           // If a mirror was added since we called HasLocalMirror then just
1980           // return since another thread has already added the mirror.
1981           if (s.code() == error::Code::ALREADY_EXISTS) {
1982             return OkStatus();
1983           }
1984 
1985           // Remove the previously added reference count since adding the mirror
1986           // failed.
1987           h->Unref();
1988           *result = nullptr;
1989           return s;
1990         }
1991       } else {
1992         *result = TensorHandle::CreateEmptyLocalHandle(
1993             /* d= */ d, /* op_device= */ device,
1994             /*resource_device=*/nullptr, h->dtype, ctx);
1995       }
1996     } else {
1997       if (mirror) {
1998         if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
1999           h->Ref();
2000           *result = h;
2001           return OkStatus();
2002         }
2003       }
2004       string remote_task;
2005       if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
2006         return errors::InvalidArgument(
2007             "Unable to find remote task corresponding to device ",
2008             device->name());
2009       }
2010       recv_op_id = ctx->RemoteMgr()->NextOpId();
2011       if (mirror) {
2012         TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
2013                                                       remote_task, ctx));
2014         h->Ref();
2015         *result = h;
2016       } else {
2017         *result = TensorHandle::CreateUnshapedRemoteHandle(
2018             recv_op_id, 0, remote_task, h->dtype, device, ctx);
2019       }
2020     }
2021 
2022     auto node = std::make_unique<eager::RemoteCopyNode>(
2023         ctx, executor, h, result[0], device, recv_op_id);
2024     Status s = executor->AddOrExecute(std::move(node));
2025     if (!s.ok()) {
2026       result[0]->Unref();
2027       result[0] = nullptr;
2028     }
2029     return s;
2030 #endif  // !IS_MOBILE_PLATFORM
2031   }
2032 }
2033 
2034 namespace {
2035 // Low-level utility function to execute the kernel specified by `kernel` on
2036 // `kernel->device()`, with the provided inputs as `op_inputs` in the 'ctx'.
2037 // Unlike `EagerKernelExecute`, which ties up the thread until the underlying
2038 // function finishes executing, this function does not block the thread and may
2039 // return before the function execution finishes. The provided
2040 // `StatusCallback` will be triggered after function execution with its status.
2041 void EagerKernelExecuteAsync(
2042     EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
2043     const absl::optional<EagerFunctionParams>& eager_func_params,
2044     const core::RefCountPtr<KernelAndDevice> kernel,
2045     GraphCollector* graph_collector, CancellationManager* cancellation_manager,
2046     TensorHandle** retvals, int num_outputs, StatusCallback done) {
2047   auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
2048   auto outputs = std::make_shared<std::vector<EagerKernelRet>>(1);
2049 
2050   Status s = inputs->Init(ctx, op_inputs, kernel);
2051   if (!s.ok()) {
2052     done(s);
2053     return;
2054   }
2055   CoordinationServiceAgent* coord_agent = nullptr;
2056 #if !defined(IS_MOBILE_PLATFORM)
2057   if (ctx->GetDistributedManager() != nullptr)
2058     coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
2059 #endif  // !IS_MOBILE_PLATFORM
2060 
2061   kernel->Ref();  // Ownership of reference is transferred to the callback
2062   kernel->RunAsync(
2063       ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager,
2064       eager_func_params, coord_agent,
2065       [retvals, inputs, outputs, num_outputs, ctx, graph_collector,
2066        eager_func_params, kernel_raw = kernel.get(),
2067        done = std::move(done)](const Status& s) {
2068         auto wrapped_done = [&](const Status& s) {
2069           kernel_raw->Unref();
2070           done(s);
2071         };
2072         if (!s.ok()) {
2073           wrapped_done(s);
2074           return;
2075         }
2076         if (graph_collector != nullptr) {
2077           CollectGraphs(ctx);
2078         }
2079         DCHECK_EQ(num_outputs, outputs->size());
2080         wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx,
2081                                       kernel_raw, eager_func_params));
2082       });
2083 }
2084 }  // namespace
2085 
2086 // Low-level utility to run the eager operation on local devices. Different from
2087 // `EagerLocalExecute`, which blocks until the op execution finishes, this
2088 // method does not block the thread and may return before the eager operation
2089 // execution finishes. The provided `StatusCallback` will be
2090 // triggered after execution with its status.
2091 void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
2092                             int* num_retvals, StatusCallback done) {
2093   if (!op->IsLocal()) {
2094     done(errors::InvalidArgument(
2095         "Remote execution is not supported in async EagerLocalExecuteAsync"));
2096     return;
2097   }
2098 
2099   profiler::ScopedMemoryDebugAnnotation op_annotation(
2100       op->op_name(), op->eager_func_params().has_value()
2101                          ? op->eager_func_params().value().step_id.value_or(0)
2102                          : 0);
2103   profiler::TraceMe activity(
2104       [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
2105       profiler::TraceMeLevel::kInfo);
2106   EagerContext& ctx = op->EagerContext();
2107 
2108   core::RefCountPtr<KernelAndDevice> kernel;
2109   Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
2110   if (!s.ok()) {
2111     done(s);
2112     return;
2113   }
2114 
2115   int num_outputs = kernel->num_outputs();
2116   s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
2117   if (!s.ok()) {
2118     done(s);
2119     return;
2120   }
2121 
2122   if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
2123     string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
2124                                  kernel->device()->name());
2125     if (!logging::LogToListeners(msg)) {
2126       LOG(INFO) << msg;
2127     }
2128   }
2129 
2130   GraphCollector* graph_collector = nullptr;
2131   if (ctx.ShouldStoreGraphs()) {
2132     graph_collector = ctx.GetGraphCollector();
2133   }
2134 
2135   for (int i = 0, end = num_outputs; i < end; ++i) {
2136     const DataTypeVector& output_dtypes = kernel->output_dtypes();
2137     retvals[i] = TensorHandle::CreateEmptyLocalHandle(
2138         /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
2139         /* op_device= */ kernel->device(),
2140         /* resource_device= */ kernel->OutputResourceDevice(i),
2141         output_dtypes[i], &ctx);
2142   }
2143 
2144   const absl::InlinedVector<TensorHandle*, 4>* inputs;
2145   s = op->TensorHandleInputs(&inputs);
2146   if (!s.ok()) {
2147     done(s);
2148     return;
2149   }
2150   EagerKernelExecuteAsync(
2151       &ctx, *inputs, op->eager_func_params(), std::move(kernel),
2152       graph_collector, op->GetCancellationManager(), retvals, num_outputs,
2153       [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
2154         op->Clear();
2155         // If the operation failed, unref any outputs that were allocated.
2157         if (!s.ok()) {
2158           for (int i = 0, end = num_outputs; i < end; ++i) {
2159             if (retvals[i] != nullptr) {
2160               retvals[i]->Unref();
2161               retvals[i] = nullptr;
2162             }
2163           }
2164         }
2165         done(s);
2166       });
2167 }
2168 }  // namespace tensorflow
2169