1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/request_callback.h> 5 #include <torch/csrc/distributed/rpc/types.h> 6 7 #include <algorithm> 8 #include <cctype> 9 #include <chrono> 10 #include <condition_variable> 11 #include <mutex> 12 #include <thread> 13 14 namespace torch::distributed::rpc { 15 16 using DeviceMap = std::unordered_map<c10::Device, c10::Device>; 17 18 // Default RPC timeout 19 constexpr float kDefaultRpcTimeoutSeconds = 60; 20 // Unset RPC timeout. This is the value agent::send() will have if user does not 21 // pass in a specific timeout, and indicates that we must use the default 22 // timeout for RPCs. 23 constexpr float kUnsetRpcTimeout = -1; 24 constexpr auto kDefaultInitMethod = "env://"; 25 constexpr float kSecToMsConversion = 1000; 26 constexpr auto kRpcTimeoutErrorStr = 27 "RPC ran for more than set timeout ({} ms) and will now be marked with an error"; 28 29 using steady_clock_time_point = 30 std::chrono::time_point<std::chrono::steady_clock>; 31 // Input is qualified name string, output is JIT StrongTypePtr 32 // Same as jit::TypeResolver, did not import jit::TypeResolver to here 33 // because it could introduce cyclic dependencies. 34 using TypeResolver = 35 std::function<c10::StrongTypePtr(const c10::QualifiedName&)>; 36 37 struct TORCH_API RpcBackendOptions { RpcBackendOptionsRpcBackendOptions38 RpcBackendOptions() 39 : RpcBackendOptions(kDefaultRpcTimeoutSeconds, kDefaultInitMethod) {} 40 RpcBackendOptionsRpcBackendOptions41 RpcBackendOptions(float rpcTimeoutSeconds, std::string initMethod) 42 : rpcTimeoutSeconds(rpcTimeoutSeconds), 43 initMethod(std::move(initMethod)) { 44 TORCH_CHECK(rpcTimeoutSeconds >= 0, "RPC Timeout must be non-negative"); 45 } 46 47 float rpcTimeoutSeconds; 48 std::string initMethod; 49 }; 50 51 // A globally unique ID to identify an RpcAgent 52 struct TORCH_API WorkerInfo : torch::CustomClassHolder { 53 WorkerInfo(std::string name, int64_t id); 54 55 WorkerInfo(std::string name, worker_id_t id); 56 57 bool operator==(const WorkerInfo& rhs) { 58 return (id_ == rhs.id_) && (name_ == rhs.name_); 59 } 60 61 static constexpr size_t MAX_NAME_LEN = 128; 62 63 const std::string name_; 64 const worker_id_t id_; 65 }; 66 67 struct TORCH_API RegisterWorkerInfoOnce { 68 RegisterWorkerInfoOnce(); 69 }; 70 71 TORCH_API std::ostream& operator<<( 72 std::ostream& os, 73 const WorkerInfo& workerInfo); 74 75 // Struct for options to configure the RPC Retry protocol. 76 struct TORCH_API RpcRetryOptions { 77 // Using a default constructor like all other Options structs in the RPC 78 // codebase. TORCH_CHECKs for input validation are done in the 79 // sendWithRetries function. 80 RpcRetryOptions() = default; 81 // Maximum number of times we will retry the RPC 82 int maxRetries{5}; 83 // Initial duration between consecutive RPC send attempts 84 std::chrono::milliseconds rpcRetryDuration{std::chrono::milliseconds(1000)}; 85 // Constant for exponential backoff used while calculating future wait 86 // durations 87 float retryBackoff{1.5}; 88 }; 89 90 // Struct that stores all the metadata needed to retry a given RPC. 91 struct TORCH_API RpcRetryInfo { RpcRetryInfoRpcRetryInfo92 RpcRetryInfo( 93 const WorkerInfo& to, 94 c10::intrusive_ptr<Message> message, 95 c10::intrusive_ptr<JitFuture> originalFuture, 96 int retryCount, 97 RpcRetryOptions options) 98 : to_(to), 99 message_(std::move(message)), 100 originalFuture_(std::move(originalFuture)), 101 retryCount_(retryCount), 102 options_(options) {} 103 104 const WorkerInfo& to_; 105 c10::intrusive_ptr<Message> message_; 106 // Future that is returned to the caller of sendWithRetries(). 107 c10::intrusive_ptr<JitFuture> originalFuture_; 108 // Number of send attempts completed so far. 109 int retryCount_; 110 RpcRetryOptions options_; 111 }; 112 113 // ``RpcAgent`` is the base class for sending and receiving RPC messages. It 114 // provides a unified ``send`` API for both request and response messages, and 115 // will invoke the given ``RequestCallback`` to process received requests. It 116 // should immediately become ready to serve request and accept response after 117 // construction. 118 class TORCH_API RpcAgent { 119 public: 120 // `WorkerInfo` is the globally unique identifier for this RpcAgent instance. 121 // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the 122 // globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent`` 123 // implementation to determine how to resolve names. ``id_`` is the globally 124 // unique ID for this ``RpcAgent``. This should be determined by the 125 // ``RpcAgent`` implementation. 126 // The ``RequestCallback`` will be invoked to handle received requests. This 127 // ``RpcAgent`` base class makes no assumption on the thread-safeness of the 128 // ``RequestCallback``. ``RpcAgent`` implementations need to make sure that 129 // its threading model conform to ``RequestCallback``'s requirement. 130 // NB: RpcAgent implementations should not start serving requests until 131 // ``start()`` is called, as there could be other contexts that have not been 132 // initialized yet at this time. 133 RpcAgent( 134 WorkerInfo id, 135 std::unique_ptr<RequestCallback> cb, 136 std::chrono::milliseconds rpcTimeout); 137 138 virtual ~RpcAgent(); 139 140 // Send a message to the ``RpcAgent`` of id ``to`` and returns a 141 // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it 142 // cannot block until it receives the response. 143 // 144 // If ``message.isRequest()`` is true, the ``JitFuture`` will be 145 // completed when the response arrives. For other message types, the Future 146 // should be ignored by the caller. 147 virtual c10::intrusive_ptr<JitFuture> send( 148 const WorkerInfo& to, 149 c10::intrusive_ptr<Message> message, 150 const float rpcTimeoutSeconds = kUnsetRpcTimeout, 151 const DeviceMap& deviceMap = {}) = 0; 152 153 // Retries sending the message up to maxRetries times until an ACK is 154 // received. The duration between consecutive sends is increased over 155 // time using an exponential backoff algorithm. 156 // 157 // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a 158 // ``JitFuture`` ptr, just like send(). Caller can specify the maximum 159 // number of retries for this RPC (default is 5), initial duration between 160 // sends (default is 1000ms), and backoff constant (default is 1.5) by 161 // passing in the RpcRetryOptions struct. This API might end up 162 // executing a method twice on the remote end (it does not guarantee 163 // exactly-once semantics). Therefore, the user must ensure their requests 164 // are idempotent. 165 c10::intrusive_ptr<JitFuture> sendWithRetries( 166 const WorkerInfo& to, 167 c10::intrusive_ptr<Message> message, 168 RpcRetryOptions retryOptions = RpcRetryOptions()); 169 170 // Return a reference to the ``WorkerInfo`` of this RpcAgent. 171 // NB: not using ``std::optional<const std::string&>`` here because we might 172 // need to create a separate RPC API lib and avoid forcing all ``RpcAgent`` 173 // implementations to depend on libtorch. 174 const WorkerInfo& getWorkerInfo() const; 175 176 // Return a reference to the ``WorkerInfo`` of the given ``workerName``. 177 virtual const WorkerInfo& getWorkerInfo( 178 const std::string& workerName) const = 0; 179 180 virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0; 181 182 virtual std::vector<WorkerInfo> getWorkerInfos() const = 0; 183 184 // Retrieve the timeout for all RPCs. getRpcTimeout()185 inline std::chrono::milliseconds getRpcTimeout() const { 186 return rpcTimeout_.load(); 187 } 188 189 // Set the timeout for all RPCs setRpcTimeout(const std::chrono::milliseconds & rpcTimeout)190 inline void setRpcTimeout(const std::chrono::milliseconds& rpcTimeout) { 191 rpcTimeout_.store(rpcTimeout); 192 } 193 194 // Call sync and join all internal threads. This method should be called 195 // before every RPC process exits. 196 virtual void join(bool shutdown = false, float timeout = 0) = 0; 197 198 // Synchronize the this process with other ``RpcAgent`` processes. Block until 199 // all ``RpcAgent``s reach this method and send all pending messages. 200 virtual void sync() = 0; 201 202 // Sets up backend-agnostic state for accepting requests. Currently, this 203 // entails setting rpcAgentRunning_ to true, creating the retry thread, and 204 // calling the backend's startImpl. 205 void start(); 206 207 // Derived classes must override this function to start accepting requests. 208 // This is used to initialize any backend-specific state. Users must call 209 // start, not startImpl, to initialize the RPC Agent. 210 virtual void startImpl() = 0; 211 212 // Stop accepting requests and shutdown the RPC framework as soon as possible 213 // by terminating all RPC threads. 214 void shutdown(); 215 216 // Derived classes must override this function to start accepting requests. 217 // THis is used to clean up any backend-specific state. Users must call 218 // shutdown, not shutdownImpl, to shutdown the RPC Agent. 219 virtual void shutdownImpl() = 0; 220 221 // Check if current RPC agent is set. 222 static bool isCurrentRpcAgentSet(); 223 224 // Retrieve the valid current RPC agent. 225 static std::shared_ptr<RpcAgent> getCurrentRpcAgent(); 226 227 // Set the current RPC agent. 228 static void setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent); 229 230 // Retrieve metrics as KV map 231 virtual std::unordered_map<std::string, std::string> getMetrics() = 0; 232 233 // Retrieve debug info in addition to metrics as KV map 234 virtual std::unordered_map<std::string, std::string> getDebugInfo(); 235 236 // Flag to control whether GIL wait times 237 // should be profiled or not. 238 void enableGILProfiling(bool flag); 239 240 // Retrieve wheher we should profile GIL wait times or not. 241 bool isGILProfilingEnabled(); 242 243 // Set type resolver that will be passed to JIT pickler to resolver type Ptr 244 // based on type str. 245 void setTypeResolver(std::shared_ptr<TypeResolver> typeResolver); 246 247 // Get the type resolver 248 std::shared_ptr<TypeResolver> getTypeResolver(); 249 250 // Retrieves the device map for the provided destination worker. 251 virtual DeviceMap getDeviceMap(const WorkerInfo& dst) const; 252 253 // Retrieve the (non-CPU) devices that are supported by the agent. 254 virtual const std::vector<c10::Device>& getDevices() const; 255 256 protected: 257 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 258 const WorkerInfo workerInfo_; 259 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 260 const std::unique_ptr<RequestCallback> cb_; 261 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 262 std::atomic<std::chrono::milliseconds> rpcTimeout_; 263 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 264 std::atomic<bool> profilingEnabled_; 265 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 266 std::shared_ptr<TypeResolver> typeResolver_; 267 // Atomic boolean indicating whether this agent is running. It controls 268 // whether several background threads should be running. It is set in 269 // RpcAgent::start() and unset in the derived class shutdown(). 270 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 271 std::atomic<bool> rpcAgentRunning_; 272 273 private: 274 static std::shared_ptr<RpcAgent> currentRpcAgent_; 275 // Add GIL wait time data point to metrics 276 virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0; 277 friend class PythonRpcHandler; 278 279 // Map that stores metadata for RPC's that may need to be re-tried as well as 280 // the timepoint at which we should re-try them. 281 std::map< 282 steady_clock_time_point, 283 std::unordered_set<std::shared_ptr<RpcRetryInfo>>> 284 rpcRetryMap_; 285 286 // Thread that checks for retryable RPC's in the rpcRetryMap_ and sleeps until 287 // the next unACKed RPC's timeout has expired. 288 std::thread rpcRetryThread_; 289 290 // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is 291 // running. 292 void retryExpiredRpcs(); 293 294 // This is the callback attached to futures corresponding to send retries. 295 // This handles 3 cases: 1). send was completed, 2). send failed with an 296 // error and we've done maxRetries failed send attempts, and 3). send 297 // failed with an error and we have more retries to go. In case 1, we mark 298 // the original future as complete. In case 2, we mark the future with an 299 // error and do not retry again. In case 3, we move the RpcRetryInfo struct 300 // to another time point in the map to schedule the RPC for a future send. 301 void rpcRetryCallback( 302 JitFuture& message, 303 steady_clock_time_point newTime, 304 std::shared_ptr<RpcRetryInfo> earliestRpc); 305 306 // Function that uses the exponential backoff algorithm to compute the next 307 // time point to retry a given RPC. computeNewRpcRetryTime(RpcRetryOptions & options,int retryCount)308 inline steady_clock_time_point computeNewRpcRetryTime( 309 RpcRetryOptions& options, 310 int retryCount) { 311 // The exponential backoff algorithm being used here is: 312 // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)). 313 std::chrono::milliseconds timedelta = 314 std::chrono::duration_cast<std::chrono::milliseconds>( 315 options.rpcRetryDuration * pow(options.retryBackoff, retryCount)); 316 return std::chrono::time_point_cast<std::chrono::milliseconds>( 317 std::chrono::steady_clock::now() + timedelta); 318 } 319 320 // Condition Variable to signal when the rpcRetryMap_ has been populated. 321 std::condition_variable rpcRetryMapCV_; 322 323 // Mutex to protect RpcRetryMap_. 324 std::mutex rpcRetryMutex_; 325 }; 326 327 } // namespace torch::distributed::rpc 328 329 namespace std { 330 template <> 331 struct hash<torch::distributed::rpc::WorkerInfo> { 332 std::size_t operator()( 333 const torch::distributed::rpc::WorkerInfo& worker_info) const noexcept { 334 return worker_info.id_; 335 } 336 }; 337 } // namespace std 338