xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/RankLocal.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 #pragma once
3 
4 #include <shared_mutex>
5 
6 #include <torch/csrc/autograd/function.h>
7 
8 namespace c10d {
9 
// `RankLocal` maintains a unique instance of T for each non-autograd thread.
// For non-autograd threads, `RankLocal<T>::get()` behaves like a
// `thread_local` variable. For autograd threads, `RankLocal<T>::get()` returns
// the instance of T corresponding to the enqueuing non-autograd thread. This
// mechanism allows rank-specific context to be shared between forward and
// backward. It works for both the one-rank-per-process and one-rank-per-thread
// scenarios.
//
// NOTE: RankLocal doesn't make the underlying objects thread-safe.
19 template <typename T>
20 class RankLocal {
21  public:
22   RankLocal(const RankLocal&) = delete;
23   RankLocal& operator=(const RankLocal&) = delete;
24 
get()25   static T& get() {
26     // Fast path: non-autograd threads can simply return
27     // the object reference cached in TLS.
28     if (cached_ != nullptr) {
29       return *cached_;
30     }
31     const auto node = torch::autograd::get_current_node();
32     auto fwd_thread_id = node == nullptr ? at::RecordFunction::currentThreadId()
33                                          : node->thread_id();
34     // Optimistically acquire the read lock first, since most likely we are in
35     // an autograd thread and the object has already been constructed.
36     {
37       std::shared_lock read_lock(lock_);
38       auto it = thread_id_to_rank_local_.find(fwd_thread_id);
39       if (it != thread_id_to_rank_local_.end()) {
40         // Cache for non-autograd threads
41         if (node == nullptr) {
42           cached_ = &it->second;
43         }
44         return it->second;
45       }
46     }
47 
48     std::unique_lock write_lock(lock_);
49     auto [it, _] = thread_id_to_rank_local_.try_emplace(fwd_thread_id);
50     // Cache for non-autograd threads
51     if (node == nullptr) {
52       cached_ = &it->second;
53     }
54     return it->second;
55   }
56 
57  private:
RankLocal()58   RankLocal(){};
59   thread_local static T* cached_;
60   static std::unordered_map<uint64_t, T> thread_id_to_rank_local_;
61   static std::shared_mutex lock_;
62 };
63 
64 template <typename T>
65 thread_local T* RankLocal<T>::cached_ = nullptr;
66 
67 template <typename T>
68 std::unordered_map<uint64_t, T> RankLocal<T>::thread_id_to_rank_local_;
69 
70 template <typename T>
71 std::shared_mutex RankLocal<T>::lock_;
72 
73 } // namespace c10d
74