#pragma once

#ifdef USE_TENSORPIPE

#include <atomic>
#include <thread>

#include <c10/core/thread_pool.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <utility>

// Forward-declare the TensorPipe classes we need, to avoid including its
// headers in PyTorch's ones and thus have it become a public dependency.

namespace tensorpipe {

class Context;
class Error;
class Listener;
class Message;
class Pipe;

namespace transport {
class Context;
} // namespace transport

namespace channel {
class Context;
} // namespace channel

} // namespace tensorpipe

namespace torch::distributed::rpc {

// These priorities instruct TensorPipe on which transport/channel to pick
// during handshake. Higher priorities take precedence over lower ones.
// The transport with the lowest priority will be the one used to bootstrap
// pipes.

constexpr int64_t kShmTransportPriority = 200;
constexpr int64_t kIbvTransportPriority = 100;
// The UV transport just uses TCP and should work everywhere, thus keep it
// last.
constexpr int64_t kUvTransportPriority = 0;

constexpr int64_t kCmaChannelPriority = 1200;
constexpr int64_t kMultiplexedUvChannelPriority = 1100;
// The basic channel reuses a transport as a channel, and is thus our fallback.
constexpr int64_t kBasicChannelPriority = 1000;

// CPU channels have higher priority than CUDA channels, since the latter
// might also handle CPU-to-CPU transfers, but will always be less efficient
// than their CPU-only counterparts.
constexpr int64_t kCudaIpcChannelPriority = 300;
constexpr int64_t kCudaGdrChannelPriority = 200;
constexpr int64_t kCudaXthChannelPriority = 400;
constexpr int64_t kCudaBasicChannelPriority = 0;

using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;

struct TORCH_API TransportRegistration {
  std::shared_ptr<tensorpipe::transport::Context> transport;
  int64_t priority;
  std::string address;
};

C10_DECLARE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration);

struct TORCH_API ChannelRegistration {
  std::shared_ptr<tensorpipe::channel::Context> channel;
  int64_t priority;
};

C10_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration);
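
// Example (illustrative sketch; the real registrations live in
// tensorpipe_agent.cpp): a transport becomes available to the agent by
// registering a factory under a name, together with its priority and the
// address it should listen on. The makeMyTransport factory below is
// hypothetical.
//
//   std::unique_ptr<TransportRegistration> makeMyTransport() {
//     std::shared_ptr<tensorpipe::transport::Context> context =
//         /* create a concrete tensorpipe transport context */;
//     return std::make_unique<TransportRegistration>(TransportRegistration{
//         std::move(context), kUvTransportPriority, "127.0.0.1"});
//   }
//   C10_REGISTER_CREATOR(TensorPipeTransportRegistry, my, makeMyTransport);
//
// Registered names can then be looked up, as the options validation below
// does with Has(), or instantiated with Create():
//
//   std::unique_ptr<TransportRegistration> reg =
//       TensorPipeTransportRegistry()->Create("my");
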
constexpr auto kDefaultNumWorkerThreads = 16;

struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions {
  TensorPipeRpcBackendOptions(
      int numWorkerThreads,
      std::optional<std::vector<std::string>> transports,
      std::optional<std::vector<std::string>> channels,
      float rpc_timeout,
      std::string init_method,
      std::unordered_map<std::string, DeviceMap> device_maps = {},
      std::vector<c10::Device> devices = {})
      : RpcBackendOptions(rpc_timeout, std::move(init_method)),
        numWorkerThreads(numWorkerThreads),
        transports(std::move(transports)),
        channels(std::move(channels)),
        deviceMaps(std::move(device_maps)),
        devices(std::move(devices)) {
    TORCH_CHECK(
        numWorkerThreads > 0,
        "num_worker_threads must be positive, got ",
        numWorkerThreads);

    if (this->transports.has_value()) {
      for (const std::string& transportName : this->transports.value()) {
        TORCH_CHECK(
            TensorPipeTransportRegistry()->Has(transportName),
            "Unknown transport: ",
            transportName);
      }
    }

    if (this->channels.has_value()) {
      for (const std::string& channelName : this->channels.value()) {
        TORCH_CHECK(
            TensorPipeChannelRegistry()->Has(channelName),
            "Unknown channel: ",
            channelName);
      }
    }
  }

  void setDeviceMap(const std::string& workerName, const DeviceMap& deviceMap) {
    auto iter = deviceMaps.find(workerName);
    if (iter == deviceMaps.end()) {
      deviceMaps[workerName] = deviceMap;
    } else {
      for (const auto& entry : deviceMap) {
        // c10::Device has no default constructor, hence map[device] doesn't
        // work; use insert_or_assign (C++17) to overwrite existing entries.
        iter->second.insert_or_assign(entry.first, entry.second);
      }
    }
  }

  int numWorkerThreads;
  const std::optional<std::vector<std::string>> transports;
  const std::optional<std::vector<std::string>> channels;
  std::unordered_map<std::string, DeviceMap> deviceMaps;
  std::vector<c10::Device> devices;
};
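
// Example (illustrative sketch): constructing options that pin the agent to
// the TCP ("uv") transport with 8 worker threads, then mapping this worker's
// cuda:0 to cuda:1 on the (made-up) peer "worker1". Calling setDeviceMap
// again for the same peer merges the maps, overwriting entries whose source
// device already exists.
//
//   TensorPipeRpcBackendOptions opts(
//       /*numWorkerThreads=*/8,
//       /*transports=*/std::vector<std::string>{"uv"},
//       /*channels=*/std::nullopt,
//       /*rpc_timeout=*/60.0f,
//       /*init_method=*/"tcp://127.0.0.1:29500");
//   DeviceMap deviceMap{{c10::Device("cuda:0"), c10::Device("cuda:1")}};
//   opts.setDeviceMap("worker1", deviceMap);
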
// Struct to track the network source metrics
struct TORCH_API NetworkSourceInfo {
  worker_id_t srcRank;
  std::vector<uint8_t> srcMachineAddr;
};

// Struct to track aggregated network metrics
struct TORCH_API AggregatedNetworkData {
  uint64_t numCalls{0};
  uint64_t totalSentBytes{0};
  uint64_t totalRecvBytes{0};
  uint64_t totalErrors{0};
};

// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
// to transparently move tensors and payloads through the fastest available
// transport or channel. It acts like a hybrid RPC transport, providing shared
// memory (Linux) and TCP (Linux & macOS) support. CUDA support is in progress.
class TORCH_API TensorPipeAgent : public RpcAgent {
 public:
  TensorPipeAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string selfName,
      worker_id_t selfId,
      std::optional<int> worldSize,
      TensorPipeRpcBackendOptions opts,
      std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
      std::vector<c10::Device> devices,
      std::unique_ptr<RequestCallback> cb);

  TensorPipeAgent(const TensorPipeAgent&) = delete;
  TensorPipeAgent& operator=(const TensorPipeAgent&) = delete;

  c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
      const DeviceMap& deviceMap = {}) override;

  // join() and sync() will be deprecated;
  // see https://github.com/pytorch/pytorch/issues/27647
  void join(bool shutdown = false, float timeout = 0) override;
  void sync() override {}
  void startImpl() override;
  void shutdownImpl() override;

  ~TensorPipeAgent() override;

  const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
  const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override;
  std::vector<WorkerInfo> getWorkerInfos() const override;
  void updateGroupMembership(
      const WorkerInfo& workerInfo,
      const std::vector<c10::Device>& devices,
      const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
      bool isJoin);

  std::unordered_map<std::string, std::string> getMetrics() override;

  void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;

  TensorPipeRpcBackendOptions getBackendOptions() const;

  const c10::intrusive_ptr<::c10d::Store> getStore() const;

  DeviceMap getDeviceMap(const WorkerInfo& dest) const override;

  const std::vector<c10::Device>& getDevices() const override;

  using NetworkDataDict =
      std::unordered_map<std::string, AggregatedNetworkData>;

  // Returns metrics tracked by the NetworkDataDict
  NetworkDataDict getNetworkData();
  // Returns NetworkSourceInfo struct
  NetworkSourceInfo getNetworkSourceInfo();

  static const std::string& guessAddress();

  // For testing purposes.
  size_t timeoutMapSize();
  size_t numPendingResponses();
  size_t messageIdToTimeoutMapSize();

  const bool isStaticGroup_;

 protected:
  // TensorPipe write function that can be used to write response messages
  // by the server, and request messages by the client. This is a protected
  // method since it is overridden by FaultyTensorPipeAgent.
  virtual void pipeWrite(
      const std::shared_ptr<tensorpipe::Pipe>&,
      c10::intrusive_ptr<Message> message,
      std::vector<c10::Device>&& devices,
      std::vector<c10::Stream> streams,
      std::function<void(const tensorpipe::Error&)>) noexcept;
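
  // Example (illustrative sketch): a testing subclass can intercept outgoing
  // messages by overriding pipeWrite, e.g. to inject faults before delegating
  // to the base implementation. FaultyTensorPipeAgent does something along
  // these lines; the shouldDrop() helper below is hypothetical, and the
  // constructors forwarding to TensorPipeAgent are omitted.
  //
  //   class DroppingTensorPipeAgent : public TensorPipeAgent {
  //    protected:
  //     void pipeWrite(
  //         const std::shared_ptr<tensorpipe::Pipe>& pipe,
  //         c10::intrusive_ptr<Message> message,
  //         std::vector<c10::Device>&& devices,
  //         std::vector<c10::Stream> streams,
  //         std::function<void(const tensorpipe::Error&)> fn) noexcept
  //         override {
  //       if (shouldDrop(*message)) {
  //         return; // drop the message to simulate a network fault
  //       }
  //       TensorPipeAgent::pipeWrite(
  //           pipe,
  //           std::move(message),
  //           std::move(devices),
  //           std::move(streams),
  //           std::move(fn));
  //     }
  //   };
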
 private:
  // Removes the given messageId with the given expirationTime from the
  // timeoutMap_.
  void removeFromTimeoutMap(uint64_t messageId);

  // Populates workerIdToInfo_ and workerNameToInfo_ using addressStore_
  void prepareNames(bool isStaticGroup);

  // Checks the static group attribute against the value set in the store
  void checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store>& store);

  const std::string& findWorkerURL(const WorkerInfo& worker) const;

  // Only used for dynamic RPC groups; method to have a worker leave the group
  void leaveGroup();

  // TensorPipe read function that can be used to read response messages
  // by the client, and request messages by the server.
  void pipeRead(
      const std::shared_ptr<tensorpipe::Pipe>&,
      std::function<void(
          const tensorpipe::Error&,
          c10::intrusive_ptr<Message>,
          std::vector<c10::Stream>)>) noexcept;

  // Callback of listener accept()
  void onListenerAccepted(
      const tensorpipe::Error& error,
      std::shared_ptr<tensorpipe::Pipe>& pipe);

  // Respond to a call from a peer
  void respond(std::shared_ptr<tensorpipe::Pipe>& pipe);

  void sendCompletedResponseMessage(
      std::shared_ptr<tensorpipe::Pipe>& pipe,
      JitFuture& futureResponseMessage,
      uint64_t messageId,
      std::vector<c10::Stream> streams);

  // Collects metrics from successful RPC calls
  void trackNetworkData(
      uint64_t requestSize,
      uint64_t responseSize,
      const std::string& destWorkerName);

  // Collects metrics from failed RPC calls
  void trackNetworkError(
      uint64_t requestSize,
      const std::string& destWorkerName);

  inline std::vector<c10::Device> getDevicesForRemote(
      const std::string& remoteName,
      const Message& message) const;

  // When a request+response completes, we need to mark the future message as
  // complete. However, if its timeout has already expired, it already has an
  // error set. There is no atomic "test-and-set" way to mark a future complete
  // only if it isn't yet. It does exist for errors (setErrorIfNeeded) but,
  // even then, it ends up printing a log message, which may worry the user.
  // To solve both issues we use a separate atomic flag to know the status of
  // the future.
  struct AtomicJitFuture {
    explicit AtomicJitFuture(const std::vector<c10::Device>& devices) {
      jitFuture = c10::make_intrusive<at::ivalue::Future>(
          at::AnyClassType::get(), devices);
    }

    std::atomic_flag isComplete = ATOMIC_FLAG_INIT;
    c10::intrusive_ptr<JitFuture> jitFuture;
  };
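
  // Example (illustrative sketch): the flag makes "complete exactly once"
  // race-free. Whichever of the response handler and the timeout poller calls
  // test_and_set() first sees false and gets to settle the future; the loser
  // sees true and backs off. The tryComplete() helper below is hypothetical;
  // the real helpers are markFutureAsComplete() / markFutureWithError().
  //
  //   void tryComplete(AtomicJitFuture& future, c10::IValue value) {
  //     if (!future.isComplete.test_and_set()) {
  //       future.jitFuture->markCompleted(std::move(value));
  //     }
  //   }
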
  // Maintains state per client pipe to track pending response messages and
  // error states. pendingResponseMessage_ should be protected by a mutex
  // since it can race with the user's send() call.
  // TODO: To achieve better performance we can have a pipe pool per
  // client that can be configured using RpcBackendOptions.
  struct ClientPipe {
    explicit ClientPipe(std::shared_ptr<tensorpipe::Pipe> pipe)
        : pipe_(std::move(pipe)) {}
    std::shared_ptr<tensorpipe::Pipe> pipe_;
    mutable std::mutex mutex_;
    bool inError_{false};
    // Map from Message Request IDs to corresponding futures.
    std::unordered_map<uint64_t, std::shared_ptr<AtomicJitFuture>>
        pendingResponseMessage_;
  };

  const c10::intrusive_ptr<::c10d::Store> store_;

  const TensorPipeRpcBackendOptions opts_;
  // For dynamic RPC, the reverse device maps are updated whenever a new rank
  // joins or leaves the group.
  std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
  // Local devices used by this agent. If the application didn't specify this
  // field, it will be initialized using the corresponding local devices in
  // opts_.deviceMaps and reverseDeviceMaps_.
  std::vector<c10::Device> devices_;

  ThreadPool threadPool_;
  std::shared_ptr<tensorpipe::Context> context_;
  std::shared_ptr<tensorpipe::Listener> listener_;

  mutable std::mutex connectedPipesMutex_;
  std::unordered_map<worker_id_t, ClientPipe> connectedPipes_;

  // Maps keyed on name and id for easy WorkerInfo lookup.
  std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_;
  std::unordered_map<std::string, WorkerInfo> workerNameToInfo_;
  std::unordered_map<std::string, std::string> workerNameToURL_;

  ::c10d::PrefixStore rankToNameStore_;
  ::c10d::PrefixStore nameToAddressStore_;
  // Store keys that will be used to count joined processes and active calls
  // during the shutdown process.
  ::c10d::PrefixStore shutdownStore_;
  int worldSize_ = 0;
  std::atomic<uint64_t> nextMessageID_{0};

  // Metadata used for tracking whether certain RPCs have timed out.
  struct TimeoutMessageMetadata {
    TimeoutMessageMetadata(
        uint64_t messageId_,
        std::shared_ptr<AtomicJitFuture> responseFuture_,
        std::chrono::milliseconds timeout_)
        : messageId(messageId_),
          responseFuture(std::move(responseFuture_)),
          timeout(timeout_) {}
    uint64_t messageId;
    std::shared_ptr<AtomicJitFuture> responseFuture;
    std::chrono::milliseconds timeout;
  };

  // Map to store the expiration times for each message.
  std::map<steady_clock_time_point, std::vector<TimeoutMessageMetadata>>
      timeoutMap_;

  // Map to store the messageId to expiry time.
  std::unordered_map<uint64_t, steady_clock_time_point> messageIdToTimeout_;

  // Thread that will poll the timeoutMap_ for timed out messages and mark
  // them with an error accordingly.
  std::thread timeoutThread_;

  // Function run by the timeoutThread_ to check for timed out RPCs
  void pollTimeoutRpcs();

  // Mutex to guard the timeoutMap_
  std::mutex timeoutMapMutex_;

  // Condition variable to signal population of the timeoutMap_
  std::condition_variable timeoutThreadCV_;
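
  // Example (illustrative, simplified sketch of the polling loop): because
  // timeoutMap_ is an ordered std::map keyed by expiry time, the earliest
  // deadline is always begin(), so the poller can sleep on the condition
  // variable until exactly that point and then error out every entry that
  // has expired.
  //
  //   std::unique_lock<std::mutex> lock(timeoutMapMutex_);
  //   while (!shuttingDown_.load()) {
  //     if (timeoutMap_.empty()) {
  //       timeoutThreadCV_.wait(lock); // wait until a message is enqueued
  //       continue;
  //     }
  //     steady_clock_time_point earliest = timeoutMap_.begin()->first;
  //     if (timeoutThreadCV_.wait_until(lock, earliest) ==
  //         std::cv_status::timeout) {
  //       // every TimeoutMessageMetadata at begin() has expired: mark each
  //       // responseFuture with a timeout error and erase the bucket
  //     }
  //   }
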
  // Returns the expiration time for an RPC by adding the current time to the
  // passed-in timeout.
  inline steady_clock_time_point computeRpcMessageExpiryTime(
      std::chrono::milliseconds timeout) const {
    return std::chrono::time_point_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() + timeout);
  }

  // Handle error on an outgoing pipe
  void handleClientError(
      ClientPipe& clientPipe,
      const tensorpipe::Error& error);

  // This is a generic struct for capturing time-series metrics. It keeps a
  // running sum and count of data points (observations), and can return an
  // average of the data points seen so far. This is currently only used for
  // tracking the GIL wait time in RPC agents, but can be used for other
  // metrics as well.
  struct TimeSeriesMetricsTracker {
    // Running sum of the data points seen so far
    uint64_t currentSum_;
    // Running count of the data points seen so far
    uint64_t currentCount_;

    explicit TimeSeriesMetricsTracker(
        uint64_t currentSum = 0,
        uint64_t currentCount = 0);

    // Adds a data point (which is basically one observation for the metric
    // being tracked) to the running sum and count.
    void addData(uint64_t dataPoint);
    // Returns the average of all the data points seen so far.
    float computeAverage() const;
  };

  // Map of time-series metrics tracked by the RPC agent
  std::unordered_map<std::string, TimeSeriesMetricsTracker> timeSeriesMetrics_;
  // Mutex to guard timeSeriesMetrics_
  std::mutex metricsMutex_;

  // Custom lock guard used to check if the RPC group is dynamic and lock the
  // mutex if so.
  struct GroupMembershipLockGuard {
    GroupMembershipLockGuard(std::mutex& mutex, bool isStaticGroup)
        : ref_(mutex), isStaticGroup_(isStaticGroup) {
      if (!isStaticGroup_) {
        ref_.lock();
      }
    }

    ~GroupMembershipLockGuard() {
      if (!isStaticGroup_) {
        ref_.unlock();
      }
    }

    GroupMembershipLockGuard(const GroupMembershipLockGuard&) = delete;

   private:
    std::mutex& ref_;
    bool isStaticGroup_;
  };
  // Mutex to guard access to group membership data,
  // e.g. updates to (workerIdToInfo_, workerNameToInfo_, workerNameToURL_).
  mutable std::mutex groupMembershipMutex_;

  // Map to track network data
  NetworkDataDict networkData_;
  // Mutex to guard networkData_
  std::mutex networkDataMutex_;

  // A mutex and a cv to guard access to the call counts and watch for
  // changes.
  std::mutex callCountMutex_;
  std::condition_variable callCountCV_;
  // Running total of un-processed, un-errored RPC calls sent
  int32_t clientActiveCalls_{0};
  // Running total of un-processed RPC requests received
  int32_t serverActiveCalls_{0};
  // Running total of RPC requests that will be completed asynchronously
  int32_t serverActiveAsyncCalls_{0};

  // Whether a global graceful shutdown has begun, in which case we'll silence
  // error messages due to remote workers closing their pipes.
  std::atomic<bool> shuttingDown_{false};

  // Helpers to modify the counts while correctly dealing with the mutex and
  // cv.
  void increaseCallCount(int32_t& count);
  void decreaseCallCount(int32_t& count);

  // Helpers to set the state of the requests.
  void markFutureAsComplete(
      std::shared_ptr<AtomicJitFuture> atomicFuture,
      c10::intrusive_ptr<Message> message,
      std::vector<c10::Stream> streams);
  void markFutureWithError(
      std::shared_ptr<AtomicJitFuture> atomicFuture,
      std::string errorMsg);
};

} // namespace torch::distributed::rpc

#endif // USE_TENSORPIPE