xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <shared_mutex>
4 #include <utility>
5 
6 #include <torch/csrc/autograd/profiler.h>
7 
8 namespace torch::distributed::rpc::profiler::processglobal {
9 
10 using namespace torch::autograd::profiler;
11 
12 // Process global profiler state.
13 //
14 // This class holds information about a profiling range, from "enable" to
15 // "disable".
16 // An instance of this ``State`` will be
17 // pushed into a global stack, so nested profiling range is supported.
18 //
19 // It has 2 members.
20 // One is ``autograd::profiler::ProfilerConfig``. It's set by user and
21 // will be copied to thread-local profiler state of RPC threads.
22 // The other is a container that aggregates recorded
23 // ``autograd::profiler::Event``s from all thread-local profilers on RPC
24 // threads.
25 class State {
26  public:
State(ProfilerConfig config)27   explicit State(ProfilerConfig config) : config_(std::move(config)) {}
28   ~State() = default;
29 
config()30   const ProfilerConfig& config() const {
31     return config_;
32   }
33 
pushResult(thread_event_lists result)34   void pushResult(thread_event_lists result) {
35     std::unique_lock<std::mutex> lock(resultsMutex_);
36 
37     // NB: When a thread wants to push an entry into the this container,
38     // main control logic might have exited the process-global profile range.
39     results_.emplace_back(std::move(result));
40   }
41 
42   std::vector<thread_event_lists> results();
43 
44  private:
45   // Each result comes from a profile range. In each profile range, there is a
46   // "__profiler_start" marker event that all following events calculate time
47   // relative to it, so it's required to call
48   // parse_cpu_trace(result) for results of all profile range.
49   std::mutex resultsMutex_;
50   std::vector<thread_event_lists> results_;
51   const ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled);
52 };
53 
// Forward declaration; the class itself is defined further down in this
// header.
class StateStackEntry;

// Mutex and lock aliases. The reader/writer lock types are spelled in terms
// of ``mutexType`` so the three aliases cannot drift apart.
#if defined(__MACH__)
// Compiler error: 'shared_timed_mutex' is unavailable: introduced in
// macOS 10.12
// Compiler error: 'shared_lock' is unavailable: introduced in
// macOS 10.12
//
// Hence a plain mutex is used here, and readers and writers both take the
// same exclusive lock type.
using mutexType = std::mutex;
using wLockType = std::unique_lock<mutexType>;
using rLockType = wLockType;
#else
using mutexType = std::shared_timed_mutex;
using rLockType = std::shared_lock<mutexType>;
using wLockType = std::unique_lock<mutexType>;
#endif
69 
// This is the global stack of ``State``s.
//
// Both variables are only declared here (``extern``); their definitions are
// not part of this header. ``currentStateStackEntryPtr`` is the pointer that
// ``StateStackEntry::current()`` returns a copy of.
TORCH_API extern std::shared_ptr<StateStackEntry> currentStateStackEntryPtr;
// ``StateStackEntry::current()`` takes an ``rLockType`` lock on this mutex
// before reading ``currentStateStackEntryPtr``.
TORCH_API extern mutexType currentStateStackEntryMutex;
73 
74 // This class is used to implement a stack of ``State``s.
75 // It has 2 members.
76 // One is `prevPtr`, a shared_ptr pointing to previous element in the
77 // stack.
78 // The other is ``statePtr``, a shared_ptr pointing to ``State``.
79 class StateStackEntry {
80  public:
StateStackEntry(std::shared_ptr<StateStackEntry> prevPtr,std::shared_ptr<State> statePtr)81   StateStackEntry(
82       std::shared_ptr<StateStackEntry> prevPtr,
83       std::shared_ptr<State> statePtr)
84       : prevPtr_(std::move(prevPtr)), statePtr_(std::move(statePtr)) {}
85 
86   static void pushRange(std::shared_ptr<State> profilerProcessGlobalStatePtr);
87   static std::shared_ptr<State> popRange();
88 
current()89   static std::shared_ptr<StateStackEntry> current() {
90     rLockType rlock(currentStateStackEntryMutex);
91 
92     return currentStateStackEntryPtr;
93   }
94 
prevPtr()95   std::shared_ptr<StateStackEntry> prevPtr() const {
96     return prevPtr_;
97   }
98 
statePtr()99   std::shared_ptr<State> statePtr() const {
100     return statePtr_;
101   }
102 
103  private:
104   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
105   const std::shared_ptr<StateStackEntry> prevPtr_{nullptr};
106   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
107   const std::shared_ptr<State> statePtr_{nullptr};
108 };
109 
// Push ``result`` to the ``State`` of the current profile range and,
// recursively, to the ``State``s of the outer profile ranges.
//
// ``stateStackEntryPtr`` is the stack entry to start from; ``result`` is
// taken by const reference. Only the declaration lives in this header.
// NOTE(review): presumably the recursion follows
// ``StateStackEntry::prevPtr()`` and calls ``State::pushResult`` on each
// entry -- confirm against the definition.
TORCH_API void pushResultRecursive(
    std::shared_ptr<StateStackEntry> stateStackEntryPtr,
    const thread_event_lists& result);
115 
// User-facing API.
//
// Enter a server-side process-global profiling range.
// Profiling ranges can be nested, so it's OK to call this API multiple
// times. This enables profiling on all RPC threads running server-side
// request callbacks.
//
// ``new_config`` is the profiler configuration for the range being entered.
TORCH_API void enableServer(const ProfilerConfig& new_config);
//
// Exit a server-side process-global profiling range.
// Profiling ranges can be nested, so it's possible that the profiler is still
// on after calling this API.
// NOTE(review): the comment here previously repeated the sentence "This
// enables all RPC threads running server-side request callbacks." from
// ``enableServer``; presumably this call instead ends profiling for the range
// being exited -- confirm against the definition.
//
// Returns a ``std::vector<thread_event_lists>``; presumably the results
// collected for the exited range -- confirm against the definition.
TORCH_API std::vector<thread_event_lists> disableServer();
128 
129 } // namespace torch::distributed::rpc::profiler::processglobal
130