#include <torch/csrc/profiler/data_flow.h>

#include <c10/util/overloaded.h>
#include <torch/csrc/profiler/collection.h>

namespace torch::profiler::impl {

namespace {
static constexpr TensorImplAddress NoTensorImpl{nullptr};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct RawTensorInfo {
  TensorImplAddress impl_;
  StorageImplData storage_;
  c10::Device device_;
  bool is_free_;

  // Used to assign back to the original structs.
  std::reference_wrapper<std::optional<AllocationID>> allocation_id_ref_;
  std::reference_wrapper<std::optional<TensorID>> id_ref_;
};

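// Visitor used with `std::visit` / `c10::overloaded` below to collect every
// Tensor and allocation reference from the various event payloads into a
// single flat list of RawTensorInfo records.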
struct RawTensors {
  std::vector<RawTensorInfo>& get() {
    return tensors_;
  }

  void operator()(TensorMetadata& t) {
    tensors_.emplace_back(RawTensorInfo{
        t.impl(), t.data_, t.device_, false, t.allocation_id_, t.id_});
  }

  void operator()(std::optional<TensorMetadata>& t) {
    if (t.has_value()) {
      (*this)(*t);
    }
  }

  void operator()(ExtraFields<EventType::Allocation>& a) {
    const StorageImplData ptr{a.ptr_};
    const auto is_free = a.alloc_size_ < 0;
    tensors_.emplace_back(RawTensorInfo{
        NoTensorImpl, ptr, a.device(), is_free, a.allocation_id_, a.id_});
  }

  void operator()(std::vector<TensorMetadata>& t) {
    for (auto& ti : t) {
      (*this)(ti);
    }
  }

  template <typename T>
  void operator()(T&) {}

  std::vector<RawTensorInfo> tensors_;
};
} // namespace

void calculateUniqueTensorIDs(
    std::vector<std::shared_ptr<Result>>& sorted_results) {
  // This task is equivalent to https://leetcode.com/problems/number-of-islands/
  // We first cluster events with a greedy index assignment, and then merge
  // groups that overlap.
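  // Concretely: each versioned (storage, device) allocation acts as a node,
  // nodes are linked when the same TensorImpl is observed using both, and
  // each resulting group receives one TensorID.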
  std::vector<RawTensorInfo> tensors;

  // Flatten results to a uniform representation.
  // --------------------------------------------------------------------------
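  // Collect every Tensor reference we know about (op inputs, nn.Module
  // parameters and gradients, optimizer state, and raw allocation events)
  // into RawTensorInfo records that also hold references back to the fields
  // we will eventually fill in.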
  {
    RawTensors raw_tensors;

    // The python tracer caches values, so it's only safe to use the first
    // occurrence of each module or optimizer that we see.
    ska::flat_hash_set<PyModuleSelf> seen_modules;
    ska::flat_hash_set<PyOptimizerSelf> seen_optimizers;
    for (auto& result : sorted_results) {
      result->visit(c10::overloaded(
          [&](ExtraFields<EventType::TorchOp>& torch_op) {
            for (auto& i : torch_op.inputs_) {
              std::visit(raw_tensors, i);
            }
          },
          [&](ExtraFields<EventType::PyCall>& py_call) {
            // torch.nn.Module
            if (py_call.module_.has_value() &&
                seen_modules.insert(py_call.module_->self_).second) {
              for (auto& p : py_call.module_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
              }
            }

            // torch.optim.Optimizer
            if (py_call.optimizer_.has_value() &&
                seen_optimizers.insert(py_call.optimizer_->self_).second) {
              for (auto& p : py_call.optimizer_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
                for (auto& state_i : p.state_) {
                  raw_tensors(state_i.second);
                }
              }
            }
          },
          [&](auto& i) { raw_tensors(i); }));
    }
    tensors = std::move(raw_tensors.tensors_);
  }

  // Assign IDs to solve ABA for Storage.
  // --------------------------------------------------------------------------
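  // A storage address can be freed and later reused for an unrelated Tensor
  // (the ABA problem), so instead of keying on the raw pointer we give each
  // (storage, device) pair a version number and retire that version when the
  // corresponding free event is seen.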
  {
    size_t counter{1};
    using key_t = std::pair<StorageImplData, c10::Device>;
    ska::flat_hash_map<key_t, size_t, HashCombine> versions;
    for (auto& t : tensors) {
      auto inserted = versions.insert({{t.storage_, t.device_}, counter});
      counter += inserted.second;
      t.allocation_id_ref_.get().emplace(AllocationID(inserted.first->second));
      if (t.is_free_) {
        versions.erase(inserted.first);
      }
    }
  }

  // Handle any allocation events which we cannot prove are for Tensor storage.
  // --------------------------------------------------------------------------
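  // Allocation / free events whose AllocationID is never referenced by any
  // TensorImpl cannot be tied to a Tensor, so they are dropped rather than
  // assigned a TensorID.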
  {
    ska::flat_hash_set<AllocationID> tensor_set;
    for (const auto& t : tensors) {
      if (t.impl_ != NoTensorImpl) {
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        tensor_set.insert(*t.allocation_id_ref_.get());
      }
    }
    tensors.erase(
        std::remove_if(
            tensors.begin(),
            tensors.end(),
            [&tensor_set](const auto& i) {
              auto it = tensor_set.find(*i.allocation_id_ref_.get());
              return it == tensor_set.end();
            }),
        tensors.end());
  }

  // Handle the case that the storage of a TensorImpl changed.
  // --------------------------------------------------------------------------
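  // A TensorImpl can point at different storages over its lifetime (for
  // example after a resize). For every observation we add the sorted pair of
  // (first AllocationID recorded for that impl, current AllocationID) to
  // `same_group_set`, which links all storages used by one impl into a
  // single group.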
  using storage_id_pair_t = std::pair<AllocationID, AllocationID>;
  ska::flat_hash_set<storage_id_pair_t, HashCombine> same_group_set;
  {
    ska::flat_hash_map<TensorImplAddress, AllocationID> impl_map;
    for (const auto& t : tensors) {
      // Storage allocations / frees don't have an associated TensorImpl, so
      // we don't want all storages to merge through nullptr.
      if (!t.impl_) {
        continue;
      }

      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      const auto allocation_id = *t.allocation_id_ref_.get();
      const auto it = impl_map.insert({t.impl_, allocation_id}).first;

      // The pair needs to be sorted for the coalesce step to work properly.
      it->second < allocation_id
          ? same_group_set.insert({it->second, allocation_id})
          : same_group_set.insert({allocation_id, it->second});
    }
  }

  // Coalesce groups and assign final IDs.
  // --------------------------------------------------------------------------
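  // Greedy merge over the sorted pairs: the smaller member of each pair gets
  // a new group ID if unseen (otherwise it keeps its existing one), and the
  // larger member is assigned to that same group; the second `insert` is a
  // no-op for IDs that already have a group.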
  ska::flat_hash_map<AllocationID, size_t> id_map;
  {
    std::vector<storage_id_pair_t> unique_pairs;
    for (const auto& i : same_group_set) {
      unique_pairs.push_back(i);
    }
    std::sort(unique_pairs.begin(), unique_pairs.end());

    size_t current_id{0};
    for (const auto& i : unique_pairs) {
      auto inserted = id_map.insert({i.first, current_id});
      current_id += inserted.second;
      id_map.insert({i.second, inserted.first->second});
    }
  }

  // Write back to Tensor IDs.
  // --------------------------------------------------------------------------
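  // Store the final group ID into each original event through the references
  // captured during flattening.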
  for (const auto& t : tensors) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const auto id = id_map.at(*t.allocation_id_ref_.get());
    t.id_ref_.get().emplace(TensorID(id));
  }
}

} // namespace torch::profiler::impl