#include <cstring>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <torch/csrc/autograd/profiler_kineto.h>

#include <c10/macros/Export.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/overloaded.h>

#include <torch/csrc/profiler/api.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/events.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/orchestration/observer.h>
#include <torch/csrc/profiler/perf.h>
#include <torch/csrc/profiler/standalone/itt_observer.h>
#include <torch/csrc/profiler/standalone/nvtx_observer.h>
#include <torch/csrc/profiler/standalone/privateuse1_observer.h>
#include <torch/csrc/profiler/util.h>

#include <ATen/Context.h>

#include <stdexcept>
#include <utility>

#ifdef USE_KINETO
#include <ApproximateClock.h>
#include <libkineto.h>
#include <time_since_epoch.h>

#ifndef _MSC_VER
// TODO: To be removed once this properly works from libkineto.
// Literal copy-n-paste from third_party/kineto/libkineto/src/WeakSymbols.cpp
extern "C" {
// This function is needed to avoid a superfluous dependency on the GNU OpenMP
// library when cuPTI is linked statically. For more details see
// https://github.com/pytorch/pytorch/issues/51026
__attribute__((weak)) int acc_get_device_type();
__attribute__((weak)) int acc_get_device_type() {
  throw std::runtime_error(
      "Dummy implementation of acc_get_device_type is not supposed to be called!");
}
} // extern "C"
#endif // _MSC_VER
#endif // USE_KINETO

namespace torch {
namespace autograd::profiler {

namespace {
inline int64_t getTimeNs() {
#ifdef USE_KINETO
  return libkineto::timeSinceEpoch(std::chrono::system_clock::now());
#else
  return c10::getTime();
#endif // USE_KINETO
}

using torch::profiler::impl::ActiveProfilerType;
using torch::profiler::impl::EventType;
using torch::profiler::impl::ExtraFields;
using torch::profiler::impl::get_record_concrete_inputs_enabled;
using torch::profiler::impl::ivalueListToStr;
using torch::profiler::impl::ivalueToStr;
using torch::profiler::impl::op_input_t;
using torch::profiler::impl::ProfilerStateBase;
using torch::profiler::impl::PyExtraFieldsBase;
using torch::profiler::impl::Result;
using torch::profiler::impl::shape;
using torch::profiler::impl::shapesToStr;
using torch::profiler::impl::stacksToStr;
using torch::profiler::impl::strListToStr;
using torch::profiler::impl::TensorMetadata;
using torch::profiler::impl::variantShapesToStr;

struct OpArgData {
  bool hasData;
  std::vector<shape> shapes;
  std::vector<std::string> dtypes;
  std::vector<c10::IValue> concreteInputs;
  std::vector<std::vector<int64_t>> shapesForKinetoEvent;
  std::vector<shape> strides;
};

auto parseArgData(
    const std::vector<op_input_t>& input_shapes,
    const std::vector<op_input_t>& concreteInputs) {
  if (input_shapes.empty()) {
    return OpArgData{false, {}, {}, {}, {}, {}};
  }

  std::vector<shape> shapes(input_shapes.size());
  std::vector<shape> strides(input_shapes.size());
  std::vector<std::vector<int64_t>> shapesForKinetoEvent(input_shapes.size());

  std::vector<std::string> dtypes(input_shapes.size());
  std::vector<c10::IValue> concrete_inputs_list;

  for (const auto& i : c10::irange(input_shapes.size())) {
    std::visit(
        c10::overloaded(
            [&](const TensorMetadata& t) {
              shapes[i] = t.sizes_;
              shapesForKinetoEvent[i] = t.sizes_;
              dtypes[i] = std::string(scalarTypeToTypeMeta(t.dtype_).name());
              strides[i] = t.strides_;
            },
            [&](const std::vector<TensorMetadata>& l) {
              std::vector<std::vector<int64_t>> shape;
              shape.reserve(l.size());
              std::vector<std::vector<int64_t>> stride;
              stride.reserve(l.size());
              for (const auto& t : l) {
                shape.emplace_back(t.sizes_);
                stride.emplace_back(t.strides_);
              }
              shapes[i] = shape;
              strides[i] = stride;
              dtypes[i] = "TensorList";
            },
            [&](const c10::IValue&) { dtypes[i] = "Scalar"; },
            [&](const auto&) {}),
        input_shapes[i]);
  }

  // If we recorded concrete inputs, then parse them
  if (input_shapes.size() == concreteInputs.size() && !concreteInputs.empty()) {
    concrete_inputs_list.resize(input_shapes.size());

    for (const auto& i : c10::irange(input_shapes.size())) {
      std::visit(
          c10::overloaded(
              [&](const c10::IValue& val) { concrete_inputs_list[i] = val; },
              [&](const auto&) {}),
          input_shapes[i]);
      std::visit(
          c10::overloaded(
              [&](const c10::IValue& val) {
                concrete_inputs_list[i] = val;
                dtypes[i] = "ScalarList";
              },
              [&](const auto&) {}),
          concreteInputs[i]);
    }
  }

  return OpArgData{
      true,
      shapes,
      dtypes,
      concrete_inputs_list,
      shapesForKinetoEvent,
      strides};
}
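
// Shape of the parsed result, as a minimal sketch (illustrative): for a
// recorded op such as aten::add(Tensor, Tensor, Scalar) with two 2x3 float
// tensors, input_shapes holds a TensorMetadata entry per tensor argument and
// an IValue entry for the scalar, so parseArgData returns roughly
//
//   OpArgData{
//       /*hasData=*/true,
//       /*shapes=*/{{2, 3}, {2, 3}, {}},
//       /*dtypes=*/{"float", "float", "Scalar"},
//       /*concreteInputs=*/{},  // filled only if concrete inputs were recorded
//       /*shapesForKinetoEvent=*/{{2, 3}, {2, 3}, {}},
//       /*strides=*/{{3, 1}, {3, 1}, {}}};
//
// The concrete sizes and dtype strings above are assumptions for illustration.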

struct MetadataBase {
  /* implicit */ MetadataBase(const std::shared_ptr<Result>& result)
      : kinetoActivity_{result->kineto_activity_} {
    if (std::holds_alternative<ExtraFields<EventType::Kineto>>(
            result->extra_fields_)) {
      // In order to add metadata we have to downcast from
      // `libkineto::ITraceActivity` to `libkineto::GenericTraceActivity`. We
      // know that all activities provided by PyTorch are of the correct type,
      // however Kineto profilers can (and do) add events that inherit directly
      // from ITraceActivity. As a result, any Result which was constructed
      // from an event that Kineto provided is unsafe to cast.
      if (!(SOFT_ASSERT(!hasKinetoActivity()))) {
        result->kineto_activity_ = nullptr;
      }
      kinetoActivity_ = result->kineto_activity_;
    }
  }

  void addMetadata(const std::string& key, const std::string& value) {
    if (kinetoActivity_ && !value.empty() && value != "\"\"") {
      torch::profiler::impl::kineto::addMetadata(
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<torch::profiler::impl::kineto::activity_t*>(
              kinetoActivity_),
          key,
          value);
    }
  }

  bool hasKinetoActivity() const {
    return kinetoActivity_ != nullptr;
  }

 private:
  const torch::profiler::impl::kineto::activity_t* kinetoActivity_{nullptr};
};
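
// addMetadata behavior in brief (illustrative sketch; `result` stands for any
// std::shared_ptr<Result> in scope): writes are silently dropped when there is
// no Kineto activity to attach to, or when the value is empty.
//
//   MetadataBase md(result);
//   md.addMetadata("Input Dims", "[[2, 3]]");  // attached if activity exists
//   md.addMetadata("Input Dims", "\"\"");      // dropped: empty JSON string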

struct AddTensorboardFields : public MetadataBase {
  AddTensorboardFields(
      const std::shared_ptr<Result>& result,
      KinetoEvent& kineto_event)
      : MetadataBase(result) {
    result->visit(*this);
    const auto module_hierarchy = kineto_event.moduleHierarchy();
    addMetadata("Module Hierarchy", stacksToStr(module_hierarchy.vec(), "."));
    addMetadata("Call stack", stacksToStr(kineto_event.stack().vec(), ";"));

    result->visit_if_base<PyExtraFieldsBase>([&, this](const auto& i) -> void {
      this->addMetadata("Python id", std::to_string(i.id_));

      std::optional<std::string> parent_id;
      std::shared_ptr<Result> parent = result->parent_.lock();
      while (parent && !parent_id.has_value()) {
        parent->visit_if_base<PyExtraFieldsBase>(
            [&](const auto& j) { parent_id = std::to_string(j.id_); });
        parent = parent->parent_.lock();
      }
      this->addMetadata("Python parent id", parent_id.value_or("null"));
    });
  }

  void operator()(const ExtraFields<EventType::PyCall>& py_call) {
    if (py_call.module_.has_value()) {
      addMetadata("Python module id", std::to_string(py_call.module_->id_));
    }
  }

  template <typename T>
  void operator()(const T&) {}
};

struct AddGenericMetadata : public MetadataBase {
  AddGenericMetadata(
      std::shared_ptr<Result>& result,
      const torch::profiler::impl::ProfilerConfig* config)
      : MetadataBase(result), config_(config) {
    result->visit(*this);
    if (config->experimental_config.verbose) {
      result->visit_if_base<PyExtraFieldsBase>(
          [&, this](const auto& i) -> void {
            this->addMetadata("Python thread", std::to_string(i.python_tid_));
          });
    }
  }

  void operator()(ExtraFields<EventType::TorchOp>& op_event) {
    const auto arg_data =
        parseArgData(op_event.inputs_, op_event.concrete_inputs_);

    if (arg_data.hasData) {
      if (get_record_concrete_inputs_enabled()) {
        addMetadata("Input Dims", variantShapesToStr(arg_data.shapes));
        addMetadata("Input Strides", variantShapesToStr(arg_data.strides));
      } else {
        addMetadata("Input Dims", shapesToStr(arg_data.shapesForKinetoEvent));
      }
      addMetadata("Input type", strListToStr(arg_data.dtypes));
      if (!arg_data.concreteInputs.empty()) {
        addMetadata(
            "Concrete Inputs", ivalueListToStr(arg_data.concreteInputs));
      }
    }

    // Add metadata for kwinputs if they exist
    for (const auto& [key, val] : op_event.kwinputs_) {
      bool isString = val.isString();
      addMetadata(key, ivalueToStr(val, isString));
    }
    // Add extra metadata if any
    for (const auto& [key, val] : op_event.extra_meta_) {
      addMetadata(key, val);
    }

    if (config_ && !config_->experimental_config.performance_events.empty()) {
      auto& event_names = config_->experimental_config.performance_events;
      for (const auto i : c10::irange(op_event.perf_event_counters_->size())) {
        addMetadata(
            event_names[i],
            std::to_string((*op_event.perf_event_counters_)[i]));
      }
    }

    // Add information about an associated forward op if a sequence number
    // is available (e.g. during training).
    if (op_event.sequence_number_ >= 0) {
      addMetadata("Fwd thread id", std::to_string(op_event.forward_tid_));
      addMetadata(
          "Sequence number", std::to_string(op_event.sequence_number_));
    }
    addMetadata(
        "Record function id", std::to_string(op_event.record_function_id_));
  }

  void operator()(ExtraFields<EventType::Backend>& backend_event) {
    if (!backend_event.backend_.empty()) {
      addMetadata("Backend", "\"" + backend_event.backend_ + "\"");
    }
  }

  void operator()(const ExtraFields<EventType::Allocation>& alloc) {
    addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_));
    addMetadata("Device Id", std::to_string(alloc.device_index_));
    addMetadata("Addr", std::to_string(reinterpret_cast<intptr_t>(alloc.ptr_)));
    addMetadata("Bytes", std::to_string(alloc.alloc_size_));
    addMetadata("Total Allocated", std::to_string(alloc.total_allocated_));
    addMetadata("Total Reserved", std::to_string(alloc.total_reserved_));
  }

  void operator()(const ExtraFields<EventType::OutOfMemory>& alloc) {
    addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_));
    addMetadata("Device Id", std::to_string(alloc.device_index_));
    addMetadata("Bytes", std::to_string(alloc.alloc_size_));
    addMetadata("Total Allocated", std::to_string(alloc.total_allocated_));
    addMetadata("Total Reserved", std::to_string(alloc.total_reserved_));
  }

  template <typename T>
  void operator()(const T&) {}

 private:
  /* To get names of the performance events */
  const torch::profiler::impl::ProfilerConfig* config_;
};

struct KinetoThreadLocalState : public ProfilerStateBase {
  explicit KinetoThreadLocalState(
      const ProfilerConfig& config,
      std::set<torch::profiler::impl::ActivityType> activities)
      : ProfilerStateBase(config),
        startTime(getTimeNs()),
        recordQueue(config, std::move(activities)) {}
  ~KinetoThreadLocalState() override = default;

  static KinetoThreadLocalState* get(bool global) {
    auto* state = ProfilerStateBase::get(/*global=*/global);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        state == nullptr ||
        state->profilerType() == ActiveProfilerType::KINETO);
    return static_cast<KinetoThreadLocalState*>(state);
  }

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::KINETO;
  }

  void reportVulkanEventToProfiler(torch::profiler::impl::vulkan_id_t id) {
    if (!config_.disabled()) {
      recordQueue.getSubqueue()->emplace_vulkan_event(
          c10::getApproximateTime(), id);
    }
  }

  void reportMemoryUsage(
      void* ptr,
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      c10::Device device) override {
    if (config_.profile_memory && !config_.disabled()) {
      recordQueue.getSubqueue()->emplace_allocation_event(
          c10::getApproximateTime(),
          ptr,
          alloc_size,
          total_allocated,
          total_reserved,
          device.type(),
          device.index());
    }
  }

  void reportOutOfMemory(
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      c10::Device device) override {
    if (config_.profile_memory && !config_.disabled()) {
      recordQueue.getSubqueue()->emplace_ooms_event(
          c10::getApproximateTime(),
          alloc_size,
          total_allocated,
          total_reserved,
          device.type(),
          device.index());
    }
  }

  void setEventPostProcessingCallback(post_process_t&& cb) {
    eventPostProcessCb = std::move(cb);
  }

  void pausePython() {
    recordQueue.stop();
  }

  void resumePython() {
    recordQueue.restart();
  }

  std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>
  finalizeTrace() {
    auto end_time = getTimeNs();
    recordQueue.stop();

    std::lock_guard<std::mutex> guard(state_mutex_);
    auto converter = clockConverter.makeConverter();
#ifdef USE_KINETO
    libkineto::get_time_converter() = converter;
#endif
    auto records_and_trace =
        recordQueue.getRecords(std::move(converter), startTime, end_time);

    materializeOpEvents(records_and_trace.first);

    // `kinetoEvents` does not include Python events. Instead it exposes them
    // via the `stacks` property.
    kinetoEvents.erase(
        std::remove_if(
            kinetoEvents.begin(),
            kinetoEvents.end(),
            [](const auto& i) { return i.isPythonFunction(); }),
        kinetoEvents.end());

    return std::move(records_and_trace.second);
  }

  template <typename T>
  void invokeCallback(T& t) {
    if (eventPostProcessCb) {
      eventPostProcessCb(t.debug_handle_, t.jit_stack_, t.jit_modules_);
    }
  }

  void materializeOpEvents(std::vector<std::shared_ptr<Result>>& events) {
    for (auto& e : events) {
      if (e->parent_.expired() && e->deviceType() == c10::DeviceType::CPU) {
        eventTree.push_back(e);
      }

      if (e->finished_) {
        e->visit(c10::overloaded(
            [this](ExtraFields<EventType::TorchOp>& i) { invokeCallback(i); },
            [this](ExtraFields<EventType::Backend>& i) { invokeCallback(i); },
            [](auto&) {}));

        kinetoEvents.emplace_back(e, config_.experimental_config.verbose);
        AddTensorboardFields add_tb(e, kinetoEvents.back());
        AddGenericMetadata add_generic(e, &config_);

        // It is not safe to use the activity after post processing.
        e->kineto_activity_ = nullptr;
      }
    }
  }

  uint64_t startTime;
  c10::ApproximateClockToUnixTimeConverter clockConverter;
  torch::profiler::impl::RecordQueue recordQueue;
  std::vector<KinetoEvent> kinetoEvents;
  std::vector<experimental_event_t> eventTree;
  // Optional, if event post-processing is enabled.
  post_process_t eventPostProcessCb;
};

template <bool use_global_state_ptr = false>
std::unique_ptr<at::ObserverContext> onFunctionEnter(
    const at::RecordFunction& fn) {
  auto state_ptr = KinetoThreadLocalState::get(use_global_state_ptr);
  if (!state_ptr) {
    return nullptr;
  }
  return state_ptr->recordQueue.getSubqueue()->begin_op(fn);
}

// @lint-ignore CLANGTIDY clang-diagnostic-unused-parameter
template <bool use_global_state_ptr = false>
void onFunctionExit(
    const at::RecordFunction& fn,
    at::ObserverContext* ctx_ptr) {
  auto state_ptr = KinetoThreadLocalState::get(use_global_state_ptr);
  if (!state_ptr) {
    return;
  }
  const auto& config = state_ptr->config();
  auto* kineto_ctx_ptr =
      static_cast<torch::profiler::impl::KinetoObserverContext*>(ctx_ptr);
  TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr);
  kineto_ctx_ptr->event_->end_time_ = c10::getApproximateTime();
  if (!config.experimental_config.performance_events.empty()) {
    state_ptr->recordQueue.getSubqueue()->disable_perf_profiler(
        *kineto_ctx_ptr->event_->counters_);
  }
  kineto_ctx_ptr->event_->basic_fields_.end_tid_ =
      at::RecordFunction::currentThreadId();
  if (config.state == ProfilerState::KINETO_GPU_FALLBACK) {
    try {
      auto fallback = kineto_ctx_ptr->fallback_;
      TORCH_INTERNAL_ASSERT(fallback != nullptr);
      torch::profiler::impl::cudaStubs()->record(
          nullptr, &fallback->device_event_end_, nullptr);
    } catch (const std::exception& e) {
      LOG(WARNING) << "Failed to record CUDA event. " << e.what();
    }
  } else if (config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
    auto fallback = kineto_ctx_ptr->fallback_;
    TORCH_INTERNAL_ASSERT(fallback != nullptr);
    torch::profiler::impl::privateuse1Stubs()->record(
        nullptr, &fallback->device_event_end_, nullptr);
  }

  if (fn.scope() == at::RecordScope::USER_SCOPE) {
    torch::profiler::impl::kineto::popUserCorrelationId();
  } else {
    torch::profiler::impl::kineto::popCorrelationId();
  }
}

template <bool use_global_callback = false>
void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
  auto registration_state_ptr =
      KinetoThreadLocalState::get(use_global_callback);
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto recordFunctionCallback =
      at::RecordFunctionCallback(
          onFunctionEnter<use_global_callback>,
          onFunctionExit<use_global_callback>)
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .scopes(scopes);

  if constexpr (use_global_callback) {
    registration_state_ptr->setCallbackHandle(
        at::addGlobalCallback(recordFunctionCallback));
  } else {
    registration_state_ptr->setCallbackHandle(
        at::addThreadLocalCallback(recordFunctionCallback));
  }
}
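
// Callback-wiring sketch (illustrative): once pushProfilingCallbacks has run,
// every at::RecordFunction in the selected scopes is routed through
// onFunctionEnter and onFunctionExit. For example, assuming a profiler is
// active on this thread:
//
//   {
//     RECORD_FUNCTION("my_op", std::vector<c10::IValue>{});
//     // ... op body runs; onFunctionEnter fired above, and
//     // onFunctionExit fires when this scope unwinds.
//   }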

struct ProfilerStateInfo {
  std::shared_ptr<KinetoThreadLocalState> state_ptr;
  std::unordered_set<at::RecordScope> scopes;
};
std::shared_ptr<ProfilerStateInfo> profiler_state_info_ptr{nullptr};

} // namespace

void reportBackendEventToActiveKinetoProfiler(
    const int64_t start_time_us,
    const int64_t end_time_us,
    const int64_t debug_handle,
    const at::RecordScope scope,
    const std::string& event_name,
    const std::string& backend_name) {
  TORCH_INTERNAL_ASSERT(
      KinetoThreadLocalState::get(/*global=*/true) == nullptr,
      "On-demand profiling does not support post processing callback");

  auto state_ptr = KinetoThreadLocalState::get(/*global=*/false);
  if (!state_ptr) {
    return;
  }

  state_ptr->recordQueue.getSubqueue()->emplace_backend_event(
      start_time_us,
      end_time_us,
      debug_handle,
      scope,
      event_name,
      backend_name);

  /* no support for input shapes now?
  if (config.report_input_shapes) {
    ctx_ptr->shapes = inputSizes(fn);
    ctx_ptr->dtypes = inputTypes(fn);
  }
  */
}

void prepareProfiler(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities) {
  if (config.state == ProfilerState::NVTX ||
      config.state == ProfilerState::ITT) {
    return;
  }
  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
          config.state == ProfilerState::KINETO_GPU_FALLBACK ||
          config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK,
      "Supported only in Kineto profiler");
  torch::profiler::impl::kineto::prepareTrace(
      /*cpuOnly=*/!(
          at::hasCUDA() || at::hasXPU() || at::hasMTIA() ||
          c10::get_privateuse1_backend() != "privateuseone"),
      activities,
      config.experimental_config);

  if (!config.experimental_config.performance_events.empty()) {
    /* For now only CPU activity is supported */
    TORCH_CHECK(
        activities.count(torch::autograd::profiler::ActivityType::CPU),
        "Cannot run cpu hardware profiler without CPU activities, please only use CPU activity type");
    /*
     * Send a warning and pass the non-standard event to the backend; the
     * backend can abort if the event is not supported.
     * TODO: Should we gracefully drop an invalid event if we have at least
     * one valid event?
     */
    auto is_standard_event = [](const std::string& event) -> bool {
      for (auto e : torch::profiler::ProfilerPerfEvents) {
        if (!std::strcmp(event.c_str(), e)) {
          return true;
        }
      }
      return false;
    };

    for (const auto& e : config.experimental_config.performance_events) {
      if (!is_standard_event(e)) {
        TORCH_WARN("Forwarding a non-standard CPU performance event : ", e);
      }
    }
  }
}
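
// Usage sketch (illustrative; the exact ProfilerConfig constructor arguments
// are an assumption, the entry points are the ones defined in this file):
//
//   using torch::profiler::impl::ProfilerConfig;
//   using torch::profiler::impl::ProfilerState;
//
//   ProfilerConfig config{ProfilerState::KINETO};
//   std::set<torch::profiler::impl::ActivityType> activities{
//       torch::profiler::impl::ActivityType::CPU};
//   prepareProfiler(config, activities);  // warms up Kineto before tracing
//   enableProfiler(config, activities, /*scopes=*/{});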

static void toggleTorchOpCollectionDynamic(bool enable) {
  auto state_ptr = ProfilerStateBase::get();
  if (state_ptr) {
    const auto& config = state_ptr->config();
    if (enable) {
      auto scopes = profiler_state_info_ptr->scopes;
      config.global() ? pushProfilingCallbacks</*global=*/true>(scopes)
                      : pushProfilingCallbacks</*global=*/false>(scopes);
    } else {
      state_ptr->removeCallback();
    }
  }
}

// This function is marked unused because the profiler implementation needs
// more refactoring before Python op collection can be toggled dynamically.
#ifdef _MSC_VER
#define UNUSED
#else
#define UNUSED __attribute__((unused))
#endif
static UNUSED void togglePythonCollectionDynamic(bool enable) {
  auto state_ptr = ProfilerStateBase::get();
  if (state_ptr) {
    auto global = state_ptr->config().global();
    KinetoThreadLocalState* kineto_thread_local_state_ptr =
        KinetoThreadLocalState::get(global);
    if (enable) {
      kineto_thread_local_state_ptr->resumePython();
    } else {
      kineto_thread_local_state_ptr->pausePython();
    }
  }
}

static void toggleCPUCollectionDynamic(bool enable) {
  toggleTorchOpCollectionDynamic(enable);
  // For now we only support dynamic toggling of Torch op collection.
  // Supporting Python ops would require string parsing to strip out the
  // toggling events and other unfinished events, as well as changes to the
  // stack logic.
  // togglePythonCollectionDynamic(enable);
}

void toggleCollectionDynamic(
    const bool enable,
    const std::set<torch::profiler::impl::ActivityType>& activities) {
  if (activities.count(torch::autograd::profiler::ActivityType::CPU) > 0 &&
      activities.count(torch::autograd::profiler::ActivityType::CUDA) == 0) {
    LOG(WARNING)
        << "Toggling CPU activity with CUDA activity on may result in traces with CUDA events on arbitrary tracks";
  }
  for (auto act : activities) {
    if (act == torch::autograd::profiler::ActivityType::CUDA) {
      torch::profiler::impl::kineto::toggleCollectionDynamic(enable);
    } else if (act == torch::autograd::profiler::ActivityType::CPU) {
      toggleCPUCollectionDynamic(enable);
    } else {
      LOG(WARNING)
          << "Dynamic toggle is only supported for CPU/GPU activity, skipping toggling of "
          << actToString(act);
      continue;
    }
  }
}
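
// Dynamic-toggling sketch (illustrative): pause and resume CPU-side op
// collection in the middle of an active profile without tearing it down.
//
//   std::set<torch::profiler::impl::ActivityType> acts{
//       torch::profiler::impl::ActivityType::CPU};
//   toggleCollectionDynamic(/*enable=*/false, acts);  // stop collecting ops
//   run_untraced_region();  // hypothetical workload to exclude from the trace
//   toggleCollectionDynamic(/*enable=*/true, acts);   // resume collecting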

void enableProfilerWithEventPostProcess(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    post_process_t&& cb,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      config.state != ProfilerState::NVTX,
      "NVTX does not support post processing callback.");
  TORCH_CHECK(
      config.state != ProfilerState::ITT,
      "ITT does not support post processing callback.");
  TORCH_INTERNAL_ASSERT(
      KinetoThreadLocalState::get(/*global=*/true) == nullptr,
      "On-demand profiling does not support post processing callback");

  enableProfiler(config, activities, scopes);
  auto state_ptr = KinetoThreadLocalState::get(config.global());
  state_ptr->setEventPostProcessingCallback(std::move(cb));
}

void enableProfiler(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    const std::unordered_set<at::RecordScope>& scopes) {
  const auto has_cpu = activities.count(ActivityType::CPU);
  TORCH_CHECK(
      KinetoThreadLocalState::get(/*global=*/config.global()) == nullptr,
      "Profiler is already enabled",
      (config.global() ? "." : " on this thread."));

  if (config.state == ProfilerState::NVTX) {
    torch::profiler::impl::pushNVTXCallbacks(config, scopes);
    return;
  } else if (config.state == ProfilerState::ITT) {
    torch::profiler::impl::pushITTCallbacks(config, scopes);
    return;
  } else if (config.state == ProfilerState::PRIVATEUSE1) {
    torch::profiler::impl::pushPRIVATEUSE1CallbacksStub(config, scopes);
    return;
  }

  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
      config.state == ProfilerState::KINETO_GPU_FALLBACK ||
      config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
      config.global());
  TORCH_CHECK(!activities.empty(), "No activities specified.");
  TORCH_INTERNAL_ASSERT(
      has_cpu || !config.global(),
      "On-demand profiling must enable CPU tracing");

  auto state_ptr = std::make_shared<KinetoThreadLocalState>(config, activities);
  KinetoThreadLocalState::push(state_ptr);

  if (has_cpu) {
    config.global() ? pushProfilingCallbacks</*global=*/true>(scopes)
                    : pushProfilingCallbacks</*global=*/false>(scopes);
  }

  if (!config.global()) {
    torch::profiler::impl::kineto::startTrace();
  }

  if (has_cpu) {
    auto state_info_ptr = std::make_shared<ProfilerStateInfo>();
    state_info_ptr->state_ptr = state_ptr;
    state_info_ptr->scopes = scopes;
    profiler_state_info_ptr = state_info_ptr;
  }
}

bool isProfilerEnabledInMainThread() {
  return profiler_state_info_ptr != nullptr;
}

void enableProfilerInChildThread() {
  auto state_info_ptr = profiler_state_info_ptr;
  TORCH_CHECK(state_info_ptr, "Profiler is not enabled in main thread.");
  TORCH_CHECK(
      KinetoThreadLocalState::get(/*global=*/false) == nullptr,
      "Profiler is already enabled in this thread.");

  KinetoThreadLocalState::push(state_info_ptr->state_ptr);
  pushProfilingCallbacks</*global=*/false>(state_info_ptr->scopes);
}
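
// Child-thread sketch (illustrative): once the main thread has enabled the
// profiler, worker threads can opt in and out around their traced work.
//
//   std::thread worker([] {
//     if (isProfilerEnabledInMainThread()) {
//       enableProfilerInChildThread();
//       do_work();  // hypothetical traced workload
//       disableProfilerInChildThread();
//     }
//   });
//   worker.join();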

void disableProfilerInChildThread() {
  auto state_ptr = ProfilerStateBase::pop();
  TORCH_CHECK(
      state_ptr,
      "Can't disable Kineto profiler when it's not running in this thread");
  state_ptr->removeCallback();
}

std::unique_ptr<ProfilerResult> disableProfiler() {
  // Release to inform child threads to stop profiling.
  profiler_state_info_ptr = nullptr;

  // Check the state pointer before dereferencing it for its config.
  auto state_ptr = ProfilerStateBase::pop();
  TORCH_CHECK(
      state_ptr, "Can't disable Kineto profiler when it's not running");
  const auto& config = state_ptr->config();
  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
          config.state == ProfilerState::KINETO_GPU_FALLBACK ||
          config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
          config.state == ProfilerState::KINETO_ONDEMAND ||
          config.state == ProfilerState::NVTX ||
          config.state == ProfilerState::ITT ||
          config.state == ProfilerState::PRIVATEUSE1,
      "Can't disable Kineto profiler when it's not running");

  state_ptr->removeCallback();

  // Traces are converged automatically by libkineto for the on-demand flow.
  if (state_ptr->config().global()) {
    (void)std::static_pointer_cast<KinetoThreadLocalState>(state_ptr)
        ->finalizeTrace();
    return std::make_unique<ProfilerResult>();
  }

  // Shared among NVTX, PRIVATEUSE1, KINETO, KINETO_GPU_FALLBACK,
  // KINETO_PRIVATEUSE1_FALLBACK
  std::unique_ptr<ProfilerResult> result;
  if (state_ptr->config().state == ProfilerState::NVTX ||
      state_ptr->config().state == ProfilerState::PRIVATEUSE1) {
    result = std::make_unique<ProfilerResult>();
  }

  if (config.state == ProfilerState::KINETO ||
      config.state == ProfilerState::KINETO_GPU_FALLBACK ||
      config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
    auto kineto_state_ptr =
        std::static_pointer_cast<KinetoThreadLocalState>(state_ptr);
    auto trace = kineto_state_ptr->finalizeTrace();
    result = std::make_unique<ProfilerResult>(
        kineto_state_ptr->startTime,
        std::move(kineto_state_ptr->kinetoEvents),
        std::move(trace),
        std::move(kineto_state_ptr->eventTree));
  }

  return result;
}
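
// End-to-end lifecycle sketch (illustrative), tying the entry points in this
// file together:
//
//   prepareProfiler(config, activities);             // warm up Kineto
//   enableProfiler(config, activities, /*scopes=*/{});
//   model_step();                                    // hypothetical workload
//   auto result = disableProfiler();                 // ProfilerResult
//   result->save("trace.json");                      // Chrome trace file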

KinetoEvent::KinetoEvent(
    const std::shared_ptr<const torch::profiler::impl::Result>& result,
    const bool verbose)
    : result_{result} {
  TORCH_INTERNAL_ASSERT(result != nullptr);

  if (verbose) {
    // Populate Python stack
    auto parent = result_->parent_.lock();
    while (parent != nullptr) {
      parent->visit_if_base<PyExtraFieldsBase>(
          [&](const auto&) { python_stack_.push_back(parent->name()); });
      parent = parent->parent_.lock();
    }
  }

  result->visit_if_base<ExtraFields<EventType::TorchOp>>([&](const auto& op) {
    auto arg_data = parseArgData(op.inputs_, op.concrete_inputs_);
    shapes_ = std::move(arg_data.shapesForKinetoEvent);
    dtypes_ = std::move(arg_data.dtypes);
    concrete_inputs_ = std::move(arg_data.concreteInputs);
    kwinputs_ = std::move(op.kwinputs_);
  });
}

bool KinetoEvent::isPythonFunction() const {
  bool out{false};
  result_->visit_if_base<PyExtraFieldsBase>([&](const auto&) { out = true; });
  return out;
}

bool KinetoEvent::hasShapes() const {
  return !shapes_.empty();
}

const c10::ArrayRef<std::vector<int64_t>> KinetoEvent::shapes() const {
  return shapes_;
}

bool KinetoEvent::hasTypes() const {
  return !dtypes_.empty();
}

const c10::ArrayRef<std::string> KinetoEvent::dtypes() const {
  return dtypes_;
}

bool KinetoEvent::hasConcreteInputs() const {
  return !concrete_inputs_.empty();
}

const c10::ArrayRef<c10::IValue> KinetoEvent::concreteInputs() const {
  return concrete_inputs_;
}

bool KinetoEvent::hasKwinputs() const {
  return !kwinputs_.empty();
}

const std::unordered_map<std::string, c10::IValue> KinetoEvent::kwinputs()
    const {
  return kwinputs_;
}

const c10::ArrayRef<std::string> KinetoEvent::stack() const {
  auto get = [&](const auto& i) -> auto& {
    return !i.jit_stack_.empty() ? i.jit_stack_ : python_stack_;
  };

  auto const& extra_fields = result_->extra_fields_;
  if (auto p = std::get_if<ExtraFields<EventType::TorchOp>>(&extra_fields)) {
    return get(*p);
  }
  if (auto p = std::get_if<ExtraFields<EventType::Backend>>(&extra_fields)) {
    return get(*p);
  }
  return python_stack_;
}

const c10::ArrayRef<std::string> KinetoEvent::moduleHierarchy() const {
  auto const& extra_fields = result_->extra_fields_;
  if (auto p = std::get_if<ExtraFields<EventType::TorchOp>>(&extra_fields)) {
    return p->jit_modules_;
  }
  if (auto p = std::get_if<ExtraFields<EventType::Backend>>(&extra_fields)) {
    return p->jit_modules_;
  }
  return {};
}

uint64_t KinetoEvent::endNs() const {
  return result_->endTimeNS();
}

uint64_t KinetoEvent::durationNs() const {
  return (result_->endTimeNS() - result_->start_time_ns_);
}

int64_t KinetoEvent::debugHandle() const {
  return result_->visit(c10::overloaded(
      [](const ExtraFields<EventType::TorchOp>& i) { return i.debug_handle_; },
      [](const ExtraFields<EventType::Backend>& i) { return i.debug_handle_; },
      [](const auto&) -> int64_t { return -1; }));
}

int KinetoEvent::deviceIndex() const {
  return result_->visit(c10::overloaded(
      [](const ExtraFields<EventType::Allocation>& i) {
        return static_cast<int>(i.device_index_);
      },
      [](const ExtraFields<EventType::OutOfMemory>& i) {
        return static_cast<int>(i.device_index_);
      },
      [&](const auto&) {
        return static_cast<int>(result_->kineto_info_.device);
      }));
}

bool KinetoEvent::hasStack() const {
  return !stack().empty();
}

int64_t KinetoEvent::cudaElapsedUs() const {
  auto cuda_event_start = fallbackStart();
  auto cuda_event_end = fallbackEnd();
  if (!cuda_event_start || !cuda_event_end) {
    return -1;
  }
  try {
    return (int64_t)torch::profiler::impl::cudaStubs()->elapsed(
        &cuda_event_start, &cuda_event_end);
  } catch (std::exception& e) {
    LOG(WARNING) << "Failed to measure time between two CUDA events. "
                 << e.what();
  }
  return -1;
}

int64_t KinetoEvent::privateuse1ElapsedUs() const {
  auto privateuse1_event_start = fallbackStart();
  auto privateuse1_event_end = fallbackEnd();
  if (!privateuse1_event_start || !privateuse1_event_end) {
    return -1;
  }
  return (int64_t)torch::profiler::impl::privateuse1Stubs()->elapsed(
      &privateuse1_event_start, &privateuse1_event_end);
}

void KinetoEvent::getPerfEventCounters(std::vector<uint64_t>& in) const {
  return result_->visit(c10::overloaded(
      [&in](const ExtraFields<EventType::TorchOp>& e) -> void {
        const size_t n = e.perf_event_counters_->size();
        // This resize should be rare.
        if (in.size() < n) {
          in.resize(n, 0);
        }
        for (size_t i = 0; i < n; ++i) {
          in[i] = (*e.perf_event_counters_)[i];
        }
      },
      [](const auto&) -> void { return; }));
}

#define FORWARD_FROM_RESULT(method_name, result_expr)                        \
  decltype(std::declval<KinetoEvent>().method_name())                        \
  KinetoEvent::method_name() const {                                         \
    return static_cast<decltype(std::declval<KinetoEvent>().method_name())>( \
        result_->result_expr);                                               \
  }

FORWARD_FROM_RESULT(startThreadId, start_tid_)
FORWARD_FROM_RESULT(endThreadId, endTID())
FORWARD_FROM_RESULT(activityType, kinetoType())
FORWARD_FROM_RESULT(name, name())
FORWARD_FROM_RESULT(deviceType, deviceType())
FORWARD_FROM_RESULT(startNs, start_time_ns_)
FORWARD_FROM_RESULT(correlationId, correlationID())
FORWARD_FROM_RESULT(deviceResourceId, kineto_info_.resource)
#undef FORWARD_FROM_RESULT
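
// For reference, FORWARD_FROM_RESULT(startNs, start_time_ns_) expands to
// (approximately):
//
//   decltype(std::declval<KinetoEvent>().startNs())
//   KinetoEvent::startNs() const {
//     return static_cast<decltype(std::declval<KinetoEvent>().startNs())>(
//         result_->start_time_ns_);
//   }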

// Most of the fields in `KinetoEvent` only make sense for a single event type.
// (Generally TorchOp.) For all other types they simply return the default
// value. This macro provides a succinct way of expressing this behavior.
#define TYPED_ATTR_WITH_DEFAULT(                                        \
    event_type, method_name, expression, default_value)                 \
  decltype(std::declval<KinetoEvent>().method_name())                   \
  KinetoEvent::method_name() const {                                    \
    using out_t = decltype(std::declval<KinetoEvent>().method_name());  \
    return result_->visit(c10::overloaded(                              \
        [](const ExtraFields<EventType::event_type>& e) -> out_t {      \
          return expression;                                            \
        },                                                              \
        [](const auto&) -> out_t { return default_value; }));           \
  }

#define TYPED_ATTR(event_type, method_name, expression) \
  TYPED_ATTR_WITH_DEFAULT(event_type, method_name, expression, {})
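
// For reference, TYPED_ATTR(TorchOp, isAsync, e.is_async_) expands to
// (approximately):
//
//   decltype(std::declval<KinetoEvent>().isAsync())
//   KinetoEvent::isAsync() const {
//     using out_t = decltype(std::declval<KinetoEvent>().isAsync());
//     return result_->visit(c10::overloaded(
//         [](const ExtraFields<EventType::TorchOp>& e) -> out_t {
//           return e.is_async_;
//         },
//         [](const auto&) -> out_t { return {}; }));
//   }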

TYPED_ATTR_WITH_DEFAULT(TorchOp, sequenceNr, e.sequence_number_, -1)
TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0)
TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_))
TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty())
TYPED_ATTR(TorchOp, isAsync, e.is_async_)
TYPED_ATTR(TorchOp, extraMeta, e.extra_meta_)
TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_)
TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_)
TYPED_ATTR(
    TorchOp,
    flops,
    !e.extra_args_.empty()
        ? torch::profiler::impl::computeFlops(e.name_, e.extra_args_)
        : 0)
TYPED_ATTR(Backend, backend, e.backend_)
TYPED_ATTR(Allocation, nBytes, e.alloc_size_)
TYPED_ATTR(Kineto, linkedCorrelationId, [&]() {
  const auto linked = e.linked_activity_.lock();
  return linked ? linked->correlationID() : 0;
}())
#undef TYPED_ATTR
#undef TYPED_ATTR_WITH_DEFAULT

ProfilerResult::ProfilerResult(
    uint64_t start_time,
    std::vector<KinetoEvent> events,
    std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>&&
        trace,
    std::vector<experimental_event_t>&& event_tree)
    : trace_start_ns_(start_time),
      events_(std::move(events)),
      trace_(std::move(trace)),
      event_tree_(std::move(event_tree)) {}
ProfilerResult::ProfilerResult() = default;
ProfilerResult::~ProfilerResult() = default;

void ProfilerResult::save(const std::string& path) {
  trace_->save(path);
}

} // namespace autograd::profiler

namespace profiler::impl {
void _reportVulkanEventToProfiler(vulkan_id_t id) {
  auto state_ptr = ::torch::autograd::profiler::KinetoThreadLocalState::get(
      /*global=*/false);
  if (state_ptr) {
    state_ptr->reportVulkanEventToProfiler(id);
  }
}
} // namespace profiler::impl

} // namespace torch