/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"

#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace eager {

namespace {
void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
  remote_op->set_name(op->Name());

  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(op->DeviceName());
}

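// Builds a fresh KernelAndDeviceOp for `op`, bypassing the kernel cache.
// Used for the one-off _Send/_Recv kernels this node constructs.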
Status CreateUncachedKernelAndDeviceOp(
    EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
  EagerContext& ctx = op->EagerContext();
  Device* device = std::get<Device*>(op->Device());

  FunctionLibraryRuntime* flr = ctx.func_lib(device);
  if (flr == nullptr) {
    return errors::Unavailable(
        "Unable to find a FunctionLibraryRuntime corresponding to device ",
        device->name());
  }

  auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx.runner();
  kernel->reset(new KernelAndDeviceOp(ctx.GetRendezvous(), ctx.LogMemory(), flr,
                                      runner, ctx.GetCollectiveExecutorHandle(),
                                      ctx.HostCPU()));

  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
  return kernel->get()->Init(ctx.LogDevicePlacement(), ndef,
                             /*graph_collector=*/nullptr);
}

// Returns a unique wire ID for a _Send/_Recv pair. A process-wide random
// seed is prepended so that IDs do not collide with those generated by the
// worker's other clients.
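// For example (the seed value below is illustrative), successive calls
// return "5960575934907276104_0", "5960575934907276104_1", and so on.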
string GetUniqueWireID() {
  static tensorflow::uint64 random_seed = random::New64();
  // The atomic counter makes the increment thread-safe on its own; no mutex
  // is needed.
  static std::atomic<int64_t> wire_id;
  return strings::StrCat(random_seed, "_", wire_id++);
}

}  // namespace

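// The node keeps references on `src` and the context for its lifetime; both
// are released in the destructor. At least one of the send and recv devices
// must be remote, since a purely local copy never creates this node.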
RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
                               TensorHandle* src, TensorHandle* dst,
                               Device* recv_device, uint64 recv_op_id)
    : AsyncEagerNode(),
      src_(src),
      ctx_(ctx),
      executor_(executor),
      send_device_(src->DeviceOrHostCPU(*ctx)),
      recv_device_(recv_device),
      wire_id_(GetUniqueWireID()),
      recv_op_id_(recv_op_id),
      captured_state_(std::make_shared<CapturedSharedState>(dst)),
      started_(false) {
  DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
  src_->Ref();
  ctx_->Ref();
}

RemoteCopyNode::~RemoteCopyNode() {
  src_->Unref();
  ctx_->Unref();
}

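// Runs the _Send kernel synchronously on the local send device, feeding it
// the source tensor as its single input.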
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
  TF_RETURN_IF_ERROR(executor_->status());

  TF_RETURN_IF_ERROR(op->AddInput(src_));

  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

  EagerKernelArgs args(1);
  Device* d = ctx_->CanonicalDevice(std::get<Device*>(op->Device()));
  TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
  CoordinationServiceAgent* coord_agent = nullptr;
  if (ctx_->GetDistributedManager() != nullptr) {
    coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();
  }

  return kernel->Run(/*step_container=*/nullptr, args, /*outputs=*/nullptr,
                     /*cancellation_manager=*/nullptr,
                     /*eager_func_params=*/std::nullopt,
                     /*stack_trace=*/std::nullopt, coord_agent);
}

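// Kicks off the send half of the copy. A local send runs the _Send kernel
// synchronously; a remote send enqueues the op over the eager service, and
// the RPC callback records the send status and cancels the pending recv on
// failure.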
void RemoteCopyNode::StartSend() {
  // TODO(gjn): We should consider just using the low-level SendOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Send", /*device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  if (!status.ok()) {
    captured_state_->SetSendStatus(status);
    return;
  }

  op.SetDevice(send_device_);

  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64_t>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("T", src_->dtype);

  DCHECK(send_device_ != nullptr);

  if (send_device_->IsLocal()) {
    status = RunLocalSend(&op);
    captured_state_->SetSendStatus(status);
    return;
  } else {
    // Prepare the request.
    EnqueueRequest request;
    request.set_context_id(ctx_->GetContextId());
    auto* remote_op = request.add_queue()->mutable_operation();
    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
        src_, /*wait_until_ready=*/false,
        remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
        src_->DeviceOrHostCPU(*ctx_)->name());
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    PrepareRemoteOp(remote_op, &op);
    remote_op->set_id(ctx_->RemoteMgr()->NextOpId());

    // Issue the RPC.
    core::RefCountPtr<eager::EagerClient> eager_client;
    status = ctx_->GetClient(send_device_, &eager_client);
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    const std::shared_ptr<CapturedSharedState>& captured_state =
        captured_state_;
    EnqueueResponse* response = new EnqueueResponse;
    // If StartRecv fails very quickly, `this` can be destroyed before the
    // callback below is executed. So, we can't capture `this`.
    eager_client->StreamingEnqueueAsync(
        ctx_->Executor().StreamingEnqueue(),
        /*call_opts=*/nullptr, &request, response,
        [response, captured_state](const Status& s) {
          captured_state->SetSendStatus(s);
          if (!s.ok()) {
            captured_state->recv_cancellation()->StartCancel();
          }
          delete response;
        });
  }
}

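// Runs the _Recv kernel synchronously on the local recv device and returns
// the received tensor in `outputs`.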
Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
                                    std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(executor_->status());

  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

  EagerKernelArgs args;
  std::vector<EagerKernelRet> rets;
  CoordinationServiceAgent* coord_agent = nullptr;
  if (ctx_->GetDistributedManager() != nullptr) {
    coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();
  }
  TF_RETURN_IF_ERROR(kernel->Run(/*step_container=*/nullptr, args, &rets,
                                 captured_state_->recv_cancellation(),
                                 /*eager_func_params=*/std::nullopt,
                                 /*stack_trace=*/std::nullopt, coord_agent));
  outputs->clear();
  for (const auto& ret : rets) {
    if (ret.index() == 0) {
      outputs->push_back(std::get<Tensor>(ret));
    } else {
      return errors::Internal(
          "Expected to receive a Tensor but got a TensorShape.");
    }
  }
  return OkStatus();
}

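// Enqueues the _Recv op on the remote recv device. Waits for the send to
// complete first, then sets or poisons the destination handle's remote
// shape from the RPC callback.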
void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  auto* remote_op = request.add_queue()->mutable_operation();
  PrepareRemoteOp(remote_op, op);
  remote_op->set_id(recv_op_id_);
  uint64 context_view_id = ctx_->GetContextViewId();

  core::RefCountPtr<eager::EagerClient> eager_client;
  Status status = ctx_->GetClient(recv_device_, &eager_client);
  if (!status.ok()) {
    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
    done(status);
    return;
  }

  // Don't issue the recv until the send has completed; GetSendStatus()
  // blocks until then.
  // - A local send completes very quickly.
  // - A remote send takes some time, but remote->remote copies are probably
  //   rare enough that we don't care much.
  Status send_status = captured_state_->GetSendStatus();
  if (!send_status.ok()) {
    captured_state_->dst()->PoisonRemote(send_status, recv_device_,
                                         context_view_id);
    done(send_status);
    return;
  }

  EnqueueResponse* response = new EnqueueResponse;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  Device* recv_device = recv_device_;
  eager_client->StreamingEnqueueAsync(
      ctx_->Executor().StreamingEnqueue(),
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              response->queue_response(0).shape(0), recv_device,
              context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by remote Recv op: "
                       << status.ToString()
                       << "\nThis should never happen. "
                          "Please file an issue with the TensorFlow Team.";
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

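// Kicks off the recv half of the copy. A local recv runs the _Recv kernel
// synchronously and stores the resulting tensor in the destination handle;
// a remote recv is enqueued via RunRemoteRecv().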
void RemoteCopyNode::StartRecv(StatusCallback done) {
  // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Recv", /*device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  Device* recv_device = ctx_->CanonicalDevice(recv_device_);
  if (!status.ok()) {
    captured_state_->dst()->Poison(status, recv_device);
    done(status);
    return;
  }

  op.SetDevice(recv_device_);

  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64_t>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("tensor_type", src_->dtype);

  if (recv_device_->IsLocal()) {
    std::vector<Tensor> outputs(1);
    status = RunLocalRecv(&op, &outputs);
    if (!status.ok()) {
      captured_state_->dst()->Poison(status, recv_device);
      done(status);
      return;
    }
    status =
        captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device);
    done(status);
  } else {
    // Handles captured_state_->dst_ internally.
    RunRemoteRecv(&op, std::move(done));
  }
}

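// Serializes a packed TensorHandle into a SendPackedHandleOp: local
// component handles are copied to host CPU and embedded as tensor protos,
// while remote component handles are serialized as remote handle references.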
Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
                             const Device* target_device, EagerContext* ctx,
                             SendPackedHandleOp* op) {
  op->set_op_id(op_id);
  op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
    TensorHandle* h = nullptr;
    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
    if (h->Type() == TensorHandle::LOCAL) {
      // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
      // copy it to the CPU before copying it out.
      Tensor tensor;
      TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
      auto* local_handle = op->add_handles()->mutable_local_handle();
      local_handle->set_device(h->op_device() ? h->op_device()->name()
                                              : ctx->HostCPU()->name());
      tensor.AsProtoTensorContent(local_handle->mutable_tensor());
    } else if (h->Type() == TensorHandle::REMOTE) {
      // Only serialize the resource dtype and shape of the first handle, since
      // all handles are of the same resource dtype and shape.
      // If src_device is on the same task as target_device, the handle is a
      // local handle on the target device, which means the resource dtype and
      // shape are known on the target device.
      Device* src_device = h->device();
      const bool serialize_resource_dtype_and_shape =
          (i == 0) && (h->dtype == DT_RESOURCE) &&
          (!ctx->OnSameTask(src_device, target_device));
      // For a remote component function, a function execution request and an
      // input generation request may come from different workers. We need to
      // guarantee that the input generation request is processed before the
      // function execution request, so wait until the underlying remote handles
      // are ready before sending a packed handle to the function device.
      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
          h, /*wait_until_ready=*/true,
          op->add_handles()->mutable_remote_handle(), src_device,
          h->DeviceOrHostCPU(*ctx)->name(),
          serialize_resource_dtype_and_shape));
    } else {
      return errors::InvalidArgument("Nested packed handles are not supported");
    }
  }
  return OkStatus();
}

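// Sends a packed handle from the local client to the remote recv device via
// a SendPackedHandle request. The destination handle's remote shape is set
// from the source shape once the RPC succeeds.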
void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
  Status s;
  const uint64 context_view_id = ctx_->GetContextViewId();
  if (!send_device_->IsLocal()) {
    s = errors::InvalidArgument(
        "Copying a packed handle from a remote device is not supported");
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
                            request.add_queue()->mutable_send_packed_handle());
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  TensorShape shape;
  s = src_->Shape(&shape);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  captured_state_->SetSrcShape(shape);

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueResponse* response = new EnqueueResponse;
  Device* recv_device = recv_device_;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  eager_client->StreamingEnqueueAsync(
      ctx_->Executor().StreamingEnqueue(),
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendPackedHandle rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

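// Ships the tensor contents to the remote recv device in a single
// SendTensor request, avoiding the _Send/_Recv rendezvous. Only used when
// the source is local, the destination is remote, and SendTensor RPCs are
// enabled (see RunAsync below).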
void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
  Status s;
  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  auto* send_tensor = request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(recv_op_id_);
  send_tensor->set_device_name(recv_device_->name());
  uint64 context_view_id = ctx_->GetContextViewId();

  // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
  // copy it to the CPU before copying it out.
  // TODO(fishx): Make CopyToDevice asynchronous.
  Tensor tensor;
  s = src_->CopyToDevice(*ctx_, ctx_->HostCPU(), &tensor);
  if (!s.ok()) {
    done(s);
    return;
  }
  tensor.AsProtoTensorContent(send_tensor->add_tensors());

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  EnqueueResponse* response = new EnqueueResponse;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  captured_state->SetSrcShape(tensor.shape());
  Device* recv_device = recv_device_;
  eager_client->StreamingEnqueueAsync(
      ctx_->Executor().StreamingEnqueue(),
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendTensor rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

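// Propagates the source handle's inference shape to the destination handle
// before the node is scheduled.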
Status RemoteCopyNode::Prepare() {
  TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_));
  return OkStatus();
}

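// Dispatches the copy: packed handles go through StartSendPackedHandle();
// local->remote copies use a single SendTensor RPC when enabled; everything
// else runs the _Send/_Recv pair via StartSend()/StartRecv().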
void RemoteCopyNode::RunAsync(StatusCallback done) {
  started_ = true;
  if (src_->Type() == TensorHandle::PACKED) {
    return StartSendPackedHandle(std::move(done));
  }

  if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
      !recv_device_->IsLocal()) {
    return StartRemoteSendTensor(std::move(done));
  }
  StartSend();

  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  auto done_wrapper = [captured_state,
                       done = std::move(done)](const Status& s) {
    if (!s.ok() && errors::IsCancelled(s)) {
      Status send_status = captured_state->GetSendStatus();
      if (!send_status.ok()) {
        // In this case, the Recv was cancelled because the Send op failed.
        // Return the status of the Send op instead.
        done(send_status);
      }
    } else {
      done(s);
    }
  };

  // StartRecv() takes care of doing the right thing to the dst handle.
  // No need to poison it after this point.
  StartRecv(std::move(done_wrapper));
}

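// If the node was aborted before it started running, poison the destination
// handle so that consumers observe the abort status.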
void RemoteCopyNode::Abort(Status status) {
  if (!started_) {
    uint64 context_view_id = ctx_->GetContextViewId();
    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
  }
}

}  // namespace eager
}  // namespace tensorflow