1 #include <torch/csrc/profiler/standalone/itt_observer.h>
2
3 #include <torch/csrc/profiler/stubs/base.h>
4 #include <torch/csrc/profiler/util.h>
5
6 namespace torch::profiler::impl {
7
8 struct ITTThreadLocalState : ProfilerStateBase {
ITTThreadLocalStatetorch::profiler::impl::ITTThreadLocalState9 explicit ITTThreadLocalState(const ProfilerConfig& config)
10 : ProfilerStateBase(config) {
11 // Only `report_input_shapes` makes sense in this context.
12 TORCH_CHECK(!config.profile_memory);
13 TORCH_CHECK(!config.with_stack);
14 TORCH_CHECK(!config.with_flops);
15 TORCH_CHECK(!config.with_modules);
16 }
17 ~ITTThreadLocalState() override = default;
18
profilerTypetorch::profiler::impl::ITTThreadLocalState19 ActiveProfilerType profilerType() override {
20 return ActiveProfilerType::ITT;
21 }
22
reportMemoryUsagetorch::profiler::impl::ITTThreadLocalState23 void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
24 }
25
getTLStorch::profiler::impl::ITTThreadLocalState26 static ITTThreadLocalState* getTLS() {
27 auto tls = ProfilerStateBase::get(/*global=*/false);
28 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
29 tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT);
30 return static_cast<ITTThreadLocalState*>(tls);
31 }
32 };
33
34 template <bool report_input_shapes>
enterITT(const at::RecordFunction & fn)35 std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) {
36 if (ITTThreadLocalState::getTLS() != nullptr) {
37 torch::profiler::impl::ittStubs()->rangePush(fn.name());
38 }
39 return nullptr;
40 }
41
pushITTCallbacks(const ProfilerConfig & config,const std::unordered_set<at::RecordScope> & scopes)42 void pushITTCallbacks(
43 const ProfilerConfig& config,
44 const std::unordered_set<at::RecordScope>& scopes) {
45 TORCH_CHECK(
46 torch::profiler::impl::ittStubs()->enabled(),
47 "Can't use ITT profiler - PyTorch was compiled without ITT");
48
49 c10::ThreadLocalDebugInfo::_push(
50 c10::DebugInfoKind::PROFILER_STATE,
51 std::make_shared<ITTThreadLocalState>(config));
52
53 auto state_ptr = ITTThreadLocalState::getTLS();
54 TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
55
56 auto handle = at::addThreadLocalCallback(
57 at::RecordFunctionCallback(
58 state_ptr->config().report_input_shapes
59 ? &enterITT</*report_input_shapes=*/true>
60 : &enterITT</*report_input_shapes=*/false>,
61 [](const at::RecordFunction&, at::ObserverContext*) {
62 torch::profiler::impl::ittStubs()->rangePop();
63 })
64 .needsInputs(config.report_input_shapes)
65 .scopes(scopes));
66 state_ptr->setCallbackHandle(handle);
67 }
68
69 } // namespace torch::profiler::impl
70