xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
2 
3 namespace torch::distributed::rpc::profiler::processglobal {
4 
5 using namespace torch::autograd::profiler;
6 
results()7 std::vector<thread_event_lists> State::results() {
8   std::unique_lock<std::mutex> lock(resultsMutex_);
9 
10   std::vector<thread_event_lists> results;
11   results.swap(results_);
12   return results;
13 }
14 
15 mutexType currentStateStackEntryMutex;
16 std::shared_ptr<StateStackEntry> currentStateStackEntryPtr = nullptr;
17 
pushRange(std::shared_ptr<State> profilerProcessGlobalStatePtr)18 void StateStackEntry::pushRange(
19     std::shared_ptr<State> profilerProcessGlobalStatePtr) {
20   wLockType wlock(currentStateStackEntryMutex);
21 
22   auto previousStateStackEntryPtr = currentStateStackEntryPtr;
23   currentStateStackEntryPtr = std::make_shared<StateStackEntry>(
24       previousStateStackEntryPtr, std::move(profilerProcessGlobalStatePtr));
25 }
26 
popRange()27 std::shared_ptr<State> StateStackEntry::popRange() {
28   wLockType wlock(currentStateStackEntryMutex);
29 
30   auto poppedStateStackEntryPtr = currentStateStackEntryPtr;
31   TORCH_INTERNAL_ASSERT(
32       poppedStateStackEntryPtr && poppedStateStackEntryPtr->statePtr_);
33   currentStateStackEntryPtr = poppedStateStackEntryPtr->prevPtr_;
34   return poppedStateStackEntryPtr->statePtr_;
35 }
36 
pushResultRecursive(std::shared_ptr<StateStackEntry> stateStackEntryPtr,const thread_event_lists & result)37 void pushResultRecursive(
38     std::shared_ptr<StateStackEntry> stateStackEntryPtr,
39     const thread_event_lists& result) {
40   while (stateStackEntryPtr) {
41     // Put event_lists into the process-global profiler state.
42     stateStackEntryPtr->statePtr()->pushResult(result);
43     stateStackEntryPtr = stateStackEntryPtr->prevPtr();
44   }
45 }
46 
enableServer(const ProfilerConfig & new_config)47 void enableServer(const ProfilerConfig& new_config) {
48   auto new_state = std::make_shared<State>(new_config);
49   StateStackEntry::pushRange(std::move(new_state));
50 }
51 
disableServer()52 std::vector<thread_event_lists> disableServer() {
53   auto statePtr = StateStackEntry::popRange();
54   return statePtr->results();
55 }
56 
57 } // namespace torch::distributed::rpc::profiler::processglobal
58