xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/orchestration/observer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/profiler/orchestration/observer.h>
2 
3 #include <torch/csrc/profiler/util.h>
4 
5 #include <utility>
6 
7 namespace torch {
8 namespace profiler {
9 namespace impl {
10 
11 using GlobalManager = GlobalStateManager<ProfilerStateBase>;
12 
13 // ----------------------------------------------------------------------------
14 // -- Profiler Config ---------------------------------------------------------
15 // ----------------------------------------------------------------------------
ExperimentalConfig(std::vector<std::string> profiler_metrics,bool profiler_measure_per_kernel,bool verbose,std::vector<std::string> performance_events,bool enable_cuda_sync_events,bool adjust_timestamps)16 ExperimentalConfig::ExperimentalConfig(
17     std::vector<std::string> profiler_metrics,
18     bool profiler_measure_per_kernel,
19     bool verbose,
20     std::vector<std::string> performance_events,
21     bool enable_cuda_sync_events,
22     bool adjust_timestamps)
23     : profiler_metrics{std::move(profiler_metrics)},
24       profiler_measure_per_kernel{profiler_measure_per_kernel},
25       verbose{verbose},
26       performance_events(std::move(performance_events)),
27       enable_cuda_sync_events{enable_cuda_sync_events},
28       adjust_timestamps{adjust_timestamps} {}
29 
operator bool() const30 /*explicit*/ ExperimentalConfig::operator bool() const {
31   return !profiler_metrics.empty();
32 }
33 
ProfilerConfig(ProfilerState state,bool report_input_shapes,bool profile_memory,bool with_stack,bool with_flops,bool with_modules,ExperimentalConfig experimental_config)34 ProfilerConfig::ProfilerConfig(
35     ProfilerState state,
36     bool report_input_shapes,
37     bool profile_memory,
38     bool with_stack,
39     bool with_flops,
40     bool with_modules,
41     ExperimentalConfig experimental_config)
42     : state{state},
43       experimental_config{std::move(experimental_config)},
44       report_input_shapes{report_input_shapes},
45       profile_memory{profile_memory},
46       with_stack{with_stack},
47       with_flops{with_flops},
48       with_modules{with_modules} {}
49 
disabled() const50 bool ProfilerConfig::disabled() const {
51   return state == torch::profiler::impl::ProfilerState::Disabled;
52 }
53 
global() const54 bool ProfilerConfig::global() const {
55   return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
56 }
57 
namespace {
// Slot positions inside the GenericList produced by ProfilerConfig::toIValue()
// and consumed by ProfilerConfig::fromIValue(). Only this subset of
// ProfilerConfig is serialized. Enumerators take consecutive values starting
// at 0, so the final entry equals the number of serialized fields.
enum ProfilerIValueIdx {
  STATE,
  REPORT_INPUT_SHAPES,
  PROFILE_MEMORY,
  // Sentinel: keep this last so it always equals the field count.
  NUM_PROFILER_CFG_IVALUE_IDX
};
} // namespace
66 
toIValue() const67 at::IValue ProfilerConfig::toIValue() const {
68   c10::impl::GenericList eventIValueList(at::AnyType::get());
69   eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX);
70   eventIValueList.emplace_back(static_cast<int64_t>(state));
71   eventIValueList.emplace_back(report_input_shapes);
72   eventIValueList.emplace_back(profile_memory);
73   return eventIValueList;
74 }
75 
fromIValue(const at::IValue & profilerConfigIValue)76 ProfilerConfig ProfilerConfig::fromIValue(
77     const at::IValue& profilerConfigIValue) {
78   TORCH_INTERNAL_ASSERT(
79       profilerConfigIValue.isList(),
80       "Expected IValue to contain type c10::impl::GenericList");
81   auto ivalues = profilerConfigIValue.toList();
82   TORCH_INTERNAL_ASSERT(
83       ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX,
84       c10::str(
85           "Expected exactly ",
86           NUM_PROFILER_CFG_IVALUE_IDX,
87           " ivalues to resconstruct ProfilerConfig."));
88   return ProfilerConfig(
89       static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt()),
90       ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(),
91       ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool());
92 }
93 
94 // ----------------------------------------------------------------------------
95 // -- Profiler base class -----------------------------------------------------
96 // ----------------------------------------------------------------------------
ProfilerStateBase(ProfilerConfig config)97 /*explicit*/ ProfilerStateBase::ProfilerStateBase(ProfilerConfig config)
98     : c10::MemoryReportingInfoBase(), config_(std::move(config)) {}
99 
~ProfilerStateBase()100 ProfilerStateBase::~ProfilerStateBase() {
101   if (handle_) {
102     auto handle = handle_;
103     removeCallback();
104     SOFT_ASSERT(false, "Leaked callback handle: ", handle);
105   }
106 }
107 
get(bool global)108 /*static*/ ProfilerStateBase* ProfilerStateBase::get(bool global) {
109   auto* out = global
110       ? GlobalManager::get()
111       : static_cast<ProfilerStateBase*>(
112             c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
113   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
114   return out;
115 }
116 
push(std::shared_ptr<ProfilerStateBase> && state)117 /*static*/ void ProfilerStateBase::push(
118     std::shared_ptr<ProfilerStateBase>&& state) {
119   TORCH_INTERNAL_ASSERT(state != nullptr);
120   if (state->config().global()) {
121     GlobalManager::push(std::move(state));
122   } else {
123     c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
124   }
125 }
126 
127 namespace {
popTLS()128 std::shared_ptr<ProfilerStateBase> popTLS() {
129   // If there is no active thread local profiler then we simply return null.
130   // However if there is an active profiler but it is not the top
131   // `DebugInfoBase`then `c10::ThreadLocalDebugInfo::_pop` will throw.
132   // TODO(robieta): make `noexcept` version.
133   return c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)
134       ? std::static_pointer_cast<ProfilerStateBase>(
135             c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE))
136       : nullptr;
137 }
138 } // namespace
139 
pop(bool global)140 /*static*/ std::shared_ptr<ProfilerStateBase> ProfilerStateBase::pop(
141     bool global) {
142   auto out = global ? GlobalManager::pop() : popTLS();
143   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
144   return out;
145 }
146 
setCallbackHandle(at::CallbackHandle handle)147 void ProfilerStateBase::setCallbackHandle(at::CallbackHandle handle) {
148   if (handle_) {
149     at::removeCallback(handle_);
150     SOFT_ASSERT(
151         false,
152         "ProfilerStateBase already has a registered callback. "
153         "Removing to avoid leaked callback.");
154   }
155 
156   handle_ = handle;
157 }
158 
removeCallback()159 void ProfilerStateBase::removeCallback() {
160   if (handle_) {
161     at::removeCallback(handle_);
162     handle_ = 0;
163   }
164 }
165 
profilerEnabled()166 bool profilerEnabled() {
167   auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
168   return state_ptr && !state_ptr->config().disabled();
169 }
170 
profilerType()171 TORCH_API ActiveProfilerType profilerType() {
172   auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
173   return state_ptr == nullptr ? ActiveProfilerType::NONE
174                               : state_ptr->profilerType();
175 }
176 
getProfilerConfig()177 torch::profiler::impl::ProfilerConfig getProfilerConfig() {
178   auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
179   TORCH_CHECK(
180       state_ptr,
181       "Tried to access profiler config, but profiler is not enabled!");
182   return state_ptr->config();
183 }
184 
185 } // namespace impl
186 } // namespace profiler
187 } // namespace torch
188