xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rpc_agent.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/DeadlockDetection.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 
4 namespace torch::distributed::rpc {
5 
RegisterWorkerInfoOnce()6 RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
7   // WorkerInfo needs to be registered exactly once. Since the op registration
8   // happens in libtorch_python we wrap the class registration in a helper to
9   // make sure that if there's multiple copies of Python such as used in
10   // torch::deploy we only ever register it once.
11   static auto workerInfo = torch::class_<WorkerInfo>("dist_rpc", "WorkerInfo")
12                                .def(torch::init<std::string, int64_t>());
13 }
14 
WorkerInfo(std::string name,int64_t id)15 WorkerInfo::WorkerInfo(std::string name, int64_t id)
16     : WorkerInfo(std::move(name), (worker_id_t)id) {
17   TORCH_CHECK(
18       id <= std::numeric_limits<worker_id_t>::max(),
19       "RPC worker id ",
20       id,
21       " out of bound of int16_t.");
22 }
23 
WorkerInfo(std::string name,worker_id_t id)24 WorkerInfo::WorkerInfo(std::string name, worker_id_t id)
25     : name_(std::move(name)), id_(id) {
26   bool validSize = name_.length() < MAX_NAME_LEN && !name_.empty();
27   bool validChar =
28       std::find_if(name_.begin(), name_.end(), [](char c) {
29         return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
30       }) == name_.end();
31   TORCH_CHECK(
32       validSize && validChar,
33       "Worker name must match ^[A-Za-z0-9-_:]*$, "
34       "and must be non-empty and shorter than ",
35       MAX_NAME_LEN,
36       " chars, "
37       "but got ",
38       name_);
39 }
40 
// Effectively-infinite timeout used when waiting on the retry condition
// variable until the retry map is populated.
// std::chrono::time_point<std::chrono::steady_clock>::max() cannot be used
// instead because of a known overflow-related bug.
constexpr std::chrono::hours kLargeTimeDuration{10000};
46 
RpcAgent(WorkerInfo workerId,std::unique_ptr<RequestCallback> cb,std::chrono::milliseconds rpcTimeout)47 RpcAgent::RpcAgent(
48     WorkerInfo workerId,
49     std::unique_ptr<RequestCallback> cb,
50     std::chrono::milliseconds rpcTimeout)
51     : workerInfo_(std::move(workerId)),
52       cb_(std::move(cb)),
53       rpcTimeout_(rpcTimeout),
54       profilingEnabled_(false),
55       rpcAgentRunning_(false) {}
56 
~RpcAgent()57 RpcAgent::~RpcAgent() {
58   if (rpcAgentRunning_.load()) {
59     shutdown();
60   }
61 }
62 
start()63 void RpcAgent::start() {
64   rpcAgentRunning_.store(true);
65   rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
66   startImpl();
67 }
68 
shutdown()69 void RpcAgent::shutdown() {
70   TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
71   std::unique_lock<std::mutex> lock(rpcRetryMutex_);
72   rpcAgentRunning_.store(false);
73   lock.unlock();
74   rpcRetryMapCV_.notify_one();
75   if (rpcRetryThread_.joinable()) {
76     rpcRetryThread_.join();
77   }
78   // NOLINTNEXTLINE(clang-analyzer-cplusplus.PureVirtualCall)
79   shutdownImpl();
80 }
81 
sendWithRetries(const WorkerInfo & to,c10::intrusive_ptr<Message> message,RpcRetryOptions retryOptions)82 c10::intrusive_ptr<JitFuture> RpcAgent::sendWithRetries(
83     const WorkerInfo& to,
84     c10::intrusive_ptr<Message> message,
85     RpcRetryOptions retryOptions) {
86   TORCH_CHECK(retryOptions.maxRetries >= 0, "maxRetries cannot be negative.");
87   TORCH_CHECK(
88       retryOptions.retryBackoff >= 1,
89       "maxRetries cannot be exponentially decaying.");
90   TORCH_CHECK(
91       retryOptions.rpcRetryDuration.count() >= 0,
92       "rpcRetryDuration cannot be negative.");
93 
94   auto originalFuture =
95       c10::make_intrusive<JitFuture>(at::AnyClassType::get(), getDevices());
96   steady_clock_time_point newTime =
97       computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
98   auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
99       to,
100       message,
101       originalFuture,
102       /* retryCount */ 0,
103       retryOptions);
104   auto jitFuture = send(to, std::move(message));
105   jitFuture->addCallback([this, newTime, firstRetryRpc](JitFuture& future) {
106     rpcRetryCallback(future, newTime, firstRetryRpc);
107   });
108 
109   return originalFuture;
110 }
111 
retryExpiredRpcs()112 void RpcAgent::retryExpiredRpcs() {
113   // Stores the retried futures so callbacks can be added outside the lock.
114   std::vector<
115       std::pair<c10::intrusive_ptr<JitFuture>, std::shared_ptr<RpcRetryInfo>>>
116       futures;
117   // Stores futures and exception messages for non-retriable error-ed futures.
118   std::vector<std::pair<c10::intrusive_ptr<JitFuture>, std::string>>
119       errorFutures;
120 
121   while (rpcAgentRunning_.load()) {
122     std::unique_lock<std::mutex> lock(rpcRetryMutex_);
123 
124     // We must continue sleeping as long as the RPC Agent is running and when
125     // either the Retry Map is empty, or when the Retry Map's earliest expiring
126     // RPC is set to be retried in the future.
127     steady_clock_time_point earliestTimeout =
128         std::chrono::steady_clock::now() + kLargeTimeDuration;
129 
130     for (;;) {
131       if (!rpcAgentRunning_.load())
132         return;
133       if (std::chrono::steady_clock::now() >= earliestTimeout)
134         break;
135       if (!rpcRetryMap_.empty()) {
136         earliestTimeout = rpcRetryMap_.begin()->first;
137       }
138       rpcRetryMapCV_.wait_until(lock, earliestTimeout);
139     }
140 
141     // Updating these since something may have been added to the map while this
142     // thread was sleeping.
143     earliestTimeout = rpcRetryMap_.begin()->first;
144     auto& earliestRpcList = rpcRetryMap_.begin()->second;
145 
146     // We iterate through all the RPC's set to be retried at the current
147     // timepoint, resend those RPC's, and add the RPC's and their futures to
148     // a list to later attach callbacks. These callbacks either schedule
149     // the RPC for a future retry or marks it with success/error depending on
150     // the outcome of the current send. Then, we clean up the rpcRetryMap_.
151     for (auto it = earliestRpcList.begin(); it != earliestRpcList.end();
152          /* no increment */) {
153       auto& earliestRpc = *it;
154       c10::intrusive_ptr<JitFuture> jitFuture;
155 
156       // send() will throw an exception if an RPC is retried while the agent is
157       // shutdown. We must catch this exception and mark the original future
158       // with an error, since this RPC never succeeded and can no longer be
159       // retried.
160       try {
161         jitFuture = send(earliestRpc->to_, earliestRpc->message_);
162         futures.emplace_back(jitFuture, earliestRpc);
163       } catch (std::exception& e) {
164         // We must store the futures and exception messages here and only mark
165         // the futures with an error after releasing the lock.
166         errorFutures.emplace_back(earliestRpc->originalFuture_, e.what());
167       }
168 
169       // A callback will be attached to all futures for the retries in this
170       // list. Thus they will either be rescheduled for future retries or they
171       // will be marked as complete. We can safely delete them from the retry
172       // Map for the current timepoint.
173       it = earliestRpcList.erase(it);
174     }
175 
176     // If there are no more RPC's set to be retried at the current timepoint,
177     // we can remove the corresponding unordered_set from the retry map.
178     if (earliestRpcList.empty()) {
179       rpcRetryMap_.erase(earliestTimeout);
180     }
181 
182     lock.unlock();
183     // We attach callbacks to the futures outside of the lock to prevent
184     // potential deadlocks.
185     for (const auto& it : futures) {
186       auto jitFuture = it.first;
187       auto earliestRpc = it.second;
188       steady_clock_time_point newTime = computeNewRpcRetryTime(
189           earliestRpc->options_, earliestRpc->retryCount_);
190       earliestRpc->retryCount_++;
191 
192       jitFuture->addCallback([this, newTime, earliestRpc](JitFuture& future) {
193         rpcRetryCallback(future, newTime, earliestRpc);
194       });
195     }
196     futures.clear();
197 
198     // For exceptions caught while retrying RPC's above, we set those futures
199     // with errors now that we have released the lock.
200     for (const auto& it : errorFutures) {
201       auto errorFuture = it.first;
202       auto errorMsg = it.second;
203       errorFuture->setError(
204           std::make_exception_ptr(std::runtime_error(errorMsg)));
205     }
206     errorFutures.clear();
207   }
208 }
209 
rpcRetryCallback(JitFuture & jitFuture,steady_clock_time_point newTime,std::shared_ptr<RpcRetryInfo> earliestRpc)210 void RpcAgent::rpcRetryCallback(
211     JitFuture& jitFuture,
212     steady_clock_time_point newTime,
213     std::shared_ptr<RpcRetryInfo> earliestRpc) {
214   if (jitFuture.hasError()) {
215     // Adding one since we want to include the original send as well and not
216     // just the retry count.
217     LOG(INFO) << "Send try " << (earliestRpc->retryCount_ + 1) << " failed";
218     if (!rpcAgentRunning_.load()) {
219       // If the RPC Agent has shutdown, we cannot retry messages. Thus we mark
220       // the future with an error since the RPC was never completed
221       // successfully.
222       std::string errorMessage = c10::str(
223           "RPC Agent is no longer running on Node ",
224           RpcAgent::getWorkerInfo().id_,
225           ". Cannot retry message.");
226       earliestRpc->originalFuture_->setError(jitFuture.exception_ptr());
227     } else if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) {
228       // If the previous future completed with an error and we haven't
229       // completed maxRetries send attempts, we move the earliestRpc
230       // struct to a new time point in the retry map (effectively
231       // scheduling it for a future retry.)
232       {
233         std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_);
234         rpcRetryMap_[newTime].emplace(std::move(earliestRpc));
235       }
236       // The retry thread waits for the map to be populated. Thus we notify
237       // once an item has been added.
238       rpcRetryMapCV_.notify_one();
239     } else {
240       // We have completed maxRetries send attempts. We're now marking
241       // the future with an error.
242       std::string errorMessage = c10::str(
243           "The RPC has not succeeded after the specified number of max retries (",
244           earliestRpc->options_.maxRetries,
245           ").");
246       earliestRpc->originalFuture_->setError(
247           std::make_exception_ptr(std::runtime_error(errorMessage)));
248     }
249   } else {
250     // This try succeeded, so we can make the original future as complete.
251     earliestRpc->originalFuture_->markCompleted(
252         jitFuture.value(), jitFuture.storages());
253   }
254 }
255 
getWorkerInfo() const256 const WorkerInfo& RpcAgent::getWorkerInfo() const {
257   return workerInfo_;
258 }
259 
260 std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;
261 
isCurrentRpcAgentSet()262 bool RpcAgent::isCurrentRpcAgentSet() {
263   return std::atomic_load(&currentRpcAgent_) != nullptr;
264 }
265 
getCurrentRpcAgent()266 std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
267   std::shared_ptr<RpcAgent> agent = std::atomic_load(&currentRpcAgent_);
268   TORCH_CHECK(
269       agent,
270       "Current RPC agent is not set! Did you initialize the RPC "
271       "framework (e.g. by calling `rpc.init_rpc`)?");
272   return agent;
273 }
274 
setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent)275 void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
276   if (rpcAgent) {
277     std::shared_ptr<RpcAgent> previousAgent;
278     // Use compare_exchange so that we don't actually perform the exchange if
279     // that would trigger the assert just below. See:
280     // https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange
281     std::atomic_compare_exchange_strong(
282         &currentRpcAgent_, &previousAgent, std::move(rpcAgent));
283     TORCH_INTERNAL_ASSERT(
284         previousAgent == nullptr, "Current RPC agent is set!");
285   } else {
286     // We can't use compare_exchange (we don't know what value to expect) but we
287     // don't need to, as the only case that would trigger the assert is if we
288     // replaced nullptr with nullptr, which we can just do as it has no effect.
289     std::shared_ptr<RpcAgent> previousAgent =
290         std::atomic_exchange(&currentRpcAgent_, std::move(rpcAgent));
291     TORCH_INTERNAL_ASSERT(
292         previousAgent != nullptr, "Current RPC agent is not set!");
293   }
294 }
295 
setTypeResolver(std::shared_ptr<TypeResolver> typeResolver)296 void RpcAgent::setTypeResolver(std::shared_ptr<TypeResolver> typeResolver) {
297   typeResolver_ = std::move(typeResolver);
298 }
299 
getTypeResolver()300 std::shared_ptr<TypeResolver> RpcAgent::getTypeResolver() {
301   TORCH_INTERNAL_ASSERT(typeResolver_, "Type resolver is not set!");
302   return typeResolver_;
303 }
304 
enableGILProfiling(bool flag)305 void RpcAgent::enableGILProfiling(bool flag) {
306   profilingEnabled_ = flag;
307 }
308 
isGILProfilingEnabled()309 bool RpcAgent::isGILProfilingEnabled() {
310   return profilingEnabled_.load();
311 }
312 
getDeviceMap(const WorkerInfo &) const313 DeviceMap RpcAgent::getDeviceMap(const WorkerInfo& /* unused */) const {
314   // Default implementation has no device map.
315   return {};
316 }
317 
getDevices() const318 const std::vector<c10::Device>& RpcAgent::getDevices() const {
319   // By default the agent is CPU-only.
320   static const std::vector<c10::Device> noDevices = {};
321   return noDevices;
322 }
323 
getDebugInfo()324 std::unordered_map<std::string, std::string> RpcAgent::getDebugInfo() {
325   /* This would later include more info other than metrics for eg: may include
326      stack traces for the threads owned by the agent */
327   // Default implementation: return getMetrics().
328   return getMetrics();
329 }
330 
operator <<(std::ostream & os,const WorkerInfo & workerInfo)331 std::ostream& operator<<(std::ostream& os, const WorkerInfo& workerInfo) {
332   return os << "WorkerInfo(id=" << workerInfo.id_
333             << ", name=" << workerInfo.name_ << ")";
334 }
335 
336 } // namespace torch::distributed::rpc
337