#include <torch/csrc/profiler/data_flow.h>

#include <c10/util/overloaded.h>
#include <torch/csrc/profiler/collection.h>

namespace torch::profiler::impl {

namespace {
static constexpr TensorImplAddress NoTensorImpl{nullptr};

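// Uniform, flattened view of a single tensor / storage observation. The
// reference wrappers point back into the originating event structs so the
// IDs computed below can be written back in place.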
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct RawTensorInfo {
  TensorImplAddress impl_;
  StorageImplData storage_;
  c10::Device device_;
  bool is_free_;

  // Used to assign back to the original structs.
  std::reference_wrapper<std::optional<AllocationID>> allocation_id_ref_;
  std::reference_wrapper<std::optional<TensorID>> id_ref_;
};

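// Visitor that flattens the various event payloads (TensorMetadata held
// directly, in an optional, or in a vector, plus raw allocation events) into
// RawTensorInfo records. Payloads with no tensor data hit the catch-all
// overload and are ignored.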
struct RawTensors {
  std::vector<RawTensorInfo>& get() {
    return tensors_;
  }

  void operator()(TensorMetadata& t) {
    tensors_.emplace_back(RawTensorInfo{
        t.impl(), t.data_, t.device_, false, t.allocation_id_, t.id_});
  }

  void operator()(std::optional<TensorMetadata>& t) {
    if (t.has_value()) {
      (*this)(*t);
    }
  }

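  // Raw allocation / free events have no associated TensorImpl; a negative
  // alloc_size_ marks a free.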
  void operator()(ExtraFields<EventType::Allocation>& a) {
    const StorageImplData ptr{a.ptr_};
    const auto is_free = a.alloc_size_ < 0;
    tensors_.emplace_back(RawTensorInfo{
        NoTensorImpl, ptr, a.device(), is_free, a.allocation_id_, a.id_});
  }

  void operator()(std::vector<TensorMetadata>& t) {
    for (auto& ti : t) {
      (*this)(ti);
    }
  }

  template <typename T>
  void operator()(T&) {}

  std::vector<RawTensorInfo> tensors_;
};
} // namespace

void calculateUniqueTensorIDs(
    std::vector<std::shared_ptr<Result>>& sorted_results) {
  // This task is equivalent to https://leetcode.com/problems/number-of-islands/
  // We first cluster events with a greedy index assignment, and then merge
  // groups that overlap.
  std::vector<RawTensorInfo> tensors;

  // Flatten results to a uniform representation.
  // --------------------------------------------------------------------------
  {
    RawTensors raw_tensors;

    // The python tracer caches values, so it's only safe to use the first
    // occurrence of each module / optimizer.
    ska::flat_hash_set<PyModuleSelf> seen_modules;
    ska::flat_hash_set<PyOptimizerSelf> seen_optimizers;
    for (auto& result : sorted_results) {
      result->visit(c10::overloaded(
          [&](ExtraFields<EventType::TorchOp>& torch_op) {
            for (auto& i : torch_op.inputs_) {
              std::visit(raw_tensors, i);
            }
          },
          [&](ExtraFields<EventType::PyCall>& py_call) {
            // torch.nn.Module
            if (py_call.module_.has_value() &&
                seen_modules.insert(py_call.module_->self_).second) {
              for (auto& p : py_call.module_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
              }
            }

            // torch.optim.Optimizer
            if (py_call.optimizer_.has_value() &&
                seen_optimizers.insert(py_call.optimizer_->self_).second) {
              for (auto& p : py_call.optimizer_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
                for (auto& state_i : p.state_) {
                  raw_tensors(state_i.second);
                }
              }
            }
          },
          [&](auto& i) { raw_tensors(i); }));
    }
    tensors = std::move(raw_tensors.tensors_);
  }

  // Assign IDs to solve ABA for Storage.
  // --------------------------------------------------------------------------
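  // Each live (storage pointer, device) pair maps to a version counter, and
  // the map entry is erased when a free event is seen. A later reuse of the
  // same address therefore receives a fresh AllocationID instead of aliasing
  // the earlier lifetime.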
  {
    size_t counter{1};
    using key_t = std::pair<StorageImplData, c10::Device>;
    ska::flat_hash_map<key_t, size_t, HashCombine> versions;
    for (auto& t : tensors) {
      auto inserted = versions.insert({{t.storage_, t.device_}, counter});
      counter += inserted.second;
      t.allocation_id_ref_.get().emplace(AllocationID(inserted.first->second));
      if (t.is_free_) {
        versions.erase(inserted.first);
      }
    }
  }

  // Handle any allocation events which we cannot prove are for Tensor storage.
  // --------------------------------------------------------------------------
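  // Keep only records whose AllocationID was observed with a TensorImpl at
  // least once; allocations that never surface as a Tensor are dropped.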
  {
    ska::flat_hash_set<AllocationID> tensor_set;
    for (const auto& t : tensors) {
      if (t.impl_ != NoTensorImpl) {
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        tensor_set.insert(*t.allocation_id_ref_.get());
      }
    }
    tensors.erase(
        std::remove_if(
            tensors.begin(),
            tensors.end(),
            [&tensor_set](const auto& i) {
              auto it = tensor_set.find(*i.allocation_id_ref_.get());
              return it == tensor_set.end();
            }),
        tensors.end());
  }

  // Handle the case that the storage of a TensorImpl changed.
  // --------------------------------------------------------------------------
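  // AllocationIDs observed under the same TensorImpl belong to the same
  // logical Tensor whose storage changed over its lifetime; record them as
  // pairs so the groups can be merged below.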
  using storage_id_pair_t = std::pair<AllocationID, AllocationID>;
  ska::flat_hash_set<storage_id_pair_t, HashCombine> same_group_set;
  {
    ska::flat_hash_map<TensorImplAddress, AllocationID> impl_map;
    for (const auto& t : tensors) {
      // Storage allocations / frees don't have an associated TensorImpl, so
      // we don't want all storages to merge through nullptr.
      if (!t.impl_) {
        continue;
      }

      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      const auto allocation_id = *t.allocation_id_ref_.get();
      const auto it = impl_map.insert({t.impl_, allocation_id}).first;

      // The pair needs to be sorted for the coalesce step to work properly.
      it->second < allocation_id
          ? same_group_set.insert({it->second, allocation_id})
          : same_group_set.insert({allocation_id, it->second});
    }
  }

  // Coalesce groups and assign final IDs.
  // --------------------------------------------------------------------------
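  // Sweep the sorted pairs: the first member of a pair receives a fresh index
  // the first time it is seen, and the second member is mapped to the first
  // member's index. Every surviving AllocationID appears in at least one pair
  // (possibly a self-pair), so the id_map lookup in the write-back is safe.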
  ska::flat_hash_map<AllocationID, size_t> id_map;
  {
    std::vector<storage_id_pair_t> unique_pairs;
    for (const auto& i : same_group_set) {
      unique_pairs.push_back(i);
    }
    std::sort(unique_pairs.begin(), unique_pairs.end());

    size_t current_id{0};
    for (const auto& i : unique_pairs) {
      auto inserted = id_map.insert({i.first, current_id});
      current_id += inserted.second;
      id_map.insert({i.second, inserted.first->second});
    }
  }

  // Write back to Tensor IDs.
  // --------------------------------------------------------------------------
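  // Publish the final group index through the stored reference wrappers so
  // the original event structs receive their TensorID.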
  for (const auto& t : tensors) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const auto id = id_map.at(*t.allocation_id_ref_.get());
    t.id_ref_.get().emplace(TensorID(id));
  }
}

} // namespace torch::profiler::impl