1 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> 2 #include <torch/csrc/distributed/rpc/rpc_agent.h> 3 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 4 #include <torch/csrc/jit/serialization/pickle.h> 5 #include <torch/csrc/utils/byte_order.h> 6 7 namespace torch::distributed::rpc { 8 const std::string REMOTE_PROFILING_KEY_PREFIX = "#remote_op: "; 9 constexpr int kAutoIncrementBits = 48; 10 /*static */ thread_local std::optional<std::string> 11 RemoteProfilerManager::currentThreadLocalKey_ = std::nullopt; getInstance()12/*static */ RemoteProfilerManager& RemoteProfilerManager::getInstance() { 13 static RemoteProfilerManager* handler = new RemoteProfilerManager(); 14 return *handler; 15 } 16 setCurrentKey(std::string key)17void RemoteProfilerManager::setCurrentKey(std::string key) { 18 // We should not allow overriding the current key, it needs to be committed 19 // with writeKey() explicitly first. 20 if (RemoteProfilerManager::currentThreadLocalKey_) { 21 TORCH_CHECK( 22 false, 23 "Cannot call RemoteProfilerManager::setCurrentKey when current key is already set."); 24 } 25 currentThreadLocalKey_ = std::move(key); 26 } 27 isCurrentKeySet() const28bool RemoteProfilerManager::isCurrentKeySet() const { 29 return currentThreadLocalKey_.has_value(); 30 } 31 unsetCurrentKey()32void RemoteProfilerManager::unsetCurrentKey() { 33 currentThreadLocalKey_ = std::nullopt; 34 } 35 eraseKey(const ProfilingId & globallyUniqueId)36void RemoteProfilerManager::eraseKey(const ProfilingId& globallyUniqueId) { 37 std::lock_guard<std::mutex> guard(mutex_); 38 auto it = profiledRpcKeys_.find(globallyUniqueId); 39 TORCH_INTERNAL_ASSERT(it != profiledRpcKeys_.end()); 40 profiledRpcKeys_.erase(it); 41 } 42 retrieveRPCProfilingKey(const ProfilingId & globallyUniqueId)43std::string RemoteProfilerManager::retrieveRPCProfilingKey( 44 const ProfilingId& globallyUniqueId) { 45 std::lock_guard<std::mutex> guard(mutex_); 46 auto it = profiledRpcKeys_.find(globallyUniqueId); 47 TORCH_INTERNAL_ASSERT(it != profiledRpcKeys_.end()); 48 return it->second; 49 } 50 getNextProfilerId()51ProfilingId RemoteProfilerManager::getNextProfilerId() { 52 auto localId = getNextLocalId(); 53 auto localWorkerId = RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_; 54 auto globallyUniqueId = 55 torch::distributed::rpc::ProfilingId(localWorkerId, localId); 56 return globallyUniqueId; 57 } 58 getNextLocalId()59local_id_t RemoteProfilerManager::getNextLocalId() { 60 std::lock_guard<std::mutex> guard(mutex_); 61 return currentLocalId_++; 62 } 63 getCurrentProfilingKey()64std::string& RemoteProfilerManager::getCurrentProfilingKey() { 65 TORCH_CHECK( 66 RemoteProfilerManager::currentThreadLocalKey_, 67 "Must set currentThreadLocalKey_ before calling getCurrentProfilingKey"); 68 return *currentThreadLocalKey_; 69 } 70 saveRPCKey(ProfilingId globallyUniqueId,const std::string & rpcProfilingKey)71void RemoteProfilerManager::saveRPCKey( 72 ProfilingId globallyUniqueId, 73 const std::string& rpcProfilingKey) { 74 std::lock_guard<std::mutex> guard(mutex_); 75 profiledRpcKeys_.emplace( 76 std::piecewise_construct, 77 std::forward_as_tuple(globallyUniqueId), 78 std::forward_as_tuple(rpcProfilingKey)); 79 } 80 RemoteProfilerManager()81RemoteProfilerManager::RemoteProfilerManager() { 82 auto workerId = 83 static_cast<int64_t>(RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_); 84 currentLocalId_ = workerId << kAutoIncrementBits; 85 } 86 } // namespace torch::distributed::rpc 87