/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"

#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace eager {

namespace {

void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
  remote_op->set_name(op->Name());

  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(op->DeviceName());
}

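// Builds a fresh KernelAndDevice for `op` without going through the kernel
// cache; each copy constructs its one-off _Send/_Recv kernel from scratch.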
Status CreateUncachedKernelAndDeviceOp(
    EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
  EagerContext& ctx = op->EagerContext();
  Device* device = std::get<Device*>(op->Device());

  FunctionLibraryRuntime* flr = ctx.func_lib(device);
  if (flr == nullptr) {
    return errors::Unavailable(
        "Unable to find a FunctionLibraryRuntime corresponding to device ",
        device->name());
  }

  auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx.runner();
  kernel->reset(new KernelAndDeviceOp(ctx.GetRendezvous(), ctx.LogMemory(), flr,
                                      runner, ctx.GetCollectiveExecutorHandle(),
                                      ctx.HostCPU()));

  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
  return kernel->get()->Init(ctx.LogDevicePlacement(), ndef,
                             /*graph_collector=*/nullptr);
}

// Returns a unique wire ID. A random per-process identifier is included so
// that, if the worker is servicing other clients, the IDs do not collide.
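// The resulting ID has the form "<random_seed>_<counter>", e.g.
// "9481672301457623894_0" (digits illustrative).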
string GetUniqueWireID() {
  static tensorflow::uint64 random_seed = random::New64();
  static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
  static std::atomic<int64_t> wire_id;
  return strings::StrCat(random_seed, "_", wire_id++);
}

}  // namespace

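// A RemoteCopyNode is typically created by the eager copy-to-device path and
// scheduled on an EagerExecutor, along the lines of (illustrative sketch, not
// the exact call site):
//
//   auto node = std::make_unique<eager::RemoteCopyNode>(
//       ctx, executor, src_handle, dst_handle, recv_device, recv_op_id);
//   Status s = executor->AddOrExecute(std::move(node));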
RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
                               TensorHandle* src, TensorHandle* dst,
                               Device* recv_device, uint64 recv_op_id)
    : AsyncEagerNode(),
      src_(src),
      ctx_(ctx),
      executor_(executor),
      send_device_(src->DeviceOrHostCPU(*ctx)),
      recv_device_(recv_device),
      wire_id_(GetUniqueWireID()),
      recv_op_id_(recv_op_id),
      captured_state_(std::make_shared<CapturedSharedState>(dst)),
      started_(false) {
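  // At least one endpoint must be remote; purely local copies are handled
  // elsewhere and should never construct a RemoteCopyNode.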
  DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
  src_->Ref();
  ctx_->Ref();
}

RemoteCopyNode::~RemoteCopyNode() {
  src_->Unref();
  ctx_->Unref();
}

Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
  TF_RETURN_IF_ERROR(executor_->status());

  TF_RETURN_IF_ERROR(op->AddInput(src_));

  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

  EagerKernelArgs args(1);
  Device* d = ctx_->CanonicalDevice(std::get<Device*>(op->Device()));
  TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
  CoordinationServiceAgent* coord_agent = nullptr;
  if (ctx_->GetDistributedManager() != nullptr)
    coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();

  return kernel->Run(/*step_container=*/nullptr, args, /*outputs=*/nullptr,
                     /*cancellation_manager=*/nullptr,
                     /*eager_func_params=*/std::nullopt,
                     /*stack_trace=*/std::nullopt, coord_agent);
}

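// The _Send half of the copy. The matching _Recv issued by StartRecv() pairs
// with it through the rendezvous key, which is derived from the attributes
// set below: send_device, send_device_incarnation, recv_device, and
// tensor_name (the wire ID). Both ops must therefore set identical values.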
void RemoteCopyNode::StartSend() {
  // TODO(gjn): We should consider just using the low-level SendOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Send", /*device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  if (!status.ok()) {
    captured_state_->SetSendStatus(status);
    return;
  }

  op.SetDevice(send_device_);

  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64_t>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("T", src_->dtype);

  DCHECK(send_device_ != nullptr);

  if (send_device_->IsLocal()) {
    status = RunLocalSend(&op);
    captured_state_->SetSendStatus(status);
    return;
  } else {
    // Prepare the request
    EnqueueRequest request;
    request.set_context_id(ctx_->GetContextId());
    auto* remote_op = request.add_queue()->mutable_operation();
    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
        src_, /*wait_until_ready=*/false,
        remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
        src_->DeviceOrHostCPU(*ctx_)->name());
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    PrepareRemoteOp(remote_op, &op);
    remote_op->set_id(ctx_->RemoteMgr()->NextOpId());

    // Issue the RPC
    core::RefCountPtr<eager::EagerClient> eager_client;
    status = ctx_->GetClient(send_device_, &eager_client);
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    const std::shared_ptr<CapturedSharedState>& captured_state =
        captured_state_;
    EnqueueResponse* response = new EnqueueResponse;
    // If StartRecv fails very quickly, `this` can be destroyed before the
    // callback below is executed. So, we can't capture `this`.
    eager_client->StreamingEnqueueAsync(
        ctx_->Executor().StreamingEnqueue(),
        /*call_opts=*/nullptr, &request, response,
        [response, captured_state](const Status& s) {
          captured_state->SetSendStatus(s);
          if (!s.ok()) {
            captured_state->recv_cancellation()->StartCancel();
          }
          delete response;
        });
  }
}

Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
                                    std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(executor_->status());

  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

  EagerKernelArgs args;
  std::vector<EagerKernelRet> rets;
  CoordinationServiceAgent* coord_agent = nullptr;
  if (ctx_->GetDistributedManager() != nullptr)
    coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();
  TF_RETURN_IF_ERROR(kernel->Run(/*step_container=*/nullptr, args, &rets,
                                 captured_state_->recv_cancellation(),
                                 /*eager_func_params=*/std::nullopt,
                                 /*stack_trace=*/std::nullopt, coord_agent));
  outputs->clear();
  for (const auto& ret : rets) {
    if (ret.index() == 0) {
      outputs->push_back(std::get<Tensor>(ret));
    } else {
      return errors::Internal(
          "Expected to receive a Tensor but got a TensorShape.");
    }
  }
  return OkStatus();
}

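// The _Recv half when the receiver is remote: the op is enqueued on the
// receiving worker's eager service, and the response's queue_response carries
// the received tensor's shape, which is then recorded on the destination
// handle.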
void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  auto* remote_op = request.add_queue()->mutable_operation();
  PrepareRemoteOp(remote_op, op);
  remote_op->set_id(recv_op_id_);
  uint64 context_view_id = ctx_->GetContextViewId();

  core::RefCountPtr<eager::EagerClient> eager_client;
  Status status = ctx_->GetClient(recv_device_, &eager_client);
  if (!status.ok()) {
    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
    done(status);
    return;
  }

  // Don't issue the recv until send has completed.
  //  - local send will complete very quickly.
  //  - remote send will take some time, but remote->remote copy is
  //    probably rare enough that we don't care much.
  // Blocks until send has completed.
  Status send_status = captured_state_->GetSendStatus();
  if (!send_status.ok()) {
    captured_state_->dst()->PoisonRemote(send_status, recv_device_,
                                         context_view_id);
    done(send_status);
    return;
  }

  EnqueueResponse* response = new EnqueueResponse;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  Device* recv_device = recv_device_;
  eager_client->StreamingEnqueueAsync(
      ctx_->Executor().StreamingEnqueue(),
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              response->queue_response(0).shape(0), recv_device,
              context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by remote Recv op: "
                       << status.ToString()
                       << "\nThis should never happen. "
                          "Please file an issue with the TensorFlow Team.";
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

void RemoteCopyNode::StartRecv(StatusCallback done) {
  // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Recv", /*device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  Device* recv_device = ctx_->CanonicalDevice(recv_device_);
  if (!status.ok()) {
    captured_state_->dst()->Poison(status, recv_device);
    done(status);
    return;
  }

  op.SetDevice(recv_device_);

  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64_t>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("tensor_type", src_->dtype);

  if (recv_device_->IsLocal()) {
    std::vector<Tensor> outputs(1);
    status = RunLocalRecv(&op, &outputs);
    if (!status.ok()) {
      captured_state_->dst()->Poison(status, recv_device);
      done(status);
      return;
    }
    status =
        captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device);
    done(status);
  } else {
    // Handles captured_state_->dst_ internally.
    RunRemoteRecv(&op, std::move(done));
  }
}

Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
                             const Device* target_device, EagerContext* ctx,
                             SendPackedHandleOp* op) {
  op->set_op_id(op_id);
  op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
    TensorHandle* h = nullptr;
    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
    if (h->Type() == TensorHandle::LOCAL) {
      // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
      // copy it to the CPU before copying it out.
      Tensor tensor;
      TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
      auto* local_handle = op->add_handles()->mutable_local_handle();
      local_handle->set_device(h->op_device() ? h->op_device()->name()
                                              : ctx->HostCPU()->name());
      tensor.AsProtoTensorContent(local_handle->mutable_tensor());
    } else if (h->Type() == TensorHandle::REMOTE) {
      // Only serialize the resource dtype and shape of the first handle, since
      // all handles are of the same resource dtype and shape.
      // If src_device is on the same task as target_device, the handle is a
      // local handle on the target device, which means the resource dtype and
      // shape are known on the target device.
      Device* src_device = h->device();
      const bool serialize_resource_dtype_and_shape =
          (i == 0) && (h->dtype == DT_RESOURCE) &&
          (!ctx->OnSameTask(src_device, target_device));
      // For a remote component function, a function execution request and an
      // input generation request may come from different workers. We need to
      // guarantee that the input generation request is processed before the
      // function execution request, so wait until the underlying remote handles
      // are ready before sending a packed handle to the function device.
      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
          h, /*wait_until_ready=*/true,
          op->add_handles()->mutable_remote_handle(), src_device,
          h->DeviceOrHostCPU(*ctx)->name(),
          serialize_resource_dtype_and_shape));
    } else {
      return errors::InvalidArgument("Nested packed handles are not supported");
    }
  }
  return OkStatus();
}

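// Sends a packed handle (a bundle of per-component handles, e.g. one per
// replica) to the receiver: local components are shipped by value, remote
// components by reference, all within a single SendPackedHandle enqueue item.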
void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
  Status s;
  const uint64 context_view_id = ctx_->GetContextViewId();
  if (!send_device_->IsLocal()) {
    s = errors::InvalidArgument(
        "Copying a packed handle from a remote device is not supported");
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
                            request.add_queue()->mutable_send_packed_handle());
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  TensorShape shape;
  s = src_->Shape(&shape);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  captured_state_->SetSrcShape(shape);

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueResponse* response = new EnqueueResponse;
  Device* recv_device = recv_device_;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  eager_client->StreamingEnqueueAsync(
      ctx_->Executor().StreamingEnqueue(),
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendPackedHandle rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

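// Fast path for local -> remote copies when UseSendTensorRPC() is enabled:
// the tensor bytes are shipped inline in the EnqueueRequest (a SendTensor
// item) instead of going through the _Send/_Recv rendezvous pair.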
void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
  Status s;
  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  auto* send_tensor = request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(recv_op_id_);
  send_tensor->set_device_name(recv_device_->name());
  uint64 context_view_id = ctx_->GetContextViewId();

  // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
  // copy it to the CPU before copying it out.
  // TODO(fishx): Make CopyToDevice asynchronous.
  Tensor tensor;
  s = src_->CopyToDevice(*ctx_, ctx_->HostCPU(), &tensor);
  if (!s.ok()) {
    done(s);
    return;
  }
  tensor.AsProtoTensorContent(send_tensor->add_tensors());

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  EnqueueResponse* response = new EnqueueResponse;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  captured_state->SetSrcShape(tensor.shape());
  Device* recv_device = recv_device_;
  eager_client->StreamingEnqueueAsync(
      ctx_->Executor().StreamingEnqueue(),
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendTensor rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

Status RemoteCopyNode::Prepare() {
  TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_));
  return OkStatus();
}

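// Dispatches the copy: packed handles go through StartSendPackedHandle, the
// SendTensor fast path handles local -> remote copies when enabled, and
// everything else runs the paired _Send (StartSend) and _Recv (StartRecv).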
void RemoteCopyNode::RunAsync(StatusCallback done) {
  started_ = true;
  if (src_->Type() == TensorHandle::PACKED) {
    return StartSendPackedHandle(std::move(done));
  }

  if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
      !recv_device_->IsLocal()) {
    return StartRemoteSendTensor(std::move(done));
  }
  StartSend();

  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  auto done_wrapper = [captured_state,
                       done = std::move(done)](const Status& s) {
    if (!s.ok() && errors::IsCancelled(s)) {
      Status send_status = captured_state->GetSendStatus();
      if (!send_status.ok()) {
        // In this case, Recv is cancelled because the Send op failed.
        // Return the status of the Send op instead.
        done(send_status);
      }
    } else {
      done(s);
    }
  };

  // StartRecv() takes care of doing the right thing to the dst handle.
  // No need to poison it after this point.
  StartRecv(std::move(done_wrapper));
}

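// Only poison the destination if the node never ran: once RunAsync has
// started, StartRecv() (or its failure paths) determines the handle's final
// state, so poisoning it again here would be redundant.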
void RemoteCopyNode::Abort(Status status) {
  if (!started_) {
    uint64 context_view_id = ctx_->GetContextViewId();
    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
  }
}

}  // namespace eager
}  // namespace tensorflow