1 2 #pragma once 3 4 #include <shared_mutex> 5 6 #include <torch/csrc/autograd/function.h> 7 8 namespace c10d { 9 10 // `RankLocal` maintains a unique instance of T for each non-autograd thread. 11 // For non-autograd threads, `RankLocal<T>::get()` functions similar to 12 // thread_local. For autograd threads, `RankLocal<T>::get()` returns the 13 // instance of T corresponding to the enqueuing non-autograd thread. The 14 // mechanism allows for rank-specific context shared between forward and 15 // backward. It works for both the one-rank-per-process and one-rank-per-thread 16 // scenarios. 17 // 18 // NOTE: RankLocal doesn't make the underlying objects thread-safe. 19 template <typename T> 20 class RankLocal { 21 public: 22 RankLocal(const RankLocal&) = delete; 23 RankLocal& operator=(const RankLocal&) = delete; 24 get()25 static T& get() { 26 // Fast path: non-autograd threads can simply return 27 // the object reference cached in TLS. 28 if (cached_ != nullptr) { 29 return *cached_; 30 } 31 const auto node = torch::autograd::get_current_node(); 32 auto fwd_thread_id = node == nullptr ? at::RecordFunction::currentThreadId() 33 : node->thread_id(); 34 // Optimistically acquire the read lock first, since most likely we are in 35 // an autograd thread and the object has already been constructed. 36 { 37 std::shared_lock read_lock(lock_); 38 auto it = thread_id_to_rank_local_.find(fwd_thread_id); 39 if (it != thread_id_to_rank_local_.end()) { 40 // Cache for non-autograd threads 41 if (node == nullptr) { 42 cached_ = &it->second; 43 } 44 return it->second; 45 } 46 } 47 48 std::unique_lock write_lock(lock_); 49 auto [it, _] = thread_id_to_rank_local_.try_emplace(fwd_thread_id); 50 // Cache for non-autograd threads 51 if (node == nullptr) { 52 cached_ = &it->second; 53 } 54 return it->second; 55 } 56 57 private: RankLocal()58 RankLocal(){}; 59 thread_local static T* cached_; 60 static std::unordered_map<uint64_t, T> thread_id_to_rank_local_; 61 static std::shared_mutex lock_; 62 }; 63 64 template <typename T> 65 thread_local T* RankLocal<T>::cached_ = nullptr; 66 67 template <typename T> 68 std::unordered_map<uint64_t, T> RankLocal<T>::thread_id_to_rank_local_; 69 70 template <typename T> 71 std::shared_mutex RankLocal<T>::lock_; 72 73 } // namespace c10d 74