1 #include <c10/util/DeadlockDetection.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3
4 namespace torch::distributed::rpc {
5
RegisterWorkerInfoOnce()6 RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
7 // WorkerInfo needs to be registered exactly once. Since the op registration
8 // happens in libtorch_python we wrap the class registration in a helper to
9 // make sure that if there's multiple copies of Python such as used in
10 // torch::deploy we only ever register it once.
11 static auto workerInfo = torch::class_<WorkerInfo>("dist_rpc", "WorkerInfo")
12 .def(torch::init<std::string, int64_t>());
13 }
14
WorkerInfo(std::string name,int64_t id)15 WorkerInfo::WorkerInfo(std::string name, int64_t id)
16 : WorkerInfo(std::move(name), (worker_id_t)id) {
17 TORCH_CHECK(
18 id <= std::numeric_limits<worker_id_t>::max(),
19 "RPC worker id ",
20 id,
21 " out of bound of int16_t.");
22 }
23
WorkerInfo(std::string name,worker_id_t id)24 WorkerInfo::WorkerInfo(std::string name, worker_id_t id)
25 : name_(std::move(name)), id_(id) {
26 bool validSize = name_.length() < MAX_NAME_LEN && !name_.empty();
27 bool validChar =
28 std::find_if(name_.begin(), name_.end(), [](char c) {
29 return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
30 }) == name_.end();
31 TORCH_CHECK(
32 validSize && validChar,
33 "Worker name must match ^[A-Za-z0-9-_:]*$, "
34 "and must be non-empty and shorter than ",
35 MAX_NAME_LEN,
36 " chars, "
37 "but got ",
38 name_);
39 }
40
// How far ahead the retry thread's condition-variable deadline is set while it
// waits for the retry map to be populated. A finite but very large duration is
// used rather than std::chrono::time_point<std::chrono::steady_clock>::max(),
// because waiting until max() runs into a known overflow-related bug.
constexpr std::chrono::hours kLargeTimeDuration{10000};
46
RpcAgent(WorkerInfo workerId,std::unique_ptr<RequestCallback> cb,std::chrono::milliseconds rpcTimeout)47 RpcAgent::RpcAgent(
48 WorkerInfo workerId,
49 std::unique_ptr<RequestCallback> cb,
50 std::chrono::milliseconds rpcTimeout)
51 : workerInfo_(std::move(workerId)),
52 cb_(std::move(cb)),
53 rpcTimeout_(rpcTimeout),
54 profilingEnabled_(false),
55 rpcAgentRunning_(false) {}
56
~RpcAgent()57 RpcAgent::~RpcAgent() {
58 if (rpcAgentRunning_.load()) {
59 shutdown();
60 }
61 }
62
start()63 void RpcAgent::start() {
64 rpcAgentRunning_.store(true);
65 rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
66 startImpl();
67 }
68
shutdown()69 void RpcAgent::shutdown() {
70 TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
71 std::unique_lock<std::mutex> lock(rpcRetryMutex_);
72 rpcAgentRunning_.store(false);
73 lock.unlock();
74 rpcRetryMapCV_.notify_one();
75 if (rpcRetryThread_.joinable()) {
76 rpcRetryThread_.join();
77 }
78 // NOLINTNEXTLINE(clang-analyzer-cplusplus.PureVirtualCall)
79 shutdownImpl();
80 }
81
sendWithRetries(const WorkerInfo & to,c10::intrusive_ptr<Message> message,RpcRetryOptions retryOptions)82 c10::intrusive_ptr<JitFuture> RpcAgent::sendWithRetries(
83 const WorkerInfo& to,
84 c10::intrusive_ptr<Message> message,
85 RpcRetryOptions retryOptions) {
86 TORCH_CHECK(retryOptions.maxRetries >= 0, "maxRetries cannot be negative.");
87 TORCH_CHECK(
88 retryOptions.retryBackoff >= 1,
89 "maxRetries cannot be exponentially decaying.");
90 TORCH_CHECK(
91 retryOptions.rpcRetryDuration.count() >= 0,
92 "rpcRetryDuration cannot be negative.");
93
94 auto originalFuture =
95 c10::make_intrusive<JitFuture>(at::AnyClassType::get(), getDevices());
96 steady_clock_time_point newTime =
97 computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
98 auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
99 to,
100 message,
101 originalFuture,
102 /* retryCount */ 0,
103 retryOptions);
104 auto jitFuture = send(to, std::move(message));
105 jitFuture->addCallback([this, newTime, firstRetryRpc](JitFuture& future) {
106 rpcRetryCallback(future, newTime, firstRetryRpc);
107 });
108
109 return originalFuture;
110 }
111
retryExpiredRpcs()112 void RpcAgent::retryExpiredRpcs() {
113 // Stores the retried futures so callbacks can be added outside the lock.
114 std::vector<
115 std::pair<c10::intrusive_ptr<JitFuture>, std::shared_ptr<RpcRetryInfo>>>
116 futures;
117 // Stores futures and exception messages for non-retriable error-ed futures.
118 std::vector<std::pair<c10::intrusive_ptr<JitFuture>, std::string>>
119 errorFutures;
120
121 while (rpcAgentRunning_.load()) {
122 std::unique_lock<std::mutex> lock(rpcRetryMutex_);
123
124 // We must continue sleeping as long as the RPC Agent is running and when
125 // either the Retry Map is empty, or when the Retry Map's earliest expiring
126 // RPC is set to be retried in the future.
127 steady_clock_time_point earliestTimeout =
128 std::chrono::steady_clock::now() + kLargeTimeDuration;
129
130 for (;;) {
131 if (!rpcAgentRunning_.load())
132 return;
133 if (std::chrono::steady_clock::now() >= earliestTimeout)
134 break;
135 if (!rpcRetryMap_.empty()) {
136 earliestTimeout = rpcRetryMap_.begin()->first;
137 }
138 rpcRetryMapCV_.wait_until(lock, earliestTimeout);
139 }
140
141 // Updating these since something may have been added to the map while this
142 // thread was sleeping.
143 earliestTimeout = rpcRetryMap_.begin()->first;
144 auto& earliestRpcList = rpcRetryMap_.begin()->second;
145
146 // We iterate through all the RPC's set to be retried at the current
147 // timepoint, resend those RPC's, and add the RPC's and their futures to
148 // a list to later attach callbacks. These callbacks either schedule
149 // the RPC for a future retry or marks it with success/error depending on
150 // the outcome of the current send. Then, we clean up the rpcRetryMap_.
151 for (auto it = earliestRpcList.begin(); it != earliestRpcList.end();
152 /* no increment */) {
153 auto& earliestRpc = *it;
154 c10::intrusive_ptr<JitFuture> jitFuture;
155
156 // send() will throw an exception if an RPC is retried while the agent is
157 // shutdown. We must catch this exception and mark the original future
158 // with an error, since this RPC never succeeded and can no longer be
159 // retried.
160 try {
161 jitFuture = send(earliestRpc->to_, earliestRpc->message_);
162 futures.emplace_back(jitFuture, earliestRpc);
163 } catch (std::exception& e) {
164 // We must store the futures and exception messages here and only mark
165 // the futures with an error after releasing the lock.
166 errorFutures.emplace_back(earliestRpc->originalFuture_, e.what());
167 }
168
169 // A callback will be attached to all futures for the retries in this
170 // list. Thus they will either be rescheduled for future retries or they
171 // will be marked as complete. We can safely delete them from the retry
172 // Map for the current timepoint.
173 it = earliestRpcList.erase(it);
174 }
175
176 // If there are no more RPC's set to be retried at the current timepoint,
177 // we can remove the corresponding unordered_set from the retry map.
178 if (earliestRpcList.empty()) {
179 rpcRetryMap_.erase(earliestTimeout);
180 }
181
182 lock.unlock();
183 // We attach callbacks to the futures outside of the lock to prevent
184 // potential deadlocks.
185 for (const auto& it : futures) {
186 auto jitFuture = it.first;
187 auto earliestRpc = it.second;
188 steady_clock_time_point newTime = computeNewRpcRetryTime(
189 earliestRpc->options_, earliestRpc->retryCount_);
190 earliestRpc->retryCount_++;
191
192 jitFuture->addCallback([this, newTime, earliestRpc](JitFuture& future) {
193 rpcRetryCallback(future, newTime, earliestRpc);
194 });
195 }
196 futures.clear();
197
198 // For exceptions caught while retrying RPC's above, we set those futures
199 // with errors now that we have released the lock.
200 for (const auto& it : errorFutures) {
201 auto errorFuture = it.first;
202 auto errorMsg = it.second;
203 errorFuture->setError(
204 std::make_exception_ptr(std::runtime_error(errorMsg)));
205 }
206 errorFutures.clear();
207 }
208 }
209
rpcRetryCallback(JitFuture & jitFuture,steady_clock_time_point newTime,std::shared_ptr<RpcRetryInfo> earliestRpc)210 void RpcAgent::rpcRetryCallback(
211 JitFuture& jitFuture,
212 steady_clock_time_point newTime,
213 std::shared_ptr<RpcRetryInfo> earliestRpc) {
214 if (jitFuture.hasError()) {
215 // Adding one since we want to include the original send as well and not
216 // just the retry count.
217 LOG(INFO) << "Send try " << (earliestRpc->retryCount_ + 1) << " failed";
218 if (!rpcAgentRunning_.load()) {
219 // If the RPC Agent has shutdown, we cannot retry messages. Thus we mark
220 // the future with an error since the RPC was never completed
221 // successfully.
222 std::string errorMessage = c10::str(
223 "RPC Agent is no longer running on Node ",
224 RpcAgent::getWorkerInfo().id_,
225 ". Cannot retry message.");
226 earliestRpc->originalFuture_->setError(jitFuture.exception_ptr());
227 } else if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) {
228 // If the previous future completed with an error and we haven't
229 // completed maxRetries send attempts, we move the earliestRpc
230 // struct to a new time point in the retry map (effectively
231 // scheduling it for a future retry.)
232 {
233 std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_);
234 rpcRetryMap_[newTime].emplace(std::move(earliestRpc));
235 }
236 // The retry thread waits for the map to be populated. Thus we notify
237 // once an item has been added.
238 rpcRetryMapCV_.notify_one();
239 } else {
240 // We have completed maxRetries send attempts. We're now marking
241 // the future with an error.
242 std::string errorMessage = c10::str(
243 "The RPC has not succeeded after the specified number of max retries (",
244 earliestRpc->options_.maxRetries,
245 ").");
246 earliestRpc->originalFuture_->setError(
247 std::make_exception_ptr(std::runtime_error(errorMessage)));
248 }
249 } else {
250 // This try succeeded, so we can make the original future as complete.
251 earliestRpc->originalFuture_->markCompleted(
252 jitFuture.value(), jitFuture.storages());
253 }
254 }
255
getWorkerInfo() const256 const WorkerInfo& RpcAgent::getWorkerInfo() const {
257 return workerInfo_;
258 }
259
260 std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;
261
isCurrentRpcAgentSet()262 bool RpcAgent::isCurrentRpcAgentSet() {
263 return std::atomic_load(¤tRpcAgent_) != nullptr;
264 }
265
getCurrentRpcAgent()266 std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
267 std::shared_ptr<RpcAgent> agent = std::atomic_load(¤tRpcAgent_);
268 TORCH_CHECK(
269 agent,
270 "Current RPC agent is not set! Did you initialize the RPC "
271 "framework (e.g. by calling `rpc.init_rpc`)?");
272 return agent;
273 }
274
setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent)275 void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
276 if (rpcAgent) {
277 std::shared_ptr<RpcAgent> previousAgent;
278 // Use compare_exchange so that we don't actually perform the exchange if
279 // that would trigger the assert just below. See:
280 // https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange
281 std::atomic_compare_exchange_strong(
282 ¤tRpcAgent_, &previousAgent, std::move(rpcAgent));
283 TORCH_INTERNAL_ASSERT(
284 previousAgent == nullptr, "Current RPC agent is set!");
285 } else {
286 // We can't use compare_exchange (we don't know what value to expect) but we
287 // don't need to, as the only case that would trigger the assert is if we
288 // replaced nullptr with nullptr, which we can just do as it has no effect.
289 std::shared_ptr<RpcAgent> previousAgent =
290 std::atomic_exchange(¤tRpcAgent_, std::move(rpcAgent));
291 TORCH_INTERNAL_ASSERT(
292 previousAgent != nullptr, "Current RPC agent is not set!");
293 }
294 }
295
setTypeResolver(std::shared_ptr<TypeResolver> typeResolver)296 void RpcAgent::setTypeResolver(std::shared_ptr<TypeResolver> typeResolver) {
297 typeResolver_ = std::move(typeResolver);
298 }
299
getTypeResolver()300 std::shared_ptr<TypeResolver> RpcAgent::getTypeResolver() {
301 TORCH_INTERNAL_ASSERT(typeResolver_, "Type resolver is not set!");
302 return typeResolver_;
303 }
304
enableGILProfiling(bool flag)305 void RpcAgent::enableGILProfiling(bool flag) {
306 profilingEnabled_ = flag;
307 }
308
isGILProfilingEnabled()309 bool RpcAgent::isGILProfilingEnabled() {
310 return profilingEnabled_.load();
311 }
312
getDeviceMap(const WorkerInfo &) const313 DeviceMap RpcAgent::getDeviceMap(const WorkerInfo& /* unused */) const {
314 // Default implementation has no device map.
315 return {};
316 }
317
getDevices() const318 const std::vector<c10::Device>& RpcAgent::getDevices() const {
319 // By default the agent is CPU-only.
320 static const std::vector<c10::Device> noDevices = {};
321 return noDevices;
322 }
323
getDebugInfo()324 std::unordered_map<std::string, std::string> RpcAgent::getDebugInfo() {
325 /* This would later include more info other than metrics for eg: may include
326 stack traces for the threads owned by the agent */
327 // Default implementation: return getMetrics().
328 return getMetrics();
329 }
330
operator <<(std::ostream & os,const WorkerInfo & workerInfo)331 std::ostream& operator<<(std::ostream& os, const WorkerInfo& workerInfo) {
332 return os << "WorkerInfo(id=" << workerInfo.id_
333 << ", name=" << workerInfo.name_ << ")";
334 }
335
336 } // namespace torch::distributed::rpc
337