xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/tensorpipe_agent.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#ifdef USE_TENSORPIPE

#include <atomic>
#include <thread>

#include <c10/core/thread_pool.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <utility>

// Forward-declare the TensorPipe classes we need, to avoid including its
// headers in PyTorch's own headers and thus making TensorPipe a public
// dependency.

namespace tensorpipe {

class Context;
class Error;
class Listener;
class Message;
class Pipe;

namespace transport {
class Context;
} // namespace transport

namespace channel {
class Context;
} // namespace channel

} // namespace tensorpipe

namespace torch::distributed::rpc {

// These priorities instruct TensorPipe on which transport/channel to pick
// during handshake. Higher priorities will take precedence over lower ones.
// The transport with lowest priority will be the one used to bootstrap pipes.

constexpr int64_t kShmTransportPriority = 200;
constexpr int64_t kIbvTransportPriority = 100;
// The UV transport just uses TCP and should work everywhere, thus keep it last.
constexpr int64_t kUvTransportPriority = 0;

constexpr int64_t kCmaChannelPriority = 1200;
constexpr int64_t kMultiplexedUvChannelPriority = 1100;
// The basic channel reuses a transport as a channel, and is thus our fallback.
constexpr int64_t kBasicChannelPriority = 1000;

// CPU channels have higher priority than CUDA channels, since the latter
// might handle CPU-to-CPU transfers but will always be less efficient than
// their CPU-only counterparts.
constexpr int64_t kCudaIpcChannelPriority = 300;
constexpr int64_t kCudaGdrChannelPriority = 200;
constexpr int64_t kCudaXthChannelPriority = 400;
constexpr int64_t kCudaBasicChannelPriority = 0;

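// Worked example of the priority rules above: if two Linux workers both have
// shm, ibv, and uv available, the handshake settles on shm (200) for the
// transport and cma (1200) for the channel, while uv (priority 0, plain TCP)
// is the lowest-priority transport used to bootstrap the pipe in the first
// place.
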
using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;

struct TORCH_API TransportRegistration {
  std::shared_ptr<tensorpipe::transport::Context> transport;
  int64_t priority;
  std::string address;
};

C10_DECLARE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration);

struct TORCH_API ChannelRegistration {
  std::shared_ptr<tensorpipe::channel::Context> channel;
  int64_t priority;
};

C10_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration);

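// Registration sketch (hedged: the actual creator functions live in
// tensorpipe_agent.cpp; "uv" follows the lowercase registry keys used there).
// A creator bundles a TensorPipe context with its priority and address:
//
//   std::unique_ptr<TransportRegistration> makeUvTransport() {
//     auto context = tensorpipe::transport::uv::create();
//     std::string address = TensorPipeAgent::guessAddress();
//     return std::make_unique<TransportRegistration>(TransportRegistration{
//         std::move(context), kUvTransportPriority, std::move(address)});
//   }
//   C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport);
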
constexpr auto kDefaultNumWorkerThreads = 16;

struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions {
  TensorPipeRpcBackendOptions(
      int numWorkerThreads,
      std::optional<std::vector<std::string>> transports,
      std::optional<std::vector<std::string>> channels,
      float rpc_timeout,
      std::string init_method,
      std::unordered_map<std::string, DeviceMap> device_maps = {},
      std::vector<c10::Device> devices = {})
      : RpcBackendOptions(rpc_timeout, std::move(init_method)),
        numWorkerThreads(numWorkerThreads),
        transports(std::move(transports)),
        channels(std::move(channels)),
        deviceMaps(std::move(device_maps)),
        devices(std::move(devices)) {
    TORCH_CHECK(
        numWorkerThreads > 0,
        "num_worker_threads must be positive, got ",
        numWorkerThreads);

    if (this->transports.has_value()) {
      for (const std::string& transportName : this->transports.value()) {
        TORCH_CHECK(
            TensorPipeTransportRegistry()->Has(transportName),
            "Unknown transport: ",
            transportName);
      }
    }

    if (this->channels.has_value()) {
      for (const std::string& channelName : this->channels.value()) {
        TORCH_CHECK(
            TensorPipeChannelRegistry()->Has(channelName),
            "Unknown channel: ",
            channelName);
      }
    }
  }

  void setDeviceMap(const std::string& workerName, const DeviceMap& deviceMap) {
    auto iter = deviceMaps.find(workerName);
    if (iter == deviceMaps.end()) {
      deviceMaps[workerName] = deviceMap;
    } else {
      for (auto& entry : deviceMap) {
        // c10::Device has no default constructor, hence map[device] doesn't
        // work; use insert_or_assign (C++17) instead.
        iter->second.insert_or_assign(entry.first, entry.second);
      }
    }
  }

  int numWorkerThreads;
  const std::optional<std::vector<std::string>> transports;
  const std::optional<std::vector<std::string>> channels;
  std::unordered_map<std::string, DeviceMap> deviceMaps;
  std::vector<c10::Device> devices;
};

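// Usage sketch (all values are made-up examples): restrict the agent to the
// TCP transport and the basic channel, and map this worker's cuda:0 to
// worker1's cuda:1 for tensors sent its way.
//
//   TensorPipeRpcBackendOptions opts(
//       /*numWorkerThreads=*/8,
//       /*transports=*/std::vector<std::string>{"uv"},
//       /*channels=*/std::vector<std::string>{"basic"},
//       /*rpc_timeout=*/30.0,
//       /*init_method=*/"tcp://127.0.0.1:29500");
//   opts.setDeviceMap(
//       "worker1", {{c10::Device("cuda:0"), c10::Device("cuda:1")}});
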
// Struct to track the network source metrics
struct TORCH_API NetworkSourceInfo {
  worker_id_t srcRank;
  std::vector<uint8_t> srcMachineAddr;
};

// Struct to track aggregated network metrics
struct TORCH_API AggregatedNetworkData {
  uint64_t numCalls{0};
  uint64_t totalSentBytes{0};
  uint64_t totalRecvBytes{0};
  uint64_t totalErrors{0};
};

// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
// to transparently move tensors and payloads through the fastest available
// transport or channel. It acts like a hybrid RPC transport, providing shared
// memory (Linux) and TCP (Linux & macOS) support. CUDA support is in progress.
class TORCH_API TensorPipeAgent : public RpcAgent {
 public:
  TensorPipeAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string selfName,
      worker_id_t selfId,
      std::optional<int> worldSize,
      TensorPipeRpcBackendOptions opts,
      std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
      std::vector<c10::Device> devices,
      std::unique_ptr<RequestCallback> cb);

  TensorPipeAgent(const TensorPipeAgent&) = delete;
  TensorPipeAgent& operator=(const TensorPipeAgent&) = delete;

  c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
      const DeviceMap& deviceMap = {}) override;

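  // Caller-side flow (a sketch; the worker name and message are
  // hypothetical): send() returns immediately with a future that is completed
  // once the response arrives, or marked with an error on timeout.
  //
  //   c10::intrusive_ptr<JitFuture> fut =
  //       agent.send(agent.getWorkerInfo("worker1"), std::move(message));
  //   fut->waitAndThrow(); // blocks; rethrows remote errors and timeouts
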
  // join() and sync() are slated for deprecation -
  // https://github.com/pytorch/pytorch/issues/27647
  void join(bool shutdown = false, float timeout = 0) override;
  void sync() override {}
  void startImpl() override;
  void shutdownImpl() override;

  ~TensorPipeAgent() override;

  const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
  const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override;
  std::vector<WorkerInfo> getWorkerInfos() const override;
  void updateGroupMembership(
      const WorkerInfo& workerInfo,
      const std::vector<c10::Device>& devices,
      const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
      bool isJoin);

  std::unordered_map<std::string, std::string> getMetrics() override;

  void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;

  TensorPipeRpcBackendOptions getBackendOptions() const;

  const c10::intrusive_ptr<::c10d::Store> getStore() const;

  DeviceMap getDeviceMap(const WorkerInfo& dest) const override;

  const std::vector<c10::Device>& getDevices() const override;

  using NetworkDataDict =
      std::unordered_map<std::string, AggregatedNetworkData>;

  // Returns metrics tracked by the NetworkDataDict
  NetworkDataDict getNetworkData();
  // Returns NetworkSourceInfo struct
  NetworkSourceInfo getNetworkSourceInfo();

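  // Reading the aggregated metrics (sketch; "worker1" is a hypothetical peer,
  // and an entry exists only once at least one call to it was tracked):
  //
  //   TensorPipeAgent::NetworkDataDict data = agent.getNetworkData();
  //   auto it = data.find("worker1");
  //   if (it != data.end() && it->second.numCalls > 0) {
  //     double avgSentBytes = static_cast<double>(it->second.totalSentBytes) /
  //         it->second.numCalls;
  //   }
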
  static const std::string& guessAddress();

  // For testing purposes.
  size_t timeoutMapSize();
  size_t numPendingResponses();
  size_t messageIdToTimeoutMapSize();

  const bool isStaticGroup_;

 protected:
  // TensorPipe write function that can be used to write response messages by
  // the server and request messages by the client. This is a protected method
  // since it is overridden by FaultyTensorPipeAgent.
  virtual void pipeWrite(
      const std::shared_ptr<tensorpipe::Pipe>&,
      c10::intrusive_ptr<Message> message,
      std::vector<c10::Device>&& devices,
      std::vector<c10::Stream> streams,
      std::function<void(const tensorpipe::Error&)>) noexcept;

 private:
  // Removes the entry for the given messageId from the timeoutMap_.
  void removeFromTimeoutMap(uint64_t messageId);

  // Populates workerIdToInfo_ and workerNameToInfo_ using rankToNameStore_
  void prepareNames(bool isStaticGroup);

  // Checks that the static-group attribute matches the value set in the store
  void checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store>& store);

  const std::string& findWorkerURL(const WorkerInfo& worker) const;

  // Only used for dynamic RPC groups; makes this worker leave the group
  void leaveGroup();

  // TensorPipe read function that can be used to read response messages by
  // the client and request messages by the server.
  void pipeRead(
      const std::shared_ptr<tensorpipe::Pipe>&,
      std::function<void(
          const tensorpipe::Error&,
          c10::intrusive_ptr<Message>,
          std::vector<c10::Stream>)>) noexcept;

  // Callback of listener accept()
  void onListenerAccepted(
      const tensorpipe::Error& error,
      std::shared_ptr<tensorpipe::Pipe>& pipe);

  // Respond to a call from a peer
  void respond(std::shared_ptr<tensorpipe::Pipe>& pipe);

  void sendCompletedResponseMessage(
      std::shared_ptr<tensorpipe::Pipe>& pipe,
      JitFuture& futureResponseMessage,
      uint64_t messageId,
      std::vector<c10::Stream> streams);

  // Collects metrics from successful RPC calls
  void trackNetworkData(
      uint64_t requestSize,
      uint64_t responseSize,
      const std::string& destWorkerName);

  // Collects metrics from failed RPC calls
  void trackNetworkError(
      uint64_t requestSize,
      const std::string& destWorkerName);

  inline std::vector<c10::Device> getDevicesForRemote(
      const std::string& remoteName,
      const Message& message) const;

  // When a request+response completes, we need to mark the future message as
  // complete. However, if its timeout has already expired, it already has an
  // error set. There is no atomic "test-and-set" way to mark a future complete
  // only if it isn't yet. It does exist for errors (setErrorIfNeeded) but, even
  // then, it ends up printing a log message, which may worry the user. To solve
  // both issues we use a separate atomic flag to know the status of the future.
  struct AtomicJitFuture {
    explicit AtomicJitFuture(const std::vector<c10::Device>& devices) {
      jitFuture = c10::make_intrusive<at::ivalue::Future>(
          at::AnyClassType::get(), devices);
    }

    std::atomic_flag isComplete = ATOMIC_FLAG_INIT;
    c10::intrusive_ptr<JitFuture> jitFuture;
  };

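  // Intended use of the flag (a sketch; "value" is illustrative): the
  // completion and timeout paths both race on test_and_set(), and only the
  // first one to flip the flag touches the future, so it is completed exactly
  // once and without the setErrorIfNeeded() log noise.
  //
  //   if (!atomicFuture->isComplete.test_and_set()) {
  //     atomicFuture->jitFuture->markCompleted(std::move(value));
  //   } // else: the other path already completed or errored the future.
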
  // Maintains state per client pipe to track pending response messages and
  // error states. pendingResponseMessage_ should be protected by a mutex since
  // it can race with user send() calls.
  // TODO: To achieve better performance we can have a pipe pool per
  // client that can be configured using RpcBackendOptions.
  struct ClientPipe {
    explicit ClientPipe(std::shared_ptr<tensorpipe::Pipe> pipe)
        : pipe_(std::move(pipe)) {}
    std::shared_ptr<tensorpipe::Pipe> pipe_;
    mutable std::mutex mutex_;
    bool inError_{false};
    // Map from Message Request IDs to corresponding futures.
    std::unordered_map<uint64_t, std::shared_ptr<AtomicJitFuture>>
        pendingResponseMessage_;
  };

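  // Lifecycle sketch (illustrative names, based on the comment above): the
  // send() path registers the pending future under a fresh message ID before
  // writing, and the read callback for the response claims it under the same
  // mutex.
  //
  //   {
  //     std::lock_guard<std::mutex> lock(clientPipe.mutex_);
  //     clientPipe.pendingResponseMessage_[messageId] = atomicFuture;
  //   }
  //   // ... later, in the pipeRead() callback for the response:
  //   {
  //     std::lock_guard<std::mutex> lock(clientPipe.mutex_);
  //     auto it = clientPipe.pendingResponseMessage_.find(messageId);
  //     if (it != clientPipe.pendingResponseMessage_.end()) {
  //       auto future = std::move(it->second);
  //       clientPipe.pendingResponseMessage_.erase(it);
  //     }
  //   }
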
  const c10::intrusive_ptr<::c10d::Store> store_;

  const TensorPipeRpcBackendOptions opts_;
  // For dynamic RPC, the reverse device maps are updated whenever a new rank
  // joins or leaves the group
  std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
  // Local devices used by this agent. If the application didn't specify this
  // field, it will be initialized using the corresponding local devices in
  // opts_.deviceMaps and reverseDeviceMaps_.
  std::vector<c10::Device> devices_;

  ThreadPool threadPool_;
  std::shared_ptr<tensorpipe::Context> context_;
  std::shared_ptr<tensorpipe::Listener> listener_;

  mutable std::mutex connectedPipesMutex_;
  std::unordered_map<worker_id_t, ClientPipe> connectedPipes_;

  // Maps keyed on name and id for easy WorkerInfo lookup.
  std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_;
  std::unordered_map<std::string, WorkerInfo> workerNameToInfo_;
  std::unordered_map<std::string, std::string> workerNameToURL_;

  ::c10d::PrefixStore rankToNameStore_;
  ::c10d::PrefixStore nameToAddressStore_;
  // Store keys that will be used to count joined processes and active calls
  // during the shutdown process
  ::c10d::PrefixStore shutdownStore_;
  int worldSize_ = 0;
  std::atomic<uint64_t> nextMessageID_{0};

  // Metadata used for tracking whether certain RPCs have timed out.
  struct TimeoutMessageMetadata {
    TimeoutMessageMetadata(
        uint64_t messageId_,
        std::shared_ptr<AtomicJitFuture> responseFuture_,
        std::chrono::milliseconds timeout_)
        : messageId(messageId_),
          responseFuture(std::move(responseFuture_)),
          timeout(timeout_) {}
    uint64_t messageId;
    std::shared_ptr<AtomicJitFuture> responseFuture;
    std::chrono::milliseconds timeout;
  };

  // Map to store the expiration times for each message.
  std::map<steady_clock_time_point, std::vector<TimeoutMessageMetadata>>
      timeoutMap_;

  // Map to store the messageId to expiry time.
  std::unordered_map<uint64_t, steady_clock_time_point> messageIdToTimeout_;

  // Thread that will poll the timeoutMap_ for timed-out messages and mark them
  // with an error accordingly
  std::thread timeoutThread_;

  // Function run by the timeoutThread_ to check for timed-out RPCs
  void pollTimeoutRpcs();

  // Mutex to guard the timeoutMap_
  std::mutex timeoutMapMutex_;

  // Condition Variable to signal population of the timeoutMap_
  std::condition_variable timeoutThreadCV_;

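  // Shape of the polling loop (a sketch, not the exact implementation):
  //
  //   while (running) {
  //     std::unique_lock<std::mutex> lock(timeoutMapMutex_);
  //     if (timeoutMap_.empty()) {
  //       timeoutThreadCV_.wait(lock); // woken when an entry is added
  //       continue;
  //     }
  //     steady_clock_time_point earliest = timeoutMap_.begin()->first;
  //     timeoutThreadCV_.wait_until(lock, earliest);
  //     // Pop every entry whose expiry is in the past and mark its future
  //     // with a timeout error (outside the lock).
  //   }
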
  // Returns the expiration time for an RPC by adding the current time to the
  // passed-in timeout.
  inline steady_clock_time_point computeRpcMessageExpiryTime(
      std::chrono::milliseconds timeout) const {
    return std::chrono::time_point_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() + timeout);
  }

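  // e.g. with a 5000 ms timeout and steady_clock::now() equal to t, the
  // timeoutMap_ entry is keyed on t + 5000 ms, truncated to millisecond
  // resolution by the time_point_cast above.
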
  // Handle error on an outgoing pipe
  void handleClientError(
      ClientPipe& clientPipe,
      const tensorpipe::Error& error);

  // This is a generic struct for capturing Time-Series Metrics. It keeps a
  // running sum and count of data points (observations), and can return an
  // average of the data points seen so far. This is currently only used for
  // tracking the GIL Wait Time in RPC Agents, but can be used for other metrics
  // as well.
  struct TimeSeriesMetricsTracker {
    // Running sum of the data points seen so far
    uint64_t currentSum_;
    // Running count of the data points seen so far
    uint64_t currentCount_;

    explicit TimeSeriesMetricsTracker(
        uint64_t currentSum = 0,
        uint64_t currentCount = 0);

    // Adds a data point (which is basically one observation for the metric
    // being tracked) to the running sum and count.
    void addData(uint64_t dataPoint);
    // Returns the average of all the data points seen so far.
    float computeAverage() const;
  };

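  // Usage sketch (made-up values): two GIL-wait observations of 120 us and
  // 80 us average out to 100 us.
  //
  //   TimeSeriesMetricsTracker tracker;
  //   tracker.addData(120);
  //   tracker.addData(80);
  //   float avg = tracker.computeAverage(); // (120 + 80) / 2 = 100.0
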
  // Map of Time-Series metrics tracked by the RPC Agent
  std::unordered_map<std::string, TimeSeriesMetricsTracker> timeSeriesMetrics_;
  // Mutex to guard timeSeriesMetrics_
  std::mutex metricsMutex_;

  // Custom lock guard that checks if the RPC group is dynamic and locks the
  // mutex if so (static groups never mutate membership, so no lock is needed)
  struct GroupMembershipLockGuard {
    GroupMembershipLockGuard(std::mutex& mutex, bool isStaticGroup)
        : ref_(mutex), isStaticGroup_(isStaticGroup) {
      if (!isStaticGroup_) {
        ref_.lock();
      }
    }

    ~GroupMembershipLockGuard() {
      if (!isStaticGroup_) {
        ref_.unlock();
      }
    }

    GroupMembershipLockGuard(const GroupMembershipLockGuard&) = delete;

   private:
    std::mutex& ref_;
    bool isStaticGroup_;
  };
  // Mutex to guard access to group membership data, e.g. updates to
  // workerIdToInfo_, workerNameToInfo_, and workerNameToURL_
  mutable std::mutex groupMembershipMutex_;

  // Map to track network data
  NetworkDataDict networkData_;
  // Mutex to guard networkData_
  std::mutex networkDataMutex_;

  // A mutex and a cv to guard access to the call counts and watch for changes.
  std::mutex callCountMutex_;
  std::condition_variable callCountCV_;
  // Running total of un-processed, un-errored RPC calls sent
  int32_t clientActiveCalls_{0};
  // Running total of un-processed RPC requests received
  int32_t serverActiveCalls_{0};
  // Running total of RPC requests that will be completed asynchronously
  int32_t serverActiveAsyncCalls_{0};

  // Whether a global graceful shutdown has begun, in which case we'll silence
  // error messages due to remote workers closing their pipes.
  std::atomic<bool> shuttingDown_{false};

  // Helpers to modify the counts while correctly dealing with the mutex and cv.
  void increaseCallCount(int32_t& count);
  void decreaseCallCount(int32_t& count);

  // Helpers to set the state of the requests.
  void markFutureAsComplete(
      std::shared_ptr<AtomicJitFuture> atomicFuture,
      c10::intrusive_ptr<Message> message,
      std::vector<c10::Stream> streams);
  void markFutureWithError(
      std::shared_ptr<AtomicJitFuture> atomicFuture,
      std::string errorMsg);
};

} // namespace torch::distributed::rpc

#endif // USE_TENSORPIPE