xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rref_context.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/message.h>
4 #include <torch/csrc/distributed/rpc/rpc_agent.h>
5 #include <torch/csrc/distributed/rpc/rref_impl.h>
6 #include <torch/csrc/distributed/rpc/types.h>
7 #include <torch/csrc/distributed/rpc/utils.h>
8 
9 #include <atomic>
10 #include <optional>
11 
12 namespace torch::distributed::rpc {
13 
14 namespace callback {
15 // Callback attached to the future of a RemoteCall.
16 TORCH_API void confirmPendingUser(
17     const JitFuture& jitFuture, const ForkId& expectedForkId);
18 
19 // Callback run when creation of an owner RRef finishes. It returns the
20 // deleted RRef so that, when that RRef contains a Python object, the caller
21 // in python_functions.cpp can handle it while holding the GIL.
22 TORCH_API c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
23     const JitFuture& jitFuture, const RRefId& rrefId);
24 } // namespace callback
25 
26 // Manages RRef lifetime and keeps track of RRef forks.
27 class TORCH_API RRefContext {
28  public:
29   static RRefContext& getInstance();
30   // NB: This method must be called before destructing RRefContext singleton.
31   // Similar to delForkOfOwner, this method returns a vector of OwnerRRefs that
32   // hold py::object. The call-site is also responsible for resetting those
33   // shared_ptr objects with a GIL. See comments at delForkOfOwner() for more
34   // details.
35   static std::vector<c10::intrusive_ptr<RRef>> destroyInstance(
36       bool ignoreRRefLeak = true);
37 
38   static void handleException(const JitFuture& jitFuture);
39 
40   // handle exception without throw ::c10::Error again
41   static void handleExceptionSilent(const JitFuture& jitFuture);
42 
43   RRefContext(const RRefContext&) = delete;
44   RRefContext(RRefContext&& other) = delete;
45   void operator=(const RRefContext&) = delete;
46   RRefContext& operator=(RRefContext&& other) = delete;
47 
48   ~RRefContext();
49 
50   // get the worker id of the current worker
getWorkerId()51   inline worker_id_t getWorkerId() const {
52     return agent_->getWorkerInfo().id_;
53   }
54 
55   // get the worker name of the current worker
getWorkerName()56   inline const std::string& getWorkerName() const {
57     return agent_->getWorkerInfo().name_;
58   }
59 
60   //  generate a globally unique ID
genGloballyUniqueId()61   inline GloballyUniqueId genGloballyUniqueId() {
62     return GloballyUniqueId(getWorkerId(), nextLocalId_++);
63   }
64 
agent()65   inline const std::shared_ptr<RpcAgent>& agent() const {
66     return agent_;
67   }
68 
69   // create a ``UserRRef`` owned by the worker ``ownerId``
70   c10::intrusive_ptr<UserRRef> createUserRRef(
71       worker_id_t ownerId,
72       const TypePtr& type);
73 
74   // Convert an RRefForkData into an RRef. This RRef could be user or owner.
75   // This RRef could have already existed before, or could be created in this
76   // method, we pass type here to validate or help the rref creation.
77   c10::intrusive_ptr<RRef> getOrCreateRRef(
78       const RRefForkData& rfd,
79       const TypePtr& type);
80 
81   // Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new
82   // one. This function is called in two places:
83   // 1. when processing ``rpc.remote()``, i.e., ``SCRIPT_REMOTE_CALL``
84   //    ``PYTHON_REMOTE_CALL``.
85   // 2. when unpickling ``OwnerRRef``.
86   // What's common in these two cases are, 1) the RRefId is already generated
87   // 2) the TypePtr is presented. So it can always create the ``OwnerRRef`` if
88   // it is not yet available.
89   c10::intrusive_ptr<OwnerRRef> getOrCreateOwnerRRef(
90       const RRefId& rrefId,
91       const TypePtr& type);
92 
93   // Create an empty owner rref of type.
94   // This method is called to first time generate an ``OwnerRRef``, e.g.,
95   // 1) ``rpc.RRef(obj)``
96   // 2) create the ``OwnerRRef`` on `rpc.remote()` caller side.
97   // What's common in these two cases are, 1) the RRefId hasn't been generated
98   // 2) the TypePtr is presented.
99   c10::intrusive_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
100 
101   // Returns a Future of the OwnerRRef, which will be marked completed when
102   // ``OwnerRRef`` is created. This method is used when the TypePtr is not
103   // available, e.g., when processing to_here(). The forceCreated flag can be
104   // used to ensure that the rref is created on the owner, otherwise throw in
105   // cases where the user of this API expects this to return a completed future.
106   // Note that the return value is a intrusive_ptr to a c10::ivalue::Future that
107   // holds the RRef.
108   c10::intrusive_ptr<JitFuture> getOwnerRRef(
109       const RRefId& rrefId,
110       bool forceCreated = false);
111 
112   // Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when
113   // making a remote call to self, which as for now, still goes through serde
114   // and invokes request callback. In this case, the OwnerRRef has already been
115   // created on the send side, and we need to pass it to the receive side,
116   // instead of creating a new OwnerRRef. This is done by adding the OwnerRRef
117   // into owners_. However, that alone is not enough, as it could be deleted
118   // when all UserRRef die, which would then remove the OwnerRRef from owners_
119   // and this could happen before the self remote call finishes. To prevent
120   // that, this API adds the RRefId as a ForkId, which will then delete the
121   // ForkId when the self remote is done.
122   void addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref);
123 
124   // Register a fork of the ``OwnerRRef``, and inserts a intrusive_ptr of the
125   // ``OwnerRRef`` in a map to keep it alive.
126   void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId);
127   // Performs the same function as addForkOfOwner but ignores duplicate
128   // requests. This idempotent function is used with RREF_FORK_REQUEST calls,
129   // whereas all other message types use the non-idempotent variant.
130   void addForkOfOwnerIfNotPresent(const RRefId& rrefId, const ForkId& forkId);
131   // Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the
132   // IValue or py::object. For the later, this method will acquire GIL.
133   // NB: If this fork deletion triggered deleting OwnerRRef, this method will
134   // return a shared_ptr to the OwnerRRef, which is likely to be the last
135   // shared_ptr instance for it. Therefore, deleting this shared_ptr<OwnerRRef>
136   // will also trigger deleting the object it points to. If OwnerRRef holds a
137   // py::object, deleting it require GIL. The call site should guarded it with
138   // a GIL and reset the shared_ptr. The GIL-guarded deletion is intentionally
139   // left out of this function to avoid creating dependency on pybind.
140   c10::intrusive_ptr<RRef> delForkOfOwner(
141       const RRefId& rrefId,
142       const ForkId& forkId);
143 
144   // Invoked when pickling an RRef to setup child/fork properly
145   RRefForkData prepareChildFork(const c10::intrusive_ptr<RRef>& rref);
146   // Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and
147   // send RREF_CHILD_ACCEPT to the parent.
148   // NB: forkId is necessary here as the rref could be an OwnerRRef
149   void notifyOwnerAndParentOfFork(
150       const ForkId& forkId,
151       worker_id_t parent,
152       const c10::intrusive_ptr<RRef>& rref);
153 
154   // When a UserRRef is forked to another worker (user or owner), it is added
155   // into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT
156   // from the child.
157   // NB: This is necessary for both user and owner child. As we do not have FIFO
158   // communication between workers, we need this strategy to make sure that all
159   // previously submitted rpc/remote calls are acked before sending out the
160   // RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too
161   // soon.
162   void addPendingChild(
163       const ForkId& forkId,
164       const c10::intrusive_ptr<RRef>& rref);
165   void delPendingChild(const ForkId& forkId);
166 
167   // When a UserRRef is created, it is added into pendingUsers_ to be held alive
168   // until it receives RREF_USER_ACCEPT from the owner.
169   void addPendingUser(
170       const ForkId& forkId,
171       const c10::intrusive_ptr<RRef>& rref);
172   void delPendingUser(const ForkId& forkId);
173   void addConfirmedUser(
174       const ForkId& forkId,
175       const c10::intrusive_ptr<RRef>& rref);
176 
177   // Retrieve a pending user given the fork ID. Throws if the user has already
178   // been confirmed (i.e. is no longer in the pendingUsers_ map).
179   c10::intrusive_ptr<RRef> getPendingUser(const ForkId& forkId);
180 
181   // Start recording new pending UserRRefs. All pending UserRRefs introduced
182   // after this point will be put into the thread_local userTable_, which will
183   // then be consumed and cleared in waitForThreadLocalPendingRRefs().
184   void recordThreadLocalPendingRRefs();
185   // End recording new pending UserRRefs, and clear the thread_local userTable_.
186   // Returns a Future which will be marked as completed when all pending
187   // UserRRefs in the current userTable_ are confirmed by their owners. The bool
188   // value in the Future is unused.
189   // This method is useful to make sure RRefs in user function arguments are
190   // confirmed before launching user code.
191   // NB: Callers of this method does not need to keep the returned Future alive,
192   // because this Future is already captured in callbacks of the
193   // PendingUserState. If there is no pending UserRRefs, this method returns a
194   // completed future.
195   c10::intrusive_ptr<JitFuture> waitForThreadLocalPendingRRefs();
196   // Only call this function when there are errors during a recording session,
197   // and it is likely that waitForThreadLocalPendingRRefs() cannot be invoked
198   // properly.
199   // TODO: make this a context guard
200   void clearRecordedPendingRRefsOnError();
201 
202   void delUser(
203       const worker_id_t owner,
204       const RRefId& rrefId,
205       const ForkId& forkId);
206   void delAllUsersAndUnforkedOwners(std::chrono::milliseconds timeoutMillis);
207 
208   std::unordered_map<std::string, std::string> getDebugInfo();
209 
210  private:
211   struct PendingUserState {
PendingUserStatePendingUserState212     PendingUserState(c10::intrusive_ptr<RRef> rref)
213         : rref_(std::move(rref)),
214           confirmationFuture_(c10::make_intrusive<JitFuture>(BoolType::get())) {
215     }
216 
confirmPendingUserState217     inline void confirm() {
218       c10::static_intrusive_pointer_cast<UserRRef>(rref_)->confirm();
219       confirmationFuture_->markCompleted();
220     }
221 
222     c10::intrusive_ptr<RRef> rref_;
223     // Use Future.wait() and Future.markCompleted() to block and unblock user
224     // functions. The bool value wrapped by the future_ is not used.
225     c10::intrusive_ptr<JitFuture> confirmationFuture_;
226   };
227 
228   RRefContext(std::shared_ptr<RpcAgent>);
229 
230   c10::intrusive_ptr<UserRRef> createUserRRef(
231       worker_id_t ownerId,
232       const RRefId& rrefId,
233       const ForkId& forkId,
234       const TypePtr& type);
235 
236   void finishForkRequest(const ForkId& forkId, worker_id_t parent);
237 
238   // If there is any leak on any RRef, this method will throw an error.
239   void checkRRefLeaks(bool ignoreRRefLeak);
240 
241   static std::atomic<local_id_t> nextLocalId_;
242 
243   const std::shared_ptr<RpcAgent> agent_;
244   mutable std::mutex mutex_;
245   // Keep OwnerRRefs alive until there is no living UserRRefs.
246   std::unordered_map<RRefId, c10::intrusive_ptr<RRef>, RRefId::Hash> owners_;
247   // A map to track OwnerRRefs that are requested but not yet created. This can
248   // happen if the to_here() message is processed on the owner before the
249   // corresponding creator rpc.remote() message. If this happens, instead of
250   // to_here() RPC thread to block waiting for the OwnerRRef creation, the
251   // RRefContext returns a Future, so that the RPC request processing logic can
252   // attach subsequent code as a callback to that Future.
253   // NB: the OwnerRRefs in this map must be cleared when the corresponding
254   // OwnerRRef is created. Note that the values in this map are intrusive_ptrs
255   // to c10::ivalue::Future that will be marked completed with the owner RRef.
256   std::unordered_map<RRefId, c10::intrusive_ptr<JitFuture>, RRefId::Hash>
257       pendingOwners_;
258   // Tracks known living UserRRefs of an OwnerRRef
259   std::unordered_map<
260       RRefId,
261       std::unordered_set<ForkId, ForkId::Hash>,
262       RRefId::Hash>
263       forks_;
264 
265   // This cond var is used by deleteAllUsers(), a event notification is sent if
266   // number of pending UserRRef or UserRRef children is reduced, or
267   // number of owned OwnerRRef is reduced.
268   std::condition_variable deleteAllUsersCV_;
269   // The follow 3 maps keep UserRRefs alive by holding a intrusive_ptr to the
270   // RRef instances. A UserRRef must be added into this map if any of the
271   // following two conditions is true:
272   //
273   // (1) A UserRRef has not been accepted by owner yet.
274   //
275   //     It can be used or shared, but cannot be deleted, and hence kept alive
276   //     in this map. A message of type RREF_USER_ACCEPT will move the
277   //     corresponding RRef from pendingUsers_ map to confirmedUsers_ map.
278   std::unordered_map<ForkId, std::shared_ptr<PendingUserState>, ForkId::Hash>
279       pendingUsers_;
280   //     UserRRefs are added into this map when it is confirmed by the owner.
281   //     When destroying RRefContext this map helps to find local UserRRefs
282   //     and send delete messages if they are still not deleted by Python
283   //     garbage collection.
284   std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash>
285       confirmedUsers_;
286 
287   // (2) A UserRRef has forked a child UserRRef which has not been accepted by
288   //     the owner yet.
289   //
290   //     In this case, this UserRRef cannot send out RREF_USER_DELETE message,
291   //     as it could potentially trigger the OwnerRRef been deleted before the
292   //     owner learns about the forked child.
293   std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash>
294       pendingChildren_;
295 
296   // The RRef context performs its operations through async RPC requests, in
297   // order to not block the user code. Therefore the RRef context's state may be
298   // lagging a bit behind what it is intended to be, while it waits for these
299   // requests to complete. To allow syncing when needed, we store the count of
300   // these pending requests, so that users can wait for it to reach zero.
301   std::atomic<int64_t> numPendingFutures_{0};
302 
303   std::mutex destroyedMutex_;
304   bool destroyed_{false};
305 
306   // Thread local states to keep UserRRefs deserialized from user function
307   // arguments.
308   static thread_local std::vector<std::shared_ptr<PendingUserState>> userTable_;
309   // A flag indicating whether subsequently created UserRRefs should be added to
310   // the thread_local userTable_. The flag is set to true before serializing
311   // RPC arguments and then set to false before running the corresponding
312   // user code. See addPendingUser and delPendingUser for more details.
313   // NB: The reason for having this flag is because addPendingUser are called in
314   // two cases, and we only want to track the 2nd case.
315   // (1) RRef as the return value: when calling rpc.remote, the UserRRef on the
316   //     caller side is added to the context using addPendingUser.
317   // (2) RRef as an argument: When running an RPC using RRefs as arguments, the
318   //     RRef is forwarded to the callee as new UserRRefs (if the callee is not
319   //     the owner). In this case, we block running the user function until all
320   //     UserRRefs are confirmed by the owner.
321   // This contract gurantees that no UserRRefs can be used remotely without
322   // confirmation. Note that, however, the UserRRef created by rpc.remote can
323   // still be passed to local functions as arguments and used there. This is by
324   // design, because this feature is especially useful when, say a master node
325   // creates multiple UserRRefs in a loop and then shares them with other nodes.
326   // Blocking every iteration in the loop until RRefs are confirmed will slow
327   // this down. This nuance on UserRRef can be interpreted as we only make
328   // exceptions for UserRRef creators. And using the UserRRef on its creator
329   // without confirmation is OK, because the creator would either call to_here
330   // or forward the UserRRef, and both would then require confirmations from the
331   // owner.
332   static thread_local bool recording_;
333 };
334 
335 } // namespace torch::distributed::rpc
336