1 #include <torch/csrc/distributed/rpc/rref_impl.h>
2
3 #include <ATen/record_function.h>
4 #include <c10/core/impl/DeviceGuardImplInterface.h>
5 #include <fmt/format.h>
6 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
7 #include <torch/csrc/distributed/autograd/utils.h>
8 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
9 #include <torch/csrc/distributed/rpc/rref_context.h>
10 #include <torch/csrc/distributed/rpc/rref_proto.h>
11 #include <torch/csrc/distributed/rpc/utils.h>
12
13 #include <utility>
14
15 namespace {
16 // If the type is subtype of named type, return its qualifiedname, otherwise
17 // return its type str.
getTypeStr(const c10::TypePtr & type)18 std::string getTypeStr(const c10::TypePtr& type) {
19 switch (type->kind()) {
20 case c10::TypeKind::FunctionType:
21 return type->castRaw<c10::FunctionType>()->name()->qualifiedName();
22 case c10::TypeKind::TupleType:
23 return type->castRaw<c10::TupleType>()->name()->qualifiedName();
24 case c10::TypeKind::ClassType:
25 return type->castRaw<c10::ClassType>()->name()->qualifiedName();
26 case c10::TypeKind::InterfaceType:
27 return type->castRaw<c10::InterfaceType>()->name()->qualifiedName();
28 default:
29 return type->annotation_str();
30 }
31 }
32
33 } // namespace
34
35 namespace torch::distributed::rpc {
36
// Out-of-class definition of RRefContext's static atomic counter, initialized
// to zero. NOTE(review): presumably the source of the local-id part of ids
// produced by RRefContext::genGloballyUniqueId() — confirm in rref_context.
std::atomic<local_id_t> RRefContext::nextLocalId_{0};
38
39 ////////////////////////// RRefForkData /////////////////////////////////
40
RRefForkData(worker_id_t ownerId,const RRefId & rrefId,const ForkId & forkId,worker_id_t parent,std::string typeStr)41 RRefForkData::RRefForkData(
42 worker_id_t ownerId,
43 const RRefId& rrefId,
44 const ForkId& forkId,
45 worker_id_t parent,
46 std::string typeStr)
47 : ownerId_(ownerId),
48 rrefId_(rrefId),
49 forkId_(forkId),
50 parent_(parent),
51 typeStr_(std::move(typeStr)) {}
52
53 ////////////////////////////// RRef /////////////////////////////////////
54
RRef(worker_id_t ownerId,const RRefId & rrefId,TypePtr type)55 RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
56 : RRefInterface(),
57 ownerId_(ownerId),
58 rrefId_(rrefId),
59 type_(std::move(type)) {}
60
fork() const61 RRefForkData RRef::fork() const {
62 auto& ctx = RRefContext::getInstance();
63 return RRefForkData(
64 ownerId_,
65 rrefId_,
66 ctx.genGloballyUniqueId(),
67 ctx.getWorkerId(),
68 getTypeStr(type_));
69 }
70
handleError(RPCErrorType errorType,const JitFuture & jitFuture)71 void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) {
72 static std::unordered_map<
73 RPCErrorType,
74 std::function<void(const JitFuture& jitFuture)>,
75 std::hash<int>>
76 errorHandlers = {
77 {RPCErrorType::TIMEOUT,
78 [this](const JitFuture& /* unused */) { setTimedOut(); }},
79 {RPCErrorType::INTENTIONAL_FAILURE,
80 [this](const JitFuture& /* unused */) { setTimedOut(); }},
81 {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) {
82 // Default error handler
83 RRefContext::handleException(jitFuture);
84 }}};
85 errorHandlers.find(errorType)->second(jitFuture);
86 }
87
88 ////////////////////////// UserRRef /////////////////////////////////////
89
UserRRef(worker_id_t ownerId,const RRefId & rrefId,const ForkId & forkId,TypePtr type)90 UserRRef::UserRRef(
91 worker_id_t ownerId,
92 const RRefId& rrefId,
93 const ForkId& forkId,
94 TypePtr type)
95 : RRef(ownerId, rrefId, std::move(type)),
96 forkId_(forkId),
97 confirmedByOwner_(false) {
98 // Do nothing,
99 // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
100 // a RREF_FORK_REQUEST message to the owner.
101 // (2) If this the creator UserRRef, ScriptRemoteCall or PythonRemoteCall will
102 // properly notify the owner.
103 }
104
tryDel()105 void UserRRef::tryDel() {
106 std::lock_guard<std::mutex> lockGuard(deletedOnOwnerMutex_);
107 if (!deletedOnOwner_) {
108 try {
109 RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
110 deletedOnOwner_ = true;
111 } catch (const std::exception& ex) {
112 LOG(ERROR) << "Error occurred when deleting" << *this << " : "
113 << ex.what();
114 } catch (...) {
115 LOG(ERROR) << "Error occurred when deleting" << *this << " : "
116 << "unknown error";
117 }
118 }
119 }
120
~UserRRef()121 UserRRef::~UserRRef() {
122 tryDel();
123 }
124
release_resources()125 void UserRRef::release_resources() {
126 tryDel();
127 }
128
forkId() const129 const ForkId& UserRRef::forkId() const {
130 return forkId_;
131 }
132
// Fetches the referenced value from the owner, blocking until the fetch
// response arrives.
//
// Args:
//   timeoutSeconds: timeout forwarded to sendMessageWithAutograd() for the
//     fetch request.
// Returns:
//   For Python-object RRefs (isPyObj()), a Tuple wrapping all values in the
//   fetch response; otherwise the first value of the response.
// Throws (via TORCH_CHECK):
//   - if creation via rpc.remote() timed out,
//   - if this UserRRef is already deleted on the owner (best-effort check),
//   - if the RRef refers to a ScriptModule, which cannot be sent over RPC.
//   Errors stored in the RPC future are surfaced by waitAndThrow().
IValue UserRRef::toHere(const float timeoutSeconds) const {
  TORCH_CHECK(
      !getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  // see Note [Best-Effort Check on Deleted UserRRefs]
  TORCH_CHECK(
      !deletedOnOwner_,
      *this,
      " has been deleted. Cannot call to_here() on it after deletion.");
  // Name of the profiling range; only formatted when the profiler is enabled,
  // otherwise left empty.
  auto toHereKey = std::string("");
  if (torch::autograd::profiler::profilerEnabled()) {
    toHereKey = fmt::format(
        "to_here#({})->({})",
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_);
  }
  // Opens a user-scope profiling range named toHereKey for the rest of this
  // function.
  RECORD_USER_SCOPE(toHereKey);
  TORCH_CHECK(
      !type_->is_module(),
      *this,
      " is an RRef to a ScriptModule. "
      "It can't be sent through RPC "
      "from owner, ",
      ownerWorkerInfo(),
      ", to user, ",
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo(),
      ".");

  auto agent = RpcAgent::getCurrentRpcAgent();

  // ScriptRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response would
  // potentially contain tensors.
  c10::intrusive_ptr<Message> msgToSend;

  // Python-object RRefs and TorchScript RRefs use different fetch requests.
  if (isPyObj()) {
    msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
  } else {
    msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
  }

  // toHere is profiled as a blocking call, and does not execute operations on
  // the remote node. Hence, don't wrap it with a profiling message since we
  // don't need the profiler to be enabled remotely.
  auto jitFuture = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      std::move(msgToSend),
      true /* forceGradRecording */,
      timeoutSeconds,
      true /* forceDisableProfiling */);

  // TODO: we should ideally be able to interrupt this blocking wait if we check
  // getTimedOut() and it is true
  // (https://github.com/pytorch/pytorch/issues/39411).
  jitFuture->waitAndThrow();
  auto messagePtr = jitFuture->constValue().toCustomClass<Message>();
  MessageType msgType = messagePtr->type();
  auto response = deserializeResponse(*messagePtr, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
          msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should either be SCRIPT_RREF_FETCH_RET "
      "or PYTHON_RREF_FETCH_RET");
  // Both accepted response types are handled through the common RRefFetchRet
  // base.
  RpcCommandBase& rpc = *response;
  auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
  if (isPyObj()) {
    // Wrap the vector of Python-serialized IValues into a tuple so that the
    // C++ toHere interface returns a single IValue.
    return ivalue::Tuple::create(rrefFetchRet.values());
  } else {
    // NOTE(review): front() assumes the owner returned at least one value —
    // confirm the owner always populates values() for script RRefs.
    return rrefFetchRet.values().front();
  }
}
208
fork() const209 RRefForkData UserRRef::fork() const {
210 // Note [Best-Effort Check on Deleted UserRRefs]
211 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
212 // This check does not guarantee correctness, as there could be another thread
213 // trying to delete this UserRRef concurrently. Passing this check does not
214 // mean this RRef will be alive throughout this function. This is just our
215 // best-effort attempt to raise proper error messages. The behavior of using
216 // deleted UserRRefs is undefined.
217 //
218 // The reason for not implementing strict checks are:
219 // 1. This would need to acquire lock on deletedOnOwnerMutex_, which would
220 // introduce unnecessary overhead for most normal use cases.
221 // 2. This would introduce a lot of complexities to get the behavior correct.
222 // Assume we acquired the lock here, and there is another thread X block
223 // waiting in tryDel() on the lock. Exiting this fork function would
224 // unblock thread X. However, while X proceeds with deleting this UserRRef,
225 // the call site of fork() might have added the UserRRef to
226 // pendingChildren_ map, but up to this point, nothing prevents X from
227 // deleting this RRef even if it shouldn't do so due to the state change
228 // in pendingChildren_. We might be able to get it right for now by locking
229 // and checking pendingChildren_ in X, but the gain does not seem to
230 // worth the complexity.
231 TORCH_CHECK(
232 !deletedOnOwner_,
233 *this,
234 " has been deleted. Cannot call fork an UserRRef after deletion.");
235 return RRef::fork();
236 }
237
238 ////////////////////////// OwnerRRef /////////////////////////////////////
239
OwnerRRef(worker_id_t ownerId,const RRefId & rrefId,TypePtr type,std::vector<c10::Device> devices)240 OwnerRRef::OwnerRRef(
241 worker_id_t ownerId,
242 const RRefId& rrefId,
243 TypePtr type,
244 std::vector<c10::Device> devices)
245 : OwnerRRef(
246 ownerId,
247 rrefId,
248 std::move(type),
249 /* value */ {},
250 std::move(devices)) {}
251
OwnerRRef(worker_id_t ownerId,const RRefId & rrefId,TypePtr type,std::optional<IValue> value,std::vector<c10::Device> devices)252 OwnerRRef::OwnerRRef(
253 worker_id_t ownerId,
254 const RRefId& rrefId,
255 TypePtr type,
256 std::optional<IValue> value,
257 std::vector<c10::Device> devices)
258 : RRef(ownerId, rrefId, std::move(type)) {
259 future_ = c10::make_intrusive<JitFuture>(type_, std::move(devices));
260
261 if (value.has_value()) {
262 future_->markCompleted(value.value());
263 }
264 }
265
getValue() const266 const IValue& OwnerRRef::getValue() const {
267 TORCH_CHECK(
268 !getTimedOut(),
269 "RRef creation via rpc.remote() timed out, and it "
270 "is possible that the RRef on the owner node does not exist.");
271 future_->waitAndThrow();
272 return future_->constValue();
273 }
274
hasValue() const275 bool OwnerRRef::hasValue() const {
276 return future_->completed();
277 }
278
getFuture()279 c10::intrusive_ptr<JitFuture> OwnerRRef::getFuture() {
280 return future_;
281 }
282
setValue(IValue && value)283 void OwnerRRef::setValue(IValue&& value) {
284 future_->markCompleted(value);
285 }
286
setError(std::exception_ptr eptr)287 void OwnerRRef::setError(std::exception_ptr eptr) {
288 future_->setErrorIfNeeded(std::move(eptr));
289 }
290
operator <<(std::ostream & os,const RRef & rref)291 std::ostream& operator<<(std::ostream& os, const RRef& rref) {
292 if (rref.isOwner()) {
293 return os << "OwnerRRef("
294 << "rref_id=" << rref.rrefId() << ")";
295 } else {
296 return os << "UserRRef("
297 << "rref_id=" << rref.rrefId()
298 << ", fork_id=" << static_cast<const UserRRef*>(&rref)->forkId()
299 << ")";
300 }
301 }
302
303 } // namespace torch::distributed::rpc
304