#include <torch/csrc/profiler/standalone/nvtx_observer.h>

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>

namespace torch::profiler::impl {

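// Thread-local profiler state for the NVTX observer. Besides the usual
// ProfilerStateBase bookkeeping, it remembers which observed op produced
// each output tensor so that producer/consumer edges can be encoded into
// the NVTX range names.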
struct NVTXThreadLocalState : ProfilerStateBase {
  explicit NVTXThreadLocalState(const ProfilerConfig& config)
      : ProfilerStateBase(config) {
    // Only `report_input_shapes` makes sense in this context.
    TORCH_CHECK(!config.profile_memory);
    TORCH_CHECK(!config.with_stack);
    TORCH_CHECK(!config.with_flops);
    TORCH_CHECK(!config.with_modules);
  }
  ~NVTXThreadLocalState() override = default;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::NVTX;
  }

  void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
  }

  static NVTXThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::NVTX);
    return static_cast<NVTXThreadLocalState*>(tls);
  }
  std::pair<at::RecordFunctionHandle, int> getOpIdFromInput(
      const at::Tensor& tensor);

  void setProducerTensorMap(
      at::TensorImpl* tensor,
      at::RecordFunctionHandle op_id,
      int output_nr) {
    producer_tensor_map_[(void*)tensor] =
        std::pair<at::RecordFunctionHandle, int>{op_id, output_nr};
  }

 protected:
  // Maps the address of an output Tensor to the unique op id and output
  // index of the tensor. at::TensorImpl* is the actual type of the key,
  // but void* is used to indicate that the pointer is only being used
  // as a key.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<void*, std::pair<at::RecordFunctionHandle, int>>
      producer_tensor_map_;
};

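// Looks up the (op id, output index) pair recorded for `tensor`'s
// TensorImpl address; returns the sentinel (0, -1) if the tensor is
// undefined or has no recorded producer.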
std::pair<at::RecordFunctionHandle, int> NVTXThreadLocalState::getOpIdFromInput(
    const at::Tensor& tensor) {
  std::pair<at::RecordFunctionHandle, int> producer_op_pair(0, -1);
  if (tensor.defined()) {
    at::TensorImpl* ten_addr = tensor.unsafeGetTensorImpl();
    // See if the address is already in the map
    if (producer_tensor_map_.count((void*)ten_addr) > 0) {
      producer_op_pair = producer_tensor_map_[(void*)ten_addr];
    }
  }
  return producer_op_pair;
}

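// Collects the producer (op id, output index) pair for every tensor
// element of a c10::List input; non-tensor elements are skipped.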
static std::list<std::pair<at::RecordFunctionHandle, int>> flattenOpIdList(
    const c10::List<c10::IValue>& list) {
  std::list<std::pair<at::RecordFunctionHandle, int>> input_op_id_list;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& input : list) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      auto producer_op_pair = state_ptr->getOpIdFromInput(tensor);
      input_op_id_list.push_back(producer_op_pair);
    }
  }
  return input_op_id_list;
}

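// Walks the inputs of `fn` and returns one producer (op id, output index)
// pair per input. List inputs are flattened; non-tensor inputs and tensors
// without a recorded producer yield the sentinel pair (0, -1).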
static std::list<std::pair<at::RecordFunctionHandle, int>> getInputTensorOpIds(
    const at::RecordFunction& fn) {
  std::pair<at::RecordFunctionHandle, int> undefined_op_pair(0, -1);
  std::list<std::pair<at::RecordFunctionHandle, int>> input_producer_ops_;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& input_item : fn.inputs()) {
    if (input_item.isTensor()) {
      const at::Tensor& tensor = input_item.toTensor();
      auto producer_pair = state_ptr->getOpIdFromInput(tensor);
      input_producer_ops_.push_back(producer_pair);
    } else {
      if (input_item.isList()) {
        std::list<std::pair<at::RecordFunctionHandle, int>> tmp_op_ids =
            flattenOpIdList(input_item.toList());
        // Extend the current op id list with the pairs gathered from the
        // list input's tensors
        if (!tmp_op_ids.empty()) {
          input_producer_ops_.splice(input_producer_ops_.end(), tmp_op_ids);
        } else {
          input_producer_ops_.emplace_back(undefined_op_pair);
        }
      } else {
        input_producer_ops_.emplace_back(undefined_op_pair);
      }
    }
  }
  return input_producer_ops_;
}

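// After `fn` finishes, record it as the producer of each of its defined
// output tensors, keyed by the outputs' TensorImpl addresses.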
static void updateOutputTensorTracker(const at::RecordFunction& fn) {
  int output_nr = 0;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& s_tensor : fn.outputs()) {
    if (s_tensor.isTensor()) {
      const at::Tensor& tensor = s_tensor.toTensor();
      if (tensor.defined()) {
        auto ten_addr = tensor.unsafeGetTensorImpl();
        state_ptr->setProducerTensorMap(ten_addr, fn.handle(), output_nr);
      }
    }
    output_nr++;
  }
}

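// Start callback: pushes an NVTX range whose name encodes the op name,
// sequence number and handle, plus input shapes and producer op ids when
// `report_input_shapes` is set. Templating on `report_input_shapes` moves
// that check to compile time instead of evaluating it on every op.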
template <bool report_input_shapes>
std::unique_ptr<at::ObserverContext> enterNVTX(const at::RecordFunction& fn) {
  if (NVTXThreadLocalState::getTLS() != nullptr) {
    auto input_op_ids = getInputTensorOpIds(fn);
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(
            fn.name(),
            fn.seqNr(),
            report_input_shapes ? torch::profiler::impl::inputSizes(fn, true)
                                : std::vector<std::vector<int64_t>>(),
            fn.handle(),
            report_input_shapes
                ? input_op_ids
                : std::list<std::pair<at::RecordFunctionHandle, int>>())
            .c_str());
  }
  return nullptr;
}

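// Registers the NVTX observer as a thread-local RecordFunction callback;
// the exit callback pops the NVTX range and records the finished op as
// the producer of its outputs.
//
// Minimal usage sketch (an assumption about the surrounding enable path,
// which is normally driven from Python via
// torch.autograd.profiler.emit_nvtx(); the exact ProfilerConfig
// constructor arguments may differ across versions):
//
//   ProfilerConfig config(
//       ProfilerState::NVTX, /*report_input_shapes=*/true);
//   pushNVTXCallbacks(config, {at::RecordScope::FUNCTION});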
void pushNVTXCallbacks(
    const ProfilerConfig& config,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE,
      std::make_shared<NVTXThreadLocalState>(config));

  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          state_ptr->config().report_input_shapes
              ? &enterNVTX</*report_input_shapes=*/true>
              : &enterNVTX</*report_input_shapes=*/false>,
          [](const at::RecordFunction& fn, at::ObserverContext* ctx) {
            torch::profiler::impl::cudaStubs()->rangePop();
            updateOutputTensorTracker(fn);
          })
          .needsInputs(config.report_input_shapes)
          .needsOutputs(config.report_input_shapes)
          .needsIds(true)
          .scopes(scopes));
  state_ptr->setCallbackHandle(handle);
}

} // namespace torch::profiler::impl