xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/standalone/itt_observer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/profiler/standalone/itt_observer.h>
2 
3 #include <torch/csrc/profiler/stubs/base.h>
4 #include <torch/csrc/profiler/util.h>
5 
6 namespace torch::profiler::impl {
7 
8 struct ITTThreadLocalState : ProfilerStateBase {
ITTThreadLocalStatetorch::profiler::impl::ITTThreadLocalState9   explicit ITTThreadLocalState(const ProfilerConfig& config)
10       : ProfilerStateBase(config) {
11     // Only `report_input_shapes` makes sense in this context.
12     TORCH_CHECK(!config.profile_memory);
13     TORCH_CHECK(!config.with_stack);
14     TORCH_CHECK(!config.with_flops);
15     TORCH_CHECK(!config.with_modules);
16   }
17   ~ITTThreadLocalState() override = default;
18 
profilerTypetorch::profiler::impl::ITTThreadLocalState19   ActiveProfilerType profilerType() override {
20     return ActiveProfilerType::ITT;
21   }
22 
reportMemoryUsagetorch::profiler::impl::ITTThreadLocalState23   void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
24   }
25 
getTLStorch::profiler::impl::ITTThreadLocalState26   static ITTThreadLocalState* getTLS() {
27     auto tls = ProfilerStateBase::get(/*global=*/false);
28     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
29         tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT);
30     return static_cast<ITTThreadLocalState*>(tls);
31   }
32 };
33 
34 template <bool report_input_shapes>
enterITT(const at::RecordFunction & fn)35 std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) {
36   if (ITTThreadLocalState::getTLS() != nullptr) {
37     torch::profiler::impl::ittStubs()->rangePush(fn.name());
38   }
39   return nullptr;
40 }
41 
pushITTCallbacks(const ProfilerConfig & config,const std::unordered_set<at::RecordScope> & scopes)42 void pushITTCallbacks(
43     const ProfilerConfig& config,
44     const std::unordered_set<at::RecordScope>& scopes) {
45   TORCH_CHECK(
46       torch::profiler::impl::ittStubs()->enabled(),
47       "Can't use ITT profiler - PyTorch was compiled without ITT");
48 
49   c10::ThreadLocalDebugInfo::_push(
50       c10::DebugInfoKind::PROFILER_STATE,
51       std::make_shared<ITTThreadLocalState>(config));
52 
53   auto state_ptr = ITTThreadLocalState::getTLS();
54   TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
55 
56   auto handle = at::addThreadLocalCallback(
57       at::RecordFunctionCallback(
58           state_ptr->config().report_input_shapes
59               ? &enterITT</*report_input_shapes=*/true>
60               : &enterITT</*report_input_shapes=*/false>,
61           [](const at::RecordFunction&, at::ObserverContext*) {
62             torch::profiler::impl::ittStubs()->rangePop();
63           })
64           .needsInputs(config.report_input_shapes)
65           .scopes(scopes));
66   state_ptr->setCallbackHandle(handle);
67 }
68 
69 } // namespace torch::profiler::impl
70