xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rpc_agent.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/types.h>

#include <algorithm>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
13 
14 namespace torch::distributed::rpc {
15 
16 using DeviceMap = std::unordered_map<c10::Device, c10::Device>;
17 
// Timeout, in seconds, applied to an RPC when nothing more specific is set.
constexpr float kDefaultRpcTimeoutSeconds = 60.0f;
// Sentinel timeout: agent::send() receives this value when the caller did not
// pass an explicit timeout, meaning the agent's default timeout must be used.
constexpr float kUnsetRpcTimeout = -1.0f;
// Default initialization method string.
// NOTE(review): presumably environment-variable based rendezvous — confirm.
constexpr const char* kDefaultInitMethod = "env://";
// Factor converting seconds to milliseconds.
constexpr float kSecToMsConversion = 1000.0f;
// Error text for an RPC that exceeded its timeout; the "{}" placeholder takes
// the timeout in milliseconds.
constexpr const char* kRpcTimeoutErrorStr =
    "RPC ran for more than set timeout ({} ms) and will now be marked with an error";

// A point in time on the monotonic steady clock.
using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;
31 // Input is qualified name string, output is JIT StrongTypePtr
32 // Same as jit::TypeResolver, did not import jit::TypeResolver to here
33 // because it could introduce cyclic dependencies.
34 using TypeResolver =
35     std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
36 
37 struct TORCH_API RpcBackendOptions {
RpcBackendOptionsRpcBackendOptions38   RpcBackendOptions()
39       : RpcBackendOptions(kDefaultRpcTimeoutSeconds, kDefaultInitMethod) {}
40 
RpcBackendOptionsRpcBackendOptions41   RpcBackendOptions(float rpcTimeoutSeconds, std::string initMethod)
42       : rpcTimeoutSeconds(rpcTimeoutSeconds),
43         initMethod(std::move(initMethod)) {
44     TORCH_CHECK(rpcTimeoutSeconds >= 0, "RPC Timeout must be non-negative");
45   }
46 
47   float rpcTimeoutSeconds;
48   std::string initMethod;
49 };
50 
51 // A globally unique ID to identify an RpcAgent
52 struct TORCH_API WorkerInfo : torch::CustomClassHolder {
53   WorkerInfo(std::string name, int64_t id);
54 
55   WorkerInfo(std::string name, worker_id_t id);
56 
57   bool operator==(const WorkerInfo& rhs) {
58     return (id_ == rhs.id_) && (name_ == rhs.name_);
59   }
60 
61   static constexpr size_t MAX_NAME_LEN = 128;
62 
63   const std::string name_;
64   const worker_id_t id_;
65 };
66 
67 struct TORCH_API RegisterWorkerInfoOnce {
68   RegisterWorkerInfoOnce();
69 };
70 
71 TORCH_API std::ostream& operator<<(
72     std::ostream& os,
73     const WorkerInfo& workerInfo);
74 
75 // Struct for options to configure the RPC Retry protocol.
76 struct TORCH_API RpcRetryOptions {
77   // Using a default constructor like all other Options structs in the RPC
78   // codebase. TORCH_CHECKs for input validation are done in the
79   // sendWithRetries function.
80   RpcRetryOptions() = default;
81   // Maximum number of times we will retry the RPC
82   int maxRetries{5};
83   // Initial duration between consecutive RPC send attempts
84   std::chrono::milliseconds rpcRetryDuration{std::chrono::milliseconds(1000)};
85   // Constant for exponential backoff used while calculating future wait
86   // durations
87   float retryBackoff{1.5};
88 };
89 
90 // Struct that stores all the metadata needed to retry a given RPC.
91 struct TORCH_API RpcRetryInfo {
RpcRetryInfoRpcRetryInfo92   RpcRetryInfo(
93       const WorkerInfo& to,
94       c10::intrusive_ptr<Message> message,
95       c10::intrusive_ptr<JitFuture> originalFuture,
96       int retryCount,
97       RpcRetryOptions options)
98       : to_(to),
99         message_(std::move(message)),
100         originalFuture_(std::move(originalFuture)),
101         retryCount_(retryCount),
102         options_(options) {}
103 
104   const WorkerInfo& to_;
105   c10::intrusive_ptr<Message> message_;
106   // Future that is returned to the caller of sendWithRetries().
107   c10::intrusive_ptr<JitFuture> originalFuture_;
108   // Number of send attempts completed so far.
109   int retryCount_;
110   RpcRetryOptions options_;
111 };
112 
113 // ``RpcAgent`` is the base class for sending and receiving RPC messages. It
114 // provides a unified ``send`` API for both request and response messages, and
115 // will invoke the given ``RequestCallback`` to process received requests. It
116 // should immediately become ready to serve request and accept response after
117 // construction.
118 class TORCH_API RpcAgent {
119  public:
120   // `WorkerInfo` is the globally unique identifier for this RpcAgent instance.
121   // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the
122   // globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent``
123   // implementation to determine how to resolve names. ``id_`` is the globally
124   // unique ID for this ``RpcAgent``. This should be determined by the
125   // ``RpcAgent`` implementation.
126   // The ``RequestCallback`` will be invoked to handle received requests. This
127   // ``RpcAgent`` base class makes no assumption on the thread-safeness of the
128   // ``RequestCallback``. ``RpcAgent`` implementations need to make sure that
129   // its threading model conform to ``RequestCallback``'s requirement.
130   // NB: RpcAgent implementations should not start serving requests until
131   // ``start()`` is called, as there could be other contexts that have not been
132   // initialized yet at this time.
133   RpcAgent(
134       WorkerInfo id,
135       std::unique_ptr<RequestCallback> cb,
136       std::chrono::milliseconds rpcTimeout);
137 
138   virtual ~RpcAgent();
139 
140   // Send a message to the ``RpcAgent`` of id ``to`` and returns a
141   // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it
142   // cannot block until it receives the response.
143   //
144   // If ``message.isRequest()`` is true, the ``JitFuture`` will be
145   // completed when the response arrives. For other message types, the Future
146   // should be ignored by the caller.
147   virtual c10::intrusive_ptr<JitFuture> send(
148       const WorkerInfo& to,
149       c10::intrusive_ptr<Message> message,
150       const float rpcTimeoutSeconds = kUnsetRpcTimeout,
151       const DeviceMap& deviceMap = {}) = 0;
152 
153   // Retries sending the message up to maxRetries times until an ACK is
154   // received. The duration between consecutive sends is increased over
155   // time using an exponential backoff algorithm.
156   //
157   // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a
158   // ``JitFuture`` ptr, just like send(). Caller can specify the maximum
159   // number of retries for this RPC (default is 5), initial duration between
160   // sends (default is 1000ms), and backoff constant (default is 1.5) by
161   // passing in the RpcRetryOptions struct. This API might end up
162   // executing a method twice on the remote end (it does not guarantee
163   // exactly-once semantics). Therefore, the user must ensure their requests
164   // are idempotent.
165   c10::intrusive_ptr<JitFuture> sendWithRetries(
166       const WorkerInfo& to,
167       c10::intrusive_ptr<Message> message,
168       RpcRetryOptions retryOptions = RpcRetryOptions());
169 
170   // Return a reference to the ``WorkerInfo`` of this RpcAgent.
171   // NB: not using ``std::optional<const std::string&>`` here because we might
172   // need to create a separate RPC API lib and avoid forcing all ``RpcAgent``
173   // implementations to depend on libtorch.
174   const WorkerInfo& getWorkerInfo() const;
175 
176   // Return a reference to the ``WorkerInfo`` of the given ``workerName``.
177   virtual const WorkerInfo& getWorkerInfo(
178       const std::string& workerName) const = 0;
179 
180   virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0;
181 
182   virtual std::vector<WorkerInfo> getWorkerInfos() const = 0;
183 
184   // Retrieve the timeout for all RPCs.
getRpcTimeout()185   inline std::chrono::milliseconds getRpcTimeout() const {
186     return rpcTimeout_.load();
187   }
188 
189   // Set the timeout for all RPCs
setRpcTimeout(const std::chrono::milliseconds & rpcTimeout)190   inline void setRpcTimeout(const std::chrono::milliseconds& rpcTimeout) {
191     rpcTimeout_.store(rpcTimeout);
192   }
193 
194   // Call sync and join all internal threads. This method should be called
195   // before every RPC process exits.
196   virtual void join(bool shutdown = false, float timeout = 0) = 0;
197 
198   // Synchronize the this process with other ``RpcAgent`` processes. Block until
199   // all ``RpcAgent``s reach this method and send all pending messages.
200   virtual void sync() = 0;
201 
202   // Sets up backend-agnostic state for accepting requests. Currently, this
203   // entails setting rpcAgentRunning_ to true, creating the retry thread, and
204   // calling the backend's startImpl.
205   void start();
206 
207   // Derived classes must override this function to start accepting requests.
208   // This is used to initialize any backend-specific state. Users must call
209   // start, not startImpl, to initialize the RPC Agent.
210   virtual void startImpl() = 0;
211 
212   // Stop accepting requests and shutdown the RPC framework as soon as possible
213   // by terminating all RPC threads.
214   void shutdown();
215 
216   // Derived classes must override this function to start accepting requests.
217   // THis is used to clean up any backend-specific state. Users must call
218   // shutdown, not shutdownImpl, to shutdown the RPC Agent.
219   virtual void shutdownImpl() = 0;
220 
221   // Check if current RPC agent is set.
222   static bool isCurrentRpcAgentSet();
223 
224   // Retrieve the valid current RPC agent.
225   static std::shared_ptr<RpcAgent> getCurrentRpcAgent();
226 
227   // Set the current RPC agent.
228   static void setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent);
229 
230   // Retrieve metrics as KV map
231   virtual std::unordered_map<std::string, std::string> getMetrics() = 0;
232 
233   // Retrieve debug info in addition to metrics as KV map
234   virtual std::unordered_map<std::string, std::string> getDebugInfo();
235 
236   // Flag to control whether GIL wait times
237   // should be profiled or not.
238   void enableGILProfiling(bool flag);
239 
240   // Retrieve wheher we should profile GIL wait times or not.
241   bool isGILProfilingEnabled();
242 
243   // Set type resolver that will be passed to JIT pickler to resolver type Ptr
244   // based on type str.
245   void setTypeResolver(std::shared_ptr<TypeResolver> typeResolver);
246 
247   // Get the type resolver
248   std::shared_ptr<TypeResolver> getTypeResolver();
249 
250   // Retrieves the device map for the provided destination worker.
251   virtual DeviceMap getDeviceMap(const WorkerInfo& dst) const;
252 
253   // Retrieve the (non-CPU) devices that are supported by the agent.
254   virtual const std::vector<c10::Device>& getDevices() const;
255 
256  protected:
257   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
258   const WorkerInfo workerInfo_;
259   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
260   const std::unique_ptr<RequestCallback> cb_;
261   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
262   std::atomic<std::chrono::milliseconds> rpcTimeout_;
263   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
264   std::atomic<bool> profilingEnabled_;
265   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
266   std::shared_ptr<TypeResolver> typeResolver_;
267   // Atomic boolean indicating whether this agent is running. It controls
268   // whether several background threads should be running. It is set in
269   // RpcAgent::start() and unset in the derived class shutdown().
270   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
271   std::atomic<bool> rpcAgentRunning_;
272 
273  private:
274   static std::shared_ptr<RpcAgent> currentRpcAgent_;
275   // Add GIL wait time data point to metrics
276   virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
277   friend class PythonRpcHandler;
278 
279   // Map that stores metadata for RPC's that may need to be re-tried as well as
280   // the timepoint at which we should re-try them.
281   std::map<
282       steady_clock_time_point,
283       std::unordered_set<std::shared_ptr<RpcRetryInfo>>>
284       rpcRetryMap_;
285 
286   // Thread that checks for retryable RPC's in the rpcRetryMap_ and sleeps until
287   // the next unACKed RPC's timeout has expired.
288   std::thread rpcRetryThread_;
289 
290   // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is
291   // running.
292   void retryExpiredRpcs();
293 
294   // This is the callback attached to futures corresponding to send retries.
295   // This handles 3 cases: 1). send was completed, 2). send failed with an
296   // error and we've done maxRetries failed send attempts, and 3). send
297   // failed with an error and we have more retries to go. In case 1, we mark
298   // the original future as complete. In case 2, we mark the future with an
299   // error and do not retry again. In case 3, we move the RpcRetryInfo struct
300   // to another time point in the map to schedule the RPC for a future send.
301   void rpcRetryCallback(
302       JitFuture& message,
303       steady_clock_time_point newTime,
304       std::shared_ptr<RpcRetryInfo> earliestRpc);
305 
306   // Function that uses the exponential backoff algorithm to compute the next
307   // time point to retry a given RPC.
computeNewRpcRetryTime(RpcRetryOptions & options,int retryCount)308   inline steady_clock_time_point computeNewRpcRetryTime(
309       RpcRetryOptions& options,
310       int retryCount) {
311     // The exponential backoff algorithm being used here is:
312     // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)).
313     std::chrono::milliseconds timedelta =
314         std::chrono::duration_cast<std::chrono::milliseconds>(
315             options.rpcRetryDuration * pow(options.retryBackoff, retryCount));
316     return std::chrono::time_point_cast<std::chrono::milliseconds>(
317         std::chrono::steady_clock::now() + timedelta);
318   }
319 
320   // Condition Variable to signal when the rpcRetryMap_ has been populated.
321   std::condition_variable rpcRetryMapCV_;
322 
323   // Mutex to protect RpcRetryMap_.
324   std::mutex rpcRetryMutex_;
325 };
326 
327 } // namespace torch::distributed::rpc
328 
329 namespace std {
330 template <>
331 struct hash<torch::distributed::rpc::WorkerInfo> {
332   std::size_t operator()(
333       const torch::distributed::rpc::WorkerInfo& worker_info) const noexcept {
334     return worker_info.id_;
335   }
336 };
337 } // namespace std
338