xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rref_impl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/rref_impl.h>
2 
3 #include <ATen/record_function.h>
4 #include <c10/core/impl/DeviceGuardImplInterface.h>
5 #include <fmt/format.h>
6 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
7 #include <torch/csrc/distributed/autograd/utils.h>
8 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
9 #include <torch/csrc/distributed/rpc/rref_context.h>
10 #include <torch/csrc/distributed/rpc/rref_proto.h>
11 #include <torch/csrc/distributed/rpc/utils.h>
12 
13 #include <utility>
14 
15 namespace {
16 // If the type is subtype of named type, return its qualifiedname, otherwise
17 // return its type str.
getTypeStr(const c10::TypePtr & type)18 std::string getTypeStr(const c10::TypePtr& type) {
19   switch (type->kind()) {
20     case c10::TypeKind::FunctionType:
21       return type->castRaw<c10::FunctionType>()->name()->qualifiedName();
22     case c10::TypeKind::TupleType:
23       return type->castRaw<c10::TupleType>()->name()->qualifiedName();
24     case c10::TypeKind::ClassType:
25       return type->castRaw<c10::ClassType>()->name()->qualifiedName();
26     case c10::TypeKind::InterfaceType:
27       return type->castRaw<c10::InterfaceType>()->name()->qualifiedName();
28     default:
29       return type->annotation_str();
30   }
31 }
32 
33 } // namespace
34 
35 namespace torch::distributed::rpc {
36 
37 std::atomic<local_id_t> RRefContext::nextLocalId_{0};
38 
39 //////////////////////////  RRefForkData  /////////////////////////////////
40 
RRefForkData(worker_id_t ownerId,const RRefId & rrefId,const ForkId & forkId,worker_id_t parent,std::string typeStr)41 RRefForkData::RRefForkData(
42     worker_id_t ownerId,
43     const RRefId& rrefId,
44     const ForkId& forkId,
45     worker_id_t parent,
46     std::string typeStr)
47     : ownerId_(ownerId),
48       rrefId_(rrefId),
49       forkId_(forkId),
50       parent_(parent),
51       typeStr_(std::move(typeStr)) {}
52 
53 //////////////////////////////  RRef  /////////////////////////////////////
54 
RRef(worker_id_t ownerId,const RRefId & rrefId,TypePtr type)55 RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
56     : RRefInterface(),
57       ownerId_(ownerId),
58       rrefId_(rrefId),
59       type_(std::move(type)) {}
60 
fork() const61 RRefForkData RRef::fork() const {
62   auto& ctx = RRefContext::getInstance();
63   return RRefForkData(
64       ownerId_,
65       rrefId_,
66       ctx.genGloballyUniqueId(),
67       ctx.getWorkerId(),
68       getTypeStr(type_));
69 }
70 
handleError(RPCErrorType errorType,const JitFuture & jitFuture)71 void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) {
72   static std::unordered_map<
73       RPCErrorType,
74       std::function<void(const JitFuture& jitFuture)>,
75       std::hash<int>>
76       errorHandlers = {
77           {RPCErrorType::TIMEOUT,
78            [this](const JitFuture& /* unused */) { setTimedOut(); }},
79           {RPCErrorType::INTENTIONAL_FAILURE,
80            [this](const JitFuture& /* unused */) { setTimedOut(); }},
81           {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) {
82              // Default error handler
83              RRefContext::handleException(jitFuture);
84            }}};
85   errorHandlers.find(errorType)->second(jitFuture);
86 }
87 
88 //////////////////////////  UserRRef  /////////////////////////////////////
89 
UserRRef(worker_id_t ownerId,const RRefId & rrefId,const ForkId & forkId,TypePtr type)90 UserRRef::UserRRef(
91     worker_id_t ownerId,
92     const RRefId& rrefId,
93     const ForkId& forkId,
94     TypePtr type)
95     : RRef(ownerId, rrefId, std::move(type)),
96       forkId_(forkId),
97       confirmedByOwner_(false) {
98   // Do nothing,
99   // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
100   //     a RREF_FORK_REQUEST message to the owner.
101   // (2) If this the creator UserRRef, ScriptRemoteCall or PythonRemoteCall will
102   //     properly notify the owner.
103 }
104 
tryDel()105 void UserRRef::tryDel() {
106   std::lock_guard<std::mutex> lockGuard(deletedOnOwnerMutex_);
107   if (!deletedOnOwner_) {
108     try {
109       RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
110       deletedOnOwner_ = true;
111     } catch (const std::exception& ex) {
112       LOG(ERROR) << "Error occurred when deleting" << *this << " : "
113                  << ex.what();
114     } catch (...) {
115       LOG(ERROR) << "Error occurred when deleting" << *this << " : "
116                  << "unknown error";
117     }
118   }
119 }
120 
~UserRRef()121 UserRRef::~UserRRef() {
122   tryDel();
123 }
124 
release_resources()125 void UserRRef::release_resources() {
126   tryDel();
127 }
128 
forkId() const129 const ForkId& UserRRef::forkId() const {
130   return forkId_;
131 }
132 
toHere(const float timeoutSeconds) const133 IValue UserRRef::toHere(const float timeoutSeconds) const {
134   TORCH_CHECK(
135       !getTimedOut(),
136       "RRef creation via rpc.remote() timed out, and it "
137       "is possible that the RRef on the owner node does not exist.");
138   // see Note [Best-Effort Check on Deleted UserRRefs]
139   TORCH_CHECK(
140       !deletedOnOwner_,
141       *this,
142       " has been deleted. Cannot call to_here() on it after deletion.");
143   auto toHereKey = std::string("");
144   if (torch::autograd::profiler::profilerEnabled()) {
145     toHereKey = fmt::format(
146         "to_here#({})->({})",
147         RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
148         RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_);
149   }
150   RECORD_USER_SCOPE(toHereKey);
151   TORCH_CHECK(
152       !type_->is_module(),
153       *this,
154       " is an RRef to a ScriptModule. "
155       "It can't be sent through RPC "
156       "from owner, ",
157       ownerWorkerInfo(),
158       ", to user, ",
159       RpcAgent::getCurrentRpcAgent()->getWorkerInfo(),
160       ".");
161 
162   auto agent = RpcAgent::getCurrentRpcAgent();
163 
164   // ScriptRRefFetchCall message always carries autograd context id even if
165   // the message itself does not contain any tensor, because the response would
166   // potentially contain tensors.
167   c10::intrusive_ptr<Message> msgToSend;
168 
169   if (isPyObj()) {
170     msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
171   } else {
172     msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
173   }
174 
175   // toHere is profiled as a blocking call, and does not execute operations on
176   // the remote node. Hence, don't wrap it with a profiling message since we
177   // don't need the profiler to be enabled remotely.
178   auto jitFuture = autograd::sendMessageWithAutograd(
179       *agent,
180       agent->getWorkerInfo(ownerId_),
181       std::move(msgToSend),
182       true /* forceGradRecording */,
183       timeoutSeconds,
184       true /* forceDisableProfiling */);
185 
186   // TODO: we should ideally be able to interrupt this blocking wait if we check
187   // getTimedOut() and it is true
188   // (https://github.com/pytorch/pytorch/issues/39411).
189   jitFuture->waitAndThrow();
190   auto messagePtr = jitFuture->constValue().toCustomClass<Message>();
191   MessageType msgType = messagePtr->type();
192   auto response = deserializeResponse(*messagePtr, msgType);
193   TORCH_INTERNAL_ASSERT(
194       msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
195           msgType == MessageType::PYTHON_RREF_FETCH_RET,
196       "Message type should either be SCRIPT_RREF_FETCH_RET "
197       "or PYTHON_RREF_FETCH_RET");
198   RpcCommandBase& rpc = *response;
199   auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
200   if (isPyObj()) {
201     // wrap python serialized vector of ivalues into tuple, this
202     // made the C++ toHere interface to return single IValue
203     return ivalue::Tuple::create(rrefFetchRet.values());
204   } else {
205     return rrefFetchRet.values().front();
206   }
207 }
208 
fork() const209 RRefForkData UserRRef::fork() const {
210   // Note [Best-Effort Check on Deleted UserRRefs]
211   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
212   // This check does not guarantee correctness, as there could be another thread
213   // trying to delete this UserRRef concurrently. Passing this check does not
214   // mean this RRef will be alive throughout this function. This is just our
215   // best-effort attempt to raise proper error messages. The behavior of using
216   // deleted UserRRefs is undefined.
217   //
218   // The reason for not implementing strict checks are:
219   // 1. This would need to acquire lock on deletedOnOwnerMutex_, which would
220   //    introduce unnecessary overhead for most normal use cases.
221   // 2. This would introduce a lot of complexities to get the behavior correct.
222   //    Assume we acquired the lock here, and there is another thread X block
223   //    waiting in tryDel() on the lock. Exiting this fork function would
224   //    unblock thread X. However, while X proceeds with deleting this UserRRef,
225   //    the call site of fork() might have added the UserRRef to
226   //    pendingChildren_ map, but up to this point, nothing prevents X from
227   //    deleting this RRef even if it shouldn't do so due to the state change
228   //    in pendingChildren_. We might be able to get it right for now by locking
229   //    and checking pendingChildren_ in X, but the gain does not seem to
230   //    worth the complexity.
231   TORCH_CHECK(
232       !deletedOnOwner_,
233       *this,
234       " has been deleted. Cannot call fork an UserRRef after deletion.");
235   return RRef::fork();
236 }
237 
238 //////////////////////////  OwnerRRef  /////////////////////////////////////
239 
OwnerRRef(worker_id_t ownerId,const RRefId & rrefId,TypePtr type,std::vector<c10::Device> devices)240 OwnerRRef::OwnerRRef(
241     worker_id_t ownerId,
242     const RRefId& rrefId,
243     TypePtr type,
244     std::vector<c10::Device> devices)
245     : OwnerRRef(
246           ownerId,
247           rrefId,
248           std::move(type),
249           /* value */ {},
250           std::move(devices)) {}
251 
OwnerRRef(worker_id_t ownerId,const RRefId & rrefId,TypePtr type,std::optional<IValue> value,std::vector<c10::Device> devices)252 OwnerRRef::OwnerRRef(
253     worker_id_t ownerId,
254     const RRefId& rrefId,
255     TypePtr type,
256     std::optional<IValue> value,
257     std::vector<c10::Device> devices)
258     : RRef(ownerId, rrefId, std::move(type)) {
259   future_ = c10::make_intrusive<JitFuture>(type_, std::move(devices));
260 
261   if (value.has_value()) {
262     future_->markCompleted(value.value());
263   }
264 }
265 
getValue() const266 const IValue& OwnerRRef::getValue() const {
267   TORCH_CHECK(
268       !getTimedOut(),
269       "RRef creation via rpc.remote() timed out, and it "
270       "is possible that the RRef on the owner node does not exist.");
271   future_->waitAndThrow();
272   return future_->constValue();
273 }
274 
hasValue() const275 bool OwnerRRef::hasValue() const {
276   return future_->completed();
277 }
278 
getFuture()279 c10::intrusive_ptr<JitFuture> OwnerRRef::getFuture() {
280   return future_;
281 }
282 
setValue(IValue && value)283 void OwnerRRef::setValue(IValue&& value) {
284   future_->markCompleted(value);
285 }
286 
setError(std::exception_ptr eptr)287 void OwnerRRef::setError(std::exception_ptr eptr) {
288   future_->setErrorIfNeeded(std::move(eptr));
289 }
290 
operator <<(std::ostream & os,const RRef & rref)291 std::ostream& operator<<(std::ostream& os, const RRef& rref) {
292   if (rref.isOwner()) {
293     return os << "OwnerRRef("
294               << "rref_id=" << rref.rrefId() << ")";
295   } else {
296     return os << "UserRRef("
297               << "rref_id=" << rref.rrefId()
298               << ", fork_id=" << static_cast<const UserRRef*>(&rref)->forkId()
299               << ")";
300   }
301 }
302 
303 } // namespace torch::distributed::rpc
304