xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/unwind/range_table.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/profiler/unwind/unwind_error.h>
3 #include <algorithm>
4 #include <memory>
5 #include <optional>
6 #include <vector>
7 
8 namespace torch::unwind {
9 template <typename T>
10 struct RangeTable {
RangeTableRangeTable11   RangeTable() {
12     // guarentee that lower_bound[-1] is always valid
13     addresses_.push_back(0);
14     payloads_.emplace_back(std::nullopt);
15   }
addRangeTable16   void add(uint64_t address, std::optional<T> payload, bool sorted) {
17     if (addresses_.back() > address) {
18       UNWIND_CHECK(!sorted, "expected addresses to be sorted");
19       sorted_ = false;
20     }
21     addresses_.push_back(address);
22     payloads_.emplace_back(std::move(payload));
23   }
findRangeTable24   std::optional<T> find(uint64_t address) {
25     maybeSort();
26     auto it = std::upper_bound(addresses_.begin(), addresses_.end(), address);
27     return payloads_.at(it - addresses_.begin() - 1);
28   }
dumpRangeTable29   void dump() {
30     for (size_t i = 0; i < addresses_.size(); i++) {
31       fmt::print("{} {:x}: {}\n", i, addresses_[i], payloads_[i] ? "" : "END");
32     }
33   }
sizeRangeTable34   size_t size() const {
35     return addresses_.size();
36   }
backRangeTable37   uint64_t back() {
38     maybeSort();
39     return addresses_.back();
40   }
41 
42  private:
maybeSortRangeTable43   void maybeSort() {
44     if (sorted_) {
45       return;
46     }
47     std::vector<uint64_t> indices;
48     indices.reserve(addresses_.size());
49     for (size_t i = 0; i < addresses_.size(); i++) {
50       indices.push_back(i);
51     }
52     std::sort(indices.begin(), indices.end(), [&](uint64_t a, uint64_t b) {
53       return addresses_[a] < addresses_[b] ||
54           (addresses_[a] == addresses_[b] &&
55            bool(payloads_[a]) < bool(payloads_[b]));
56     });
57     std::vector<uint64_t> addresses;
58     std::vector<std::optional<T>> payloads;
59     addresses.reserve(addresses_.size());
60     payloads.reserve(addresses_.size());
61     for (auto i : indices) {
62       addresses.push_back(addresses_[i]);
63       payloads.push_back(payloads_[i]);
64     }
65     addresses_ = std::move(addresses);
66     payloads_ = std::move(payloads);
67     sorted_ = true;
68   }
69   bool sorted_ = true;
70   std::vector<uint64_t> addresses_;
71   std::vector<std::optional<T>> payloads_;
72 };
73 } // namespace torch::unwind
74