1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
17
18 #include "tensorflow/core/common_runtime/device.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/dma_helper.h"
21 #include "tensorflow/core/common_runtime/process_util.h"
22 #include "tensorflow/core/distributed_runtime/request_id.h"
23 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache.h"
25 #include "tensorflow/core/distributed_runtime/worker_interface.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/strings/numbers.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/notification.h"
33 #include "tensorflow/core/platform/types.h"
34
35 namespace tensorflow {
36
37 namespace {
38
39 class RpcRemoteRendezvous : public BaseRemoteRendezvous {
40 public:
RpcRemoteRendezvous(const WorkerEnv * env,int64_t step_id)41 RpcRemoteRendezvous(const WorkerEnv* env, int64_t step_id)
42 : BaseRemoteRendezvous(env, step_id) {}
43
44 protected:
45 void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
46 const Rendezvous::Args& args,
47 DoneCallback done) override;
48
49 private:
~RpcRemoteRendezvous()50 ~RpcRemoteRendezvous() override {}
51
52 TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
53 };
54
55 // Used only to retrieve tensors from remote processes.
56 class RpcRecvTensorCall : public BaseRecvTensorCall {
57 public:
RpcRecvTensorCall()58 RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}
59
Init(WorkerInterface * wi,int64_t step_id,StringPiece key,AllocatorAttributes alloc_attrs,Device * dst_device,const Rendezvous::Args & recv_args,Rendezvous::DoneCallback done)60 void Init(WorkerInterface* wi, int64_t step_id, StringPiece key,
61 AllocatorAttributes alloc_attrs, Device* dst_device,
62 const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
63 wi_ = wi;
64 alloc_attrs_ = alloc_attrs;
65 dst_device_ = dst_device;
66 recv_args_ = recv_args;
67 done_ = std::move(done);
68 req_.set_step_id(step_id);
69 req_.set_rendezvous_key(key.data(), key.size());
70 req_.set_request_id(GetUniqueRequestId());
71 }
72
Reset()73 void Reset() {
74 // The RpcRemoteRendezvous using this object is responsible for calling
75 // ReleaseWorker() before Reset().
76 DCHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
77 << "Leaking WorkerInterface in RpcRecvTensorCall::Reset().";
78
79 alloc_attrs_ = AllocatorAttributes();
80 dst_device_ = nullptr;
81 // We don't clear opts_ and assume that Init will set up the state for
82 // opts_ appropriately.
83 req_.Clear();
84 resp_.Clear();
85 {
86 mutex_lock l(mu_);
87 status_ = OkStatus();
88 }
89 done_ = nullptr;
90 }
91
~RpcRecvTensorCall()92 ~RpcRecvTensorCall() override {
93 // Since only the RpcRecvTensorFreeList will delete an
94 // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
95 // the user releases a Call object to the free list.
96 CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
97 << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
98 }
99
Start(std::function<void ()> recv_done)100 void Start(std::function<void()> recv_done) override {
101 StartRTCall(std::move(recv_done));
102 }
103
StartAbort(const Status & s)104 void StartAbort(const Status& s) override {
105 {
106 mutex_lock l(mu_);
107 status_.Update(s);
108 }
109 opts_.StartCancel();
110 }
111
status() const112 Status status() const override {
113 mutex_lock l(mu_);
114 return status_;
115 }
116
ReleaseWorker(WorkerCacheInterface * worker_cache)117 void ReleaseWorker(WorkerCacheInterface* worker_cache) {
118 DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
119 << "RpcRecvTensorCall::ReleaseWorker() called twice.";
120 worker_cache->ReleaseWorker(src_worker_, wi_);
121 wi_ = nullptr;
122 }
123
tensor() const124 const Tensor& tensor() const { return resp_.tensor(); }
125
is_dead() const126 bool is_dead() const { return resp_.metadata().is_dead(); }
127
dst_device() const128 Device* dst_device() const { return dst_device_; }
recv_args() const129 const Rendezvous::Args& recv_args() const { return recv_args_; }
done() const130 const Rendezvous::DoneCallback& done() const { return done_; }
131
132 private:
133 friend class RpcRemoteRendezvous;
134
135 // Start the main RecvTensor call, checking for an async abort.
StartRTCall(std::function<void ()> recv_done)136 void StartRTCall(std::function<void()> recv_done) {
137 resp_.InitAlloc(dst_device_, alloc_attrs_);
138 auto abort_checked = std::make_shared<Notification>();
139 auto cb = [this, abort_checked,
140 recv_done = std::move(recv_done)](const Status& s) {
141 // Make sure the Rendezvous abort checking is finished before running the
142 // callback, which might destroy the current call object.
143 abort_checked->WaitForNotification();
144 if (!s.ok()) {
145 mutex_lock l(mu_);
146 status_.Update(s);
147 }
148 recv_done();
149 };
150 wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
151
152 // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
153 // ordering is important because `StartAbort` could be called right before
154 // the `RecvTensorAsync` request registers its RPC cancellation to `opts_`.
155 // In that case, the previous `StartAbort` would not trigger the
156 // cancellation of this call.
157 Status s;
158 {
159 mutex_lock l(mu_);
160 s = status_;
161 }
162 if (!s.ok()) {
163 opts_.StartCancel();
164 }
165 // Notify that the abort check has finished.
166 abort_checked->Notify();
167 }
168
169 string src_worker_;
170 string src_rel_device_;
171 WorkerInterface* wi_; // Not owned.
172 AllocatorAttributes alloc_attrs_;
173 Device* dst_device_;
174 CallOptions opts_;
175 RecvTensorRequest req_;
176 TensorResponse resp_;
177 Rendezvous::Args recv_args_;
178 Rendezvous::DoneCallback done_;
179
180 mutable mutex mu_;
181 Status status_ TF_GUARDED_BY(mu_);
182
183 TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
184 };
185
186 class RpcRecvTensorFreeList {
187 public:
RpcRecvTensorFreeList()188 RpcRecvTensorFreeList() {}
~RpcRecvTensorFreeList()189 ~RpcRecvTensorFreeList() {
190 for (size_t i = 0; i < objects_.size(); i++) {
191 delete objects_[i];
192 }
193 }
194
New()195 RpcRecvTensorCall* New() {
196 {
197 mutex_lock l(mu_);
198 if (!objects_.empty()) {
199 RpcRecvTensorCall* result = objects_.back();
200 objects_.pop_back();
201 return result;
202 }
203 }
204 return new RpcRecvTensorCall;
205 }
206
Release(RpcRecvTensorCall * obj)207 void Release(RpcRecvTensorCall* obj) {
208 obj->Reset();
209 {
210 mutex_lock l(mu_);
211 if (objects_.size() < kMaxObjects) {
212 objects_.push_back(obj);
213 return;
214 }
215 }
216 delete obj;
217 }
218
219 private:
220 static constexpr int kMaxObjects = 1000;
221
222 mutex mu_;
223 std::vector<RpcRecvTensorCall*> objects_ TF_GUARDED_BY(mu_);
224 };
225
get_call_freelist()226 static RpcRecvTensorFreeList* get_call_freelist() {
227 static RpcRecvTensorFreeList* call_freelist = new RpcRecvTensorFreeList();
228 return call_freelist;
229 }
230
RecvFromRemoteAsync(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)231 void RpcRemoteRendezvous::RecvFromRemoteAsync(
232 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
233 DoneCallback done) {
234 CHECK(is_initialized());
235 Status s;
236
237 // Prepare a RecvTensor call that can handle being aborted.
238 RpcRecvTensorCall* call = get_call_freelist()->New();
239
240 // key.src_device identifies a remote device.
241 if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
242 &call->src_rel_device_)) {
243 s = errors::Internal(parsed.src_device,
244 " is invalid remote source device.");
245 }
246 WorkerSession* sess = session();
247 std::shared_ptr<WorkerCacheInterface> worker_cache =
248 sess->GetSharedWorkerCache();
249 // The worker will be released in a subsequent call to
250 // `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been
251 // initialized) or `call->ReleaseWorker()` (if it has been initialized).
252 WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);
253 if (s.ok() && rwi == nullptr) {
254 s = errors::Internal("No worker known as ", call->src_worker_);
255 }
256
257 Device* dst_device;
258 if (s.ok()) {
259 s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
260 }
261 if (!s.ok()) {
262 if (rwi != nullptr) {
263 sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
264 }
265 get_call_freelist()->Release(call);
266 done(s, Args(), recv_args, Tensor{}, false);
267 return;
268 }
269
270 call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
271 recv_args, std::move(done));
272
273 // Record "call" in calls_ so that it can be aborted cleanly.
274 RegisterCall(call, recv_args);
275
276 // RendezvousMgr already aborted, shouldn't send RPC call any more
277 if (!call->status().ok()) {
278 DeregisterCall(call, recv_args);
279 // NOTE: `*sess` can potentially be deleted before we return from
280 // `call->done()(...)`, so we must release the worker before calling the
281 // callback.
282 call->ReleaseWorker(sess->worker_cache());
283 call->done()(call->status(), Args(), Args(), Tensor(), false);
284 get_call_freelist()->Release(call);
285 return;
286 }
287
288 // Start "call".
289 Ref();
290 call->Start([this, call, recv_args, worker_cache]() {
291 // Removes "call" from calls_. Prevent StartAbort().
292 DeregisterCall(call, recv_args);
293 // If StartAbort was called prior to DeregisterCall, then the
294 // current status should be bad.
295 Status s = call->status();
296 // NOTE: `*session()` can potentially be deleted before we return from
297 // `call->done()(...)`, so we must release the worker before calling the
298 // callback.
299 call->ReleaseWorker(session()->worker_cache());
300 call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
301 get_call_freelist()->Release(call);
302 Unref();
303 });
304 }
305
306 } // namespace
307
RpcRendezvousMgr(const WorkerEnv * env)308 RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
309 : BaseRendezvousMgr(env) {}
310
Create(int64_t step_id,const WorkerEnv * worker_env)311 BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64_t step_id,
312 const WorkerEnv* worker_env) {
313 return new RpcRemoteRendezvous(worker_env, step_id);
314 }
315
316 } // end namespace tensorflow
317