#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

#include <sstream>

namespace torch::distributed::rpc {

thread_local std::vector<std::shared_ptr<RRefContext::PendingUserState>>
    RRefContext::userTable_;
thread_local bool RRefContext::recording_ = false;
namespace callback {
void confirmPendingUser(
    const JitFuture& jitFuture,
    const ForkId& expectedForkId) {
  if (!jitFuture.hasError()) {
    auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
    auto msgType = msgPtr->type();
    auto rpc = deserializeResponse(*msgPtr, msgType);
    auto& rr = dynamic_cast<RemoteRet&>(*rpc);
    TORCH_INTERNAL_ASSERT(rr.forkId() == expectedForkId);
  } else {
    // Handle errors, such as timeouts, by invoking the error handler on the
    // RRef.
    // Note [Best Effort Error handling for Remote calls]:
    // When remote calls initiated by rpc.remote() fail, such as with a timeout
    // error, we take a best-effort approach to error handling. We handle
    // errors when callbacks corresponding to the remote call run, and set the
    // error information on the RRef. If the RRef has not been used by the
    // application before this point (such as a to_here() or fork call), then
    // future uses of the RRef will appropriately raise errors. However, it is
    // possible that the user application uses the RRef before the errors are
    // handled. In this case, errors may not be raised as they have not yet
    // been handled.
    auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId);
    auto errorType = getRPCErrorType(jitFuture);
    rref_ptr->handleError(errorType, jitFuture);
  }
  RRefContext::getInstance().delPendingUser(expectedForkId);
}

c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
    const JitFuture& jitFuture,
    const RRefId& rrefId) {
  if (jitFuture.hasError()) {
    auto& ctx = RRefContext::getInstance();
    // We expect to run this callback only after the OwnerRRef has been
    // created, since this is only invoked when sending to self.
    auto rref_ptr =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
    auto errorType = getRPCErrorType(jitFuture);
    rref_ptr->handleError(errorType, jitFuture);
    // OwnerRRefs do not have a forkId, so there is no need to assert here.
    auto deletedRRef =
        ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId());
    return deletedRRef;
  } else {
    auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
    auto msgType = msgPtr->type();
    auto rpc = deserializeResponse(*msgPtr, msgType);
    auto& rr = dynamic_cast<RemoteRet&>(*rpc);
    TORCH_INTERNAL_ASSERT(
        rr.rrefId() == rr.forkId(),
        "Expecting an OwnerRRef as RemoteRet but got a fork.");
    auto& ctx = RRefContext::getInstance();
    auto deletedRRef = ctx.delForkOfOwner(rr.rrefId(), rr.rrefId());
    return deletedRRef;
  }
}

} // namespace callback

// Keys for RRef-related debug information.
const std::string kNumOwnerRRefs = "num_owner_rrefs";
const std::string kNumPendingFutures = "num_pending_futures";
const std::string kNumPendingUsers = "num_pending_users";
const std::string kNumForks = "num_forks";

RRefContext& RRefContext::getInstance() {
  // Leaky singleton to avoid module destructor races.
  static RRefContext* context = new RRefContext(RpcAgent::getCurrentRpcAgent());
  return *context;
}

std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
    bool ignoreRRefLeak) {
  auto& ctx = RRefContext::getInstance();
  {
    std::lock_guard<std::mutex> lock(ctx.destroyedMutex_);
    ctx.destroyed_ = true;
  }
  ctx.checkRRefLeaks(ignoreRRefLeak);
  std::vector<c10::intrusive_ptr<RRef>> deletedRRefs;
  for (auto& entry : ctx.owners_) {
    auto rref = entry.second;
    if (rref->isPyObj()) {
      deletedRRefs.emplace_back(std::move(rref));
    }
  }
  ctx.owners_.clear();
  ctx.pendingOwners_.clear();
  return deletedRRefs;
}
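
// A minimal sketch of a plausible shutdown sequence using the functions in
// this file (the exact call sites live in the Python shutdown path; the
// ordering here is illustrative, not asserted by this file):
//
//   auto& ctx = RRefContext::getInstance();
//   ctx.delAllUsersAndUnforkedOwners(timeout);  // defined below
//   auto pyRRefs = RRefContext::destroyInstance(/*ignoreRRefLeak=*/true);
//   // pyRRefs holds the PyObj OwnerRRefs collected above; they wrap Python
//   // objects, so the caller is expected to release them Python-side.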

void RRefContext::handleException(const JitFuture& jitFuture) {
  if (jitFuture.hasError()) {
    auto errMsg = jitFuture.tryRetrieveErrorMessage();
    VLOG(1) << "Got exception: " << errMsg;
    TORCH_CHECK(false, errMsg);
  }
}

void RRefContext::handleExceptionSilent(const JitFuture& jitFuture) {
  if (jitFuture.hasError()) {
    auto errMsg = jitFuture.tryRetrieveErrorMessage();
    VLOG(1) << "Got exception: " << errMsg;
    TORCH_CHECK_MSG(false, errMsg);
  }
}

RRefContext::RRefContext(std::shared_ptr<RpcAgent> agent)
    : agent_(std::move(agent)) {}

RRefContext::~RRefContext() {
  if (!owners_.empty()) {
    VLOG(1) << "Destructing RRefContext with non-empty OwnerRRef set. "
            << "This would likely cause Python deref errors. "
            << "Make sure destroyInstance() is invoked before destruction.";
  }
}

std::unordered_map<std::string, std::string> RRefContext::getDebugInfo() {
  std::unordered_map<std::string, std::string> info;
  std::unique_lock<std::mutex> lock(mutex_);
  auto ownerSize = owners_.size();
  auto numPendingUsers = pendingUsers_.size();
  int numForks = 0;
  for (const auto& owner : forks_) {
    numForks += owner.second.size();
  }
  lock.unlock();
  info[kNumOwnerRRefs] = std::to_string(ownerSize);
  info[kNumPendingFutures] = std::to_string(numPendingFutures_.load());
  info[kNumPendingUsers] = std::to_string(numPendingUsers);
  info[kNumForks] = std::to_string(numForks);
  return info;
}
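
// A hypothetical example of the map returned by getDebugInfo() (values are
// illustrative):
//   {"num_owner_rrefs": "2", "num_pending_futures": "0",
//    "num_pending_users": "1", "num_forks": "3"}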

void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) {
  if (!forks_.empty()) {
    std::stringstream ss;
    for (auto& entry : forks_) {
      const RRefId& rrefId = entry.first;
      for (const auto& forkId : entry.second) {
        ss << "Leaking RRef " << rrefId << " with fork Id " << forkId << '\n';
      }
    }

    LOG(WARNING)
        << "Detected RRef leaks during shutdown. This usually "
        << "occurs when the application code still holds references to RRef "
        << "instances when calling shutdown(). If the program has "
        << "completed correctly and the process is exiting, it is OK to "
        << "ignore these leaks. However, if your program will keep running "
        << "after this, these leaks could result in memory leaks on RRef "
        << "owners. Please make sure all RRefs are out of scope and Python "
        << "GC has deleted them before calling shutdown(): \n"
        << ss.str();
    if (!ignoreRRefLeak) {
      TORCH_CHECK(false, ss.str());
    }
  }
}

c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
    worker_id_t ownerId,
    const TypePtr& type) {
  TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
  // Explicitly create rrefId before forkId to make sure the order is
  // deterministic, as the argument evaluation order is system- and
  // compiler-dependent.
  const auto rrefId = genGloballyUniqueId();
  const auto forkId = genGloballyUniqueId();
  return createUserRRef(ownerId, rrefId, forkId, type);
}

c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    const TypePtr& type) {
  TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
  // RRefContext does not track UserRRefs; a UserRRef is destructed when there
  // are no shared_ptrs pointing to it.
  //
  // NB: cannot use make_shared here as the constructor of UserRRef is private.
  // NB: This UserRRef has not been confirmed by the owner yet. This function's
  // call site is responsible for adding this UserRRef to pendingUsers_.
  // Currently, there are two call sites:
  // (1) The creator user in python_functions.cpp
  // (2) The callee user in RRefContext::notifyOwnerAndParentOfFork.
  //
  // The reason for not adding the pending user here is to keep addPendingUser()
  // close to where the RPC occurs, and it is clearer to pair it with
  // delPendingUser() in the response callback at the call site. See the sketch
  // after this function.
  return c10::make_intrusive<UserRRef>(ownerId, rrefId, forkId, type);
}
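
// A minimal sketch of the call-site pairing described above (the message type
// and callback wiring are illustrative, not the exact production code):
//
//   auto& ctx = RRefContext::getInstance();
//   auto userRRef = ctx.createUserRRef(ownerId, type);
//   ctx.addPendingUser(userRRef->forkId(), userRRef);
//   auto fut = agent->send(ownerInfo, SomeRemoteCall(/*...*/).toMessage());
//   fut->addCallback([forkId = userRRef->forkId()](JitFuture& f) {
//     // Confirms the pending user on success, records the error otherwise,
//     // and deletes the pending user in both cases.
//     callback::confirmPendingUser(f, forkId);
//   });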

void RRefContext::delUser(
    const worker_id_t owner,
    const RRefId& rrefId,
    const ForkId& forkId) {
  {
    std::lock_guard<std::mutex> lock(destroyedMutex_);
    if (!destroyed_) {
      // Sending an RRefUserDelete causes the receiver to run delForkOfOwner,
      // which is now idempotent. See the comment at RRefContext::delForkOfOwner
      // for more details.
      ++numPendingFutures_;
      auto jitFuture = agent_->sendWithRetries(
          agent_->getWorkerInfo(owner),
          RRefUserDelete(rrefId, forkId).toMessage());

      jitFuture->addCallback([this](JitFuture& future) {
        handleExceptionSilent(future);
        --numPendingFutures_;
      });
    }
  }

  std::lock_guard<std::mutex> lock(mutex_);
  confirmedUsers_.erase(forkId);
}

void RRefContext::delAllUsersAndUnforkedOwners(
    std::chrono::milliseconds timeoutMillis) {
  // First, wait for all pending UserRRefs to be confirmed. There are two
  // kinds: pendingUsers_, which are shared from an owner, and
  // pendingChildren_, which are shared from another user.
  std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash>
      tempConfirmedUsers;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noPending = deleteAllUsersCV_.wait_for(lock, timeoutMillis, [this]() {
      return pendingUsers_.empty() && pendingChildren_.empty();
    });
    if (!noPending) {
      LOG(ERROR)
          << "Timed out waiting for pending UserRRefs to be confirmed by owner and parent.";
    }
    tempConfirmedUsers.swap(confirmedUsers_);
  }

  // Start sending UserRRef delete messages after all pending users are
  // confirmed. Note, there should be no new forks in between, because it is
  // assumed that this utility is called during graceful shutdown, where no
  // new user RPCs can be initiated anymore.
  for (const auto& user : tempConfirmedUsers) {
    c10::intrusive_ptr<RRef> rref_ptr = user.second.lock();
    if (!rref_ptr) {
      continue;
    }
    // tryDel() below will re-acquire the lock, so the lock must be released
    // here.
    rref_ptr->tryDel();
  }

  // If an RRef in the owners_ map has never been forked, we will never get a
  // corresponding message from the forking node(s) telling us to delete the
  // RRef. Hence we delete the RRef here. This can occur when a remote call is
  // sent to self and times out.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    std::vector<RRefId> unforkedOwners;
    for (const auto& it : owners_) {
      auto rrefId = it.first;
      if (forks_.find(rrefId) == forks_.end()) {
        // Successful fork of owner was never processed.
        unforkedOwners.push_back(rrefId);
      }
    }
    for (auto& rrefId : unforkedOwners) {
      LOG(INFO) << "Removing unforked OwnerRRef with RRefId: " << rrefId;
      auto iter = owners_.find(rrefId);
      TORCH_CHECK(
          iter != owners_.end(),
          c10::str("Did not find OwnerRRef with RRefId: ", rrefId));
      owners_.erase(iter);
    }
  }
  // Wait for this node to process all delete UserRRef messages it may get for
  // the OwnerRRefs that exist on this node.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noOwner = deleteAllUsersCV_.wait_for(
        lock, timeoutMillis, [this]() { return owners_.empty(); });
    if (!noOwner) {
      LOG(ERROR) << "Timed out waiting for pending OwnerRRefs to be deleted.";
    }
  }
}

c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef(
    const RRefForkData& rrefForkData,
    const TypePtr& type) {
  auto& ownerId = rrefForkData.ownerId_;
  auto& rrefId = rrefForkData.rrefId_;
  auto& forkId = rrefForkData.forkId_;
  if (ownerId == getWorkerId()) {
    return getOrCreateOwnerRRef(rrefId, type);
  } else {
    return createUserRRef(ownerId, rrefId, forkId, type);
  }
}

c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
    const RRefId& rrefId,
    const TypePtr& type) {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    // Scenario (1): the first time this owner knows about this RRef.
    //
    // NB: cannot use make_shared here as the constructor of OwnerRRef is
    // private.
    auto rref = c10::make_intrusive<OwnerRRef>(
        getWorkerId(), rrefId, type, agent_->getDevices());
    owners_[rref->rrefId()] = rref;
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter != pendingOwners_.end()) {
      // cast to RRefInterface to hold it in an IValue
      auto rrefPtr = fromOwnerRRef(rref);
      pendingOwnerIter->second->markCompleted(IValue(rrefPtr));
      pendingOwners_.erase(pendingOwnerIter);
    }
    return rref;
  } else {
    // Scenario (2): retrieving an existing RRef.
    auto ownerRRef = fromRRefInterface(iter->second);
    // Now double check that the two types match.
    //
    // Why are we special-casing the check for tensor types here? Because
    // tensor types may get specialized on tensors when we pass inputs to the
    // function, i.e. a TensorType can be filled with specific shape info,
    // requires_grad info, etc. So the OwnerRRef we found might already carry
    // that info, while the `type` passed in here is a plain TensorType. The
    // two are not equal, but related by subtyping:
    //   specialized TensorType <: plain TensorType
    //
    // In RPC we don't care about the difference, as we ser/de with just the
    // plain TensorType. This is not an issue for UserRRef creation either,
    // since a tensor type can only get specialized by a previous run of a
    // local JIT function, and we shouldn't propagate the specialized subtype
    // information to other workers.
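    //
    // For example (illustrative): a specialized type such as
    // `Float(2, 2, requires_grad=1)` is a subtype of the plain `Tensor` type,
    // so the subtype check below accepts the pair even though `==` on the two
    // types would fail.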
    if (type->isSubtypeOf(*TensorType::get())) {
      TORCH_INTERNAL_ASSERT(
          ownerRRef->type()->isSubtypeOf(*TensorType::get()),
          "Expect OwnerRRef to be a sub-type of TensorType, but got ",
          ownerRRef->type()->repr_str());
    } else {
      TORCH_INTERNAL_ASSERT(
          *ownerRRef->type() == *type,
          "OwnerRRef type is ",
          ownerRRef->type()->repr_str(),
          ", expected type is ",
          type->repr_str());
    }
    return ownerRRef;
  }
}

c10::intrusive_ptr<OwnerRRef> RRefContext::createOwnerRRef(
    const TypePtr& type) {
  // Don't add this OwnerRRef to the owners_ map yet, otherwise
  // it will never be removed from there. Instead, only add it to the
  // map in prepareChildFork, in case this local RRef is being passed
  // to another worker.
  return c10::make_intrusive<OwnerRRef>(
      getWorkerId(), genGloballyUniqueId(), type, agent_->getDevices());
}

c10::intrusive_ptr<JitFuture> RRefContext::getOwnerRRef(
    const RRefId& rrefId,
    bool forceCreated) {
  std::unique_lock<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    if (forceCreated) {
      TORCH_INTERNAL_ASSERT(
          false,
          c10::str("Expected OwnerRRef with id ", rrefId, " to be created."));
    }
    // Scenario (1): RRef is used before it is created.
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter == pendingOwners_.end()) {
      // Note: The type passed into RRefType::create() does not matter here, as
      // the future is marked as completed with the RRef of the specific type
      // in getOrCreateOwnerRRef().
      // We need to set devices here, even if they won't be used by the value
      // (an RRef object doesn't contain any tensors, it just provides means to
      // retrieve them) because we need them to be propagated to child futures.
      // This is silly and we should find a way to avoid this.
      auto futureOwner = c10::make_intrusive<JitFuture>(
          RRefType::create(c10::AnyType::get()), agent_->getDevices());
      pendingOwners_[rrefId] = futureOwner;
      return futureOwner;
    } else {
      return pendingOwnerIter->second;
    }
  } else {
    // Scenario (2): retrieving an existing RRef.
    // Marks the IValue future as completed with the RRef IValue.
    auto owner = iter->second;
    auto rrefPtr = fromOwnerRRef(owner);

    // We need to set devices here, even if they won't be used by the value (an
    // RRef object doesn't contain any tensors, it just provides means to
    // retrieve them) because we need them to be propagated to child futures.
    // This is silly and we should find a way to avoid this.
    auto futureOwner = c10::make_intrusive<JitFuture>(
        RRefType::create(owner->type()), agent_->getDevices());
    futureOwner->markCompleted(IValue(rrefPtr));
    return futureOwner;
  }
}

RRefForkData RRefContext::prepareChildFork(
    const c10::intrusive_ptr<RRef>& rref) {
  // If we know that rref creation on the owner has timed out, raise it to the
  // user here, otherwise continue with pickling.

  TORCH_CHECK(
      !rref->getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  auto rrefForkData = rref->fork();
  if (rref->isOwner()) {
    // Note [Early Fork Registration]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the parent (caller) is the owner, directly register the fork, instead
    // of waiting for another RREF_FORK_REQUEST or RREF_CHILD_ACCEPT message.
    // An alternative is to add the fork when the callee user ACKs. However,
    // before that, the owner would still have to add the OwnerRRef into some
    // map to keep it alive (e.g., in pendingChildren_). Hence, adding the fork
    // here or in the ACK does not make any difference; it only adds
    // complexity.
    // TODO: When adding failure retries and timeout, this fork needs to be
    // deleted if the owner does not receive the ACK within the timeout.
    addForkOfOwner(rrefForkData.rrefId_, rrefForkData.forkId_);
    // Ensure that this RRef is in the owners_ list to keep it alive.
    // This is needed for OwnerRRefs that were created locally.
    {
      std::lock_guard<std::mutex> lock(mutex_);
      owners_[rref->rrefId()] = rref;
    }
  } else {
    // Note [Useful Phantom Fork ID for User to Owner Call]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the callee of dist.remote or dist.rpc is the owner of this RRef, the
    // callee will not create a fork using this rrefForkData.forkId_, because
    // the owner will only keep one `OwnerRRef` instance and will not create
    // any `UserRRef` instances. However, this rrefForkData.forkId_ is still
    // necessary, as the caller user needs to keep this `UserRRef` alive until
    // it gets the ACK from the callee owner. Otherwise, the delete message
    // could arrive at the owner before this dist.rpc or dist.remote call,
    // which could potentially trigger the `OwnerRRef` to be deleted before
    // running the user code.
    addPendingChild(rrefForkData.forkId_, rref);
  }
  return rrefForkData;
}

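// An overview of the cases handled by notifyOwnerAndParentOfFork below:
//   (1) parent == owner == self: the fork added during pickling is removed
//       right away, since no real UserRRef will be created.
//   (2) parent == owner != self: the owner already registered the fork when
//       sending the RRef (see Note [Early Fork Registration]), so only a
//       confirmed user is recorded here.
//   (3) parent != owner: if this node owns the RRef, it only needs to ACK the
//       parent (see Note [Useful Phantom Fork ID for User to Owner Call]);
//       otherwise, it sends an RREF_FORK_REQUEST to the owner and tracks the
//       fork as a pending user until the owner confirms it.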
void RRefContext::notifyOwnerAndParentOfFork(
    const ForkId& forkId,
    worker_id_t parent,
    const c10::intrusive_ptr<RRef>& rref) {
  // Fork is shared from the owner.
  if (parent == rref->owner()) {
    if (parent == agent_->getWorkerInfo().id_) {
      // Owner sending RRef to self; remove the forkId as it was added during
      // pickling.
      auto deletedRRef = delForkOfOwner(rref->rrefId(), forkId);
      if (deletedRRef) {
        TORCH_INTERNAL_ASSERT(
            deletedRRef->rrefId() == rref->rrefId(),
            "Deleting a fork of ",
            rref->rrefId(),
            " triggered deleting the OwnerRRef of ",
            deletedRRef->rrefId());
        // NB: not necessary to reset deletedRRef as rref is another shared_ptr
        // instance pointing to the same OwnerRRef.
      }
    } else {
      // If the parent is the owner, this fork has already been added into the
      // forks_ map when the owner sent the message to the callee user.
      // Hence, it is not necessary to send another RREF_CHILD_ACCEPT or
      // RREF_FORK_REQUEST back to the owner. See Note [Early Fork
      // Registration].
      std::lock_guard<std::mutex> lock(mutex_);
      addConfirmedUser(forkId, rref);
    }
    return;
  }

  // Fork is shared from a user.
  if (rref->isOwner()) {
    // See Note [Useful Phantom Fork ID for User to Owner Call].
    // In this case, the owner is the callee, and it does not add the fork id
    // into forks_, because there will be no real `UserRRef` associated with
    // this fork ID.
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
    jitFuture->addCallback([this](JitFuture& future) {
      handleExceptionSilent(future);
      --numPendingFutures_;
    });
  } else {
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(rref->owner()),
        RRefForkRequest(rref->rrefId(), forkId).toMessage());

    addPendingUser(forkId, rref);

    jitFuture->addCallback([this, forkId, parent](JitFuture& future) {
      handleException(future);
      this->finishForkRequest(forkId, parent);
      // Decrease after calling finishForkRequest because, as that creates a
      // new future, it might otherwise cause the count to briefly go to zero.
      --numPendingFutures_;
    });
  }
}

void RRefContext::addPendingChild(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  // See Note [Early Fork Registration].
  // If the parent is the owner, it should directly add the child UserRRef as a
  // fork.
  TORCH_INTERNAL_ASSERT(
      !rref->isOwner(), "OwnerRRef should not have a pending child.");
  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_INTERNAL_ASSERT(
      pendingChildren_.find(forkId) == pendingChildren_.end(),
      "Inconsistent states: attempt to add the same child fork twice.");
  pendingChildren_[forkId] = rref;
}

void RRefContext::delPendingChild(const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedUser;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingChildren_.find(forkId);
    // We first check whether the child exists in pendingChildren_. It's
    // possible the child may have been removed by a previous send attempt, and
    // this check (as opposed to an assertion here) ensures that messages that
    // trigger this function are idempotent.
    if (iter != pendingChildren_.end()) {
      // Since this UserRRef is removed from the map, its refcount could reach
      // 0, so the "destructor", `release_resources()`, might be called, in
      // which the lock is acquired again. Hence, it must be destructed with
      // the lock released. Meet this constraint by creating a temporary
      // pointer to increase the refcount, extending its lifetime until the
      // lock is released.
      deletedUser = iter->second; // Increase refcount.
      pendingChildren_.erase(iter); // Decrease refcount.
    } else {
      LOG(INFO) << "Ignoring duplicate request to delete child UserRRef with "
                << "ForkId = " << forkId;
    }
  }
  deleteAllUsersCV_.notify_all();
  // The refcount of this UserRRef could reach 0, so the "destructor",
  // release_resources(), might be called, in which the lock is acquired
  // again. Hence, it must be destructed with the lock released.
  deletedUser.reset(); // Decrease refcount.
}

void RRefContext::addPendingUser(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  TORCH_INTERNAL_ASSERT(
      !rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");

  auto state = std::make_shared<PendingUserState>(rref);
  if (recording_) {
    // Adding and waiting for pending users are guaranteed to be called from
    // the same thread, but deleting pending users will be called from another
    // thread. As delPendingUser will not be able to access the same
    // thread_local variable, we cannot address this problem by making
    // pendingUsers_ thread_local. Instead, pendingUsers_ and userTable_ share
    // the same PendingUserState shared_ptr.
    userTable_.push_back(state);
  }

  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_INTERNAL_ASSERT(
      pendingUsers_.find(forkId) == pendingUsers_.end(),
      "Inconsistent states: attempt to add the same UserRRef twice.");

  pendingUsers_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(forkId),
      std::forward_as_tuple(state));
}

void RRefContext::delPendingUser(const ForkId& forkId) {
  std::shared_ptr<PendingUserState> deletedState = nullptr;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingUsers_.find(forkId);
    TORCH_INTERNAL_ASSERT(
        iter != pendingUsers_.end(),
        "Inconsistent states: attempt to delete a nonexistent UserRRef.");

    // There are two reasons for keeping the deleted PendingUserState alive
    // until exiting the critical section.
    // (1) Since this UserRRef is removed from the map, its refcount could
    //     reach 0. So the resource destructor (`release_resources()`) might
    //     be called, in which the lock is acquired again. Hence, it must be
    //     destructed with the lock released. To meet this constraint, we
    //     intentionally create a temporary pointer to increase the refcount
    //     of the deleted PendingUserState, extending its lifetime until the
    //     lock is released.
    // (2) Since #34497, a user function only runs after all RRefs in the
    //     arguments are confirmed by their owners, which is done by adding
    //     the RPC processing logic as a callback to the UserRRef ready
    //     future. So, calling `confirm` on the PendingUserState could trigger
    //     pending user functions, which might in turn acquire the lock in
    //     RRefContext. Hence, we must release the lock to prevent deadlock.
    // NB: Another option is to use a reentrant lock. However, it is better
    // for developers to fully understand the locking behavior instead of
    // hiding the subtle logic behind a reentrant lock.
    deletedState = iter->second; // Increase refcount.

    addConfirmedUser(forkId, iter->second->rref_);
    pendingUsers_.erase(iter); // Decrease refcount.
  }
  deletedState->confirm();
  deleteAllUsersCV_.notify_all();
  deletedState.reset(); // Decrease refcount.
}

void RRefContext::addConfirmedUser(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  // Notice: the caller needs to hold the mutex for confirmedUsers_, i.e.
  // acquire std::lock_guard<std::mutex> lock(mutex_) before calling this.
  confirmedUsers_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(forkId),
      std::forward_as_tuple(rref));
}

c10::intrusive_ptr<RRef> RRefContext::getPendingUser(const ForkId& forkId) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto it = pendingUsers_.find(forkId);
  if (it == pendingUsers_.end()) {
    TORCH_INTERNAL_ASSERT(
        false, "Pending user with forkId ", forkId, " not found");
  }
  return it->second->rref_;
}

void RRefContext::recordThreadLocalPendingRRefs() {
  TORCH_INTERNAL_ASSERT(
      userTable_.empty(),
      "User RRef table should be empty when starting recording.");
  recording_ = true;
}

c10::intrusive_ptr<JitFuture> RRefContext::waitForThreadLocalPendingRRefs() {
  // We need to set devices here, even if they won't be used by the value (it's
  // a bool, it doesn't contain tensors!) because we need them to be propagated
  // to child futures. This is silly and we should find a way to avoid this.
  auto jitFuturePtr =
      c10::make_intrusive<JitFuture>(BoolType::get(), agent_->getDevices());
  if (userTable_.empty()) {
    jitFuturePtr->markCompleted(true);
  } else {
    auto remainingRRefs =
        std::make_shared<std::atomic<uint64_t>>(userTable_.size());
    for (auto& state : userTable_) {
      state->confirmationFuture_->addCallback(
          [jitFuturePtr, remainingRRefs](JitFuture& /* unused */) {
            auto localCount = remainingRRefs->fetch_sub(1);
            if (localCount == 1) {
              jitFuturePtr->markCompleted(true);
            }
          });
    }
    userTable_.clear();
  }
  recording_ = false;
  return jitFuturePtr;
}
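
// A minimal sketch of the recording protocol implemented above, assuming a
// single request-processing thread (the surrounding steps are illustrative):
//
//   auto& ctx = RRefContext::getInstance();
//   ctx.recordThreadLocalPendingRRefs();
//   // ... deserialize RPC args; unpickling a UserRRef calls addPendingUser()
//   //     on this thread while recording_ is true ...
//   auto allConfirmed = ctx.waitForThreadLocalPendingRRefs();
//   allConfirmed->addCallback([](JitFuture&) {
//     // Safe to run the user function: all argument RRefs are confirmed.
//   });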

void RRefContext::clearRecordedPendingRRefsOnError() {
  userTable_.clear();
  recording_ = false;
}

void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
  delPendingUser(forkId);
  ++numPendingFutures_;
  auto jitFuture = agent_->sendWithRetries(
      agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());

  jitFuture->addCallback([this](JitFuture& future) {
    handleExceptionSilent(future);
    --numPendingFutures_;
  });
}

void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto& rrefId = rref->rrefId();
  owners_[rrefId] = rref;
  auto& rrefForks = forks_[rrefId];
  TORCH_INTERNAL_ASSERT(
      rrefForks.find(rrefId) == rrefForks.end(),
      "Attempt to add self as fork twice ",
      rrefId);
  rrefForks.insert(rrefId);
}

void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto& rrefForks = forks_[rrefId];
  TORCH_INTERNAL_ASSERT(
      rrefForks.find(forkId) == rrefForks.end(),
      "Got fork notification twice on the same RRef ",
      forkId);
  rrefForks.insert(forkId);
}

void RRefContext::addForkOfOwnerIfNotPresent(
    const RRefId& rrefId,
    const ForkId& forkId) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto& rrefForks = forks_[rrefId];
  // We first check whether the child exists in rrefForks. It's possible
  // the child may have been added by a previous send attempt, and this check
  // (as opposed to an assertion here) ensures that messages that trigger this
  // function are idempotent.
  if (rrefForks.find(forkId) == rrefForks.end()) {
    rrefForks.insert(forkId);
  } else {
    LOG(INFO) << "Ignoring duplicate request to add Fork of OwnerRRef with "
              << "RRefId = " << rrefId << ", ForkId = " << forkId;
  }
}

c10::intrusive_ptr<RRef> RRefContext::delForkOfOwner(
    const RRefId& rrefId,
    const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedRRef;
  bool ownerReduced = false;
  // There were previously multiple TORCH_CHECKs in this function that checked
  // whether the passed-in fork was known by the user and whether the fork had
  // already been deleted. These assertions are now replaced with nested if
  // statements to ensure this function is idempotent. This makes it safe to
  // retry RRefUserDelete messages.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto rrefIter = forks_.find(rrefId);
    if (rrefIter != forks_.end()) {
      auto& rrefForks = rrefIter->second;
      auto forkIter = rrefForks.find(forkId);
      if (forkIter != rrefForks.end()) {
        rrefForks.erase(forkId);
      } else {
        LOG(INFO)
            << "Could not find UserRRef instance, "
            << "RRefId = " << rrefId << ", ForkId = " << forkId
            << ", likely because it was deleted by a previously retried message";
      }
      if (rrefForks.empty()) {
        auto ownerIter = owners_.find(rrefId);
        if (ownerIter != owners_.end()) {
          deletedRRef = ownerIter->second;
          owners_.erase(ownerIter);
          ownerReduced = true;
        }
        forks_.erase(rrefIter);
      }
    } else {
      LOG(INFO)
          << "Could not find OwnerRRef with RRefId = " << rrefId
          << ", likely because it was deleted by a previously retried message";
    }
  }
  if (ownerReduced) {
    deleteAllUsersCV_.notify_all();
  }
  return deletedRRef;
}

} // namespace torch::distributed::rpc