xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
4 #include <torch/csrc/jit/serialization/pickle.h>
5 #include <torch/csrc/utils/byte_order.h>
6 
7 namespace torch::distributed::rpc {
8 const std::string REMOTE_PROFILING_KEY_PREFIX = "#remote_op: ";
9 constexpr int kAutoIncrementBits = 48;
10 /*static */ thread_local std::optional<std::string>
11     RemoteProfilerManager::currentThreadLocalKey_ = std::nullopt;
getInstance()12 /*static */ RemoteProfilerManager& RemoteProfilerManager::getInstance() {
13   static RemoteProfilerManager* handler = new RemoteProfilerManager();
14   return *handler;
15 }
16 
setCurrentKey(std::string key)17 void RemoteProfilerManager::setCurrentKey(std::string key) {
18   // We should not allow overriding the current key, it needs to be committed
19   // with writeKey() explicitly first.
20   if (RemoteProfilerManager::currentThreadLocalKey_) {
21     TORCH_CHECK(
22         false,
23         "Cannot call RemoteProfilerManager::setCurrentKey when current key is already set.");
24   }
25   currentThreadLocalKey_ = std::move(key);
26 }
27 
isCurrentKeySet() const28 bool RemoteProfilerManager::isCurrentKeySet() const {
29   return currentThreadLocalKey_.has_value();
30 }
31 
unsetCurrentKey()32 void RemoteProfilerManager::unsetCurrentKey() {
33   currentThreadLocalKey_ = std::nullopt;
34 }
35 
eraseKey(const ProfilingId & globallyUniqueId)36 void RemoteProfilerManager::eraseKey(const ProfilingId& globallyUniqueId) {
37   std::lock_guard<std::mutex> guard(mutex_);
38   auto it = profiledRpcKeys_.find(globallyUniqueId);
39   TORCH_INTERNAL_ASSERT(it != profiledRpcKeys_.end());
40   profiledRpcKeys_.erase(it);
41 }
42 
retrieveRPCProfilingKey(const ProfilingId & globallyUniqueId)43 std::string RemoteProfilerManager::retrieveRPCProfilingKey(
44     const ProfilingId& globallyUniqueId) {
45   std::lock_guard<std::mutex> guard(mutex_);
46   auto it = profiledRpcKeys_.find(globallyUniqueId);
47   TORCH_INTERNAL_ASSERT(it != profiledRpcKeys_.end());
48   return it->second;
49 }
50 
getNextProfilerId()51 ProfilingId RemoteProfilerManager::getNextProfilerId() {
52   auto localId = getNextLocalId();
53   auto localWorkerId = RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_;
54   auto globallyUniqueId =
55       torch::distributed::rpc::ProfilingId(localWorkerId, localId);
56   return globallyUniqueId;
57 }
58 
getNextLocalId()59 local_id_t RemoteProfilerManager::getNextLocalId() {
60   std::lock_guard<std::mutex> guard(mutex_);
61   return currentLocalId_++;
62 }
63 
getCurrentProfilingKey()64 std::string& RemoteProfilerManager::getCurrentProfilingKey() {
65   TORCH_CHECK(
66       RemoteProfilerManager::currentThreadLocalKey_,
67       "Must set currentThreadLocalKey_ before calling getCurrentProfilingKey");
68   return *currentThreadLocalKey_;
69 }
70 
saveRPCKey(ProfilingId globallyUniqueId,const std::string & rpcProfilingKey)71 void RemoteProfilerManager::saveRPCKey(
72     ProfilingId globallyUniqueId,
73     const std::string& rpcProfilingKey) {
74   std::lock_guard<std::mutex> guard(mutex_);
75   profiledRpcKeys_.emplace(
76       std::piecewise_construct,
77       std::forward_as_tuple(globallyUniqueId),
78       std::forward_as_tuple(rpcProfilingKey));
79 }
80 
RemoteProfilerManager()81 RemoteProfilerManager::RemoteProfilerManager() {
82   auto workerId =
83       static_cast<int64_t>(RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_);
84   currentLocalId_ = workerId << kAutoIncrementBits;
85 }
86 } // namespace torch::distributed::rpc
87