// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/debug_info.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <c10/util/flat_hash_map.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

8 namespace torch::jit {
9 /*
10  * MobileDebugTable:
11  * Deserializes debug_pkl and callstack_map records from PT model's zip archive
12  * and stores them in a map of debug handles to DebugInfoPair. Debug handles are
13  * unique per model and runtime, be in lite interpreter or delegate, an
14  * exception of BackendRuntimeException should raised using debug handles.
15  * getSourceDebugString method is responsible for translating debug
16  * handles to correspond debug information.
17  * This debug informatin includes stack trace of model level source code and
18  * module hierarchy where the exception occurred.
19  */
20 class MobileDebugTable {
21  public:
22   MobileDebugTable() = default;
23   MobileDebugTable(
24       std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader,
25       const std::shared_ptr<CompilationUnit>& cu);
26 
27   template <typename It>
MobileDebugTable(It begin,It end)28   MobileDebugTable(It begin, It end) : callstack_ptr_map_(begin, end) {}
29 
30   std::string getSourceDebugString(
31       const int64_t debug_handle,
32       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
33   std::string getSourceDebugString(
34       const std::vector<int64_t>& debug_handles,
35       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
36   std::string getModuleHierarchyInfo(
37       const int64_t debug_handle,
38       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
39   std::string getModuleHierarchyInfo(
40       const std::vector<int64_t>& debug_handles,
41       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
42 
getCallStackPtrMap()43   const ska::flat_hash_map<int64_t, DebugInfoTuple>& getCallStackPtrMap()
44       const {
45     return callstack_ptr_map_;
46   }
47 
48  private:
49   std::pair<std::string, std::string> getSourceDebugModuleHierarchyInfo(
50       const std::vector<int64_t>& debug_handles,
51       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
52   ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptr_map_;
53 };
54 
55 } // namespace torch::jit
56