xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
17 
18 #include "tensorflow/core/common_runtime/device.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/dma_helper.h"
21 #include "tensorflow/core/common_runtime/process_util.h"
22 #include "tensorflow/core/distributed_runtime/request_id.h"
23 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache.h"
25 #include "tensorflow/core/distributed_runtime/worker_interface.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/strings/numbers.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/notification.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 
37 namespace {
38 
39 class RpcRemoteRendezvous : public BaseRemoteRendezvous {
40  public:
RpcRemoteRendezvous(const WorkerEnv * env,int64_t step_id)41   RpcRemoteRendezvous(const WorkerEnv* env, int64_t step_id)
42       : BaseRemoteRendezvous(env, step_id) {}
43 
44  protected:
45   void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
46                            const Rendezvous::Args& args,
47                            DoneCallback done) override;
48 
49  private:
~RpcRemoteRendezvous()50   ~RpcRemoteRendezvous() override {}
51 
52   TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
53 };
54 
55 // Used only to retrieve tensors from remote processes.
56 class RpcRecvTensorCall : public BaseRecvTensorCall {
57  public:
RpcRecvTensorCall()58   RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}
59 
Init(WorkerInterface * wi,int64_t step_id,StringPiece key,AllocatorAttributes alloc_attrs,Device * dst_device,const Rendezvous::Args & recv_args,Rendezvous::DoneCallback done)60   void Init(WorkerInterface* wi, int64_t step_id, StringPiece key,
61             AllocatorAttributes alloc_attrs, Device* dst_device,
62             const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
63     wi_ = wi;
64     alloc_attrs_ = alloc_attrs;
65     dst_device_ = dst_device;
66     recv_args_ = recv_args;
67     done_ = std::move(done);
68     req_.set_step_id(step_id);
69     req_.set_rendezvous_key(key.data(), key.size());
70     req_.set_request_id(GetUniqueRequestId());
71   }
72 
Reset()73   void Reset() {
74     // The RpcRemoteRendezvous using this object is responsible for calling
75     // ReleaseWorker() before Reset().
76     DCHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
77         << "Leaking WorkerInterface in RpcRecvTensorCall::Reset().";
78 
79     alloc_attrs_ = AllocatorAttributes();
80     dst_device_ = nullptr;
81     // We don't clear opts_ and assume that Init will set up the state for
82     // opts_ appropriately.
83     req_.Clear();
84     resp_.Clear();
85     {
86       mutex_lock l(mu_);
87       status_ = OkStatus();
88     }
89     done_ = nullptr;
90   }
91 
~RpcRecvTensorCall()92   ~RpcRecvTensorCall() override {
93     // Since only the RpcRecvTensorFreeList will delete an
94     // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
95     // the user releases a Call object to the free list.
96     CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
97         << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
98   }
99 
Start(std::function<void ()> recv_done)100   void Start(std::function<void()> recv_done) override {
101     StartRTCall(std::move(recv_done));
102   }
103 
StartAbort(const Status & s)104   void StartAbort(const Status& s) override {
105     {
106       mutex_lock l(mu_);
107       status_.Update(s);
108     }
109     opts_.StartCancel();
110   }
111 
status() const112   Status status() const override {
113     mutex_lock l(mu_);
114     return status_;
115   }
116 
ReleaseWorker(WorkerCacheInterface * worker_cache)117   void ReleaseWorker(WorkerCacheInterface* worker_cache) {
118     DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
119         << "RpcRecvTensorCall::ReleaseWorker() called twice.";
120     worker_cache->ReleaseWorker(src_worker_, wi_);
121     wi_ = nullptr;
122   }
123 
tensor() const124   const Tensor& tensor() const { return resp_.tensor(); }
125 
is_dead() const126   bool is_dead() const { return resp_.metadata().is_dead(); }
127 
dst_device() const128   Device* dst_device() const { return dst_device_; }
recv_args() const129   const Rendezvous::Args& recv_args() const { return recv_args_; }
done() const130   const Rendezvous::DoneCallback& done() const { return done_; }
131 
132  private:
133   friend class RpcRemoteRendezvous;
134 
135   // Start the main RecvTensor call, checking for an async abort.
StartRTCall(std::function<void ()> recv_done)136   void StartRTCall(std::function<void()> recv_done) {
137     resp_.InitAlloc(dst_device_, alloc_attrs_);
138     auto abort_checked = std::make_shared<Notification>();
139     auto cb = [this, abort_checked,
140                recv_done = std::move(recv_done)](const Status& s) {
141       // Make sure the Rendezvous abort checking is finished before running the
142       // callback, which might destroy the current call object.
143       abort_checked->WaitForNotification();
144       if (!s.ok()) {
145         mutex_lock l(mu_);
146         status_.Update(s);
147       }
148       recv_done();
149     };
150     wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
151 
152     // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
153     // ordering is important because `StartAbort` could be called right before
154     // the `RecvTensorAsync` request registers its RPC cancellation to `opts_`.
155     // In that case, the previous `StartAbort` would not trigger the
156     // cancellation of this call.
157     Status s;
158     {
159       mutex_lock l(mu_);
160       s = status_;
161     }
162     if (!s.ok()) {
163       opts_.StartCancel();
164     }
165     // Notify that the abort check has finished.
166     abort_checked->Notify();
167   }
168 
169   string src_worker_;
170   string src_rel_device_;
171   WorkerInterface* wi_;  // Not owned.
172   AllocatorAttributes alloc_attrs_;
173   Device* dst_device_;
174   CallOptions opts_;
175   RecvTensorRequest req_;
176   TensorResponse resp_;
177   Rendezvous::Args recv_args_;
178   Rendezvous::DoneCallback done_;
179 
180   mutable mutex mu_;
181   Status status_ TF_GUARDED_BY(mu_);
182 
183   TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
184 };
185 
186 class RpcRecvTensorFreeList {
187  public:
RpcRecvTensorFreeList()188   RpcRecvTensorFreeList() {}
~RpcRecvTensorFreeList()189   ~RpcRecvTensorFreeList() {
190     for (size_t i = 0; i < objects_.size(); i++) {
191       delete objects_[i];
192     }
193   }
194 
New()195   RpcRecvTensorCall* New() {
196     {
197       mutex_lock l(mu_);
198       if (!objects_.empty()) {
199         RpcRecvTensorCall* result = objects_.back();
200         objects_.pop_back();
201         return result;
202       }
203     }
204     return new RpcRecvTensorCall;
205   }
206 
Release(RpcRecvTensorCall * obj)207   void Release(RpcRecvTensorCall* obj) {
208     obj->Reset();
209     {
210       mutex_lock l(mu_);
211       if (objects_.size() < kMaxObjects) {
212         objects_.push_back(obj);
213         return;
214       }
215     }
216     delete obj;
217   }
218 
219  private:
220   static constexpr int kMaxObjects = 1000;
221 
222   mutex mu_;
223   std::vector<RpcRecvTensorCall*> objects_ TF_GUARDED_BY(mu_);
224 };
225 
get_call_freelist()226 static RpcRecvTensorFreeList* get_call_freelist() {
227   static RpcRecvTensorFreeList* call_freelist = new RpcRecvTensorFreeList();
228   return call_freelist;
229 }
230 
RecvFromRemoteAsync(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)231 void RpcRemoteRendezvous::RecvFromRemoteAsync(
232     const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
233     DoneCallback done) {
234   CHECK(is_initialized());
235   Status s;
236 
237   // Prepare a RecvTensor call that can handle being aborted.
238   RpcRecvTensorCall* call = get_call_freelist()->New();
239 
240   // key.src_device identifies a remote device.
241   if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
242                                         &call->src_rel_device_)) {
243     s = errors::Internal(parsed.src_device,
244                          " is invalid remote source device.");
245   }
246   WorkerSession* sess = session();
247   std::shared_ptr<WorkerCacheInterface> worker_cache =
248       sess->GetSharedWorkerCache();
249   // The worker will be released in a subsequent call to
250   // `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been
251   // initialized) or `call->ReleaseWorker()` (if it has been initialized).
252   WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);
253   if (s.ok() && rwi == nullptr) {
254     s = errors::Internal("No worker known as ", call->src_worker_);
255   }
256 
257   Device* dst_device;
258   if (s.ok()) {
259     s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
260   }
261   if (!s.ok()) {
262     if (rwi != nullptr) {
263       sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
264     }
265     get_call_freelist()->Release(call);
266     done(s, Args(), recv_args, Tensor{}, false);
267     return;
268   }
269 
270   call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
271              recv_args, std::move(done));
272 
273   // Record "call" in calls_ so that it can be aborted cleanly.
274   RegisterCall(call, recv_args);
275 
276   // RendezvousMgr already aborted, shouldn't send RPC call any more
277   if (!call->status().ok()) {
278     DeregisterCall(call, recv_args);
279     // NOTE: `*sess` can potentially be deleted before we return from
280     // `call->done()(...)`, so we must release the worker before calling the
281     // callback.
282     call->ReleaseWorker(sess->worker_cache());
283     call->done()(call->status(), Args(), Args(), Tensor(), false);
284     get_call_freelist()->Release(call);
285     return;
286   }
287 
288   // Start "call".
289   Ref();
290   call->Start([this, call, recv_args, worker_cache]() {
291     // Removes "call" from calls_. Prevent StartAbort().
292     DeregisterCall(call, recv_args);
293     // If StartAbort was called prior to DeregisterCall, then the
294     // current status should be bad.
295     Status s = call->status();
296     // NOTE: `*session()` can potentially be deleted before we return from
297     // `call->done()(...)`, so we must release the worker before calling the
298     // callback.
299     call->ReleaseWorker(session()->worker_cache());
300     call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
301     get_call_freelist()->Release(call);
302     Unref();
303   });
304 }
305 
306 }  // namespace
307 
RpcRendezvousMgr(const WorkerEnv * env)308 RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
309     : BaseRendezvousMgr(env) {}
310 
Create(int64_t step_id,const WorkerEnv * worker_env)311 BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64_t step_id,
312                                                const WorkerEnv* worker_env) {
313   return new RpcRemoteRendezvous(worker_env, step_id);
314 }
315 
316 }  // end namespace tensorflow
317