1 #include <torch/csrc/profiler/orchestration/observer.h>
2
3 #include <torch/csrc/profiler/util.h>
4
5 #include <utility>
6
7 namespace torch {
8 namespace profiler {
9 namespace impl {
10
11 using GlobalManager = GlobalStateManager<ProfilerStateBase>;
12
13 // ----------------------------------------------------------------------------
14 // -- Profiler Config ---------------------------------------------------------
15 // ----------------------------------------------------------------------------
ExperimentalConfig(std::vector<std::string> profiler_metrics,bool profiler_measure_per_kernel,bool verbose,std::vector<std::string> performance_events,bool enable_cuda_sync_events,bool adjust_timestamps)16 ExperimentalConfig::ExperimentalConfig(
17 std::vector<std::string> profiler_metrics,
18 bool profiler_measure_per_kernel,
19 bool verbose,
20 std::vector<std::string> performance_events,
21 bool enable_cuda_sync_events,
22 bool adjust_timestamps)
23 : profiler_metrics{std::move(profiler_metrics)},
24 profiler_measure_per_kernel{profiler_measure_per_kernel},
25 verbose{verbose},
26 performance_events(std::move(performance_events)),
27 enable_cuda_sync_events{enable_cuda_sync_events},
28 adjust_timestamps{adjust_timestamps} {}
29
operator bool() const30 /*explicit*/ ExperimentalConfig::operator bool() const {
31 return !profiler_metrics.empty();
32 }
33
ProfilerConfig(ProfilerState state,bool report_input_shapes,bool profile_memory,bool with_stack,bool with_flops,bool with_modules,ExperimentalConfig experimental_config)34 ProfilerConfig::ProfilerConfig(
35 ProfilerState state,
36 bool report_input_shapes,
37 bool profile_memory,
38 bool with_stack,
39 bool with_flops,
40 bool with_modules,
41 ExperimentalConfig experimental_config)
42 : state{state},
43 experimental_config{std::move(experimental_config)},
44 report_input_shapes{report_input_shapes},
45 profile_memory{profile_memory},
46 with_stack{with_stack},
47 with_flops{with_flops},
48 with_modules{with_modules} {}
49
disabled() const50 bool ProfilerConfig::disabled() const {
51 return state == torch::profiler::impl::ProfilerState::Disabled;
52 }
53
global() const54 bool ProfilerConfig::global() const {
55 return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
56 }
57
namespace {
// Positions of the ProfilerConfig fields inside the generic list that
// ProfilerConfig::toIValue produces and ProfilerConfig::fromIValue consumes.
enum ProfilerIValueIdx {
  // Slot holding the ProfilerState (stored as an integer).
  STATE = 0,
  // Slot holding the `report_input_shapes` flag.
  REPORT_INPUT_SHAPES,
  // Slot holding the `profile_memory` flag.
  PROFILE_MEMORY,
  // Number of slots above; used as the expected list length.
  NUM_PROFILER_CFG_IVALUE_IDX // must be last in list
};
} // namespace
66
toIValue() const67 at::IValue ProfilerConfig::toIValue() const {
68 c10::impl::GenericList eventIValueList(at::AnyType::get());
69 eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX);
70 eventIValueList.emplace_back(static_cast<int64_t>(state));
71 eventIValueList.emplace_back(report_input_shapes);
72 eventIValueList.emplace_back(profile_memory);
73 return eventIValueList;
74 }
75
fromIValue(const at::IValue & profilerConfigIValue)76 ProfilerConfig ProfilerConfig::fromIValue(
77 const at::IValue& profilerConfigIValue) {
78 TORCH_INTERNAL_ASSERT(
79 profilerConfigIValue.isList(),
80 "Expected IValue to contain type c10::impl::GenericList");
81 auto ivalues = profilerConfigIValue.toList();
82 TORCH_INTERNAL_ASSERT(
83 ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX,
84 c10::str(
85 "Expected exactly ",
86 NUM_PROFILER_CFG_IVALUE_IDX,
87 " ivalues to resconstruct ProfilerConfig."));
88 return ProfilerConfig(
89 static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt()),
90 ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(),
91 ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool());
92 }
93
94 // ----------------------------------------------------------------------------
95 // -- Profiler base class -----------------------------------------------------
96 // ----------------------------------------------------------------------------
ProfilerStateBase(ProfilerConfig config)97 /*explicit*/ ProfilerStateBase::ProfilerStateBase(ProfilerConfig config)
98 : c10::MemoryReportingInfoBase(), config_(std::move(config)) {}
99
~ProfilerStateBase()100 ProfilerStateBase::~ProfilerStateBase() {
101 if (handle_) {
102 auto handle = handle_;
103 removeCallback();
104 SOFT_ASSERT(false, "Leaked callback handle: ", handle);
105 }
106 }
107
get(bool global)108 /*static*/ ProfilerStateBase* ProfilerStateBase::get(bool global) {
109 auto* out = global
110 ? GlobalManager::get()
111 : static_cast<ProfilerStateBase*>(
112 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
113 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
114 return out;
115 }
116
push(std::shared_ptr<ProfilerStateBase> && state)117 /*static*/ void ProfilerStateBase::push(
118 std::shared_ptr<ProfilerStateBase>&& state) {
119 TORCH_INTERNAL_ASSERT(state != nullptr);
120 if (state->config().global()) {
121 GlobalManager::push(std::move(state));
122 } else {
123 c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
124 }
125 }
126
127 namespace {
popTLS()128 std::shared_ptr<ProfilerStateBase> popTLS() {
129 // If there is no active thread local profiler then we simply return null.
130 // However if there is an active profiler but it is not the top
131 // `DebugInfoBase`then `c10::ThreadLocalDebugInfo::_pop` will throw.
132 // TODO(robieta): make `noexcept` version.
133 return c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)
134 ? std::static_pointer_cast<ProfilerStateBase>(
135 c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE))
136 : nullptr;
137 }
138 } // namespace
139
pop(bool global)140 /*static*/ std::shared_ptr<ProfilerStateBase> ProfilerStateBase::pop(
141 bool global) {
142 auto out = global ? GlobalManager::pop() : popTLS();
143 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
144 return out;
145 }
146
setCallbackHandle(at::CallbackHandle handle)147 void ProfilerStateBase::setCallbackHandle(at::CallbackHandle handle) {
148 if (handle_) {
149 at::removeCallback(handle_);
150 SOFT_ASSERT(
151 false,
152 "ProfilerStateBase already has a registered callback. "
153 "Removing to avoid leaked callback.");
154 }
155
156 handle_ = handle;
157 }
158
removeCallback()159 void ProfilerStateBase::removeCallback() {
160 if (handle_) {
161 at::removeCallback(handle_);
162 handle_ = 0;
163 }
164 }
165
profilerEnabled()166 bool profilerEnabled() {
167 auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
168 return state_ptr && !state_ptr->config().disabled();
169 }
170
profilerType()171 TORCH_API ActiveProfilerType profilerType() {
172 auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
173 return state_ptr == nullptr ? ActiveProfilerType::NONE
174 : state_ptr->profilerType();
175 }
176
getProfilerConfig()177 torch::profiler::impl::ProfilerConfig getProfilerConfig() {
178 auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
179 TORCH_CHECK(
180 state_ptr,
181 "Tried to access profiler config, but profiler is not enabled!");
182 return state_ptr->config();
183 }
184
185 } // namespace impl
186 } // namespace profiler
187 } // namespace torch
188