xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/unwind/line_number_program.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/profiler/unwind/debug_info.h>
3 #include <torch/csrc/profiler/unwind/dwarf_enums.h>
4 #include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
5 #include <torch/csrc/profiler/unwind/lexer.h>
6 #include <torch/csrc/profiler/unwind/sections.h>
7 #include <torch/csrc/profiler/unwind/unwind_error.h>
8 #include <tuple>
9 
10 namespace torch::unwind {
11 
12 struct LineNumberProgram {
LineNumberProgramLineNumberProgram13   LineNumberProgram(Sections& s, uint64_t offset) : s_(s), offset_(offset) {}
14 
offsetLineNumberProgram15   uint64_t offset() {
16     return offset_;
17   }
parseLineNumberProgram18   void parse() {
19     if (parsed_) {
20       return;
21     }
22     parsed_ = true;
23     CheckedLexer L = s_.debug_line.lexer(offset_);
24     std::tie(length_, is_64bit_) = L.readSectionLength();
25     program_end_ = (char*)L.loc() + length_;
26     auto version = L.read<uint16_t>();
27     UNWIND_CHECK(
28         version == 5 || version == 4,
29         "expected version 4 or 5 but found {}",
30         version);
31     if (version == 5) {
32       auto address_size = L.read<uint8_t>();
33       UNWIND_CHECK(
34           address_size == 8,
35           "expected 64-bit dwarf but found address size {}",
36           address_size);
37       segment_selector_size_ = L.read<uint8_t>();
38     }
39     header_length_ = is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
40     program_ = L;
41     program_.skip(int64_t(header_length_));
42     minimum_instruction_length_ = L.read<uint8_t>();
43     maximum_operations_per_instruction_ = L.read<uint8_t>();
44     default_is_stmt_ = L.read<uint8_t>();
45     line_base_ = L.read<int8_t>();
46     line_range_ = L.read<uint8_t>();
47     opcode_base_ = L.read<uint8_t>();
48     UNWIND_CHECK(line_range_ != 0, "line_range_ must be non-zero");
49     standard_opcode_lengths_.resize(opcode_base_);
50     for (size_t i = 1; i < opcode_base_; i++) {
51       standard_opcode_lengths_[i] = L.read<uint8_t>();
52     }
53     // fmt::print("{:x} {:x} {} {} {} {} {}\n", offset_, header_length_,
54     // minimum_instruction_length_, maximum_operations_per_instruction_,
55     // line_base_, line_range_, opcode_base_);
56     uint8_t directory_entry_format_count = L.read<uint8_t>();
57 
58     if (version == 5) {
59       struct Member {
60         uint64_t content_type;
61         uint64_t form;
62       };
63       std::vector<Member> directory_members;
64       for (size_t i = 0; i < directory_entry_format_count; i++) {
65         directory_members.push_back({L.readULEB128(), L.readULEB128()});
66       }
67       uint64_t directories_count = L.readULEB128();
68       for (size_t i = 0; i < directories_count; i++) {
69         for (auto& member : directory_members) {
70           switch (member.content_type) {
71             case DW_LNCT_path: {
72               include_directories_.emplace_back(
73                   s_.readString(L, member.form, is_64bit_));
74             } break;
75             default: {
76               skipForm(L, member.form);
77             } break;
78           }
79         }
80       }
81 
82       for (auto i : c10::irange(directories_count)) {
83         (void)i;
84         LOG_INFO("{} {}\n", i, include_directories_[i]);
85       }
86       auto file_name_entry_format_count = L.read<uint8_t>();
87       std::vector<Member> file_members;
88       for (size_t i = 0; i < file_name_entry_format_count; i++) {
89         file_members.push_back({L.readULEB128(), L.readULEB128()});
90       }
91       auto files_count = L.readULEB128();
92       for (size_t i = 0; i < files_count; i++) {
93         for (auto& member : file_members) {
94           switch (member.content_type) {
95             case DW_LNCT_path: {
96               file_names_.emplace_back(
97                   s_.readString(L, member.form, is_64bit_));
98             } break;
99             case DW_LNCT_directory_index: {
100               file_directory_index_.emplace_back(readData(L, member.form));
101               UNWIND_CHECK(
102                   file_directory_index_.back() < include_directories_.size(),
103                   "directory index out of range");
104             } break;
105             default: {
106               skipForm(L, member.form);
107             } break;
108           }
109         }
110       }
111       for (auto i : c10::irange(files_count)) {
112         (void)i;
113         LOG_INFO("{} {} {}\n", i, file_names_[i], file_directory_index_[i]);
114       }
115     } else {
116       include_directories_.emplace_back(""); // implicit cwd
117       while (true) {
118         auto str = L.readCString();
119         if (*str == '\0') {
120           break;
121         }
122         include_directories_.emplace_back(str);
123       }
124       file_names_.emplace_back("");
125       file_directory_index_.emplace_back(0);
126       while (true) {
127         auto str = L.readCString();
128         if (*str == '\0') {
129           break;
130         }
131         auto directory_index = L.readULEB128();
132         L.readULEB128(); // mod_time
133         L.readULEB128(); // file_length
134         file_names_.emplace_back(str);
135         file_directory_index_.push_back(directory_index);
136       }
137     }
138     UNWIND_CHECK(
139         maximum_operations_per_instruction_ == 1,
140         "maximum_operations_per_instruction_ must be 1");
141     UNWIND_CHECK(
142         minimum_instruction_length_ == 1,
143         "minimum_instruction_length_ must be 1");
144     readProgram();
145   }
146   struct Entry {
147     uint32_t file = 1;
148     int64_t line = 1;
149   };
findLineNumberProgram150   std::optional<Entry> find(uint64_t address) {
151     auto e = program_index_.find(address);
152     if (!e) {
153       return std::nullopt;
154     }
155     return all_programs_.at(*e).find(address);
156   }
filenameLineNumberProgram157   std::string filename(uint64_t index) {
158     return fmt::format(
159         "{}/{}",
160         include_directories_.at(file_directory_index_.at(index)),
161         file_names_.at(index));
162   }
163 
164  private:
skipFormLineNumberProgram165   void skipForm(CheckedLexer& L, uint64_t form) {
166     auto sz = formSize(form, is_64bit_ ? 8 : 4);
167     UNWIND_CHECK(sz, "unsupported form {}", form);
168     L.skip(int64_t(*sz));
169   }
170 
readDataLineNumberProgram171   uint64_t readData(CheckedLexer& L, uint64_t encoding) {
172     switch (encoding) {
173       case DW_FORM_data1:
174         return L.read<uint8_t>();
175       case DW_FORM_data2:
176         return L.read<uint16_t>();
177       case DW_FORM_data4:
178         return L.read<uint32_t>();
179       case DW_FORM_data8:
180         return L.read<uint64_t>();
181       case DW_FORM_udata:
182         return L.readULEB128();
183       default:
184         UNWIND_CHECK(false, "unsupported data encoding {}", encoding);
185     }
186   }
187 
produceEntryLineNumberProgram188   void produceEntry() {
189     if (shadow_) {
190       return;
191     }
192     if (ranges_.size() == 1) {
193       start_address_ = address_;
194     }
195     PRINT_LINE_TABLE(
196         "{:x}\t{}\t{}\n", address_, filename(entry_.file), entry_.line);
197     UNWIND_CHECK(
198         entry_.file < file_names_.size(),
199         "file index {} > {} entries",
200         entry_.file,
201         file_names_.size());
202     ranges_.add(address_, entry_, true);
203   }
endSequenceLineNumberProgram204   void endSequence() {
205     if (shadow_) {
206       return;
207     }
208     PRINT_LINE_TABLE(
209         "{:x}\tEND\n", address_, filename(entry_.file), entry_.line);
210     program_index_.add(start_address_, all_programs_.size(), false);
211     program_index_.add(address_, std::nullopt, false);
212     all_programs_.emplace_back(std::move(ranges_));
213     ranges_ = RangeTable<Entry>();
214   }
readProgramLineNumberProgram215   void readProgram() {
216     while (program_.loc() < program_end_) {
217       PRINT_INST("{:x}: ", (char*)program_.loc() - (s_.debug_line.data));
218       uint8_t op = program_.read<uint8_t>();
219       if (op >= opcode_base_) {
220         auto op2 = int64_t(op - opcode_base_);
221         address_ += op2 / line_range_;
222         entry_.line += line_base_ + (op2 % line_range_);
223         PRINT_INST(
224             "address += {}, line += {}\n",
225             op2 / line_range_,
226             line_base_ + (op2 % line_range_));
227         produceEntry();
228       } else {
229         switch (op) {
230           case DW_LNS_extended_op: {
231             auto len = program_.readULEB128();
232             auto extended_op = program_.read<uint8_t>();
233             switch (extended_op) {
234               case DW_LNE_end_sequence: {
235                 PRINT_INST("end_sequence\n");
236                 endSequence();
237                 entry_ = Entry{};
238               } break;
239               case DW_LNE_set_address: {
240                 address_ = program_.read<uint64_t>();
241                 if (!shadow_) {
242                   PRINT_INST(
243                       "set address {:x} {:x} {:x}\n",
244                       address_,
245                       min_address_,
246                       max_address_);
247                 }
248                 shadow_ = address_ == 0;
249               } break;
250               default: {
251                 PRINT_INST("skip extended op {}\n", extended_op);
252                 program_.skip(int64_t(len - 1));
253               } break;
254             }
255           } break;
256           case DW_LNS_copy: {
257             PRINT_INST("copy\n");
258             produceEntry();
259           } break;
260           case DW_LNS_advance_pc: {
261             PRINT_INST("advance pc\n");
262             address_ += program_.readULEB128();
263           } break;
264           case DW_LNS_advance_line: {
265             entry_.line += program_.readSLEB128();
266             PRINT_INST("advance line {}\n", entry_.line);
267 
268           } break;
269           case DW_LNS_set_file: {
270             PRINT_INST("set file\n");
271             entry_.file = program_.readULEB128();
272           } break;
273           case DW_LNS_const_add_pc: {
274             PRINT_INST("const add pc\n");
275             address_ += (255 - opcode_base_) / line_range_;
276           } break;
277           case DW_LNS_fixed_advance_pc: {
278             PRINT_INST("fixed advance pc\n");
279             address_ += program_.read<uint16_t>();
280           } break;
281           default: {
282             PRINT_INST("other {}\n", op);
283             auto n = standard_opcode_lengths_[op];
284             for (int i = 0; i < n; ++i) {
285               program_.readULEB128();
286             }
287           } break;
288         }
289       }
290     }
291     PRINT_INST(
292         "{:x}: end {:x}\n",
293         ((char*)program_.loc() - s_.debug_line.data),
294         program_end_ - s_.debug_line.data);
295   }
296 
297   uint64_t address_ = 0;
298   bool shadow_ = false;
299   bool parsed_ = false;
300   Entry entry_ = {};
301   std::vector<std::string> include_directories_;
302   std::vector<std::string> file_names_;
303   std::vector<uint64_t> file_directory_index_;
304   uint8_t segment_selector_size_ = 0;
305   uint8_t minimum_instruction_length_ = 0;
306   uint8_t maximum_operations_per_instruction_ = 0;
307   int8_t line_base_ = 0;
308   uint8_t line_range_ = 0;
309   uint8_t opcode_base_ = 0;
310   bool default_is_stmt_ = false;
311   CheckedLexer program_ = {nullptr};
312   char* program_end_ = nullptr;
313   uint64_t header_length_ = 0;
314   uint64_t length_ = 0;
315   bool is_64bit_ = false;
316   std::vector<uint8_t> standard_opcode_lengths_;
317   Sections& s_;
318   uint64_t offset_;
319   uint64_t start_address_ = 0;
320   RangeTable<uint64_t> program_index_;
321   std::vector<RangeTable<Entry>> all_programs_;
322   RangeTable<Entry> ranges_;
323 };
324 
325 } // namespace torch::unwind
326