1 #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h> 2 3 namespace torch::distributed::rpc::profiler::processglobal { 4 5 using namespace torch::autograd::profiler; 6 results()7std::vector<thread_event_lists> State::results() { 8 std::unique_lock<std::mutex> lock(resultsMutex_); 9 10 std::vector<thread_event_lists> results; 11 results.swap(results_); 12 return results; 13 } 14 15 mutexType currentStateStackEntryMutex; 16 std::shared_ptr<StateStackEntry> currentStateStackEntryPtr = nullptr; 17 pushRange(std::shared_ptr<State> profilerProcessGlobalStatePtr)18void StateStackEntry::pushRange( 19 std::shared_ptr<State> profilerProcessGlobalStatePtr) { 20 wLockType wlock(currentStateStackEntryMutex); 21 22 auto previousStateStackEntryPtr = currentStateStackEntryPtr; 23 currentStateStackEntryPtr = std::make_shared<StateStackEntry>( 24 previousStateStackEntryPtr, std::move(profilerProcessGlobalStatePtr)); 25 } 26 popRange()27std::shared_ptr<State> StateStackEntry::popRange() { 28 wLockType wlock(currentStateStackEntryMutex); 29 30 auto poppedStateStackEntryPtr = currentStateStackEntryPtr; 31 TORCH_INTERNAL_ASSERT( 32 poppedStateStackEntryPtr && poppedStateStackEntryPtr->statePtr_); 33 currentStateStackEntryPtr = poppedStateStackEntryPtr->prevPtr_; 34 return poppedStateStackEntryPtr->statePtr_; 35 } 36 pushResultRecursive(std::shared_ptr<StateStackEntry> stateStackEntryPtr,const thread_event_lists & result)37void pushResultRecursive( 38 std::shared_ptr<StateStackEntry> stateStackEntryPtr, 39 const thread_event_lists& result) { 40 while (stateStackEntryPtr) { 41 // Put event_lists into the process-global profiler state. 42 stateStackEntryPtr->statePtr()->pushResult(result); 43 stateStackEntryPtr = stateStackEntryPtr->prevPtr(); 44 } 45 } 46 enableServer(const ProfilerConfig & new_config)47void enableServer(const ProfilerConfig& new_config) { 48 auto new_state = std::make_shared<State>(new_config); 49 StateStackEntry::pushRange(std::move(new_state)); 50 } 51 disableServer()52std::vector<thread_event_lists> disableServer() { 53 auto statePtr = StateStackEntry::popRange(); 54 return statePtr->results(); 55 } 56 57 } // namespace torch::distributed::rpc::profiler::processglobal 58