1 #pragma once 2 3 #include <shared_mutex> 4 #include <utility> 5 6 #include <torch/csrc/autograd/profiler.h> 7 8 namespace torch::distributed::rpc::profiler::processglobal { 9 10 using namespace torch::autograd::profiler; 11 12 // Process global profiler state. 13 // 14 // This class holds information about a profiling range, from "enable" to 15 // "disable". 16 // An instance of this ``State`` will be 17 // pushed into a global stack, so nested profiling range is supported. 18 // 19 // It has 2 members. 20 // One is ``autograd::profiler::ProfilerConfig``. It's set by user and 21 // will be copied to thread-local profiler state of RPC threads. 22 // The other is a container that aggregates recorded 23 // ``autograd::profiler::Event``s from all thread-local profilers on RPC 24 // threads. 25 class State { 26 public: State(ProfilerConfig config)27 explicit State(ProfilerConfig config) : config_(std::move(config)) {} 28 ~State() = default; 29 config()30 const ProfilerConfig& config() const { 31 return config_; 32 } 33 pushResult(thread_event_lists result)34 void pushResult(thread_event_lists result) { 35 std::unique_lock<std::mutex> lock(resultsMutex_); 36 37 // NB: When a thread wants to push an entry into the this container, 38 // main control logic might have exited the process-global profile range. 39 results_.emplace_back(std::move(result)); 40 } 41 42 std::vector<thread_event_lists> results(); 43 44 private: 45 // Each result comes from a profile range. In each profile range, there is a 46 // "__profiler_start" marker event that all following events calculate time 47 // relative to it, so it's required to call 48 // parse_cpu_trace(result) for results of all profile range. 
49 std::mutex resultsMutex_; 50 std::vector<thread_event_lists> results_; 51 const ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled); 52 }; 53 54 class StateStackEntry; 55 56 #if defined(__MACH__) 57 // Compiler error: 'shared_timed_mutex' is unavailable: introduced in 58 // macOS 10.12 59 using mutexType = std::mutex; 60 // Compiler error: 'shared_lock' is unavailable: introduced in 61 // macOS 10.12 62 using rLockType = std::unique_lock<std::mutex>; 63 using wLockType = std::unique_lock<std::mutex>; 64 #else 65 using mutexType = std::shared_timed_mutex; 66 using rLockType = std::shared_lock<std::shared_timed_mutex>; 67 using wLockType = std::unique_lock<std::shared_timed_mutex>; 68 #endif 69 70 // This is the global stack of ``State``s. 71 TORCH_API extern std::shared_ptr<StateStackEntry> currentStateStackEntryPtr; 72 TORCH_API extern mutexType currentStateStackEntryMutex; 73 74 // This class is used to implement a stack of ``State``s. 75 // It has 2 members. 76 // One is `prevPtr`, a shared_ptr pointing to previous element in the 77 // stack. 78 // The other is ``statePtr``, a shared_ptr pointing to ``State``. 
79 class StateStackEntry { 80 public: StateStackEntry(std::shared_ptr<StateStackEntry> prevPtr,std::shared_ptr<State> statePtr)81 StateStackEntry( 82 std::shared_ptr<StateStackEntry> prevPtr, 83 std::shared_ptr<State> statePtr) 84 : prevPtr_(std::move(prevPtr)), statePtr_(std::move(statePtr)) {} 85 86 static void pushRange(std::shared_ptr<State> profilerProcessGlobalStatePtr); 87 static std::shared_ptr<State> popRange(); 88 current()89 static std::shared_ptr<StateStackEntry> current() { 90 rLockType rlock(currentStateStackEntryMutex); 91 92 return currentStateStackEntryPtr; 93 } 94 prevPtr()95 std::shared_ptr<StateStackEntry> prevPtr() const { 96 return prevPtr_; 97 } 98 statePtr()99 std::shared_ptr<State> statePtr() const { 100 return statePtr_; 101 } 102 103 private: 104 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 105 const std::shared_ptr<StateStackEntry> prevPtr_{nullptr}; 106 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 107 const std::shared_ptr<State> statePtr_{nullptr}; 108 }; 109 110 // Push the result to ``State``s of current profile range and recursively outer 111 // profile ranges. 112 TORCH_API void pushResultRecursive( 113 std::shared_ptr<StateStackEntry> stateStackEntryPtr, 114 const thread_event_lists& result); 115 116 // User-facing API. 117 // 118 // Enter a server-side process-global profiling range. 119 // Profiling range can be neste, so it's ok to call this API for multiple 120 // times. This enables all RPC threads running server-side request callbacks. 121 TORCH_API void enableServer(const ProfilerConfig& new_config); 122 // 123 // Exit a server-side process-global profiling range. 124 // Profiling range can be neste, so it's possible that profiler is still on 125 // after calling this API. 126 // This enables all RPC threads running server-side request callbacks. 
// Declaration only: the definition is not part of this header. The call takes
// no arguments and hands back a std::vector<thread_event_lists> by value.
TORCH_API std::vector<thread_event_lists> disableServer();

} // namespace torch::distributed::rpc::profiler::processglobal