#include <torch/csrc/profiler/standalone/nvtx_observer.h>

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>

namespace torch::profiler::impl {

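// Thread-local state for the NVTX profiler. In addition to the base
// ProfilerStateBase, it tracks which profiled op produced each output tensor
// so that subsequent ops can report the producer op ids of their inputs.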
struct NVTXThreadLocalState : ProfilerStateBase {
  explicit NVTXThreadLocalState(const ProfilerConfig& config)
      : ProfilerStateBase(config) {
    // Only `report_input_shapes` makes sense in this context.
    TORCH_CHECK(!config.profile_memory);
    TORCH_CHECK(!config.with_stack);
    TORCH_CHECK(!config.with_flops);
    TORCH_CHECK(!config.with_modules);
  }
  ~NVTXThreadLocalState() override = default;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::NVTX;
  }

  void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
  }

  static NVTXThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::NVTX);
    return static_cast<NVTXThreadLocalState*>(tls);
  }

  std::pair<at::RecordFunctionHandle, int> getOpIdFromInput(
      const at::Tensor& tensor);

  void setProducerTensorMap(
      at::TensorImpl* tensor,
      at::RecordFunctionHandle op_id,
      int output_nr) {
    producer_tensor_map_[(void*)tensor] =
        std::pair<at::RecordFunctionHandle, int>{op_id, output_nr};
  }

 protected:
  // Maps the address of an output tensor to the unique op id and output
  // index of the op that produced it.
  // at::TensorImpl* is the actual key type, but void* is used to indicate
  // that the pointer only serves as a key.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<void*, std::pair<at::RecordFunctionHandle, int>>
      producer_tensor_map_;
};

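// Returns the (op id, output index) pair recorded for the op that produced
// `tensor`, or the placeholder {0, -1} if the tensor is undefined or no
// producer has been recorded for it.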
std::pair<at::RecordFunctionHandle, int> NVTXThreadLocalState::getOpIdFromInput(
    const at::Tensor& tensor) {
  std::pair<at::RecordFunctionHandle, int> producer_op_pair(0, -1);
  if (tensor.defined()) {
    at::TensorImpl* ten_addr = tensor.unsafeGetTensorImpl();
    // See if the address is already in the map
    if (producer_tensor_map_.count((void*)ten_addr) > 0) {
      producer_op_pair = producer_tensor_map_[(void*)ten_addr];
    }
  }
  return producer_op_pair;
}

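// Resolves producer (op id, output index) pairs for every tensor element of
// an IValue list input; non-tensor elements are skipped.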
static std::list<std::pair<at::RecordFunctionHandle, int>> flattenOpIdList(
    const c10::List<c10::IValue>& list) {
  std::list<std::pair<at::RecordFunctionHandle, int>> input_op_id_list;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& input : list) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      auto producer_op_pair = state_ptr->getOpIdFromInput(tensor);
      input_op_id_list.push_back(producer_op_pair);
    }
  }
  return input_op_id_list;
}

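// Collects one producer (op id, output index) pair per input of `fn`: tensors
// are looked up directly, tensor lists are flattened via flattenOpIdList, and
// anything else contributes the placeholder pair {0, -1}.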
static std::list<std::pair<at::RecordFunctionHandle, int>> getInputTensorOpIds(
    const at::RecordFunction& fn) {
  std::pair<at::RecordFunctionHandle, int> undefined_op_pair(0, -1);
  std::list<std::pair<at::RecordFunctionHandle, int>> input_producer_ops_;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& input_item : fn.inputs()) {
    if (input_item.isTensor()) {
      const at::Tensor& tensor = input_item.toTensor();
      auto producer_pair = state_ptr->getOpIdFromInput(tensor);
      input_producer_ops_.push_back(producer_pair);
    } else if (input_item.isList()) {
      std::list<std::pair<at::RecordFunctionHandle, int>> tmp_op_ids =
          flattenOpIdList(input_item.toList());
      // Extend the producer list with the op ids gathered from the nested list
      if (!tmp_op_ids.empty()) {
        input_producer_ops_.splice(input_producer_ops_.end(), tmp_op_ids);
      } else {
        input_producer_ops_.emplace_back(undefined_op_pair);
      }
    } else {
      input_producer_ops_.emplace_back(undefined_op_pair);
    }
  }
  return input_producer_ops_;
}

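// Records `fn` as the producer of each of its defined output tensors, keyed
// by the output's TensorImpl address together with its output index.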
static void updateOutputTensorTracker(const at::RecordFunction& fn) {
  int output_nr = 0;
  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
  for (const c10::IValue& s_tensor : fn.outputs()) {
    if (s_tensor.isTensor()) {
      const at::Tensor& tensor = s_tensor.toTensor();
      if (tensor.defined()) {
        auto ten_addr = tensor.unsafeGetTensorImpl();
        state_ptr->setProducerTensorMap(ten_addr, fn.handle(), output_nr);
      }
    }
    output_nr++;
  }
}

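// RecordFunction "enter" callback: pushes an NVTX range whose message is
// built by getNvtxStr from the op name, sequence number, and op id, plus the
// input shapes and producer op ids when report_input_shapes is enabled.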
template <bool report_input_shapes>
std::unique_ptr<at::ObserverContext> enterNVTX(const at::RecordFunction& fn) {
  if (NVTXThreadLocalState::getTLS() != nullptr) {
    auto input_op_ids = getInputTensorOpIds(fn);
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(
            fn.name(),
            fn.seqNr(),
            report_input_shapes ? torch::profiler::impl::inputSizes(fn, true)
                                : std::vector<std::vector<int64_t>>(),
            fn.handle(),
            report_input_shapes
                ? input_op_ids
                : std::list<std::pair<at::RecordFunctionHandle, int>>())
            .c_str());
  }
  return nullptr;
}

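// Registers the NVTX profiler: installs NVTXThreadLocalState as thread-local
// debug info and adds RecordFunction callbacks that push an NVTX range on op
// entry and pop it (and update the producer map) on op exit.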
void pushNVTXCallbacks(
    const ProfilerConfig& config,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE,
      std::make_shared<NVTXThreadLocalState>(config));

  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          state_ptr->config().report_input_shapes
              ? &enterNVTX</*report_input_shapes=*/true>
              : &enterNVTX</*report_input_shapes=*/false>,
          [](const at::RecordFunction& fn, at::ObserverContext* ctx) {
            torch::profiler::impl::cudaStubs()->rangePop();
            updateOutputTensorTracker(fn);
          })
          .needsInputs(config.report_input_shapes)
          .needsOutputs(config.report_input_shapes)
          .needsIds(true)
          .scopes(scopes));
  state_ptr->setCallbackHandle(handle);
}

} // namespace torch::profiler::impl