xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rref_context.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/rref_context.h>
2 #include <torch/csrc/distributed/rpc/rref_proto.h>
3 #include <torch/csrc/distributed/rpc/utils.h>
4 
5 #include <sstream>
6 
7 namespace torch::distributed::rpc {
8 
// Thread-local recording state: while recording_ is true (see
// recordThreadLocalPendingRRefs), addPendingUser() appends each new
// PendingUserState to userTable_ so waitForThreadLocalPendingRRefs() can
// later wait for their confirmations on the same thread.
thread_local std::vector<std::shared_ptr<RRefContext::PendingUserState>>
    RRefContext::userTable_;
thread_local bool RRefContext::recording_ = false;
12 
13 namespace callback {
confirmPendingUser(const JitFuture & jitFuture,const ForkId & expectedForkId)14 void confirmPendingUser(
15     const JitFuture& jitFuture,
16     const ForkId& expectedForkId) {
17   if (!jitFuture.hasError()) {
18     auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
19     auto msgType = msgPtr->type();
20     auto rpc = deserializeResponse(*msgPtr, msgType);
21     auto& rr = dynamic_cast<RemoteRet&>(*rpc);
22     TORCH_INTERNAL_ASSERT(rr.forkId() == expectedForkId);
23   } else {
24     // Handle errors, such as timeouts, by invoking the error handler on the
25     // rref.
26     // Note [Best Effort Error handling for Remote calls]:
27     // When remote calls initiated by rpc.remote() fail, such as with a timeout
28     // error, we take a best-effort approach to error handling. We handle errors
29     // when callbacks corresponding to the remote call run, and set the error
30     // information on the RRef. If the RRef has not been used by the application
31     // before this process (such as to_here or fork call), then future uses of
32     // the RRef will appropriately raise errors. However, it is possible that
33     // the user application will use the RRef before the errors are handled. In
34     // this case, errors may not be raised as they have not yet been handled.
35     auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId);
36     auto errorType = getRPCErrorType(jitFuture);
37     rref_ptr->handleError(errorType, jitFuture);
38   }
39   RRefContext::getInstance().delPendingUser(expectedForkId);
40 }
41 
// Callback run when a remote call sent to self finishes creating an
// OwnerRRef. On success, validates the RemoteRet response and removes the
// owner's self fork; on error, records the error on the OwnerRRef first.
// Returns the OwnerRRef if deleting the fork removed its last reference,
// so the caller can dispose of it appropriately.
c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
    const JitFuture& jitFuture,
    const RRefId& rrefId) {
  if (jitFuture.hasError()) {
    auto& ctx = RRefContext::getInstance();
    // We expect to run this callback only after the OwnerRRef has been created,
    // since this is only invoked when sending to self.
    auto rref_ptr =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
    auto errorType = getRPCErrorType(jitFuture);
    rref_ptr->handleError(errorType, jitFuture);
    // OwnerRRefs do not have a forkId, so don't need to assert here.
    auto deletedRRef =
        ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId());
    return deletedRRef;
  } else {
    auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
    auto msgType = msgPtr->type();
    auto rpc = deserializeResponse(*msgPtr, msgType);
    auto& rr = dynamic_cast<RemoteRet&>(*rpc);
    // For an OwnerRRef the fork id equals the rref id; anything else means
    // the response belongs to a real fork, which is unexpected here.
    TORCH_INTERNAL_ASSERT(
        rr.rrefId() == rr.forkId(),
        "Expecting an OwnerRRef as RemoteRet but got a fork.");
    auto& ctx = RRefContext::getInstance();
    auto deletedRRef = ctx.delForkOfOwner(rr.rrefId(), rr.rrefId());
    return deletedRRef;
  }
}
72 
73 } // namespace callback
74 
// Keys for RRef-related debug information reported by getDebugInfo().
const std::string kNumOwnerRRefs = "num_owner_rrefs";
const std::string kNumPendingFutures = "num_pending_futures";
const std::string kNumPendingUsers = "num_pending_users";
const std::string kNumForks = "num_forks";
80 
getInstance()81 RRefContext& RRefContext::getInstance() {
82   // Leaky singleton to avoid module destructor races.
83   static RRefContext* context = new RRefContext(RpcAgent::getCurrentRpcAgent());
84   return *context;
85 }
86 
// Tears down the singleton during RPC shutdown: marks the context destroyed
// (so delUser stops sending delete messages), checks for leaked forks, then
// clears owners_/pendingOwners_. Python-object OwnerRRefs are returned to
// the caller so they can be released while holding the GIL.
// NOTE(review): owners_ is iterated and cleared here without holding
// mutex_; presumably safe because this runs during single-threaded
// shutdown — confirm against callers.
std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
    bool ignoreRRefLeak) {
  auto& ctx = RRefContext::getInstance();
  {
    std::lock_guard<std::mutex> lock(ctx.destroyedMutex_);
    ctx.destroyed_ = true;
  }
  ctx.checkRRefLeaks(ignoreRRefLeak);
  std::vector<c10::intrusive_ptr<RRef>> deletedRRefs;
  for (auto& entry : ctx.owners_) {
    auto rref = entry.second;
    if (rref->isPyObj()) {
      deletedRRefs.emplace_back(std::move(rref));
    }
  }
  ctx.owners_.clear();
  ctx.pendingOwners_.clear();
  return deletedRRefs;
}
106 
handleException(const JitFuture & jitFuture)107 void RRefContext::handleException(const JitFuture& jitFuture) {
108   if (jitFuture.hasError()) {
109     auto errMsg = jitFuture.tryRetrieveErrorMessage();
110     VLOG(1) << "Got exception: " << errMsg;
111     TORCH_CHECK(false, errMsg);
112   }
113 }
114 
handleExceptionSilent(const JitFuture & jitFuture)115 void RRefContext::handleExceptionSilent(const JitFuture& jitFuture) {
116   if (jitFuture.hasError()) {
117     auto errMsg = jitFuture.tryRetrieveErrorMessage();
118     VLOG(1) << "Got exception: " << errMsg;
119     TORCH_CHECK_MSG(false, errMsg);
120   }
121 }
122 
// Construct with the RpcAgent used for all outgoing fork/delete messages.
RRefContext::RRefContext(std::shared_ptr<RpcAgent> agent)
    : agent_(std::move(agent)) {}
125 
~RRefContext()126 RRefContext::~RRefContext() {
127   if (!owners_.empty()) {
128     VLOG(1) << "Destructing RRefContext with non-empty OwnerRRef set. "
129             << "This would likely cause Python deref error. "
130             << "Make sure destroyInstance() is invoked before destruction.";
131   }
132 }
133 
getDebugInfo()134 std::unordered_map<std::string, std::string> RRefContext::getDebugInfo() {
135   std::unordered_map<std::string, std::string> info;
136   std::unique_lock<std::mutex> lock(mutex_);
137   auto ownerSize = owners_.size();
138   auto numPendingUsers = pendingUsers_.size();
139   int numForks = 0;
140   for (const auto& owner : forks_) {
141     numForks += owner.second.size();
142   }
143   lock.unlock();
144   info[kNumOwnerRRefs] = std::to_string(ownerSize);
145   info[kNumPendingFutures] = std::to_string(numPendingFutures_.load());
146   info[kNumPendingUsers] = std::to_string(numPendingUsers);
147   info[kNumForks] = std::to_string(numForks);
148   return info;
149 }
150 
checkRRefLeaks(bool ignoreRRefLeak)151 void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) {
152   if (!forks_.empty()) {
153     std::stringstream ss;
154     for (auto& entry : forks_) {
155       const RRefId& rrefId = entry.first;
156       for (const auto& forkId : entry.second) {
157         ss << "Leaking RRef " << rrefId << " with fork Id " << forkId << '\n';
158       }
159     }
160 
161     LOG(WARNING)
162         << "Detected RRef Leaks during shutdown. This usually "
163         << "occurs when the application code still holds references to RRef "
164         << "instances when calling shutdown(). If the program has "
165         << "completed correctly and the process is exiting, it is OK to "
166         << "ignore these leaks. However, if you program will keep running "
167         << "after this, these leaks could result in memory leaks on RRef "
168         << "owners. Please make sure all RRefs are out of scope and Python "
169         << "GC has deleted them before calling shutdown(): \n"
170         << ss.str();
171     if (!ignoreRRefLeak) {
172       TORCH_CHECK(false, ss.str());
173     }
174   }
175 }
176 
// Creates a UserRRef for an RRef owned by a remote worker, generating a
// fresh globally unique RRefId and ForkId for it.
c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
    worker_id_t ownerId,
    const TypePtr& type) {
  TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
  // Explicitly creating rrefId before forkId to make sure the order is
  // deterministic, as the argument evaluation order is system and compiler
  // dependent.
  const auto rrefId = genGloballyUniqueId();
  const auto forkId = genGloballyUniqueId();
  return createUserRRef(ownerId, rrefId, forkId, type);
}
188 
// Creates a UserRRef with explicit rref/fork ids. The returned RRef is not
// yet confirmed by the owner; see the note below on who registers it as
// pending.
c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    const TypePtr& type) {
  TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
  // RRefContext does not track user RRefs, it will be destructed when there
  // are no shared_ptrs pointing to it.
  //
  // NB: cannot use make_shared here as the constructor of UserRRef is private.
  // NB: This UserRRef has not been confirmed by the owner yet. This function's
  // call site is responsible for adding this UserRRef to pendingUsers_.
  // Currently, there are two call sites.
  // (1) The creator user in python_functions.cpp
  // (2) The callee user in RRefContext::notifyOwnerAndParentOfFork.
  //
  // The reason for not adding the pending user here is to put addPendingUser()
  // close to where the RPC occurs, and it is more clear to pair it with
  // deletePendingUser() in the response callback at the call site.
  return c10::make_intrusive<UserRRef>(ownerId, rrefId, forkId, type);
}
210 
// Removes a confirmed UserRRef and, unless the context is already being
// destroyed, notifies the owner worker (with retries) to drop the
// corresponding fork via an RRefUserDelete message.
void RRefContext::delUser(
    const worker_id_t owner,
    const RRefId& rrefId,
    const ForkId& forkId) {
  {
    std::lock_guard<std::mutex> lock(destroyedMutex_);
    if (!destroyed_) {
      // Sending an RRefUserDelete causes the receiver to run delForkOfOwner,
      // which is now idempotent. See the comment at RRefContext::delForkOfOwner
      // for more details.
      ++numPendingFutures_;
      auto jitFuture = agent_->sendWithRetries(
          agent_->getWorkerInfo(owner),
          RRefUserDelete(rrefId, forkId).toMessage());

      jitFuture->addCallback([this](JitFuture& future) {
        handleExceptionSilent(future);
        --numPendingFutures_;
      });
    }
  }

  std::lock_guard<std::mutex> lock(mutex_);
  confirmedUsers_.erase(forkId);
}
236 
// Graceful-shutdown helper: waits for pending UserRRefs to be confirmed,
// deletes all confirmed UserRRefs, removes OwnerRRefs that were never
// forked (e.g. created by a timed-out remote call to self), then waits for
// remaining OwnerRRefs to be deleted. Each wait is bounded by
// timeoutMillis; on timeout an error is logged and shutdown proceeds
// best-effort.
void RRefContext::delAllUsersAndUnforkedOwners(
    std::chrono::milliseconds timeoutMillis) {
  // First, wait for all pending UserRRefs to be confirmed,
  // one kind is pendingUsers_, which are shared from Owner,
  // the other kind pendingChildren_, which are shared from another User.
  std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash>
      tempConfirmedUsers;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noPending = deleteAllUsersCV_.wait_for(lock, timeoutMillis, [this]() {
      return pendingUsers_.empty() && pendingChildren_.empty();
    });
    if (!noPending) {
      LOG(ERROR)
          << "Timed out waiting for pending UserRRefs to be confirmed by owner and parent.";
    }
    tempConfirmedUsers.swap(confirmedUsers_);
  }

  // Start sending UserRRef delete messages, after all pendings are confirmed.
  // Note, there should be no new forkings in between, because it's assumed that
  // this utility is called during graceful shutdown, where no new user RPCs can
  // be initiated anymore.
  for (const auto& user : tempConfirmedUsers) {
    c10::intrusive_ptr<RRef> rref_ptr = user.second.lock();
    if (!rref_ptr) {
      continue;
    }
    // tryDel() below will re-acquire lock, lock must be released here.
    rref_ptr->tryDel();
  }

  // If an rref in the owners_ map has never been forked, we will never get a
  // corresponding message from the forking node(s) telling us to delete the
  // RRef. Hence we delete the RRef here. This can occur when a remote call is
  // sent to self and times out.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    std::vector<RRefId> unforkedOwners;
    for (const auto& it : owners_) {
      auto rrefId = it.first;
      if (forks_.find(rrefId) == forks_.end()) {
        // Successful fork of owner was never processed.
        unforkedOwners.push_back(rrefId);
      }
    }
    for (auto& rrefId : unforkedOwners) {
      LOG(INFO) << "Removing unforked OwnerRRef with RRefId: " << rrefId;
      auto iter = owners_.find(rrefId);
      TORCH_CHECK(
          iter != owners_.end(),
          c10::str("Did not find OwnerRRef with RRefId: ", rrefId));
      owners_.erase(iter);
    }
  }
  // Wait for this node to process all delete UserRRef messages it may get for
  // the OwnerRRefs that exist on this node.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noOwner = deleteAllUsersCV_.wait_for(
        lock, timeoutMillis, [this]() { return owners_.empty(); });
    if (!noOwner) {
      LOG(ERROR) << "Timed out waiting for pending OwnerRRefs to be deleted.";
    }
  }
}
303 
getOrCreateRRef(const RRefForkData & rrefForkData,const TypePtr & type)304 c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef(
305     const RRefForkData& rrefForkData,
306     const TypePtr& type) {
307   auto& ownerId = rrefForkData.ownerId_;
308   auto& rrefId = rrefForkData.rrefId_;
309   auto& forkId = rrefForkData.forkId_;
310   if (ownerId == getWorkerId()) {
311     return getOrCreateOwnerRRef(rrefId, type);
312   } else {
313     return createUserRRef(ownerId, rrefId, forkId, type);
314   }
315 }
316 
// Returns the OwnerRRef for rrefId, creating it if this is the first time
// this owner hears about it. Creation also completes any future parked in
// pendingOwners_ by an earlier getOwnerRRef() call. Type compatibility with
// an existing OwnerRRef is checked, with a relaxed rule for tensors (see
// comment below).
c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
    const RRefId& rrefId,
    const TypePtr& type) {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    // Scenario (1) the first time this owner knows about this RRef
    //
    // NB: cannot use make_shared here as the constructor of OwnerRRef is
    // private.
    auto rref = c10::make_intrusive<OwnerRRef>(
        getWorkerId(), rrefId, type, agent_->getDevices());
    owners_[rref->rrefId()] = rref;
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter != pendingOwners_.end()) {
      // cast to RRefInterface to hold it into IValue
      auto rrefPtr = fromOwnerRRef(rref);
      pendingOwnerIter->second->markCompleted(IValue(rrefPtr));
      pendingOwners_.erase(pendingOwnerIter);
    }
    return rref;
  } else {
    // Scenario (2) retrieving an existing RRef
    auto ownerRRef = fromRRefInterface(iter->second);
    // Now double check if the two types match
    //
    // Why we are special casing the check for tensor type here?
    // this is because tensor types might get specialized on tensors when
    // we pass inputs to the function, i.e. TensorType can be filled with
    // specific shape info, requires_grad info, etc. so the OwnerRRef we
    // found might already have those infos, but the `type` we passed in
    // here is a plain TensorType, they are not equal relationship:
    // specialized TensorType <: plain TensorType
    //
    // In RPC we don't care about the difference as we ser/de with just the
    // plain TensorType. This is not an issue for UserRRef creation either,
    // since a Tensor can only get specialized with a previous run of a local
    // JIT function, and we shouldn't preserve the specialized SubTensorType
    // information on other workers because it is local-run-only information.
    if (type->isSubtypeOf(*TensorType::get())) {
      TORCH_INTERNAL_ASSERT(
          ownerRRef->type()->isSubtypeOf(*TensorType::get()),
          "Expect OwnerRRef to be a sub-type of TensorType, but got ",
          ownerRRef->type()->repr_str());
    } else {
      TORCH_INTERNAL_ASSERT(
          *ownerRRef->type() == *type,
          "OwnerRRef type is ",
          ownerRRef->type()->repr_str(),
          ", expected type is ",
          type->repr_str());
    }
    return ownerRRef;
  }
}
372 
// Creates a fresh OwnerRRef with a newly generated globally unique id.
c10::intrusive_ptr<OwnerRRef> RRefContext::createOwnerRRef(
    const TypePtr& type) {
  // Don't add this OwnerRRef to the owners_ map yet, otherwise
  // it will never be removed from there. Instead, only add it to the
  // map in prepareChildFork, in case this local RRef is being passed
  // to another worker.
  return c10::make_intrusive<OwnerRRef>(
      getWorkerId(), genGloballyUniqueId(), type, agent_->getDevices());
}
382 
// Returns a future that resolves to the OwnerRRef with the given id. If the
// OwnerRRef does not exist yet, an incomplete future is parked in
// pendingOwners_ (completed later by getOrCreateOwnerRRef), unless
// forceCreated is true, in which case absence is an internal error.
c10::intrusive_ptr<JitFuture> RRefContext::getOwnerRRef(
    const RRefId& rrefId,
    bool forceCreated) {
  std::unique_lock<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    if (forceCreated) {
      TORCH_INTERNAL_ASSERT(
          false,
          c10::str("Expected OwnerRRef with id ", rrefId, " to be created."));
    }
    // Scenario (1) RRef is used before it is created
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter == pendingOwners_.end()) {
      // Note: The type passed into RRefType::create() does not matter here, as
      // the future is marked as completed with the RRef of the specific type
      // in getOrCreateOwnerRRef().
      // We need to set devices here, even if they won't be used by the value
      // (an RRef object doesn't contain any tensors, it just provides means to
      // retrieve them) because we need them to be propagated to child futures.
      // This is silly and we should find a way to avoid this.
      auto futureOwner = c10::make_intrusive<JitFuture>(
          RRefType::create(c10::AnyType::get()), agent_->getDevices());
      pendingOwners_[rrefId] = futureOwner;
      return futureOwner;
    } else {
      return pendingOwnerIter->second;
    }
  } else {
    // Scenario (2) retrieving an existing RRef
    // Marks IValue Future as completed with the RRef IValue.
    auto owner = iter->second;
    auto rrefPtr = fromOwnerRRef(owner);

    // We need to set devices here, even if they won't be used by the value (an
    // RRef object doesn't contain any tensors, it just provides means to
    // retrieve them) because we need them to be propagated to child futures.
    // This is silly and we should find a way to avoid this.
    auto futureOwner = c10::make_intrusive<JitFuture>(
        RRefType::create(owner->type()), agent_->getDevices());
    futureOwner->markCompleted(IValue(rrefPtr));
    return futureOwner;
  }
}
427 
// Called while pickling an RRef to be shipped to another worker. Produces
// the RRefForkData to serialize, and registers bookkeeping so the RRef
// stays alive until the receiver acknowledges: owners register the fork
// directly, users register a pending child.
RRefForkData RRefContext::prepareChildFork(
    const c10::intrusive_ptr<RRef>& rref) {
  // If we know that rref creation on the owner has timed out, raise it to the
  // user here, otherwise continue with pickling.

  TORCH_CHECK(
      !rref->getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  auto rrefForkData = rref->fork();
  if (rref->isOwner()) {
    // Note [Early Fork Registration]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the parent (caller) is the owner, directly register the fork, instead
    // of waiting for another RREF_FORK_REQUEST or RREF_CHILD_ACCEPT message. An
    // alternative is adding the fork when the callee user ACKs. However, before
    // that, the owner still has to add the OwnerRRef into some map to keep it
    // alive (e.g., in pendingChildren_). Hence, adding the fork here or in the
    // ACK does not make any difference but only adds complexity.
    // TODO: When adding failure retries and timeout, this fork needs to be
    // deleted if the owner does not receive the ACK within the timeout.
    addForkOfOwner(rrefForkData.rrefId_, rrefForkData.forkId_);
    // ensure that this RRef is in the owners_ list to keep it alive.
    // this is needed for OwnerRRefs that were created locally.
    {
      std::lock_guard<std::mutex> lock(mutex_);
      owners_[rref->rrefId()] = rref;
    }
  } else {
    // Note [Useful Phantom Fork ID for User to Owner Call]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the callee of dist.remote or dist.rpc is the owner of this RRef, the
    // callee will not create a fork using this rrefForkData.forkId_, because
    // the owner will only keep one `OwnerRRef` instance and will not create any
    // `UserRRef` instances. However, this rrefForkData.forkId_ is still
    // necessary, as the caller user needs to keep this `UserRRef` alive until
    // it gets the ACK from the callee owner. Otherwise, the delete message
    // could arrive at the owner before this dist.rpc or dist.remote call, which
    // could potentially trigger the `OwnerRRef` to be deleted before running
    // the user code.
    addPendingChild(rrefForkData.forkId_, rref);
  }
  return rrefForkData;
}
472 
// Called on the receiving worker after unpickling a forked RRef. Depending
// on who sent it (owner vs. another user) and whether this worker is the
// owner, either cleans up the phantom self-fork, records the confirmed
// user, ACKs the parent, or sends a fork request to the owner.
void RRefContext::notifyOwnerAndParentOfFork(
    const ForkId& forkId,
    worker_id_t parent,
    const c10::intrusive_ptr<RRef>& rref) {
  // Fork is shared from owner.
  if (parent == rref->owner()) {
    if (parent == agent_->getWorkerInfo().id_) {
      // Owner sending RRef to self, remove the forkId as it was added during
      // pickling
      auto deletedRRef = delForkOfOwner(rref->rrefId(), forkId);
      if (deletedRRef) {
        TORCH_INTERNAL_ASSERT(
            deletedRRef->rrefId() == rref->rrefId(),
            "Deleting a fork of ",
            rref->rrefId(),
            " triggered deleting the OwnerRRef of ",
            deletedRRef->rrefId());
        // NB: not necessary to reset deletedRRef as rref is another shared_ptr
        // instance pointing to the same OwnerRRef.
      }
    } else {
      // If the parent is the owner, this fork has already been added into the
      // forks_ map when the owner sends the message to the callee user.
      // Hence, it is not necessary to send another RREF_CHILD_ACCEPT or
      // RREF_FORK_REQUEST back to the owner. See Note [Early Fork
      // Registration].
      std::lock_guard<std::mutex> lock(mutex_);
      addConfirmedUser(forkId, rref);
    }
    return;
  }

  // Fork is shared from user.
  if (rref->isOwner()) {
    // See Note [Useful Phantom Fork ID for User to Owner Call]
    // In this case, the owner is the caller, and it does not add the fork id
    // into forks_. Because, there will be no real `UserRRef` associated
    // with this fork ID.
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
    jitFuture->addCallback([this](JitFuture& future) {
      handleExceptionSilent(future);
      --numPendingFutures_;
    });
  } else {
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(rref->owner()),
        RRefForkRequest(rref->rrefId(), forkId).toMessage());

    addPendingUser(forkId, rref);

    jitFuture->addCallback([this, forkId, parent](JitFuture& future) {
      handleException(future);
      this->finishForkRequest(forkId, parent);
      // Decrease after calling finishForkRequest because, as that creates a new
      // future, it might otherwise cause the count to briefly go to zero.
      --numPendingFutures_;
    });
  }
}
535 
addPendingChild(const ForkId & forkId,const c10::intrusive_ptr<RRef> & rref)536 void RRefContext::addPendingChild(
537     const ForkId& forkId,
538     const c10::intrusive_ptr<RRef>& rref) {
539   // see Note [Early Fork Registration]
540   // If the parent is the owner, it should directly add the child UserRRef as a
541   // fork.
542   TORCH_INTERNAL_ASSERT(
543       !rref->isOwner(), "OwnerRRef should not have a pending child.");
544   std::lock_guard<std::mutex> lock(mutex_);
545   TORCH_INTERNAL_ASSERT(
546       pendingChildren_.find(forkId) == pendingChildren_.end(),
547       "Inconsistent states: attempt to add the same child fork twice.");
548   pendingChildren_[forkId] = rref;
549 }
550 
// Removes a pending child UserRRef once the receiver has ACKed. Idempotent:
// duplicate deletes (e.g. from message retries) are logged and ignored. The
// removed RRef is intentionally destroyed only after mutex_ is released —
// see the comments below.
void RRefContext::delPendingChild(const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedUser;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingChildren_.find(forkId);
    // We first check whether the child exists in pendingChildren_. It's
    // possible the child may have been removed by a previous send attempt, and
    // this check (as opposed to an assertion here) ensures that messages that
    // trigger this function are idempotent.
    if (iter != pendingChildren_.end()) {
      // Since this UserRRef is removed from the map,
      // the refcount of this UserRRef could reach to 0,
      // so the "destructor", `release_resources()`, might be called,
      // in which the lock is acquired again.
      // So it must be destructed with the lock released.
      // Meet this constraint by creating a temporary pointer to increase the
      // refcount, extending its lifetime until lock released.
      deletedUser = iter->second; // Increase refcount.
      pendingChildren_.erase(iter); // Decrease refcount.
    } else {
      LOG(INFO) << "Ignoring duplicate request to delete child UserRRef with "
                << "ForkId = " << forkId;
    }
  }
  deleteAllUsersCV_.notify_all();
  // The refcount of this UserRRef could reach to 0,
  // so the "destructor", release_resources(), might be called,
  // in which the lock is acquired again,
  // so must destruct it with the lock released.
  deletedUser.reset(); // Decrease refcount.
}
582 
// Registers a not-yet-confirmed UserRRef keyed by forkId. If the calling
// thread is recording (see recordThreadLocalPendingRRefs), the shared
// PendingUserState is also appended to the thread-local userTable_ so its
// confirmation can be awaited later on this thread.
void RRefContext::addPendingUser(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  TORCH_INTERNAL_ASSERT(
      !rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");

  auto state = std::make_shared<PendingUserState>(rref);
  if (recording_) {
    // adding and waiting for pending users are guaranteed to be called from the
    // same thread, but deleting pending users will be called from another
    // thread. As the delPendingUser will not be able to access the same
    // thread_local variable, we cannot address this problem by making
    // pendingUsers_ thread_local. Instead, pendingUsers_ and userTable_ share
    // the same PendingUserState shared_ptr.
    userTable_.push_back(state);
  }

  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_INTERNAL_ASSERT(
      pendingUsers_.find(forkId) == pendingUsers_.end(),
      "Inconsistent states: attempt to add the same UserRRef twice.");

  pendingUsers_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(forkId),
      std::forward_as_tuple(state));
}
610 
// Confirms a pending UserRRef: moves it from pendingUsers_ to
// confirmedUsers_ and fires its confirmation (which may run queued user
// functions). Both the state destruction and the confirm() call must happen
// with mutex_ released — see the detailed comments below.
void RRefContext::delPendingUser(const ForkId& forkId) {
  std::shared_ptr<PendingUserState> deletedState = nullptr;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingUsers_.find(forkId);
    TORCH_INTERNAL_ASSERT(
        iter != pendingUsers_.end(),
        "Inconsistent states: attempt to delete a non-exist UserRRef.");

    // There are two reasons for keeping the deleted PendingUserState alive
    // until exiting the critical section.
    // (1) Since this UserRRef is removed from the map, the refcount of this
    //     UserRRef could reach to 0. So the resource destructor
    //     (`release_resources()`) might be called, in which the lock is
    //     acquired again. Hence, it must be destructed with the lock released.
    //     To meet this constraint, we intentionally create a temporary pointer
    //     to increase the refcount of the deleted PendingUserState, extending
    //     its lifetime until lock released.
    // (2) Since #34497, a user function only runs after all RRefs in the
    //     arguments are confirmed by their owners, which is done by adding the
    //     RPC processing logic as a callback to the UserRRef ready future. So,
    //     calling `confirm` on the PendingUserState could trigger pending user
    //     functions, which might in turn acquire the lock in RRefContext.
    //     Hence, we must release the lock to prevent deadlock.
    // NB: Another option is to use reentrant lock. However, it is better for
    // the developers to fully understand the locking behavior instead of
    // hiding the subtle logic using a reentrant lock.
    deletedState = iter->second; // Increase refcount

    addConfirmedUser(forkId, iter->second->rref_);
    pendingUsers_.erase(iter); // Decrease refcount.
  }
  deletedState->confirm();
  deleteAllUsersCV_.notify_all();
  deletedState.reset(); // Decrease refcount.
}
647 
// Records a UserRRef as confirmed, keyed by forkId. The value is stored as
// a weak_intrusive_ptr (via piecewise construction) so this bookkeeping
// does not keep the UserRRef alive.
void RRefContext::addConfirmedUser(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  // Notice, caller needs to hold the mutex for confirmedUsers_.
  // std::lock_guard<std::mutex> lock(mutex_);
  confirmedUsers_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(forkId),
      std::forward_as_tuple(rref));
}
658 
getPendingUser(const ForkId & forkId)659 c10::intrusive_ptr<RRef> RRefContext::getPendingUser(const ForkId& forkId) {
660   std::lock_guard<std::mutex> lock(mutex_);
661   auto it = pendingUsers_.find(forkId);
662   if (it == pendingUsers_.end()) {
663     TORCH_INTERNAL_ASSERT(
664         false, "Pending user with forkId ", forkId, " not found");
665   }
666   return it->second->rref_;
667 }
668 
// Starts recording UserRRefs created on the current thread into userTable_.
// Must be paired with waitForThreadLocalPendingRRefs() or, on failure,
// clearRecordedPendingRRefsOnError().
void RRefContext::recordThreadLocalPendingRRefs() {
  TORCH_INTERNAL_ASSERT(
      userTable_.empty(),
      "User RRef Table should be empty when start recording");
  recording_ = true;
}
675 
// Stops recording and returns a future that completes (with `true`) once
// every UserRRef recorded on this thread has been confirmed by its owner.
// The thread-local userTable_ is drained before returning.
c10::intrusive_ptr<JitFuture> RRefContext::waitForThreadLocalPendingRRefs() {
  // We need to set devices here, even if they won't be used by the value (it's
  // a bool, it doesn't contain tensors!) because we need them to be propagated
  // to child futures. This is silly and we should find a way to avoid this.
  auto jitFuturePtr =
      c10::make_intrusive<JitFuture>(BoolType::get(), agent_->getDevices());
  if (userTable_.empty()) {
    jitFuturePtr->markCompleted(true);
  } else {
    // Count down one per confirmation; the last confirmation completes the
    // aggregate future.
    auto remainingRRefs =
        std::make_shared<std::atomic<uint64_t>>(userTable_.size());
    for (auto& state : userTable_) {
      state->confirmationFuture_->addCallback(
          [jitFuturePtr, remainingRRefs](JitFuture& /* unused */) {
            auto localCount = remainingRRefs->fetch_sub(1);
            if (localCount == 1) {
              jitFuturePtr->markCompleted(true);
            }
          });
    }
    userTable_.clear();
  }
  recording_ = false;
  return jitFuturePtr;
}
701 
// Abort the current recording session after an error: discard all RRefs
// recorded on this thread without waiting for their confirmations, and stop
// recording.
void RRefContext::clearRecordedPendingRRefsOnError() {
  userTable_.clear();
  recording_ = false;
}
706 
finishForkRequest(const ForkId & forkId,worker_id_t parent)707 void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
708   delPendingUser(forkId);
709   ++numPendingFutures_;
710   auto jitFuture = agent_->sendWithRetries(
711       agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
712 
713   jitFuture->addCallback([this](JitFuture& future) {
714     handleExceptionSilent(future);
715     --numPendingFutures_;
716   });
717 }
718 
addSelfAsFork(c10::intrusive_ptr<OwnerRRef> & rref)719 void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {
720   std::lock_guard<std::mutex> lock(mutex_);
721   const auto& rrefId = rref->rrefId();
722   owners_[rrefId] = rref;
723   auto& rrefForks = forks_[rrefId];
724   TORCH_INTERNAL_ASSERT(
725       rrefForks.find(rrefId) == rrefForks.end(),
726       "Attempt to add self as fork twice ",
727       rrefId);
728   rrefForks.insert(rrefId);
729 }
730 
addForkOfOwner(const RRefId & rrefId,const ForkId & forkId)731 void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
732   std::lock_guard<std::mutex> lock(mutex_);
733   auto& rrefForks = forks_[rrefId];
734   TORCH_INTERNAL_ASSERT(
735       rrefForks.find(forkId) == rrefForks.end(),
736       "Got fork notification twice on the same RRef ",
737       forkId);
738   rrefForks.insert(forkId);
739 }
740 
addForkOfOwnerIfNotPresent(const RRefId & rrefId,const ForkId & forkId)741 void RRefContext::addForkOfOwnerIfNotPresent(
742     const RRefId& rrefId,
743     const ForkId& forkId) {
744   std::lock_guard<std::mutex> lock(mutex_);
745   auto& rrefForks = forks_[rrefId];
746   // We first check whether the child exists in rrefForks. It's possible
747   // the child may have been added by a previous send attempt, and this check
748   // (as opposed to an assertion here) ensures that messages that trigger this
749   // function are idempotent.
750   if (rrefForks.find(forkId) == rrefForks.end()) {
751     rrefForks.insert(forkId);
752   } else {
753     LOG(INFO) << "Ignoring duplicate request to add Fork of OwnerRRef with "
754               << "RRefId = " << rrefId << ", ForkId = " << forkId;
755   }
756 }
757 
// Remove forkId from the fork set of the OwnerRRef identified by rrefId. If
// that was the last remaining fork, the OwnerRRef is also removed from
// owners_ and returned to the caller, so that its refcount drop (and any
// destruction it triggers) happens outside the lock. Returns a null
// intrusive_ptr when no owner was deleted.
c10::intrusive_ptr<RRef> RRefContext::delForkOfOwner(
    const RRefId& rrefId,
    const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedRRef;
  bool ownerReduced = false;
  // There were previously multiple TORCH_CHECKs in this function that checked
  // whether the passed in fork was known by the user and whether the fork had
  // already been deleted. These assertions are now replaced with nested if
  // statements to ensure this function is idempotent. This makes it safe to
  // retry RRefUserDelete messages.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto rrefIter = forks_.find(rrefId);
    if (rrefIter != forks_.end()) {
      auto& rrefForks = rrefIter->second;
      auto forkIter = rrefForks.find(forkId);
      if (forkIter != rrefForks.end()) {
        rrefForks.erase(forkId);
      } else {
        // Duplicate/retried delete for this fork; nothing to erase.
        LOG(INFO)
            << "Could not find UserRRef instance, "
            << "RRefId = " << rrefId << ", ForkId = " << forkId
            << ", likely because it was deleted by a previously retried message";
      }
      // Last fork gone: drop the owner entry (if still present) and stash
      // the RRef in deletedRRef so its refcount decreases after unlocking.
      if (rrefForks.empty()) {
        auto ownerIter = owners_.find(rrefId);
        if (ownerIter != owners_.end()) {
          deletedRRef = ownerIter->second;
          owners_.erase(ownerIter);
          ownerReduced = true;
        }
        forks_.erase(rrefIter);
      }
    } else {
      // Retried delete for an RRef whose fork set was already removed.
      LOG(INFO)
          << "Could not find OwnerRRef with RRefId = " << rrefId
          << ", likely because it was deleted by a previously retried message";
    }
  }
  // Notify waiters (e.g. shutdown waiting for all users to be deleted);
  // done outside the lock.
  if (ownerReduced) {
    deleteAllUsersCV_.notify_all();
  }
  return deletedRRef;
}
802 
803 } // namespace torch::distributed::rpc
804