xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/agent_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
4 #include <torch/csrc/distributed/rpc/utils.h>
5 
6 namespace torch {
7 namespace distributed {
8 namespace rpc {
9 
10 // All RPC peers should call into this function at the same time. Each peer
11 // provides its own id and name, and this function uses the given Store to
12 // gather global name-to-id mapping on all peers.
//
// Parameters:
//   store     - PrefixStore through which the entries are exchanged; taken by
//               value.
//   selfId    - id of the calling peer.
//   selfName  - name of the calling peer.
//   worldSize - presumably the total number of peers taking part in the
//               exchange; confirm against the implementation.
// Returns the gathered name-to-id map.
// NOTE(review): "at the same time" suggests the call waits until every peer
// has contributed its entry -- confirm in the implementation.
13 TORCH_API std::unordered_map<std::string, worker_id_t> collectNames(
14     ::c10d::PrefixStore store,
15     const worker_id_t selfId,
16     const std::string& selfName,
17     const int worldSize);
18 
19 // Ranks in dynamic RPC groups will initially call into this to establish the
20 // name-to-id mapping for the current peers in the group. The current rank will
21 // put its own worker info in the store and discover all the ranks that came
22 // before it. NOTE: This needs to be called with the Dynamic RPC group
23 // membership management token held.
//
// Parameters:
//   store    - PrefixStore holding the group's worker info; taken by value.
//   selfId   - id of the calling rank.
//   selfName - name of the calling rank.
// Returns the name-to-id map of the peers currently in the group.
// NOTE(review): presumably the returned map also contains the caller's own
// entry -- confirm in the implementation.
24 TORCH_API std::unordered_map<std::string, worker_id_t> collectCurrentNames(
25     ::c10d::PrefixStore store,
26     const worker_id_t selfId,
27     const std::string& selfName);
28 
29 // Remove name from Store, used in dynamic RPC groups.
30 // NOTE: This needs to be called with the Dynamic RPC group
31 // membership management token held.
//
// Parameters:
//   store    - PrefixStore from which the entry is removed; taken by value.
//   selfId   - id of the calling rank.
//   selfName - name of the calling rank.
// NOTE(review): presumably this removes the entry that collectCurrentNames
// stored for the same rank -- confirm in the implementation.
32 TORCH_API void removeCurrentName(
33     ::c10d::PrefixStore store,
34     const worker_id_t selfId,
35     const std::string& selfName);
36 
37 // This performs a synchronization of all call counts by using store.
38 // All RPC peers wait for others to join to exit at the same time.
//
// Parameters:
//   store       - PrefixStore used for the synchronization; taken by value.
//   worldSize   - presumably the number of peers expected to join; confirm
//                 against the implementation.
//   activeCalls - call count contributed by the calling peer; defaults to 0.
// NOTE(review): the meaning of the returned int is not visible from this
// header; presumably it is the call count aggregated across all peers --
// confirm in the implementation.
39 TORCH_API int syncCallCount(
40     ::c10d::PrefixStore store,
41     const int worldSize,
42     int activeCalls = 0);
43 
44 } // namespace rpc
45 } // namespace distributed
46 } // namespace torch
47