1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/rpc_agent.h> 5 #include <torch/csrc/distributed/rpc/rref_impl.h> 6 #include <torch/csrc/distributed/rpc/types.h> 7 #include <torch/csrc/distributed/rpc/utils.h> 8 9 #include <atomic> 10 #include <optional> 11 12 namespace torch::distributed::rpc { 13 14 namespace callback { 15 // It's the callback for RemoteCall. 16 void TORCH_API 17 confirmPendingUser(const JitFuture& jitFuture, const ForkId& expectedForkId); 18 19 // It's the callback for finishing creating owner rref, it returned deletedRRef, 20 // so that the deletedRRef can be handled under GIL in python_functions.cpp if 21 // deletedRRef contains python object. 22 c10::intrusive_ptr<RRef> TORCH_API 23 finishCreatingOwnerRRef(const JitFuture& jitFuture, const RRefId& rrefId); 24 } // namespace callback 25 26 // Manages RRef lifetime and keeps track of RRef forks. 27 class TORCH_API RRefContext { 28 public: 29 static RRefContext& getInstance(); 30 // NB: This method must be called before destructing RRefContext singleton. 31 // Similar to delForkOfOwner, this method returns a vector of OwnerRRefs that 32 // hold py::object. The call-site is also responsible for resetting those 33 // shared_ptr objects with a GIL. See comments at delForkOfOwner() for more 34 // details. 35 static std::vector<c10::intrusive_ptr<RRef>> destroyInstance( 36 bool ignoreRRefLeak = true); 37 38 static void handleException(const JitFuture& jitFuture); 39 40 // handle exception without throw ::c10::Error again 41 static void handleExceptionSilent(const JitFuture& jitFuture); 42 43 RRefContext(const RRefContext&) = delete; 44 RRefContext(RRefContext&& other) = delete; 45 void operator=(const RRefContext&) = delete; 46 RRefContext& operator=(RRefContext&& other) = delete; 47 48 ~RRefContext(); 49 50 // get the worker id of the current worker getWorkerId()51 inline worker_id_t getWorkerId() const { 52 return agent_->getWorkerInfo().id_; 53 } 54 55 // get the worker name of the current worker getWorkerName()56 inline const std::string& getWorkerName() const { 57 return agent_->getWorkerInfo().name_; 58 } 59 60 // generate a globally unique ID genGloballyUniqueId()61 inline GloballyUniqueId genGloballyUniqueId() { 62 return GloballyUniqueId(getWorkerId(), nextLocalId_++); 63 } 64 agent()65 inline const std::shared_ptr<RpcAgent>& agent() const { 66 return agent_; 67 } 68 69 // create a ``UserRRef`` owned by the worker ``ownerId`` 70 c10::intrusive_ptr<UserRRef> createUserRRef( 71 worker_id_t ownerId, 72 const TypePtr& type); 73 74 // Convert an RRefForkData into an RRef. This RRef could be user or owner. 75 // This RRef could have already existed before, or could be created in this 76 // method, we pass type here to validate or help the rref creation. 77 c10::intrusive_ptr<RRef> getOrCreateRRef( 78 const RRefForkData& rfd, 79 const TypePtr& type); 80 81 // Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new 82 // one. This function is called in two places: 83 // 1. when processing ``rpc.remote()``, i.e., ``SCRIPT_REMOTE_CALL`` 84 // ``PYTHON_REMOTE_CALL``. 85 // 2. when unpickling ``OwnerRRef``. 86 // What's common in these two cases are, 1) the RRefId is already generated 87 // 2) the TypePtr is presented. So it can always create the ``OwnerRRef`` if 88 // it is not yet available. 89 c10::intrusive_ptr<OwnerRRef> getOrCreateOwnerRRef( 90 const RRefId& rrefId, 91 const TypePtr& type); 92 93 // Create an empty owner rref of type. 94 // This method is called to first time generate an ``OwnerRRef``, e.g., 95 // 1) ``rpc.RRef(obj)`` 96 // 2) create the ``OwnerRRef`` on `rpc.remote()` caller side. 97 // What's common in these two cases are, 1) the RRefId hasn't been generated 98 // 2) the TypePtr is presented. 99 c10::intrusive_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type); 100 101 // Returns a Future of the OwnerRRef, which will be marked completed when 102 // ``OwnerRRef`` is created. This method is used when the TypePtr is not 103 // available, e.g., when processing to_here(). The forceCreated flag can be 104 // used to ensure that the rref is created on the owner, otherwise throw in 105 // cases where the user of this API expects this to return a completed future. 106 // Note that the return value is a intrusive_ptr to a c10::ivalue::Future that 107 // holds the RRef. 108 c10::intrusive_ptr<JitFuture> getOwnerRRef( 109 const RRefId& rrefId, 110 bool forceCreated = false); 111 112 // Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when 113 // making a remote call to self, which as for now, still goes through serde 114 // and invokes request callback. In this case, the OwnerRRef has already been 115 // created on the send side, and we need to pass it to the receive side, 116 // instead of creating a new OwnerRRef. This is done by adding the OwnerRRef 117 // into owners_. However, that alone is not enough, as it could be deleted 118 // when all UserRRef die, which would then remove the OwnerRRef from owners_ 119 // and this could happen before the self remote call finishes. To prevent 120 // that, this API adds the RRefId as a ForkId, which will then delete the 121 // ForkId when the self remote is done. 122 void addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref); 123 124 // Register a fork of the ``OwnerRRef``, and inserts a intrusive_ptr of the 125 // ``OwnerRRef`` in a map to keep it alive. 126 void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId); 127 // Performs the same function as addForkOfOwner but ignores duplicate 128 // requests. This idempotent function is used with RREF_FORK_REQUEST calls, 129 // whereas all other message types use the non-idempotent variant. 130 void addForkOfOwnerIfNotPresent(const RRefId& rrefId, const ForkId& forkId); 131 // Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the 132 // IValue or py::object. For the later, this method will acquire GIL. 133 // NB: If this fork deletion triggered deleting OwnerRRef, this method will 134 // return a shared_ptr to the OwnerRRef, which is likely to be the last 135 // shared_ptr instance for it. Therefore, deleting this shared_ptr<OwnerRRef> 136 // will also trigger deleting the object it points to. If OwnerRRef holds a 137 // py::object, deleting it require GIL. The call site should guarded it with 138 // a GIL and reset the shared_ptr. The GIL-guarded deletion is intentionally 139 // left out of this function to avoid creating dependency on pybind. 140 c10::intrusive_ptr<RRef> delForkOfOwner( 141 const RRefId& rrefId, 142 const ForkId& forkId); 143 144 // Invoked when pickling an RRef to setup child/fork properly 145 RRefForkData prepareChildFork(const c10::intrusive_ptr<RRef>& rref); 146 // Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and 147 // send RREF_CHILD_ACCEPT to the parent. 148 // NB: forkId is necessary here as the rref could be an OwnerRRef 149 void notifyOwnerAndParentOfFork( 150 const ForkId& forkId, 151 worker_id_t parent, 152 const c10::intrusive_ptr<RRef>& rref); 153 154 // When a UserRRef is forked to another worker (user or owner), it is added 155 // into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT 156 // from the child. 157 // NB: This is necessary for both user and owner child. As we do not have FIFO 158 // communication between workers, we need this strategy to make sure that all 159 // previously submitted rpc/remote calls are acked before sending out the 160 // RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too 161 // soon. 162 void addPendingChild( 163 const ForkId& forkId, 164 const c10::intrusive_ptr<RRef>& rref); 165 void delPendingChild(const ForkId& forkId); 166 167 // When a UserRRef is created, it is added into pendingUsers_ to be held alive 168 // until it receives RREF_USER_ACCEPT from the owner. 169 void addPendingUser( 170 const ForkId& forkId, 171 const c10::intrusive_ptr<RRef>& rref); 172 void delPendingUser(const ForkId& forkId); 173 void addConfirmedUser( 174 const ForkId& forkId, 175 const c10::intrusive_ptr<RRef>& rref); 176 177 // Retrieve a pending user given the fork ID. Throws if the user has already 178 // been confirmed (i.e. is no longer in the pendingUsers_ map). 179 c10::intrusive_ptr<RRef> getPendingUser(const ForkId& forkId); 180 181 // Start recording new pending UserRRefs. All pending UserRRefs introduced 182 // after this point will be put into the thread_local userTable_, which will 183 // then be consumed and cleared in waitForThreadLocalPendingRRefs(). 184 void recordThreadLocalPendingRRefs(); 185 // End recording new pending UserRRefs, and clear the thread_local userTable_. 186 // Returns a Future which will be marked as completed when all pending 187 // UserRRefs in the current userTable_ are confirmed by their owners. The bool 188 // value in the Future is unused. 189 // This method is useful to make sure RRefs in user function arguments are 190 // confirmed before launching user code. 191 // NB: Callers of this method does not need to keep the returned Future alive, 192 // because this Future is already captured in callbacks of the 193 // PendingUserState. If there is no pending UserRRefs, this method returns a 194 // completed future. 195 c10::intrusive_ptr<JitFuture> waitForThreadLocalPendingRRefs(); 196 // Only call this function when there are errors during a recording session, 197 // and it is likely that waitForThreadLocalPendingRRefs() cannot be invoked 198 // properly. 199 // TODO: make this a context guard 200 void clearRecordedPendingRRefsOnError(); 201 202 void delUser( 203 const worker_id_t owner, 204 const RRefId& rrefId, 205 const ForkId& forkId); 206 void delAllUsersAndUnforkedOwners(std::chrono::milliseconds timeoutMillis); 207 208 std::unordered_map<std::string, std::string> getDebugInfo(); 209 210 private: 211 struct PendingUserState { PendingUserStatePendingUserState212 PendingUserState(c10::intrusive_ptr<RRef> rref) 213 : rref_(std::move(rref)), 214 confirmationFuture_(c10::make_intrusive<JitFuture>(BoolType::get())) { 215 } 216 confirmPendingUserState217 inline void confirm() { 218 c10::static_intrusive_pointer_cast<UserRRef>(rref_)->confirm(); 219 confirmationFuture_->markCompleted(); 220 } 221 222 c10::intrusive_ptr<RRef> rref_; 223 // Use Future.wait() and Future.markCompleted() to block and unblock user 224 // functions. The bool value wrapped by the future_ is not used. 225 c10::intrusive_ptr<JitFuture> confirmationFuture_; 226 }; 227 228 RRefContext(std::shared_ptr<RpcAgent>); 229 230 c10::intrusive_ptr<UserRRef> createUserRRef( 231 worker_id_t ownerId, 232 const RRefId& rrefId, 233 const ForkId& forkId, 234 const TypePtr& type); 235 236 void finishForkRequest(const ForkId& forkId, worker_id_t parent); 237 238 // If there is any leak on any RRef, this method will throw an error. 239 void checkRRefLeaks(bool ignoreRRefLeak); 240 241 static std::atomic<local_id_t> nextLocalId_; 242 243 const std::shared_ptr<RpcAgent> agent_; 244 mutable std::mutex mutex_; 245 // Keep OwnerRRefs alive until there is no living UserRRefs. 246 std::unordered_map<RRefId, c10::intrusive_ptr<RRef>, RRefId::Hash> owners_; 247 // A map to track OwnerRRefs that are requested but not yet created. This can 248 // happen if the to_here() message is processed on the owner before the 249 // corresponding creator rpc.remote() message. If this happens, instead of 250 // to_here() RPC thread to block waiting for the OwnerRRef creation, the 251 // RRefContext returns a Future, so that the RPC request processing logic can 252 // attach subsequent code as a callback to that Future. 253 // NB: the OwnerRRefs in this map must be cleared when the corresponding 254 // OwnerRRef is created. Note that the values in this map are intrusive_ptrs 255 // to c10::ivalue::Future that will be marked completed with the owner RRef. 256 std::unordered_map<RRefId, c10::intrusive_ptr<JitFuture>, RRefId::Hash> 257 pendingOwners_; 258 // Tracks known living UserRRefs of an OwnerRRef 259 std::unordered_map< 260 RRefId, 261 std::unordered_set<ForkId, ForkId::Hash>, 262 RRefId::Hash> 263 forks_; 264 265 // This cond var is used by deleteAllUsers(), a event notification is sent if 266 // number of pending UserRRef or UserRRef children is reduced, or 267 // number of owned OwnerRRef is reduced. 268 std::condition_variable deleteAllUsersCV_; 269 // The follow 3 maps keep UserRRefs alive by holding a intrusive_ptr to the 270 // RRef instances. A UserRRef must be added into this map if any of the 271 // following two conditions is true: 272 // 273 // (1) A UserRRef has not been accepted by owner yet. 274 // 275 // It can be used or shared, but cannot be deleted, and hence kept alive 276 // in this map. A message of type RREF_USER_ACCEPT will move the 277 // corresponding RRef from pendingUsers_ map to confirmedUsers_ map. 278 std::unordered_map<ForkId, std::shared_ptr<PendingUserState>, ForkId::Hash> 279 pendingUsers_; 280 // UserRRefs are added into this map when it is confirmed by the owner. 281 // When destroying RRefContext this map helps to find local UserRRefs 282 // and send delete messages if they are still not deleted by Python 283 // garbage collection. 284 std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash> 285 confirmedUsers_; 286 287 // (2) A UserRRef has forked a child UserRRef which has not been accepted by 288 // the owner yet. 289 // 290 // In this case, this UserRRef cannot send out RREF_USER_DELETE message, 291 // as it could potentially trigger the OwnerRRef been deleted before the 292 // owner learns about the forked child. 293 std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash> 294 pendingChildren_; 295 296 // The RRef context performs its operations through async RPC requests, in 297 // order to not block the user code. Therefore the RRef context's state may be 298 // lagging a bit behind what it is intended to be, while it waits for these 299 // requests to complete. To allow syncing when needed, we store the count of 300 // these pending requests, so that users can wait for it to reach zero. 301 std::atomic<int64_t> numPendingFutures_{0}; 302 303 std::mutex destroyedMutex_; 304 bool destroyed_{false}; 305 306 // Thread local states to keep UserRRefs deserialized from user function 307 // arguments. 308 static thread_local std::vector<std::shared_ptr<PendingUserState>> userTable_; 309 // A flag indicating whether subsequently created UserRRefs should be added to 310 // the thread_local userTable_. The flag is set to true before serializing 311 // RPC arguments and then set to false before running the corresponding 312 // user code. See addPendingUser and delPendingUser for more details. 313 // NB: The reason for having this flag is because addPendingUser are called in 314 // two cases, and we only want to track the 2nd case. 315 // (1) RRef as the return value: when calling rpc.remote, the UserRRef on the 316 // caller side is added to the context using addPendingUser. 317 // (2) RRef as an argument: When running an RPC using RRefs as arguments, the 318 // RRef is forwarded to the callee as new UserRRefs (if the callee is not 319 // the owner). In this case, we block running the user function until all 320 // UserRRefs are confirmed by the owner. 321 // This contract gurantees that no UserRRefs can be used remotely without 322 // confirmation. Note that, however, the UserRRef created by rpc.remote can 323 // still be passed to local functions as arguments and used there. This is by 324 // design, because this feature is especially useful when, say a master node 325 // creates multiple UserRRefs in a loop and then shares them with other nodes. 326 // Blocking every iteration in the loop until RRefs are confirmed will slow 327 // this down. This nuance on UserRRef can be interpreted as we only make 328 // exceptions for UserRRef creators. And using the UserRRef on its creator 329 // without confirmation is OK, because the creator would either call to_here 330 // or forward the UserRRef, and both would then require confirmations from the 331 // owner. 332 static thread_local bool recording_; 333 }; 334 335 } // namespace torch::distributed::rpc 336