1 #pragma once 2 #include <torch/csrc/Export.h> 3 #include <torch/csrc/distributed/rpc/types.h> 4 #include <mutex> 5 #include <optional> 6 #include <unordered_map> 7 8 namespace torch::distributed::rpc { 9 extern const std::string REMOTE_PROFILING_KEY_PREFIX; 10 11 class TORCH_API RemoteProfilerManager { 12 public: 13 // Retrieves the lazily-initialized RemoteProfilerManager singleton instance. 14 static RemoteProfilerManager& getInstance(); 15 // Sets the current, thread-local profiling key. 16 void setCurrentKey(std::string key); 17 // Returns whether the current profiling key is set. 18 bool isCurrentKeySet() const; 19 // Unsets the current, thread-local profiling key to allow other RPCs to reset 20 // it. 21 void unsetCurrentKey(); 22 // inserts a pair (globallyUniqueId, key) to an in-memory map. The 23 // corresponding ID is used in RPC deserialization to prefix remotely profiled 24 // events with the right key. 25 void saveRPCKey( 26 ProfilingId globallyUniqueId, 27 const std::string& rpcProfilingKey); 28 // Retrieves the profiling key corresponding to the given globallyUniqueId. 29 // Throws if it is not found. 30 std::string retrieveRPCProfilingKey(const ProfilingId& globallyUniqueId); 31 // Generates the next globally unique ID for profiling. 32 ProfilingId getNextProfilerId(); 33 // Retrieves the currently set thread-local profiling key. Throws if it is not 34 // set. 35 std::string& getCurrentProfilingKey(); 36 // erases the globallyUniqueId from the map. This can help save memory in the 37 // case that many RPCs are being profiled. 38 void eraseKey(const ProfilingId& globallyUniqueId); 39 40 RemoteProfilerManager(const RemoteProfilerManager& other) = delete; 41 RemoteProfilerManager operator=(const RemoteProfilerManager& other) = delete; 42 RemoteProfilerManager(RemoteProfilerManager&&) = delete; 43 RemoteProfilerManager& operator=(RemoteProfilerManager&&) = delete; 44 45 private: 46 RemoteProfilerManager(); 47 ~RemoteProfilerManager() = default; 48 local_id_t getNextLocalId(); 49 std::unordered_map<ProfilingId, std::string, ProfilingId::Hash> 50 profiledRpcKeys_; 51 static thread_local std::optional<std::string> currentThreadLocalKey_; 52 std::mutex mutex_; 53 local_id_t currentLocalId_; 54 }; 55 } // namespace torch::distributed::rpc 56