1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/execute.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <optional>
21 #include <vector>
22
23 // clang-format off
24 // Required for IS_MOBILE_PLATFORM
25 #include "absl/container/btree_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_replace.h"
28 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
29 #include "tensorflow/core/framework/cancellation.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/kernel_def.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/lib/core/refcount.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/platform.h"
39 #include "tensorflow/core/platform/protobuf.h"
40
41 // clang-format on
42
43 #include "absl/container/inlined_vector.h"
44 #include "absl/strings/match.h"
45 #include "absl/strings/str_cat.h"
46 #include "absl/types/optional.h"
47 #include "tensorflow/c/tf_tensor_internal.h"
48 #include "tensorflow/compiler/jit/defs.h"
49 #include "tensorflow/core/common_runtime/colocation_graph.h"
50 #include "tensorflow/core/common_runtime/device.h"
51 #include "tensorflow/core/common_runtime/device_set.h"
52 #include "tensorflow/core/common_runtime/eager/context.h"
53 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
54 #include "tensorflow/core/common_runtime/eager/execute_node.h"
55 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
56 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
57 #include "tensorflow/core/framework/dataset.h"
58 #include "tensorflow/core/framework/function.h"
59 #include "tensorflow/core/framework/logging.h"
60 #include "tensorflow/core/framework/node_def_util.h"
61 #include "tensorflow/core/framework/tensor_reference.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 #include "tensorflow/core/lib/core/errors.h"
64 #include "tensorflow/core/platform/statusor.h"
65 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
66 #include "tensorflow/core/profiler/lib/traceme.h"
67 #include "tensorflow/core/protobuf/error_codes.pb.h"
68 #include "tensorflow/core/util/device_name_utils.h"
69 #if !defined(IS_MOBILE_PLATFORM)
70 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
71 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
72 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
73 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
74 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
75 #endif // IS_MOBILE_PLATFORM
76 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
77 #include "tensorflow/core/framework/step_stats.pb.h"
78 #include "tensorflow/core/framework/tensor.h"
79 #include "tensorflow/core/framework/types.h"
80 #include "tensorflow/core/lib/core/status.h"
81 #include "tensorflow/core/lib/gtl/cleanup.h"
82 #include "tensorflow/core/lib/gtl/flatset.h"
83 #include "tensorflow/core/lib/random/random.h"
84 #include "tensorflow/core/platform/env.h"
85 #include "tensorflow/core/platform/mutex.h"
86 #include "tensorflow/core/util/ptr_util.h"
87 #include "tensorflow/core/util/util.h"
88
89 #ifdef INTEL_MKL
90 #include "tensorflow/core/graph/mkl_graph_util.h"
91 #endif
92
93 namespace tensorflow {
94
95 namespace {
96
const string& DeviceNameOrUnspecified(Device* device) {
98 static string* unspecified_string = new string("<unspecified>");
99 return (device == nullptr) ? *unspecified_string : device->name();
100 }
101
102 // Returns whether a kernel should be cached.
bool KernelCacheEnabled(const OpDef& op_def) {
104 if (data::DatasetOpKernel::IsDatasetOp(op_def)) {
105 return false;
106 }
107 // TODO(b/162540360): Revisit a way to mark kernels as uncachable once we have
108 // 5+ kernels to exclude.
109 return true;
110 }
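
// Illustrative example of the policy above (a sketch, not an exhaustive
// list): an op such as "MapDataset", which carries a user-defined function
// attribute, is recognized by IsDatasetOp and its kernel is not cached,
// while an op such as "MatMul" is cached as usual.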
111
// This function expects *handle to point to an existing tensor handle that is
// currently on "handle_device", but where the operation expects that input to
// reside on "expected_input_device". The function will arrange for this
// transfer to happen and will return OK on success, storing a new handle to
// the equivalent tensor on the correct device in "*result". If an error is
// encountered, it returns a non-OK status and sets "*result" to nullptr.
119 //
120 // `op_device` is passed in explicitly because `op->device()` might be
121 // unset and we might have selected some specific device to run this op on.
Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op,
                                 Device* op_device,
                                 TensorHandle* handle,  // op->Inputs()[i]
                                 int i, Device* handle_device,
                                 Device* expected_input_device,
                                 TensorHandle** result) {
128 VLOG(6) << "Expected input device: " << expected_input_device->name()
129 << "; handle_device: " << handle_device->name();
130 // Should only be called when these don't match
131 DCHECK(expected_input_device != handle_device);
132 *result = nullptr;
133 const string& op_device_name = DeviceNameOrUnspecified(op_device);
134
135 switch (ctx->GetDevicePlacementPolicy()) {
136 case DEVICE_PLACEMENT_SILENT_FOR_INT32:
137 // TODO(xpan): See if we could bubble python related error up
138 // to python level.
139 if (handle->dtype == DT_INT32) {
140 // Note: enabling silent copies of int32 tensors to match behavior
141 // of graph mode.
142 break;
143 }
144 VLOG(6) << "DevicePlacementPolicy: DEVICE_PLACEMENT_SILENT_FOR_INT32 but "
145 "input type is not INT32.";
146 TF_FALLTHROUGH_INTENDED;
147 case DEVICE_PLACEMENT_EXPLICIT:
148 // tf.identity is allowed to copy, as indicated in the error message
149 // below.
150 if (op->Name() == "Identity" ||
151 op->Name() == "IdentityN"
152 // Constants start on CPU:0 and are copied via EagerConst to the
153 // current device.
154 || op->Name() == "_EagerConst") {
155 break;
156 }
157 return errors::InvalidArgument(
158 "Tensors on conflicting devices:"
159 " cannot compute ",
160 op->Name(), " as input #", i, " was expected to be on ",
161 expected_input_device->name(), " but is actually on ",
162 handle_device->name(), " (operation running on ", op_device_name, ")",
163 " Tensors can be copied explicitly using:"
164 " `with tf.device(device_name): x = tf.identity(x)`"
165 " or transparently copied by using"
166 " tf.config.experimental.set_device_policy('silent')."
167 " Copying tensors between devices may slow down your model");
168 case DEVICE_PLACEMENT_WARN:
169 LOG(WARNING) << "before computing " << op->Name() << " input #" << i
170 << " was expected to be on " << expected_input_device->name()
171 << " but is actually on " << handle_device->name()
172 << " (operation running on " << op_device_name
173 << "). This triggers a copy which can be a performance "
174 "bottleneck.";
175 break;
176 case DEVICE_PLACEMENT_SILENT: // Do nothing.
177 break;
178 }
179 // We are only here if the policy is warn or silent copies, so we should
180 // trigger a copy.
181 TensorHandle* result_handle = nullptr;
182 profiler::TraceMe activity(
183 [&] {
184 return absl::StrCat("_Send input ", i, " from ", handle_device->name(),
185 " to ", expected_input_device->name());
186 },
187 profiler::TraceMeLevel::kInfo);
188 Status status =
189 EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device,
190 /* mirror= */ true, &result_handle);
191 activity.Stop();
192 if (!status.ok()) {
193 return Status(
194 status.code(),
195 absl::StrCat("Failed copying input tensor from ", handle_device->name(),
196 " to ", expected_input_device->name(), " in order to run ",
197 op->Name(), ": ", status.error_message()));
198 }
199
200 *result = result_handle;
201
202 return OkStatus();
203 }
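
// Illustrative sketch of how the placement policies above surface in user
// code (assumed Python-level behavior mirroring the error message; not
// exercised by this file):
//
//   with tf.device("/GPU:0"):
//     x = tf.random.uniform((2, 2))
//   tf.config.experimental.set_device_policy("explicit")
//   with tf.device("/CPU:0"):
//     tf.square(x)  # InvalidArgument: input expected on CPU:0, is on GPU:0
//   tf.config.experimental.set_device_policy("silent")
//   with tf.device("/CPU:0"):
//     tf.square(x)  # silently copied to CPU:0 by CopyInputToExpectedDevice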
204
// `op_device_name` is the name of the device on which the op will run, if
// any. For functions running via the function library runtime, the device can
// be unspecified.
Status ValidateInputTypeAndPlacement(
    EagerContext* ctx, EagerOperation* op,
    const core::RefCountPtr<KernelAndDevice>& kernel) {
211 profiler::TraceMe activity("ValidateInputTypeAndPlacement",
212 profiler::TraceMeLevel::kInfo);
213 const int n_inputs = op->Inputs().size();
214 if (kernel->num_inputs() != n_inputs) {
215 return errors::InvalidArgument("expected ", kernel->num_inputs(),
216 " inputs, got ", n_inputs);
217 }
218 const bool is_function = kernel->IsFunction();
219 if (n_inputs > 0) {
220 const DataType* input_types = &kernel->input_dtypes()[0];
221 const absl::InlinedVector<TensorHandle*, 4>* handles;
222 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
223 for (int i = 0; i < n_inputs; ++i) {
224 TensorHandle* handle = (*handles)[i];
225 Device* expected_device = kernel->InputDevice(i);
226 if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
227 // Extract a handle on the op device from a packed input.
228 // This happens when a function is marked for XLA compilation.
229 // MaybePackInputTensor guarantees that a primitive op has no packed
230 // input at this point.
231 for (int j = 0; j < handle->NumPackedHandles(); ++j) {
232 TensorHandle* h = nullptr;
233 TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h));
234 if ((h->op_device() != nullptr) &&
235 (h->op_device()->name() == op->DeviceName())) {
236 op->UpdateInput(i, h);
237 handle = h;
238 break;
239 }
240 }
241 }
242 Device* handle_device = handle->DeviceOrHostCPU(*ctx);
243 const bool maybe_copy =
244 !is_function || handle->Type() != TensorHandle::REMOTE;
245 VLOG(6) << "!is_function: " << !is_function;
246 VLOG(6) << "handle->Type(): " << handle->Type();
247 // If the input is already on the right device, then nothing to do.
248 if (expected_device != handle_device && maybe_copy) {
249 TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
250 handle, i, handle_device,
251 expected_device, &handle));
252 op->UpdateInput(i, handle);
253 // Unref handle since it has a ref as an input now
254 handle->Unref();
255 }
256 if (handle->dtype != input_types[i]) {
257 return errors::InvalidArgument(
258 "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
259 " was expected to be a ", DataTypeString(input_types[i]),
260 " tensor but is a ", DataTypeString(handle->dtype), " tensor");
261 }
262 }
263 }
264 return OkStatus();
265 }
266
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
268 const auto& node_def = op->MutableAttrs()->BuildNodeDef();
269 const OpDef* op_def = nullptr;
270
271 const FunctionDef* function_def =
272 op->EagerContext().FuncLibDef()->Find(op->Name());
273 if (function_def != nullptr) {
274 op_def = &(function_def->signature());
275 } else {
276 TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
277 }
278
279 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
280
281 return OkStatus();
282 }
283
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
286 return {tensorflow::FingerprintCat64(a.low64, b.low64),
287 tensorflow::FingerprintCat64(a.high64, b.high64)};
288 }
289
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const int64_t b) {
292 auto x = tensorflow::FingerprintCat64(a.low64, b);
293 return {x, tensorflow::FingerprintCat64(a.high64, x)};
294 }
295
const KernelDef* GetKernelDef(const EagerOperation& op, const NodeDef* node_def,
                              const Device* op_device) {
298 if (node_def == nullptr || op_device == nullptr) return nullptr;
299 const KernelDef* kernel_def = nullptr;
300 Status s = FindKernelDef(DeviceType(op_device->device_type()), *node_def,
301 &kernel_def,
302 /*kernel_class_name=*/nullptr);
303 if (!s.ok()) return nullptr;
304 return kernel_def;
305 }
306
bool IsHostMemoryArg(const EagerOperation& op, const NodeDef* node_def,
                     const Device* op_device, const KernelDef* kernel_def,
                     const int port_id) {
310 if (op.is_function()) return false;
311 if (node_def == nullptr) return false;
312 if (kernel_def == nullptr || op_device == nullptr) return false;
313 const auto& host_memory_args = kernel_def->host_memory_arg();
314 const OpDef& op_def = OpRegistry::Global()->LookUp(op.Name())->op_def;
315 const int arg_id = OpPortIdToArgId(*node_def, op_def.input_arg(), port_id);
316 return std::find(host_memory_args.begin(), host_memory_args.end(),
317 op_def.input_arg(arg_id).name()) != host_memory_args.end();
318 }
319
Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx,
                         const bool is_host_memory_arg,
                         TensorHandle* tensor_handle, Device** result) {
323 Device* cpu_device = ctx.HostCPU();
324 string device_name;
325 if (tensor_handle->Type() != TensorHandle::LOCAL) {
326 Device* device = tensor_handle->device();
327 device_name = device != nullptr ? device->name() : cpu_device->name();
328 *result = (device == nullptr ? cpu_device : device);
329 } else if (tensor_handle->dtype == DT_RESOURCE) {
330 // Use the resource's actual device because it is the device that will
331 // influence partitioning the multi-device function.
332 const Tensor* tensor;
333 // TODO(fishx): Avoid blocking here.
334 TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
335 if (tensor->NumElements() == 0) {
336 return errors::InvalidArgument("Empty resource handle");
337 }
338 const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
339 device_name = handle.device();
340
341 Device* input_device;
342 TF_RETURN_IF_ERROR(
343 ctx.FindDeviceFromName(device_name.c_str(), &input_device));
344 *result = input_device;
345 } else {
346 Device* device = tensor_handle->device();
347 const bool is_tpu = device != nullptr && device->device_type() == "TPU";
    // int32 return values can be placed on TPUs.
    // int32 return values can be placed on device for eager operations.
350 const bool use_host_memory =
351 is_tpu || (!op.is_function() && device != cpu_device &&
352 !is_host_memory_arg)
353 ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype)
354 : MTypeFromDType(tensor_handle->dtype);
355 if (use_host_memory) {
356 *result = cpu_device;
357 } else {
358 // Eager ops executing as functions should have their preferred inputs set
359 // to the op's device. This allows us to avoid expensive D2H copies if a
360 // mirror of the tensor already exists on the op's device.
361 if (!op.is_function() && device != cpu_device && !is_host_memory_arg) {
362 device = absl::get<Device*>(op.Device());
363 }
364 *result = (device == nullptr ? cpu_device : device);
365 }
366 }
367 return OkStatus();
368 }
369
370 // Appends a TensorShape object to Fprint128 hash.
371 // For best performance, we would like to avoid dynamic memory allocation in
372 // this function.
373 // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we
374 // attach every dim size to hashed content.
void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
                                    Fprint128* fingerprint) {
377 if (shape.unknown_rank()) {
378 char c = '?';
379 *fingerprint = FingerprintCat128(*fingerprint, c);
380 } else {
381 for (int i = 0; i < shape.dims(); i++) {
382 int64_t dim = shape.dim_size(i);
383 *fingerprint = FingerprintCat128(*fingerprint, dim);
384 }
385 }
386 }
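
// For example (illustrative): a fully known shape [2, 3] contributes the dim
// sizes 2 and 3 to the fingerprint, a partially known shape such as [2, ?]
// contributes 2 and -1 (the sentinel PartialTensorShape uses for an unknown
// dimension), and a shape of unknown rank contributes only the character '?'.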
387
Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx,
                   const char* attr_name, bool* value) {
390 Status status = op->Attrs().Get(attr_name, value);
391 if (status.ok()) {
    VLOG(2) << "Caller explicitly specifies " << attr_name
            << (*value ? "=true " : "=false ") << op->DebugString();
394 return OkStatus();
395 }
396
397 const FunctionDef* function_def =
398 ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name());
399 if (function_def == nullptr) {
400 return errors::NotFound("Failed to find function '", op->Name(), "'");
401 }
402
403 status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value);
404 if (status.ok()) {
    VLOG(2) << "Function definition explicitly specifies " << attr_name
            << (*value ? "=true" : "=false");
407 return OkStatus();
408 }
409 return status;
410 }
411
Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
                          bool* compile_with_xla) {
414 if (!op->is_function()) {
415 *compile_with_xla = false;
416 return OkStatus();
417 }
418
419 if (op->eager_func_params().has_value() &&
420 op->eager_func_params().value().is_component_function) {
421 // If the op is a component of a multi-device function, don't compile it
422 // with XLA.
423 *compile_with_xla = false;
424 return OkStatus();
425 }
426
427 Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla);
428 if (status.ok()) {
429 return OkStatus();
430 }
431
432 // No explicit requests. Compile for XLA devices by default.
433 if (op->GetDeviceParsedName().type == "TPU" ||
434 op->GetDeviceParsedName().type == "XLA_GPU" ||
435 op->GetDeviceParsedName().type == "XLA_CPU") {
436 VLOG(2) << "Compiling " << op->Name()
437 << " with XLA because it is running on an XLA device "
438 << op->GetDeviceParsedName().type;
439 *compile_with_xla = true;
440 } else {
441 *compile_with_xla = false;
442 }
443
444 return OkStatus();
445 }
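
// For example (illustrative): a tf.function explicitly placed on an "XLA_CPU"
// device is compiled with XLA even without an explicit compile attribute,
// whereas the same function on a plain "CPU" device falls through to
// *compile_with_xla == false unless kXlaMustCompileAttr is set on the op or
// in the function definition.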
446
Status VerifyWrappableInCallOp(const OpDef& opdef, EagerOperation* op) {
448 absl::flat_hash_set<string> opdef_attrs;
449 for (const auto& attr : opdef.attr()) {
450 opdef_attrs.insert(attr.name());
451 }
452 const auto& node_def = op->MutableAttrs()->BuildNodeDef();
453 for (const auto& attr : node_def.attr()) {
454 if (opdef_attrs.find(attr.first) == opdef_attrs.end()) {
455 return errors::Unimplemented("EagerOperation: ", op->Name(),
456 " has a private attr '", attr.first, "'.");
457 }
458 }
459 return OkStatus();
460 }
461
462 using ProtoArgListType = protobuf::RepeatedPtrField<OpDef_ArgDef>;
463
string EscapeOrigName(const string& orig_name) {
465 // Replace _ with __ in the original name to avoid name conflicts.
466 return absl::StrReplaceAll(orig_name, {{"_", "__"}});
467 }
468
469 // Variadic args are flattened during wrapping. This utility returns the name
470 // of a flattened arg/attr.
string GetFlatName(const string orig_name, int index) {
472 return absl::StrCat(EscapeOrigName(orig_name), "_", index);
473 }
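
// For example (illustrative): GetFlatName("values", 1) == "values_1", and an
// underscore in the original name is escaped first, so
// GetFlatName("my_arg", 0) == "my__arg_0".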
474
475 // Builds the name of the wrapping FunctionDef for an eager op.
476 //
// For ops without variadic inputs/outputs, the name is simply
// __wrapped__OpType.
478 //
479 // For ops with variadic inputs/outputs, the arity of each variadic attr is
480 // encoded in the name. For example:
481 //
482 // IdentityN[T:[DT_FLOAT, DT_INT64]] -> __wrapped__IdentityN_T_2
483 // Concat[N:2, T:DT_FLOAT] -> __wrapped__Concat_N_2
Status BuildWrappedOpName(EagerOperation* op, const OpDef& opdef,
                          const AbstractOpAttrs* op_attrs, string* name) {
486 string fname = absl::StrCat("__wrapped__", EscapeOrigName(op->Name()));
487 // For every variadic arg in `args`, populates `attr_to_len` with
488 // (attr_name, len(arg)).
489 auto FillAttrToLen = [op_attrs, op](
490 const ProtoArgListType& args,
491 absl::btree_map<string, int>* attr_to_len) {
492 for (const auto& arg : args) {
493 if (!arg.type_list_attr().empty()) {
494 gtl::InlinedVector<DataType, 4> type_list;
495 TF_RETURN_IF_ERROR(
496 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
497 (*attr_to_len)[arg.type_list_attr()] = type_list.size();
498 } else if (!arg.number_attr().empty()) {
499 int64_t number_attr;
500 if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
501 return errors::Internal("Unable to read attr ", arg.number_attr(),
502 " for op ", op->Name());
503 }
504 (*attr_to_len)[arg.number_attr()] = number_attr;
505 }
506 }
507 return OkStatus();
508 };
509 absl::btree_map<string, int> attr_to_len;
510 TF_RETURN_IF_ERROR(FillAttrToLen(opdef.input_arg(), &attr_to_len));
511 TF_RETURN_IF_ERROR(FillAttrToLen(opdef.output_arg(), &attr_to_len));
512 for (auto& name_len : attr_to_len) {
513 absl::StrAppend(&fname, "_", name_len.first, "_", name_len.second);
514 }
  // The NodeDef in the FunctionDef gets placed on `op->DeviceName()` to ensure
  // placement consistency with eager mode.
517 // TODO(b/200153278): Ideally we would just forward the call op's device at
518 // runtime but currently there is no way to do it so we incur the cost of
519 // creating extra FunctionDefs.
520 absl::StrAppend(&fname, "_device_", op->DeviceName());
521 *name = fname;
522 return OkStatus();
523 }
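
// For example (illustrative, assuming a typical local device name): wrapping
// IdentityN with T=[DT_FLOAT, DT_INT64] placed on
// "/job:localhost/replica:0/task:0/device:CPU:0" yields a name like
//   __wrapped__IdentityN_T_2_device_/job:localhost/replica:0/task:0/device:CPU:0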
524
525 // Validates the node def. This is required when running in eager op as function
526 // mode because this code path does not go through the _apply_op_helper's
527 // validation (which is reached when executing in graph mode)
528 // or the eager execution's validation (which is reached via the CreateOpKernel
529 // call).
Status ValidateOp(EagerOperation* op) {
531 const NodeDef& node_def = op->MutableAttrs()->BuildNodeDef();
532 const OpDef* op_def;
533 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
534 return ValidateNodeDef(node_def, *op_def);
535 }
536
537 // Builds the signature of the wrapping FunctionDef for an eager op.
538 //
539 // For ops without variadic inputs/outputs, the signature is the same as the
540 // OpDef of the original op.
541 //
542 // Variadic inputs/outputs get flattened since we do not support executing
543 // functions with variadic signatures.
544 //
545 // TODO(srbs): These examples should be tests.
546 //
547 // Examples:
548 //
549 // Mixed type list:
550 //
551 // op {
552 // name: "IdentityN"
553 // input_arg {
554 // name: "input"
555 // type_list_attr: "T"
556 // }
557 // output_arg {
558 // name: "output"
559 // type_list_attr: "T"
560 // }
561 // attr {
562 // name: "T"
563 // type: "list(type)"
564 // has_minimum: true
565 // minimum: 1
566 // }
567 // }
568 //
569 // With two inputs T=[DT_FLOAT, DT_INT64] would convert to
570 //
571 // op {
572 // name: "__wrapped__IdentityN_T_2"
573 // input_arg {
574 // name: "input_0"
575 // type_attr: "T_0"
576 // }
577 // input_arg {
578 // name: "input_1"
579 // type_attr: "T_1"
580 // }
581 // output_arg {
582 // name: "output_0"
583 // type_attr: "T_0"
584 // }
585 // output_arg {
586 // name: "output_1"
587 // type_attr: "T_1"
588 // }
589 // attr {
590 // name: "T_0"
591 // type: "type"
592 // }
593 // attr {
594 // name: "T_1"
595 // type: "type"
596 // }
597 // attr {
598 // name: "T"
599 // type: "list(type)"
600 // has_minimum: true
601 // minimum: 1
602 // }
603 // }
604 //
605 // Note that the list(type) attr is preserved so that it can get copied to the
606 // inner op via a placeholder. This allows additional verification.
607 //
608 // Single type list:
609 //
610 // op {
611 // name: "ConcatV2"
612 // input_arg {
613 // name: "values"
614 // type_attr: "T"
615 // number_attr: "N"
616 // }
617 // attr {
618 // name: "N"
619 // type: "int"
620 // has_minimum: true
621 // minimum: 2
622 // }
623 // attr {
624 // name: "T"
625 // type: "type"
626 // }
627 // [axis, output, Tidx are simply copied]
628 // }
629 //
630 // With two inputs N=2 would convert to:
631 //
632 // op {
633 // name: "__wrapped__ConcatV2_N_2"
634 // input_arg {
635 // name: "values_0"
636 // type_attr: "T"
637 // }
638 // input_arg {
639 // name: "values_1"
640 // type_attr: "T"
641 // }
642 // attr {
643 // name: "N"
644 // type: "int"
645 // has_minimum: true
646 // minimum: 2
647 // }
648 // attr {
649 // name: "T"
650 // type: "type"
651 // }
652 // [axis, output, Tidx are simply copied]
653 // }
654 //
655 // Note that the N attr is preserved so that it can get copied to the
656 // inner op via a placeholder. This allows additional verification.
Status BuildWrappedOpSignature(EagerOperation* op, const OpDef& opdef,
                               const string& fname, OpDef& signature) {
659 signature = opdef;
660 signature.clear_input_arg();
661 signature.clear_output_arg();
662 signature.set_name(fname);
663 auto op_attrs = op->GetOpAttrs();
664 auto FillSignatureArgs = [op_attrs, op](
665 const ProtoArgListType& opdef_args,
666 ProtoArgListType* sig_args,
667 absl::flat_hash_set<string>& new_attrs) {
668 for (const auto& arg : opdef_args) {
669 if (!arg.type_list_attr().empty()) {
670 gtl::InlinedVector<DataType, 4> type_list;
671 TF_RETURN_IF_ERROR(
672 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
673 for (size_t i = 0; i < type_list.size(); i++) {
674 auto arg_def = sig_args->Add();
675 arg_def->set_name(GetFlatName(arg.name(), i));
676 auto attr_name = GetFlatName(arg.type_list_attr(), i);
677 new_attrs.insert(attr_name);
678 arg_def->set_type_attr(std::move(attr_name));
679 }
680 } else if (!arg.number_attr().empty()) {
681 int64_t number_attr;
682 if (!op_attrs->GetInt(arg.number_attr(), &number_attr)) {
683 return errors::Internal("Unable to read attr ", arg.number_attr(),
684 " for op ", op->Name());
685 }
686 for (int64_t i = 0; i < number_attr; i++) {
687 auto arg_def = sig_args->Add();
688 arg_def->set_name(GetFlatName(arg.name(), i));
689 if (!arg.type_attr().empty()) {
690 arg_def->set_type_attr(arg.type_attr());
691 } else {
692 arg_def->set_type(arg.type());
693 }
694 }
695 } else {
696 auto arg_def = sig_args->Add();
697 *arg_def = arg;
698 arg_def->set_name(EscapeOrigName(arg.name()));
699 if (!arg.type_attr().empty()) {
700 // Don't escape: type attrs are still referenced by the original name.
701 arg_def->set_type_attr(arg.type_attr());
702 }
703 }
704 }
705 return OkStatus();
706 };
707 absl::flat_hash_set<string> new_attrs;
708 TF_RETURN_IF_ERROR(FillSignatureArgs(
709 opdef.input_arg(), signature.mutable_input_arg(), new_attrs));
710 TF_RETURN_IF_ERROR(FillSignatureArgs(
711 opdef.output_arg(), signature.mutable_output_arg(), new_attrs));
712 for (auto& attr_name : new_attrs) {
713 auto attr_def = signature.mutable_attr()->Add();
714 attr_def->set_name(attr_name);
715 attr_def->set_type("type");
716 }
717 return OkStatus();
718 }
719
720 // For mixed type inputs "list(type)" we create new attributes in the signature
721 // for each element tensor (See examples in BuildWrappedOpSignature). Here
722 // we construct the values for those attributes and set them on the wrapped op.
Status AddMixedTypeListAttrs(EagerOperation* wrapped_op,
                             const AbstractOpAttrs* op_attrs,
                             const OpDef& opdef) {
726 auto FillAttrsToAdd =
727 [op_attrs](const ProtoArgListType& opdef_args,
728 absl::flat_hash_map<string, DataType>* attrs_to_add) {
729 for (const auto& arg : opdef_args) {
730 if (!arg.type_list_attr().empty()) {
731 gtl::InlinedVector<DataType, 4> type_list;
732 TF_RETURN_IF_ERROR(
733 op_attrs->GetTypeList(arg.type_list_attr(), &type_list));
734 for (size_t i = 0; i < type_list.size(); i++) {
735 auto attr_name = GetFlatName(arg.type_list_attr(), i);
736 (*attrs_to_add)[attr_name] = type_list[i];
737 }
738 }
739 }
740 return OkStatus();
741 };
742 absl::flat_hash_map<string, DataType> attrs_to_add;
743 TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.input_arg(), &attrs_to_add));
744 TF_RETURN_IF_ERROR(FillAttrsToAdd(opdef.output_arg(), &attrs_to_add));
745 for (auto& name_type : attrs_to_add) {
746 TF_RETURN_IF_ERROR(
747 wrapped_op->SetAttrType(name_type.first.data(), name_type.second));
748 }
749 // TODO(srbs): Rename all original attributes using EscapeOrigName.
750 return OkStatus();
751 }
752
753 // Maps the op's outputs to the function outputs. Mainly useful for variadic
754 // outputs which need to be flattened.
Status PopulateRetMap(FunctionDef* fdef, const AbstractOpAttrs* op_attrs,
                      const EagerOperation* op, const OpDef& opdef,
                      const OpDef& signature, const string& node_name) {
758 int next_sig_output = 0;
759 for (size_t i = 0; i < opdef.output_arg_size(); i++) {
760 const auto& output_arg = opdef.output_arg(i);
761 if (!output_arg.type_list_attr().empty()) {
762 gtl::InlinedVector<DataType, 4> type_list;
763 TF_RETURN_IF_ERROR(
764 op_attrs->GetTypeList(output_arg.type_list_attr(), &type_list));
765 for (int j = 0; j < type_list.size(); j++) {
766 (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
767 absl::StrCat(node_name, ":", output_arg.name(), ":", j);
768 }
769 } else if (!output_arg.number_attr().empty()) {
770 int64_t number_attr;
771 if (!op_attrs->GetInt(output_arg.number_attr(), &number_attr)) {
772 return errors::Internal("Unable to read attr ",
773 output_arg.number_attr(), " for op ",
774 op->Name());
775 }
776 for (int j = 0; j < number_attr; j++) {
777 (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
778 absl::StrCat(node_name, ":", output_arg.name(), ":", j);
779 }
780 } else {
781 (*fdef->mutable_ret())[signature.output_arg(next_sig_output++).name()] =
782 absl::StrCat(node_name, ":", output_arg.name(), ":0");
783 }
784 }
785 return OkStatus();
786 }
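
// For example (illustrative): for a wrapped IdentityN with T=[DT_FLOAT,
// DT_INT64] and node name "IdentityN", the ret map produced above is
//   {"output_0": "IdentityN:output:0", "output_1": "IdentityN:output:1"}.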
787
788 #ifdef INTEL_MKL
inline void GetMKLNodeDef(NodeDef* ndef) {
790 // All MKL eager ops have `_kernel` private attribute that needs to be set
791 // to a fixed label.
792 AttrValue attr_kernel;
793 attr_kernel.set_s(mkl_op_registry::kMklNameChangeOpLabel);
794 (*ndef->mutable_attr()).insert({"_kernel", attr_kernel});
795 }
796 #endif // INTEL_MKL
797
Status WrapInCallOp(EagerOperation* op, EagerOperation** wrapped_op) {
799 DCHECK(!op->is_function());
800 const OpDef& opdef = OpRegistry::Global()->LookUp(op->Name())->op_def;
801 // Raise an error for ops which don't support wrapping yet. This includes
802 // ops with list inputs/outputs and ops with private attrs.
803 // TODO(srbs): Support list inputs/outputs.
804 TF_RETURN_IF_ERROR(VerifyWrappableInCallOp(opdef, op));
805
806 // Build a FunctionDef containing op as a node and register with context.
807 // TODO(srbs): Here we are unable to distinguish between a FunctionDef for
808 // a wrapped eager op and an existing user defined function registered with
809 // the context e.g. with something like
810 // @tf.function
811 // def __wrapped__Add(x, y):
812 // ...
813 // This can be avoided by introducing a dict in EagerContext that stores a
814 // mapping from the eager op's name to its unique FunctionDef name.
815 auto op_attrs = op->GetOpAttrs();
816 string fname;
817 TF_RETURN_IF_ERROR(BuildWrappedOpName(op, opdef, op_attrs, &fname));
818 if (!op->EagerContext().GetFunctionDef(fname)) {
819 FunctionDef fdef;
820 // Set signature.
821 TF_RETURN_IF_ERROR(
822 BuildWrappedOpSignature(op, opdef, fname, *fdef.mutable_signature()));
823 // Add node.
824 NodeDef* ndef = fdef.add_node_def();
825 ndef->set_op(op->Name());
826 ndef->set_name(op->Name()); // This could be anything.
827 const auto& signature = fdef.signature();
828 for (size_t i = 0; i < signature.input_arg_size(); i++) {
829 ndef->add_input(absl::StrCat(fdef.signature().input_arg(i).name(), ":0"));
830 }
831 // TODO(srbs): Private attrs on the op are dropped here and applied to
832 // the call op instead. If this causes problems we might have to copy those
833 // attrs to this ndef. That would require updating fname to contain a hash
834 // of such attributes.
835 for (const auto& attr : opdef.attr()) {
836 (*ndef->mutable_attr())[attr.name()].set_placeholder(attr.name());
837 }
838 // Set the device of this node to be the exact same one that eager mode
839 // would have used.
840 // TODO(b/200153278): Ideally we would just forward the call op's device at
841 // runtime but currently there is no way to do it.
842 ndef->set_device(op->DeviceName());
843
844 #ifdef INTEL_MKL
845 if (IsMKLEnabled() &&
846 absl::StartsWith(op->Name(), mkl_op_registry::kMklOpPrefix)) {
847 GetMKLNodeDef(ndef);
848 }
849 #endif // INTEL_MKL
850
851 // Set `ret` map.
852 TF_RETURN_IF_ERROR(
853 PopulateRetMap(&fdef, op_attrs, op, opdef, signature, ndef->name()));
854 VLOG(1) << fdef.DebugString();
855 TF_RETURN_IF_ERROR(op->EagerContext().AddFunctionDef(std::move(fdef)));
856 }
857 // Build the call op.
858 auto& ctx = op->EagerContext();
859 AbstractOperationPtr call_op(ctx.CreateOperation());
860 TF_RETURN_IF_ERROR(call_op->Reset(fname.c_str(), op->DeviceName().c_str()));
861 for (auto t : op->Inputs()) {
862 TF_RETURN_IF_ERROR(call_op->AddInput(t));
863 }
864 *wrapped_op = down_cast<EagerOperation*>(call_op.release());
865 // Attributes on the elementary eager operation are applied to the call op and
866 // to the NodeDef inside the FunctionDef. This allows us to have a single
867 // FunctionDef for different attribute values. When the function is
868 // instantiated, these attributes get forwarded to the NodeDef. This is done
869 // by setting the AttrValue.placeholder field for the NodeDef attrs.
870 (*wrapped_op)->AddAttrs(op_attrs);
871 return AddMixedTypeListAttrs(*wrapped_op, op_attrs, opdef);
872 }
873
874 // Necessary condition to place int args/retvals on device but not sufficient.
875 // For eager operations return values can be placed on the device for use
876 // by subsequent eager ops. E.g.
877 // with tf.device("/GPU:0"):
878 // x = tf.random_uniform(shape=(2, 2), maxval=5, dtype=tf.int32)
879 // y = tf.random_uniform(shape=(2, 2), maxval=5, dtype=tf.int32)
880 // z = tf.bitwise.bitwise_and(x, y)
881 // In the above example `z` can use the outputs of `x` and `y` without needing
882 // an H2D copy if x and y are left on-device.
bool IntArgsAndRetvalsOnDevice(EagerOperation* op,
                               const KernelDef* kernel_def) {
885 // We choose to leave `EagerConsts`
886 // on HOST to avoid `shape` and other arguments that are traditionally pinned
887 // to HostMemory from being placed on-device and then being copied to host via
888 // an expensive D2H transfer.
889 if (op->Name() == "_EagerConst") return false;
890
891 // Check if any of the Op's output_arg(s) are pinned to Host.
892 if (kernel_def == nullptr) return false;
893 const OpDef& op_def = OpRegistry::Global()->LookUp(op->Name())->op_def;
894 for (const string& host_memory_arg : kernel_def->host_memory_arg()) {
895 for (const auto& output_arg : op_def.output_arg()) {
896 if (output_arg.name() == host_memory_arg) {
897 return false;
898 }
899 }
900 }
901
902 return true;
903 }
904
StatusOr<Fprint128> GetKernelCacheKey(
    const EagerOperation& op, const Fprint128& op_cache_key,
    const std::vector<Device*>& input_device_ptrs,
    const std::unordered_map<int, DtypeAndPartialTensorShape>&
        input_resource_variable_dtypes_and_shapes) {
910 EagerContext& ctx = op.EagerContext();
911
912 Fprint128 cache_key = op_cache_key;
  // Include soft placement policy in cache key since the placement strategy
  // can change and thus affect which kernel is picked.
915 cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement());
916
917 // Include run_eager_op_as_function policy in cache key since the execution
918 // strategy can change and affect which kernel is picked.
919 VLOG(3) << "ctx.RunEagerOpAsFunction(): " << ctx.RunEagerOpAsFunction();
920 cache_key = FingerprintCat128(cache_key, ctx.RunEagerOpAsFunction());
921
922 // When running in eager_op_as_function mode Send/Recv ops need to be
923 // placed on the same rendezvous to match the behaviour of eager mode.
924 bool reuse_rendezvous_for_functions =
925 (ctx.RunEagerOpAsFunction() && !op.is_function()) ||
926 ctx.GetReuseRendezvousForFunctions();
927 // The launch-time rendezvous reuse setting is bundled with the kernel, so we
928 // need to include it in the cache key.
929 cache_key = FingerprintCat128(cache_key, reuse_rendezvous_for_functions);
930
931 for (int i = 0, end = input_device_ptrs.size(); i < end; ++i) {
932 cache_key = FingerprintCat128(cache_key,
933 Fingerprint128(input_device_ptrs[i]->name()));
934
935 auto input_resource = input_resource_variable_dtypes_and_shapes.find(i);
936 if (input_resource != input_resource_variable_dtypes_and_shapes.end()) {
938 const DtypeAndPartialTensorShape& dtype_and_shape =
939 input_resource->second;
940 // Add _Arg index, dtype and shape to "cache_key".
941 cache_key = FingerprintCat128(cache_key, i);
942 cache_key = FingerprintCat128(cache_key, dtype_and_shape.dtype);
943 AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key);
944 }
945 }
946
947 return cache_key;
948 }
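
// Illustrative consequence of the key construction above (a sketch, not a
// test): the same op invoked once with soft placement enabled and once with
// it disabled hashes to two different cache keys, and therefore instantiates
// two separate kernels, even if every other property (inputs, devices, attrs)
// is identical.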
949
950 // Extracts function input info for `op` with `kernel_def`.
951 // The following are extracted:
952 // `input_device_ptrs` - The input devices of `op`.
953 // `composite_devices` - Maps from a CompositeDevice name to a list of
954 // physical device names.
// `input_resource_variable_dtypes_and_shapes` - A map from input index
// to dtype and shapes for resource inputs.
Status ExtractFunctionInputInfo(
    EagerOperation* op, const KernelDef* kernel_def,
    std::vector<Device*>& input_device_ptrs,
    absl::flat_hash_map<string, const std::vector<string>*>& composite_devices,
    std::unordered_map<int, DtypeAndPartialTensorShape>&
        input_resource_variable_dtypes_and_shapes) {
963 profiler::TraceMe activity("EagerCopyToDevice",
964 profiler::TraceMeLevel::kInfo);
965 EagerContext& ctx = op->EagerContext();
966 input_device_ptrs.reserve(op->Inputs().size());
967 const absl::InlinedVector<TensorHandle*, 4>* inputs;
968 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
969 Device* op_device = nullptr;
970 const NodeDef* node_def = nullptr;
971 if (!op->is_function()) {
972 op_device = absl::get<Device*>(op->Device());
973 node_def = &op->MutableAttrs()->BuildNodeDef();
974 }
975 for (int i = 0, end = inputs->size(); i < end; ++i) {
976 TensorHandle* input = (*inputs)[i];
977
978 Device* input_device;
979 bool is_host_memory_arg =
980 IsHostMemoryArg(*op, node_def, op_device, kernel_def, i);
981 TF_RETURN_IF_ERROR(
982 GetDeviceForInput(*op, ctx, is_host_memory_arg, input, &input_device));
983 VLOG(1) << op->Name() << ":input:" << i << " " << input_device->name();
984 input_device_ptrs.push_back(input_device);
985 CompositeDevice* composite_device = nullptr;
986 if (ctx.FindCompositeDeviceFromName(input_device->name(), &composite_device)
987 .ok()) {
988 composite_devices[input_device->name()] =
989 composite_device->underlying_devices();
990 }
991 if (input->dtype == DT_RESOURCE) {
992 // We only care about data type and shape for resource variable inputs.
993 // But we have no way to tell if input is resource variable (other than
994 // looking it up in ResourceMgr, which is slow). So we just get
995 // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
996 // resource_dtypes_and_shapes is not empty, take the first element.
997 std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
998 TF_RETURN_IF_ERROR(
999 input->GetResourceHandleDtypesAndShapes(&resource_dtypes_and_shapes));
1000 if (!resource_dtypes_and_shapes.empty()) {
1001 const DtypeAndPartialTensorShape& dtype_and_shape =
1002 resource_dtypes_and_shapes.at(0);
1003 input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
1004 }
1005 }
1006 }
1007 return OkStatus();
1008 }
1009
Status SetOpDevice(EagerContext& ctx, EagerOperation* op, Device** device) {
1011 // Here in local execute, set preferred device to be on the local task to
1012 // avoid placing op on a remote device with higher priority.
1013 const DeviceNameUtils::ParsedName& preferred_device =
1014 DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())
1015 ? op->GetDeviceParsedName()
1016 : DeviceNameUtils::AddressSpace(ctx.HostCPUParsedName());
1017 // Note: We use the unwrapped op for inferring the device.
1018 // Without this, when wrapping CPU-only ops like RangeDataset we would
1019 // place the wrapped op on a GPU (if one is available) which leads to
1020 // errors because placer pins the function output nodes to GPU thereby
1021 // forcing a H2D copy of the dataset variant which is not supported.
1022 auto ndef = op->MutableAttrs()->BuildNodeDef();
1023 #ifdef INTEL_MKL
1024 if (IsMKLEnabled() &&
1025 absl::StartsWith(op->Name(), mkl_op_registry::kMklOpPrefix)) {
1026 GetMKLNodeDef(&ndef);
1027 }
1028 #endif // INTEL_MKL
1029
1030 TF_RETURN_IF_ERROR(ctx.SelectDevice(preferred_device, ndef, device));
1031
1032 VLOG(1) << "PreferredDevice " << op->Name() << ": " << preferred_device;
1033 VLOG(1) << "Placer place op [" << op->Name()
1034 << "] on device: " << (*device)->name();
1035 VLOG(4) << "Available kernels for " << op->Name() << " are"
1036 << KernelsRegisteredForOp(op->Name());
1037 op->SetDevice(*device);
1038 return OkStatus();
1039 }
1040
Fprint128 GetDeviceCacheKey(EagerOperation* op, const EagerContext& ctx) {
1042 Fprint128 device_cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
1043 device_cache_key =
1044 FingerprintCat128(device_cache_key, ctx.AllowSoftPlacement());
1045 return device_cache_key;
1046 }
1047
Status GetOrCreateKernelAndDevice(
    EagerOperation* op, TensorHandle** retvals, int* num_retvals,
    core::RefCountPtr<KernelAndDevice>* out_kernel) {
1051 EagerContext& ctx = op->EagerContext();
1052 Device* device = absl::get<Device*>(op->Device());
1053
1054 // Set the EagerOperation's device prior to extracting the input_device_ptrs
1055 // to avoid any redundant H2D/D2H copies.
1056 if (device == nullptr && !op->is_function()) {
1057 Fprint128 device_cache_key = GetDeviceCacheKey(op, ctx);
1058 device = ctx.GetCachedDevice(device_cache_key);
1059 if (device == nullptr) {
1060 TF_RETURN_IF_ERROR(SetOpDevice(ctx, op, &device));
1061 ctx.AddDeviceToCache(device_cache_key, device);
1062 } else {
1063 op->SetDevice(device);
1064 }
1065 }
1066
1067 // Save the original value of reuse_rendezvous_for_functions from the context.
1068 bool reuse_rendezvous_for_functions_original_value =
1069 ctx.GetReuseRendezvousForFunctions();
1070 // When running in eager_op_as_function mode Send/Recv ops need to be
1071 // placed on the same rendezvous to match the behaviour of eager mode.
1072 bool reuse_rendezvous_for_functions =
1073 (ctx.RunEagerOpAsFunction() && !op->is_function()) ||
1074 reuse_rendezvous_for_functions_original_value;
1075
1076 std::vector<Device*> input_device_ptrs;
1077 absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
1078 std::unordered_map<int, DtypeAndPartialTensorShape>
1079 input_resource_variable_dtypes_and_shapes;
1080 const KernelDef* kernel_def = nullptr;
1081 if (!op->is_function()) {
1082 const NodeDef* node_def = &op->MutableAttrs()->BuildNodeDef();
1083 kernel_def = GetKernelDef(*op, node_def, device);
1084 }
1085 if (op->is_function() || ctx.RunEagerOpAsFunction()) {
1086 TF_RETURN_IF_ERROR(ExtractFunctionInputInfo(
1087 op, kernel_def, input_device_ptrs, composite_devices,
1088 input_resource_variable_dtypes_and_shapes));
1089 }
1090
1091 TF_ASSIGN_OR_RETURN(
1092 Fprint128 cache_key,
1093 GetKernelCacheKey(*op, op->MutableAttrs()->CacheKey(op->DeviceName()),
1094 input_device_ptrs,
1095 input_resource_variable_dtypes_and_shapes));
1096 core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
1097 AbstractOperationPtr wrapped_op_releaser;
  // We can eliminate some overhead by running simple functions using the
  // regular CallOp kernel. However, it is tricky to figure out which
  // functions should be run using CallOp. Also, currently CallOp runs neither
  // optimization passes (needed for TPU/XLA) nor grappler.
1102 // Here are some cases where a function should be run in multi-device mode:
1103 // - Function takes at least two resources on different devices.
1104 // - Function takes a resource on deviceA and a body op explicitly placed
1105 // on deviceB.
1106 // - Function has a colocation constraint.
1107 // - Function has an explicit device annotation (which might not be using
1108 // full canonical device name) different from op_device. Note that false
1109 // positives are ok.
1110 // - Function has a node or a (node) attribute that can potentially make
1111 // the function multi-device after a rewrite pass (e.g. various XLA/TPU
1112 // special nodes and attributes)
1113 if (kernel == nullptr) {
1114 VLOG(2) << "Creating new kernel for " << op->Name() << " on device "
1115 << DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
1116 bool run_function_with_flr = false;
1117 bool function_outputs_on_op_device = false;
1118 absl::optional<string> xla_compile_device_type;
1119 if (op->is_function()) {
1120 bool compile_with_xla;
1121 TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
1122 if (compile_with_xla) {
1123 if (ctx.JitCompileRewrite()) {
1124 xla_compile_device_type = op->GetDeviceParsedName().type;
1125 run_function_with_flr = true;
1126 } else {
1127 // Note that it is not ideal, but currently correct, to set this
1128 // attribute after computing the kernel cache key above.
1129 // Note: If the attribute is already set to true, this is a noop.
1130 op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
1131 }
1132 } else {
1133 run_function_with_flr = true;
1134 }
1135 GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device)
1136 .IgnoreError();
1137 }
1138
1139 VLOG(2) << op->Name() << " function_outputs_on_op_device: "
1140 << function_outputs_on_op_device;
1141 if (device == nullptr) {
1142 TF_RETURN_IF_ERROR(SetOpDevice(ctx, op, &device));
1143 } else {
1144 VLOG(1) << "Device for [" << op->Name()
1145 << "] already set to: " << device->name();
1146 }
1147
1148 // Note: We wrap the eager op AFTER the device has been inferred to ensure
1149 // that placement of the NodeDef in the function is exactly the same as in
    // eager mode. This is especially important for cases where the
1151 // preferred device is not the actual device on which the op is run.
1152 // E.g. the preferred device for a `RangeDataset` op could be set to `GPU`
1153 // but `ctx->SelectDevice` would still place it on CPU. Placer on the other
1154 // hand would throw an error.
1155 //
1156 // Note: The wrapped function is never jit compiled but rather run via the
1157 // FLR. This is needed because certain ops e.g. `VarHandleOp` can not be
1158 // jit compiled. Ideally we would run this via the jit compiled path and
1159 // expect unsupported ops to be outside compiled but that is not supported
1160 // on GPUs right now.
1161 bool allow_small_function_optimizations = false;
1162 bool int_args_and_retvals_on_device = false;
1163 bool allow_control_flow_sync_execution = false;
1164 // TODO(b/176491312): Remove this if shape inference on import flag is
1165 // removed.
1166 bool shape_inference_on_tfe_dialect_import = true;
1167 if (ctx.RunEagerOpAsFunction() && !op->is_function()) {
1168 EagerOperation* wrapped_op = nullptr;
1169 TF_RETURN_IF_ERROR(ValidateOp(op));
1170 TF_RETURN_IF_ERROR(WrapInCallOp(op, &wrapped_op));
1171 DCHECK(wrapped_op);
1172 DCHECK(wrapped_op->is_function());
1173 wrapped_op_releaser.reset(wrapped_op);
1174 run_function_with_flr = true;
1175 allow_small_function_optimizations = true;
1176 allow_control_flow_sync_execution = true;
1177 shape_inference_on_tfe_dialect_import = false;
1178 int_args_and_retvals_on_device =
1179 IntArgsAndRetvalsOnDevice(op, kernel_def);
1180 op = wrapped_op;
1181 if (int_args_and_retvals_on_device) {
1182 op->MutableAttrs()->Set(FunctionLibraryDefinition::kIntsOnDeviceAttr,
1183 true);
1184 }
1185 }
1186 const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
1187
1188 FunctionLibraryRuntime* flr =
1189 device == nullptr ? nullptr : ctx.func_lib(device);
1190 if (device != nullptr && flr == nullptr) {
1191 return errors::NotFound(
1192 "Unable to find a FunctionLibraryRuntime corresponding to device ",
1193 device->name());
1194 }
1195 auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner()
1196 : ctx.runner();
1197 GraphCollector* graph_collector = nullptr;
1198 if (ctx.ShouldStoreGraphs()) {
1199 graph_collector = ctx.GetGraphCollector();
1200 }
1201 // Treat the function as multi_device only when we are not compiling
1202 // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel
1203 // will create an XlaLaunchOp kernel to compile and run the function.
1204 if (run_function_with_flr) {
1205 // Multi-device functions don't use the rendezvous from eager context.
1206 // If we use that rendezvous, multiple concurrent calls to the same
1207 // function will likely result in collisions. However, this also means
1208 // that we don't support legitimate sending/receiving across function
1209 // boundary.
1210 VLOG(2) << "Running " << ndef.op() << " using multi-device function. "
1211 << "Full node_def=" << ndef.DebugString();
1212 std::function<int64_t()> get_op_id = nullptr;
1213 #if !defined(IS_MOBILE_PLATFORM)
1214 get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); };
1215 #endif // IS_MOBILE_PLATFORM
1216
1217 ctx.reuse_rendezvous_for_functions_mu()->lock();
1218 ctx.SetReuseRendezvousForFunctions(reuse_rendezvous_for_functions);
1219 auto rendezvous_creator = ctx.RendezvousCreator();
1220 ctx.SetReuseRendezvousForFunctions(
1221 reuse_rendezvous_for_functions_original_value);
1222 ctx.reuse_rendezvous_for_functions_mu()->unlock();
1223 kernel.reset(new KernelAndDeviceFunc(
1224 flr, ctx.pflr(), std::move(input_device_ptrs),
1225 std::move(composite_devices),
1226 std::move(input_resource_variable_dtypes_and_shapes), runner,
1227 ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
1228 function_outputs_on_op_device, allow_small_function_optimizations,
1229 allow_control_flow_sync_execution,
1230 shape_inference_on_tfe_dialect_import, int_args_and_retvals_on_device,
1231 xla_compile_device_type, std::move(rendezvous_creator), get_op_id));
1232 } else {
      VLOG(2) << "Running " << ndef.op() << " using op kernel. "
              << "Full node_def=" << ndef.DebugString();
1235 kernel.reset(new KernelAndDeviceOp(
1236 ctx.GetRendezvous(), ctx.LogMemory(), flr, runner,
1237 ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
1238 }
1239
1240 TF_RETURN_IF_ERROR(
1241 kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector));
1242
1243 if (op->is_function()) {
1244 ctx.AddKernelToCache(cache_key, kernel.get());
1245 } else {
1246 // Exclude tf.data op kernels from being cached. The reason for this is
1247 // that tf.data op kernels that accept a user-defined function will have a
1248 // unique cache key every time they are executed (because the user-defined
1249 // function is traced every time). Caching such kernels provides no
      // benefit and in some cases results in linear memory growth of user
      // programs that build input pipeline graphs in a loop.
1252 const OpDef* op_def;
1253 TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def));
1254 if (KernelCacheEnabled(*op_def)) {
1255 ctx.AddKernelToCache(cache_key, kernel.get());
1256 }
1257 }
1258 }
1259
1260 int num_outputs = kernel->num_outputs();
1261 if (num_outputs > *num_retvals) {
1262 return errors::InvalidArgument("Expecting ", num_outputs,
1263 " outputs, but *num_retvals is ",
1264 *num_retvals);
1265 }
1266 *num_retvals = num_outputs;
1267
1268 kernel->Ref(); // Ownership of reference is passed to out_kernel.
1269 out_kernel->reset(kernel.get());
1270 return OkStatus();
1271 }
1272
Status CreateUnshapedOutput(
    const KernelAndDevice& kernel, const int output_num, Device* output_device,
    const DataType& output_dtype,
    const absl::optional<EagerFunctionParams>& eager_func_params,
    EagerContext* ctx, TensorHandle** output) {
1278 #if defined(IS_MOBILE_PLATFORM)
1279 return errors::Unimplemented(
1280 "Remote outputs are not available on mobile devices.");
1281 #else // !IS_MOBILE_PLATFORM
1282 int64_t op_id;
1283 if (eager_func_params.has_value()) {
1284 op_id = eager_func_params.value().op_id;
1285 } else {
1286 return errors::InvalidArgument(
1287 "Unable to find a remote op id for a remote output of ", kernel.name());
1288 }
1289 string remote_task;
1290 if (!DeviceNameUtils::GetTaskName(output_device->parsed_name(),
1291 &remote_task)) {
1292 return errors::InvalidArgument(
1293 "Unable to find remote task corresponding to device ",
1294 output_device->name());
1295 }
1296 if (ctx->RemoteMgr()->IsMaster()) {
1297 *output = TensorHandle::CreateUnshapedRemoteHandle(
1298 op_id, output_num, remote_task, output_dtype, output_device, ctx);
1299 } else {
1300 *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
1301 output_dtype, output_device,
1302 /*is_ready=*/false, ctx);
1303 }
1304 return OkStatus();
1305 #endif // !IS_MOBILE_PLATFORM
1306 }
1307
Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
                        EagerOperation* op, TensorHandle** retvals) {
1310 EagerExecutor& executor = op->Executor();
1311 EagerContext& ctx = op->EagerContext();
1312 GraphCollector* graph_collector = nullptr;
1313 if (ctx.ShouldStoreGraphs()) {
1314 graph_collector = ctx.GetGraphCollector();
1315 }
1316 const int num_outputs = kernel->num_outputs();
1317 absl::optional<EagerFunctionParams> eager_func_params =
1318 op->eager_func_params();
1319 if (kernel->IsCrossProcess() && !eager_func_params.has_value()) {
1320     // Create an eager op id for a cross-process function if one does not exist.
1321 #if defined(IS_MOBILE_PLATFORM)
1322 return errors::Unimplemented(
1323 "Cross-process functions are not supported on mobile devices.");
1324 #else // !IS_MOBILE_PLATFORM
1325 const int64_t op_id = ctx.RemoteMgr()->NextOpId();
1326 eager_func_params = EagerFunctionParams{
1327 op_id, /* is_component_function= */ false, /* step_id= */ std::nullopt};
1328 #endif // !IS_MOBILE_PLATFORM
1329 }
1330 if (executor.Async()) {
1331 const DataTypeVector& output_dtypes = kernel->output_dtypes();
1332 for (int i = 0, end = num_outputs; i < end; ++i) {
1333 Device* output_device = ctx.CanonicalDevice(kernel->OutputDevice(i));
1334 if (output_device == nullptr || output_device->IsLocal()) {
1335 retvals[i] = TensorHandle::CreateEmptyLocalHandle(
1336 /* d= */ output_device, /* op_device= */ kernel->device(),
1337 /* resource_device= */ kernel->OutputResourceDevice(i),
1338 output_dtypes[i], &ctx);
1339 } else {
1340 TF_RETURN_IF_ERROR(
1341 CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1342 eager_func_params, &ctx, &retvals[i]));
1343 }
1344 }
1345 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1346 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1347 auto node = std::make_unique<AsyncExecuteNode>(
1348 &ctx, *inputs, eager_func_params, std::move(kernel), graph_collector,
1349 op->GetCancellationManager(),
1350 absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
1351 // Release the inputs from the eager operation since the AsyncExecuteNode
1352 // would have taken ownership. This allows the inputs to be forwarded if
1353 // possible.
1354 op->Clear();
1355     // In async mode, the executor's ordering guarantees that all input
1356     // handles are ready before the node executes.
1357 // TODO(b/137118203): Consider executing "cheap" kernels inline for
1358 // performance.
1359 return executor.AddOrExecute(std::move(node));
1360 } else {
1361 for (int i = 0, end = num_outputs; i < end; ++i) {
1362 retvals[i] = nullptr;
1363 }
1364 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1365 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1366 ExecuteNode node(&ctx, *inputs, eager_func_params, kernel, graph_collector,
1367 op->GetCancellationManager(),
1368 {retvals, static_cast<size_t>(num_outputs)},
1369 op->GetStackTrace());
1370 Status s = executor.SyncExecute(&node);
1371 // We release the inputs AFTER executing the operation in sync mode since
1372 // ExecuteNode does not increment the reference count and thus does not have
1373 // ownership of the inputs while executing.
1374 op->Clear();
1375 return s;
1376 }
1377 }
1378
1379 // There are many references to devices in this function and the surrounding code.
1380 // Here is what they mean:
1381 // EagerOperation::Device(): The device on which the user requested the op
1382 // be executed, except if we had to change the device due to resource inputs
1383 // or CPU pinning. If the user did not request a device, the op does not
1384 // take resources, and we did not pin it to CPU, the device can be nullptr.
1385 // KernelAndDevice::Device(): The first time we see an op (combined with
1386 // its attributes), we need to create a KernelAndDevice object for it.
1387 // If op->Device() is a nullptr, we select a device for the op when
1388 // creating the KernelAndDevice. A concrete device will always be selected
1389 // here except when `op` is a function to be executed using function library
1390 // runtime. In this case, we don't select a device because running
1391 //  a function with an explicitly requested device has different behavior
1392 //  than running it without an explicitly requested device.
1393 Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
1394 int* num_retvals) {
1395 profiler::ScopedMemoryDebugAnnotation op_annotation(
1396 op->op_name(), op->eager_func_params().has_value()
1397 ? op->eager_func_params().value().step_id.value_or(0)
1398 : 0);
1399 profiler::TraceMe activity(
1400 [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
1401 profiler::TraceMeLevel::kInfo);
1402 EagerContext& ctx = op->EagerContext();
1403 auto& executor = op->Executor();
1404 TF_RETURN_IF_ERROR(executor.status());
1405
1406 core::RefCountPtr<KernelAndDevice> kernel;
1407 auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1408
1409 #ifdef INTEL_MKL
1410 if (IsMKLEnabled() && kernel != nullptr &&
1411 op->Device() == kVariantDeviceNull) {
1412 // oneDNN optimization pass relies on the op's assigned device to determine
1413 // whether it can be rewritten.
1414 op->SetDevice(kernel->device());
1415 }
1416 #endif // INTEL_MKL
1417
1418   // Run all the registered rewrite passes after placement, regardless of
1419   // whether the placement was successful. The passes can either create new
1420   // ops (without placement) or update some fields of the input op.
1421 std::unique_ptr<tensorflow::EagerOperation> out_op;
1422 TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1423 EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op));
1424 if (out_op) {
1425 op = out_op.get();
1426     // If the out op doesn't have a device, either because it is a new op or
1427 // the op wasn't placed successfully, then we do the placement again.
1428 if (op->Device() == kVariantDeviceNull) {
1429 status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
1430 }
1431 }
1432 if (!status.ok()) return status;
1433
1434 int num_outputs = kernel->num_outputs();
1435 TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel));
1436
1437 if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
1438 string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
1439 kernel->device()->name());
1440 if (!logging::LogToListeners(msg)) {
1441 LOG(INFO) << msg;
1442 }
1443 }
1444
1445 Status s = AddOrExecuteNode(std::move(kernel), op, retvals);
1446   // If the operation failed, we need to Unref any outputs that were
1447   // allocated.
1448 if (!s.ok()) {
1449 for (int i = 0, end = num_outputs; i < end; ++i) {
1450 if (retvals[i] != nullptr) {
1451 retvals[i]->Unref();
1452 retvals[i] = nullptr;
1453 }
1454 }
1455 }
1456
1457 return s;
1458 }
1459
1460 // Run a Pack op to pack the tensors pointed to by a packed input TensorHandle
1461 // if the op is a primitive op.
1462 Status MaybePackInputTensor(EagerOperation* op) {
1463 if (op->is_function() || op->EagerContext().RunEagerOpAsFunction()) {
1464 // Functions could take packed TensorHandles as inputs.
1465 return OkStatus();
1466 }
1467 EagerContext& ctx = op->EagerContext();
1468 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1469 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1470 for (int i = 0; i < inputs->size(); ++i) {
1471 TensorHandle* handle = (*inputs)[i];
1472 if (handle->Type() == TensorHandle::PACKED) {
1473 EagerOperation pack_op(&ctx);
1474 TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
1475 /*remote=*/false, /*executor=*/nullptr));
1476 pack_op.MutableAttrs()->Set("N", handle->NumPackedHandles());
1477 pack_op.MutableAttrs()->Set("T", handle->dtype);
1478 for (int i = 0; i < handle->NumPackedHandles(); ++i) {
1479 tensorflow::TensorHandle* h = nullptr;
1480 TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(i, &h));
1481 TF_RETURN_IF_ERROR(pack_op.AddInput(h));
1482 }
1483 int num_retvals = 1;
1484 absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
1485 TF_RETURN_IF_ERROR(
1486 EagerLocalExecute(&pack_op, retvals.data(), &num_retvals));
1487 tensorflow::TensorHandle* ret = retvals.at(0);
1488 op->UpdateInput(i, ret);
1489 ret->Unref();
1490 }
1491 }
1492 return OkStatus();
1493 }
1494
1495 #if !defined(IS_MOBILE_PLATFORM)
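// Fills in `remote_op` from `op`: assigns a fresh op id, copies the op name
// and non-default attributes, and records the target device and whether the
// op is a function.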
1496 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
1497 EagerContext& ctx = op->EagerContext();
1498
1499 remote_op->set_id(ctx.RemoteMgr()->NextOpId());
1500 remote_op->set_name(op->Name());
1501
1502 op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
1503 remote_op->set_device(absl::get<Device*>(op->Device())->name());
1504 remote_op->set_is_function(op->is_function());
1505 }
1506
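// For a remote VarHandleOp, records the variable's dtype and shape on the
// returned TensorHandle so that remote functions taking the resource as an
// input can use them during instantiation.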
1507 Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
1508 const DataTypeVector& output_dtypes,
1509 TensorHandle** retvals) {
1510 if (remote_op.name() == "VarHandleOp") {
1511 if (output_dtypes.size() != 1) {
1512 return errors::Internal("VarHandleOp should only have one output.");
1513 }
1514 if (output_dtypes[0] != DT_RESOURCE) {
1515 return errors::Internal(
1516 "The output of VarHandleOp should be a DT_RESOURCE.");
1517 }
1518 AttrSlice attr_slice = AttrSlice(&remote_op.attrs());
1519 const AttrValue* dtype;
1520 TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
1521 const AttrValue* shape;
1522 TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
1523 retvals[0]->SetResourceHandleDtypeAndShape(
1524 {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
1525 }
1526 return OkStatus();
1527 }
1528
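// Executes `op` on a remote task: copies or serializes the inputs as needed,
// builds an EnqueueRequest, creates unshaped remote handles for the outputs,
// and adds a RemoteExecuteNode to the op's executor.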
1529 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
1530 int* num_retvals) {
1531 EagerContext& ctx = op->EagerContext();
1532
1533   // TODO(fishx): Remove the following code when lazy tensor copy is ready.
1534 if (op->Device() == kVariantDeviceNull) {
1535 tensorflow::Device* device = nullptr;
1536 string device_name = op->DeviceName();
1537 TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device));
1538 op->SetDevice(device);
1539 }
1540
1541 core::RefCountPtr<eager::EagerClient> eager_client;
1542 uint64 context_id = ctx.GetContextId();
1543 TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client));
1544 string remote_task;
1545 if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
1546 return errors::InvalidArgument(
1547 "Unable to find remote task corresponding to device ",
1548 op->DeviceName());
1549 }
1550
1551 std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
1552 request->set_context_id(context_id);
1553
1554 eager::Operation* remote_op = request->add_queue()->mutable_operation();
1555
1556 tensorflow::Device* op_device = absl::get<Device*>(op->Device());
1557 {
1558 profiler::TraceMe activity("CopyInputToExpectedDevice",
1559 profiler::TraceMeLevel::kInfo);
1560 const bool is_function = op->is_function();
1561 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1562 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1563 for (int i = 0, end = inputs->size(); i < end; i++) {
1564 tensorflow::TensorHandle* input = (*inputs)[i];
1565 tensorflow::Device* input_device = input->device();
1566 tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
1567 const string* input_device_name = &input_device_or_cpu->name();
1568 bool serialize_resource_dtype_and_shape = false;
1569 if (op_device != input_device &&
1570 // If the expected and actual devices are on the same task, don't
1571 // explicitly copy, and instead depend on the copy to happen locally
1572 // when the op is executed on the device.
1573 !ctx.OnSameTask(op_device, input_device)) {
1574 if (!is_function || input_device_or_cpu->IsLocal()) {
1575 tensorflow::Device* remote_cpu_device;
1576 TF_RETURN_IF_ERROR(
1577 ctx.CPUDeviceOnTask(op_device, &remote_cpu_device));
1578 // Always copy to the remote CPU so that the actual device can be
1579 // correctly determined after the kernel is selected/instantiated,
1580 // since the op might have its inputs on host memory.
1581 TensorHandle* handle = input;
1582 Device* handle_device = handle->DeviceOrHostCPU(ctx);
1583 // If the input is already on the right device, then nothing to do.
1584 if (remote_cpu_device != handle_device) {
1585 VLOG(6) << "remote_cpu_device != handle_device";
1586 TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
1587 &ctx, op, op_device, handle, i, handle_device,
1588 remote_cpu_device, &handle));
1589 op->UpdateInput(i, handle);
1590 input = handle;
1591 input_device = remote_cpu_device;
1592 input_device_name = &remote_cpu_device->name();
1593           // Unref the handle since the op now holds a ref to it as an input
1594 handle->Unref();
1595 }
1596 } else {
1597 serialize_resource_dtype_and_shape =
1598 (input->dtype == DT_RESOURCE) &&
1599 (!input->HasResourceShapeMirror(op_device,
1600 ctx.GetContextViewId()));
1601 }
1602 }
1603 auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
1604 // For a remote component function, a function execution request and an
1605 // input generation request may come from different workers. We need to
1606 // guarantee that the input generation request is processed before the
1607 // function execution request, so wait until the remote input is ready
1608 // before sending it to the multi-device function device.
1609 const bool wait_until_ready = op->is_function();
1610 TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
1611 input, wait_until_ready, input_handle, input_device,
1612 *input_device_name, serialize_resource_dtype_and_shape));
1613 if (!input_handle->resource_dtypes_and_shapes().empty()) {
1614 TF_RETURN_IF_ERROR(
1615 input->AddResourceShapeMirror(op_device, input_handle->op_id(),
1616 input_handle->output_num(), &ctx));
1617 }
1618 }
1619 }
1620
1621 PrepareRemoteOp(remote_op, op);
1622
1623 DataTypeVector output_dtypes;
1624 TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
1625
1626 const size_t num_outputs = output_dtypes.size();
1627 if (num_outputs != *num_retvals) {
1628 return errors::InvalidArgument(
1629 "num_retvals does not match expected output dtypes");
1630 }
1631 *num_retvals = num_outputs;
1632
1633 const tensorflow::uint64 id = remote_op->id();
1634 for (size_t i = 0; i < num_outputs; ++i) {
1635 // TODO(nareshmodi): Change the callback to instead add the decref to a
1636 // list of pending decrefs that we can send as a batch with the next
1637 // execute.
1638
1639 // The device_ and resource_device_ of this TensorHandle might be
1640 // incorrect. For multi-device functions, we don't know the output device
1641 // until the function is instantiated on a remote worker. Luckily, we don't
1642 // need to know the correct remote device here. We just need to know that it
1643     // is remote. If we need to copy this tensor to this process or run any
1644     // ops which take this tensor as an input, block until the correct device
1645     // is set.
1646 const bool unknown_device = op->is_function();
1647 retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
1648 id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
1649 }
1650
1651 // Store the data type and shape of a remote resource variable on the
1652 // corresponding remote TensorHandle (output of 'VarHandleOp').
1653 // If the variable is an input of a remote function, the function may need
1654 // the type and shape during function instantiation. Store the type and
1655   // shape on the eager master and send them to the default function device
1656   // along with the EnqueueRequest.
1657 TF_RETURN_IF_ERROR(
1658 StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals));
1659
1660 auto& executor = op->Executor();
1661 VLOG(4) << "Execute remote eager op: " << op->Name()
1662 << " (is async?: " << executor.Async() << ").";
1663
1664 const absl::InlinedVector<TensorHandle*, 4>* inputs;
1665 TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
1666
1667 std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
1668 &op->EagerContext(), std::move(request), op_device,
1669 ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
1670 op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
1671 *inputs, {retvals, num_outputs}));
1672
1673 if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
1674 string msg = strings::StrCat(
1675 "Executing op ", op->Name(), " on task ",
1676 DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName()));
1677 if (!logging::LogToListeners(msg)) {
1678 LOG(INFO) << msg;
1679 }
1680 }
1681
1682 Status s = executor.AddOrExecute(std::move(node));
1683   // If the operation failed, we need to Unref any outputs that were
1684   // allocated.
1685 if (!s.ok()) {
1686 for (size_t i = 0; i < num_outputs; ++i) {
1687 retvals[i]->Unref();
1688 // Ensure that any smart pointers created to wrap results become noops
1689 // rather than operating on invalid memory.
1690 retvals[i] = nullptr;
1691 }
1692 }
1693
1694 return s;
1695 }
1696 #endif // IS_MOBILE_PLATFORM
1697
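// Converts the EagerKernelRets produced by a kernel run into TensorHandles in
// `retvals`, either creating new handles or filling in handles that were
// pre-allocated before execution.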
1698 Status GetKernelOutputs(
1699 std::vector<EagerKernelRet>* outputs, int num_outputs,
1700 TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel,
1701 const absl::optional<EagerFunctionParams>& eager_func_params) {
1702 for (int i = 0, end = num_outputs; i < end; ++i) {
1703 if (retvals[i] == nullptr) {
1704 EagerKernelRet& ret = (*outputs)[i];
1705 Device* output_device = ctx->CanonicalDevice(kernel->OutputDevice(i));
1706 if (ret.index() == 0) {
1707 retvals[i] = TensorHandle::CreateLocalHandle(
1708 std::move(absl::get<Tensor>(ret)),
1709 /* d= */ output_device,
1710 /* op_device= */ kernel->device(),
1711 /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
1712 } else {
1713 const DataTypeVector& output_dtypes = kernel->output_dtypes();
1714 TF_RETURN_IF_ERROR(
1715 CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i],
1716 eager_func_params, ctx, &retvals[i]));
1717 #if !defined(IS_MOBILE_PLATFORM)
1718 TF_RETURN_IF_ERROR(
1719 retvals[i]->SetRemoteShape(absl::get<TensorShape>(ret),
1720 output_device, ctx->GetContextViewId()));
1721 #endif // IS_MOBILE_PLATFORM
1722 }
1723 } else {
1724 if (!kernel->IsFunction() &&
1725 TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
1726 return errors::Internal(
1727 "Kernel output tensor handle has a different op device than the "
1728 "kernel. This should never happen.");
1729 }
1730 if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
1731 retvals[i]->device())) {
1732 return errors::Internal(
1733 "Kernel output tensor handle locates on a different device than "
1734 "the specified kernel output device. This should never happen.");
1735 }
1736
1737 EagerKernelRet& ret = (*outputs)[i];
1738 if (ret.index() == 0) {
1739 TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
1740 std::move(absl::get<Tensor>(ret)),
1741 ctx->CanonicalDevice(kernel->OutputDevice(i))));
1742 } else {
1743 #if defined(IS_MOBILE_PLATFORM)
1744 return errors::Unimplemented(
1745 "Remote outputs are not available on mobile devices.");
1746 #else // !IS_MOBILE_PLATFORM
1747 TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
1748 absl::get<TensorShape>(ret), retvals[i]->device(),
1749 ctx->GetContextViewId()));
1750 #endif // !IS_MOBILE_PLATFORM
1751 }
1752 }
1753 }
1754 return OkStatus();
1755 }
1756
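// Moves the graphs accumulated by the context's GraphCollector into the
// context's RunMetadata and clears the collector.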
1757 void CollectGraphs(EagerContext* ctx) {
1758 mutex_lock ml(*ctx->MetadataMu());
1759
1760 GraphCollector* collector = ctx->GetGraphCollector();
1761 mutex_lock mll(collector->mu);
1762
1763 // Adding to partition graphs for backward compatibility.
1764 for (const auto& graph : collector->partitioned_graphs) {
1765 *ctx->RunMetadataProto()->add_partition_graphs() = graph;
1766 }
1767
1768 if (collector->dirty) {
1769 auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
1770 *function_graphs->mutable_post_optimization_graph() =
1771 collector->optimized_graph;
1772 *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
1773 for (const auto& graph : collector->partitioned_graphs) {
1774 *function_graphs->add_partition_graphs() = graph;
1775 }
1776 }
1777
1778 collector->ClearGraphs();
1779 }
1780 } // namespace
1781
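// Entry point for eager op execution: applies pre-execution rewrites, then
// dispatches to EagerLocalExecute or EagerRemoteExecute depending on whether
// the op is placed on a local or remote device. Illustrative call (variable
// names here are hypothetical):
//   int num_retvals = expected_outputs;
//   std::vector<TensorHandle*> retvals(num_retvals);
//   TF_RETURN_IF_ERROR(EagerExecute(op, retvals.data(), &num_retvals));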
1782 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
1783 int* num_retvals) {
1784 profiler::TraceMe activity([&] {
1785 return ::tensorflow::profiler::TraceMeEncode(
1786 "EagerExecute",
1787 {{"eager_op", op->Name()}, {"is_func", op->is_function()}});
1788 });
1789
1790 if (!op->Executor().Async()) {
1791 VLOG(6) << "op: " << op->Name() << " is not Async.";
1792 if (!op->EagerContext()
1793 .GetGlobalRendezvousForFunctionLocalRendezvousStatus()
1794 .ok()) {
1795 VLOG(6) << "global_rendezvous_for_functions_ is in bad state. Resetting.";
1796 op->EagerContext().ResetGlobalRendezvousForFunction();
1797 }
1798     // In sync mode, always clear the error to maintain the same behavior as before.
1799 // TODO(b/141004939): Remove this.
1800 op->Executor().ClearError();
1801 }
1802
1803 std::unique_ptr<tensorflow::EagerOperation> out_op;
1804 TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
1805 EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
1806
1807 if (op->IsLocal()) {
1808 if (out_op) {
1809 op = out_op.get();
1810 }
1811 TF_RETURN_IF_ERROR(MaybePackInputTensor(op));
1812 return EagerLocalExecute(op, retvals, num_retvals);
1813 }
1814
1815 #if defined(IS_MOBILE_PLATFORM)
1816 return errors::Unimplemented(
1817 "Eager's remote execution is not available on mobile devices.");
1818 #else // !IS_MOBILE_PLATFORM
1819 if (out_op) {
1820 op = out_op.get();
1821 }
1822 return EagerRemoteExecute(op, retvals, num_retvals);
1823 #endif // !IS_MOBILE_PLATFORM
1824 }
1825
1826 // TODO(gjn): Consider moving into ExecuteNode class
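// Runs `kernel` synchronously with `op_inputs` and converts the results into
// `retvals` via GetKernelOutputs.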
1827 Status EagerKernelExecute(
1828 EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
1829 const absl::optional<EagerFunctionParams>& eager_func_params,
1830 const core::RefCountPtr<KernelAndDevice>& kernel,
1831 GraphCollector* graph_collector, CancellationManager* cancellation_manager,
1832 absl::Span<TensorHandle*> retvals,
1833 const absl::optional<ManagedStackTrace>& stack_trace) {
1834 profiler::TraceMe activity("EagerKernelExecute",
1835 profiler::TraceMeLevel::kInfo);
1836 std::vector<EagerKernelRet> outputs(1);
1837
1838 ExecuteNodeArgs inputs(op_inputs.size());
1839 TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs, kernel));
1840 // TODO(apassos) figure out how to record stats for ops which are a part of
1841 // functions.
1842 // TODO(b/111859745): When we support recovering from kernel/device errors, we
1843 // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA
1844 // device. We don't call it now because it is an unneeded overhead (it
1845 // acquires a lock) and we can't recover from errors anyway.
1846 ScopedStepContainer* container = ctx->StepContainer();
1847 CoordinationServiceAgent* coord_agent = nullptr;
1848 #if !defined(IS_MOBILE_PLATFORM)
1849 if (ctx->GetDistributedManager() != nullptr)
1850 coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
1851 #endif // !IS_MOBILE_PLATFORM
1852 TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
1853 cancellation_manager, eager_func_params,
1854 stack_trace, coord_agent));
1855 if (graph_collector != nullptr) {
1856 CollectGraphs(ctx);
1857 }
1858
1859 if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
1860 return errors::Internal(
1861 "EagerKernelExecute returns a list of ", outputs.size(),
1862 " tensors but ", retvals.size(),
1863 " is expected. This should never "
1864 "happen. Please file a bug with the TensorFlow team.");
1865 }
1866 return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
1867 kernel.get(), eager_func_params);
1868 }
1869
1870 namespace {
1871
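// Copies `h` to device `dstd` within this process, optionally as a local
// mirror of `h`. In async mode the copy is scheduled as a CopyToDeviceNode on
// the executor; otherwise it runs inline.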
1872 Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1873 EagerExecutor* executor, Device* dstd,
1874 bool mirror, TensorHandle** result) {
1875 TF_RETURN_IF_ERROR(executor->status());
1876 Device* d = ctx->CanonicalDevice(dstd);
1877 if (mirror && h->HasLocalMirror(d)) {
1878 h->Ref();
1879 *result = h;
1880 return OkStatus();
1881 }
1882
1883 bool async = executor->Async();
1884 if (mirror) {
1885 h->Ref();
1886 *result = h;
1887
1888 if (h->HasLocalMirror(d)) {
1889 return OkStatus();
1890 }
1891
1892 // We don't bother adding an empty local mirror in sync mode since we'll be
1893 // executing the operation directly and be calling AddLocalMirror. A
1894 // reference count is still needed which will be removed if the operation
1895 // fails.
1896 if (async) {
1897 Status s = h->AddEmptyLocalMirror(d);
1898 if (!s.ok()) {
1899 // If a mirror was added since we called HasLocalMirror then just return
1900 // since another thread has already added the mirror.
1901 if (s.code() == error::Code::ALREADY_EXISTS) {
1902 return OkStatus();
1903 }
1904
1905 // Remove the previously added reference count since adding the mirror
1906 // failed.
1907 h->Unref();
1908 *result = nullptr;
1909 return s;
1910 }
1911 }
1912 } else {
1913 *result = TensorHandle::CreateEmptyLocalHandle(
1914 d, dstd, h->resource_device(), h->dtype, ctx);
1915 }
1916
1917 Status s;
1918 if (async) {
1919 // Note that `h` may not be currently ready. However execution order will
1920 // make sure that `h` is ready before the copy is actually done.
1921 std::unique_ptr<EagerNode> node(
1922 new CopyToDeviceNode(h, *result, d, *ctx, async, mirror));
1923 s = executor->AddOrExecute(std::move(node));
1924 } else {
1925 CopyToDeviceNode node(h, *result, d, *ctx, async, mirror);
1926 s = executor->SyncExecute(&node);
1927 }
1928
1929   // If the operation failed, we need to Unref any outputs that were
1930   // allocated.
1931 if (!s.ok()) {
1932 (*result)->Unref();
1933 *result = nullptr;
1934 }
1935
1936 return s;
1937 }
1938
1939 } // namespace
1940
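// Copies (or mirrors) `h` onto `device`. Local-to-local copies are handled by
// LocalEagerCopyToDevice; copies involving a remote task are scheduled as
// RemoteCopyNodes. Illustrative call (variable names here are hypothetical):
//   TensorHandle* dst = nullptr;
//   TF_RETURN_IF_ERROR(EagerCopyToDevice(h, ctx, &executor, dst_device,
//                                        /*mirror=*/true, &dst));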
1941 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
1942 EagerExecutor* executor, Device* device, bool mirror,
1943 TensorHandle** result) {
1944 TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
1945 auto send_device = h->DeviceOrHostCPU(*ctx);
1946 bool sender_is_local = send_device->IsLocal();
1947
1948 bool receiver_is_local = device->IsLocal();
1949
1950 if (!executor->Async()) {
1951     // In sync mode, always clear the error to maintain the same behavior as before.
1952 // TODO(b/141004939): Remove this.
1953 executor->ClearError();
1954 }
1955
1956 if (sender_is_local && receiver_is_local) {
1957 return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result);
1958 } else {
1959 #if defined(IS_MOBILE_PLATFORM)
1960 return errors::Unimplemented(
1961 "Eager's remote execution is not available on mobile devices.");
1962 #else // !IS_MOBILE_PLATFORM
1963 uint64 recv_op_id = 0;
1964 if (receiver_is_local) {
1965 Device* d = ctx->CanonicalDevice(device);
1966 // TODO(gjn): Need to add support for async execution. Note if receiver
1967 // is local, we need to first add support in TensorHandle to wait on local
1968 // mirrors.
1969 if (mirror) {
1970 h->Ref();
1971 *result = h;
1972
1973 if (h->HasLocalMirror(d)) {
1974 return OkStatus();
1975 }
1976
1977 Status s = h->AddEmptyLocalMirror(d);
1978 if (!s.ok()) {
1979 // If a mirror was added since we called HasLocalMirror then just
1980 // return since another thread has already added the mirror.
1981 if (s.code() == error::Code::ALREADY_EXISTS) {
1982 return OkStatus();
1983 }
1984
1985 // Remove the previously added reference count since adding the mirror
1986 // failed.
1987 h->Unref();
1988 *result = nullptr;
1989 return s;
1990 }
1991 } else {
1992 *result = TensorHandle::CreateEmptyLocalHandle(
1993 /* d= */ d, /* op_device= */ device,
1994 /*resource_device=*/nullptr, h->dtype, ctx);
1995 }
1996 } else {
1997 if (mirror) {
1998 if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
1999 h->Ref();
2000 *result = h;
2001 return OkStatus();
2002 }
2003 }
2004 string remote_task;
2005 if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
2006 return errors::InvalidArgument(
2007 "Unable to find remote task corresponding to device ",
2008 device->name());
2009 }
2010 recv_op_id = ctx->RemoteMgr()->NextOpId();
2011 if (mirror) {
2012 TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
2013 remote_task, ctx));
2014 h->Ref();
2015 *result = h;
2016 } else {
2017 *result = TensorHandle::CreateUnshapedRemoteHandle(
2018 recv_op_id, 0, remote_task, h->dtype, device, ctx);
2019 }
2020 }
2021
2022 auto node = std::make_unique<eager::RemoteCopyNode>(
2023 ctx, executor, h, result[0], device, recv_op_id);
2024 Status s = executor->AddOrExecute(std::move(node));
2025 if (!s.ok()) {
2026 result[0]->Unref();
2027 result[0] = nullptr;
2028 }
2029 return s;
2030 #endif // !IS_MOBILE_PLATFORM
2031 }
2032 }
2033
2034 namespace {
2035 // Low-level utility function to execute the kernel specified by `kernel` on
2036 // `kernel->device()`, with the provided inputs as `op_inputs` in `ctx`.
2037 // Unlike `EagerKernelExecute`, which ties up the calling thread until the
2038 // underlying function finishes executing, this function does not block the
2039 // thread and may return before the function execution finishes. The provided
2040 // `StatusCallback` will be triggered after function execution with its status.
2041 void EagerKernelExecuteAsync(
2042 EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
2043 const absl::optional<EagerFunctionParams>& eager_func_params,
2044 const core::RefCountPtr<KernelAndDevice> kernel,
2045 GraphCollector* graph_collector, CancellationManager* cancellation_manager,
2046 TensorHandle** retvals, int num_outputs, StatusCallback done) {
2047 auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size());
2048 auto outputs = std::make_shared<std::vector<EagerKernelRet>>(1);
2049
2050 Status s = inputs->Init(ctx, op_inputs, kernel);
2051 if (!s.ok()) {
2052 done(s);
2053 return;
2054 }
2055 CoordinationServiceAgent* coord_agent = nullptr;
2056 #if !defined(IS_MOBILE_PLATFORM)
2057 if (ctx->GetDistributedManager() != nullptr)
2058 coord_agent = ctx->GetDistributedManager()->GetCoordinationServiceAgent();
2059 #endif // !IS_MOBILE_PLATFORM
2060
2061 kernel->Ref(); // Ownership of reference is transferred to the callback
2062 kernel->RunAsync(
2063 ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager,
2064 eager_func_params, coord_agent,
2065 [retvals, inputs, outputs, num_outputs, ctx, graph_collector,
2066 eager_func_params, kernel_raw = kernel.get(),
2067 done = std::move(done)](const Status& s) {
2068 auto wrapped_done = [&](const Status& s) {
2069 kernel_raw->Unref();
2070 done(s);
2071 };
2072 if (!s.ok()) {
2073 wrapped_done(s);
2074 return;
2075 }
2076 if (graph_collector != nullptr) {
2077 CollectGraphs(ctx);
2078 }
2079 DCHECK_EQ(num_outputs, outputs->size());
2080 wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx,
2081 kernel_raw, eager_func_params));
2082 });
2083 }
2084 } // namespace
2085
2086 // Low-level utility to run the eager operation on local devices. Unlike
2087 // `EagerLocalExecute`, which blocks and waits for the op execution to finish,
2088 // this method does not block the thread and may return before the eager
2089 // operation execution finishes. The provided `StatusCallback` will be
2090 // triggered after execution with its status.
2091 void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
2092 int* num_retvals, StatusCallback done) {
2093 if (!op->IsLocal()) {
2094 done(errors::InvalidArgument(
2095 "Remote execution is not supported in async EagerLocalExecuteAsync"));
2096 return;
2097 }
2098
2099 profiler::ScopedMemoryDebugAnnotation op_annotation(
2100 op->op_name(), op->eager_func_params().has_value()
2101 ? op->eager_func_params().value().step_id.value_or(0)
2102 : 0);
2103 profiler::TraceMe activity(
2104 [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); },
2105 profiler::TraceMeLevel::kInfo);
2106 EagerContext& ctx = op->EagerContext();
2107
2108 core::RefCountPtr<KernelAndDevice> kernel;
2109 Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel);
2110 if (!s.ok()) {
2111 done(s);
2112 return;
2113 }
2114
2115 int num_outputs = kernel->num_outputs();
2116 s = ValidateInputTypeAndPlacement(&ctx, op, kernel);
2117 if (!s.ok()) {
2118 done(s);
2119 return;
2120 }
2121
2122 if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
2123 string msg = strings::StrCat("Executing op ", op->Name(), " in device ",
2124 kernel->device()->name());
2125 if (!logging::LogToListeners(msg)) {
2126 LOG(INFO) << msg;
2127 }
2128 }
2129
2130 GraphCollector* graph_collector = nullptr;
2131 if (ctx.ShouldStoreGraphs()) {
2132 graph_collector = ctx.GetGraphCollector();
2133 }
2134
2135 for (int i = 0, end = num_outputs; i < end; ++i) {
2136 const DataTypeVector& output_dtypes = kernel->output_dtypes();
2137 retvals[i] = TensorHandle::CreateEmptyLocalHandle(
2138 /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
2139 /* op_device= */ kernel->device(),
2140 /* resource_device= */ kernel->OutputResourceDevice(i),
2141 output_dtypes[i], &ctx);
2142 }
2143
2144 const absl::InlinedVector<TensorHandle*, 4>* inputs;
2145 s = op->TensorHandleInputs(&inputs);
2146 if (!s.ok()) {
2147 done(s);
2148 return;
2149 }
2150 EagerKernelExecuteAsync(
2151 &ctx, *inputs, op->eager_func_params(), std::move(kernel),
2152 graph_collector, op->GetCancellationManager(), retvals, num_outputs,
2153 [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
2154 op->Clear();
2155         // If the operation failed, we need to Unref any outputs that were
2156         // allocated.
2157 if (!s.ok()) {
2158 for (int i = 0, end = num_outputs; i < end; ++i) {
2159 if (retvals[i] != nullptr) {
2160 retvals[i]->Unref();
2161 retvals[i] = nullptr;
2162 }
2163 }
2164 }
2165 done(s);
2166 });
2167 }
2168 } // namespace tensorflow
2169