1 #pragma once 2 #include <torch/csrc/profiler/unwind/unwind_error.h> 3 #include <algorithm> 4 #include <memory> 5 #include <optional> 6 #include <vector> 7 8 namespace torch::unwind { 9 template <typename T> 10 struct RangeTable { RangeTableRangeTable11 RangeTable() { 12 // guarentee that lower_bound[-1] is always valid 13 addresses_.push_back(0); 14 payloads_.emplace_back(std::nullopt); 15 } addRangeTable16 void add(uint64_t address, std::optional<T> payload, bool sorted) { 17 if (addresses_.back() > address) { 18 UNWIND_CHECK(!sorted, "expected addresses to be sorted"); 19 sorted_ = false; 20 } 21 addresses_.push_back(address); 22 payloads_.emplace_back(std::move(payload)); 23 } findRangeTable24 std::optional<T> find(uint64_t address) { 25 maybeSort(); 26 auto it = std::upper_bound(addresses_.begin(), addresses_.end(), address); 27 return payloads_.at(it - addresses_.begin() - 1); 28 } dumpRangeTable29 void dump() { 30 for (size_t i = 0; i < addresses_.size(); i++) { 31 fmt::print("{} {:x}: {}\n", i, addresses_[i], payloads_[i] ? "" : "END"); 32 } 33 } sizeRangeTable34 size_t size() const { 35 return addresses_.size(); 36 } backRangeTable37 uint64_t back() { 38 maybeSort(); 39 return addresses_.back(); 40 } 41 42 private: maybeSortRangeTable43 void maybeSort() { 44 if (sorted_) { 45 return; 46 } 47 std::vector<uint64_t> indices; 48 indices.reserve(addresses_.size()); 49 for (size_t i = 0; i < addresses_.size(); i++) { 50 indices.push_back(i); 51 } 52 std::sort(indices.begin(), indices.end(), [&](uint64_t a, uint64_t b) { 53 return addresses_[a] < addresses_[b] || 54 (addresses_[a] == addresses_[b] && 55 bool(payloads_[a]) < bool(payloads_[b])); 56 }); 57 std::vector<uint64_t> addresses; 58 std::vector<std::optional<T>> payloads; 59 addresses.reserve(addresses_.size()); 60 payloads.reserve(addresses_.size()); 61 for (auto i : indices) { 62 addresses.push_back(addresses_[i]); 63 payloads.push_back(payloads_[i]); 64 } 65 addresses_ = std::move(addresses); 66 payloads_ = std::move(payloads); 67 sorted_ = true; 68 } 69 bool sorted_ = true; 70 std::vector<uint64_t> addresses_; 71 std::vector<std::optional<T>> payloads_; 72 }; 73 } // namespace torch::unwind 74