xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/Export.h>
3 #include <torch/csrc/distributed/rpc/types.h>
4 #include <mutex>
5 #include <optional>
6 #include <unordered_map>
7 
8 namespace torch::distributed::rpc {
// String constant with external linkage; only declared here, so its value is
// provided by a definition in another translation unit.
// NOTE(review): judging by the name, this is the prefix attached to remote
// profiling keys -- confirm against the definition and its users.
9 extern const std::string REMOTE_PROFILING_KEY_PREFIX;
10 
11 class TORCH_API RemoteProfilerManager {
12  public:
13   // Retrieves the lazily-initialized RemoteProfilerManager singleton instance.
14   static RemoteProfilerManager& getInstance();
15   // Sets the current, thread-local profiling key.
16   void setCurrentKey(std::string key);
17   // Returns whether the current profiling key is set.
18   bool isCurrentKeySet() const;
19   // Unsets the current, thread-local profiling key to allow other RPCs to reset
20   // it.
21   void unsetCurrentKey();
22   // inserts a pair (globallyUniqueId, key) to an in-memory map. The
23   // corresponding ID is used in RPC deserialization to prefix remotely profiled
24   // events with the right key.
25   void saveRPCKey(
26       ProfilingId globallyUniqueId,
27       const std::string& rpcProfilingKey);
28   // Retrieves the profiling key corresponding to the given globallyUniqueId.
29   // Throws if it is not found.
30   std::string retrieveRPCProfilingKey(const ProfilingId& globallyUniqueId);
31   // Generates the next globally unique ID for profiling.
32   ProfilingId getNextProfilerId();
33   // Retrieves the currently set thread-local profiling key. Throws if it is not
34   // set.
35   std::string& getCurrentProfilingKey();
36   // erases the globallyUniqueId from the map. This can help save memory in the
37   // case that many RPCs are being profiled.
38   void eraseKey(const ProfilingId& globallyUniqueId);
39 
40   RemoteProfilerManager(const RemoteProfilerManager& other) = delete;
41   RemoteProfilerManager operator=(const RemoteProfilerManager& other) = delete;
42   RemoteProfilerManager(RemoteProfilerManager&&) = delete;
43   RemoteProfilerManager& operator=(RemoteProfilerManager&&) = delete;
44 
45  private:
46   RemoteProfilerManager();
47   ~RemoteProfilerManager() = default;
48   local_id_t getNextLocalId();
49   std::unordered_map<ProfilingId, std::string, ProfilingId::Hash>
50       profiledRpcKeys_;
51   static thread_local std::optional<std::string> currentThreadLocalKey_;
52   std::mutex mutex_;
53   local_id_t currentLocalId_;
54 };
55 } // namespace torch::distributed::rpc
56