// external/pytorch/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/pickle.h>

namespace torch::jit {

namespace {
constexpr int64_t kInvalidSourceRangeTag = -1;
} // namespace

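// Serializes an InlinedCallStack chain into nested IValue tuples. Each frame
// becomes {module_instance_info, source_range_tag, callee, function_name},
// where `callee` is the serialized tuple of the next frame (or None at the
// end of the chain). With illustrative names and tags, a two-frame stack
// would serialize roughly as
// (("__torch__.M", "m"), 3, (None, 5, None, "forward"), "forward").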
c10::IValue InlinedCallStackSerializer::serialize(
    const InlinedCallStackPtr& cs_ptr,
    const SourceRangeTagMap& source_range_tags) {
  if (!cs_ptr) {
    return c10::IValue();
  }
  auto cs_it = serialized_inlined_callstack_.find(cs_ptr);
  if (cs_it != serialized_inlined_callstack_.end()) {
    return cs_it->second;
  }
  // An inlined callstack pointer is serialized as a tuple of 4 elements:
  // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack),
  // function name}.
  // The function name is serialized separately because Function is an
  // in-memory-only structure: it is reconstructed by the JIT from serialized
  // Code at runtime, and InlinedCallStack is likewise reconstructed by the
  // JIT during graph inlining. We still serialize/deserialize it here so that
  // callstack debug information can be generated when an equivalent
  // InlinedCallStack cannot be constructed at runtime, for example in the
  // lite interpreter or a delegated backend.
  std::vector<c10::IValue> elements;
  elements.reserve(4);
  elements.emplace_back(
      serialize_module_instance_info(cs_ptr->module_instance()));
  int64_t source_range_tag{kInvalidSourceRangeTag};
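  // If this range was generated from another source range, look up the tag
  // of the generating range instead.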
  const SourceRange& sr = cs_ptr->source_range().findSourceRangeThatGenerated()
      ? cs_ptr->source_range().findSourceRangeThatGenerated().value()
      : cs_ptr->source_range();
  auto sr_it = source_range_tags.find(sr);
  if (sr_it != source_range_tags.end()) {
    source_range_tag = sr_it->second;
  }
  elements.emplace_back(source_range_tag);
  if (cs_ptr->callee()) {
    elements.emplace_back(
        serialize(cs_ptr->callee().value(), source_range_tags));
  } else {
    elements.emplace_back();
  }
  auto fn_name = cs_ptr->function_name();
  if (!fn_name.empty()) {
    elements.emplace_back(fn_name);
  } else {
    elements.emplace_back("FunctionName_UNKNOWN");
  }
  c10::IValue serialized_cs = c10::ivalue::Tuple::create(elements);
  serialized_inlined_callstack_[cs_ptr] = serialized_cs;
  return serialized_cs;
}

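// Serializes ModuleInstanceInfo as a {qualified type name, instance name}
// tuple, memoized on the concatenation of the two names so that repeated
// instances reuse the same IValue.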
c10::IValue InlinedCallStackSerializer::serialize_module_instance_info(
    const std::optional<ModuleInstanceInfo>& m) {
  if (!m) {
    return c10::IValue();
  }
  const auto& m_val = m.value();
  std::string module_type_name;
  auto module_instance_name = m_val.instance_name();
  if (m_val.class_type()) {
    module_type_name = m_val.class_type()->name()->qualifiedName();
  }
  auto key_val = module_type_name + module_instance_name;
  auto m_inst_it = serialized_module_instance_info_.find(key_val);
  if (m_inst_it != serialized_module_instance_info_.end()) {
    return m_inst_it->second;
  }
  // Module instance info is serialized as
  // {type name, instance name}
  serialized_module_instance_info_[key_val] =
      c10::ivalue::Tuple::create({module_type_name, module_instance_name});
  return serialized_module_instance_info_[key_val];
}

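// Pickles the debug-handle map into a flat byte vector. The result is a tuple
// of per-handle tuples, each of the form
// {debug_handle, source_range_tag, node name, serialized inlined callstack},
// which CallStackDebugInfoUnpickler::unpickle reverses.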
std::vector<char> CallStackDebugInfoPickler::pickle(
    const std::unordered_map<int64_t, DebugInfoTuple>& callstack_ptrs,
    const SourceRangeTagMap& source_range_tags) {
  std::vector<c10::IValue> ivalues;
  for (const auto& it : callstack_ptrs) {
    int64_t debug_handle = it.first;
    std::vector<c10::IValue> elements;
    /*
     * Each debug handle and its debug info (source range + inlined callstack)
     * are serialized as a tuple of 4 elements:
     * {debug_handle, source_range_tag, node name, serialized_callstack}
     */
    elements.reserve(4);
    elements.emplace_back(debug_handle);
    int64_t source_range_tag{kInvalidSourceRangeTag};
    const auto& source_range =
        std::get<kDebugInfoTupleSourceRangeIndex>(it.second);
    const SourceRange& sr = source_range.findSourceRangeThatGenerated()
        ? source_range.findSourceRangeThatGenerated().value()
        : source_range;
    auto sr_it = source_range_tags.find(sr);
    if (sr_it != source_range_tags.end()) {
      source_range_tag = sr_it->second;
    }
    elements.emplace_back(source_range_tag);
    elements.emplace_back(std::get<kDebugInfoTupleNodeNameIndex>(it.second));
    const auto& inlined_cs_ptr =
        std::get<kDebugInfoTupleInlinedCSIndex>(it.second);
    elements.emplace_back(css_.serialize(inlined_cs_ptr, source_range_tags));
    ivalues.emplace_back(c10::ivalue::Tuple::create(elements));
  }
  std::vector<at::Tensor> table;
  c10::IValue ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
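  // The tensor table is only populated when tensor constants are pickled;
  // debug info is purely ints/strings/tuples, so it is expected to stay empty.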
  auto result = jit::pickle(ivalue, &table);
  TORCH_CHECK(table.empty(), "Expected 0 tensors to be written");
  return result;
}

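// Reconstructs an InlinedCallStack chain from the nested tuples produced by
// InlinedCallStackSerializer::serialize. Results are cached per serialized
// tuple, so shared suffixes of different callstacks deserialize to the same
// InlinedCallStackPtr.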
InlinedCallStackPtr InlinedCallStackDeserializer::deserialize(
    const c10::IValue& iv,
    const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
    const std::shared_ptr<CompilationUnit>& cu) {
  if (iv.isNone()) {
    return c10::intrusive_ptr<InlinedCallStack>();
  }
  auto tup = iv.toTuple();
  auto it = cached_inlined_callstacks_.find(tup);
  if (it != cached_inlined_callstacks_.end()) {
    return it->second;
  }

  const auto& tup_elems = tup->elements();
  TORCH_INTERNAL_ASSERT(tup_elems.size() == 4);
  // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack),
  // function name}
  auto module_instance_info =
      deserialize_module_instance_info(tup_elems[0], cu);
  int64_t source_range_tag = tup_elems[1].toInt();
  auto source_range_it = source_range_map.find(source_range_tag);
  TORCH_CHECK(
      source_range_tag == kInvalidSourceRangeTag ||
          source_range_it != source_range_map.end(),
      "Source range tag must exist in deserialized source range map."
      " Not found source range tag:",
      source_range_tag);
  SourceRange source_range;
  if (source_range_tag != kInvalidSourceRangeTag) {
    source_range = source_range_it->second;
  }
  auto callee = deserialize(tup_elems[2], source_range_map, cu);
  auto function_name = tup_elems[3].toStringRef();
  InlinedCallStackPtr cs_ptr;
  if (callee) {
    cs_ptr = c10::make_intrusive<InlinedCallStack>(
        callee, nullptr, source_range, module_instance_info, function_name);
  } else {
    cs_ptr = c10::make_intrusive<InlinedCallStack>(
        nullptr, source_range, module_instance_info, function_name);
  }
  cached_inlined_callstacks_[tup] = cs_ptr;
  // Returned by value: it is unclear whether copy elision can happen here,
  // since cs_ptr was also copied into the map above, but returning the local
  // lets the compiler move it and helps avoid an extra refcount update.
  return cs_ptr;
}

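// Reconstructs ModuleInstanceInfo from a {type name, instance name} tuple.
// If the class type cannot be resolved in the CompilationUnit, the type name
// is folded into the instance name so the module identity is not lost.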
std::optional<ModuleInstanceInfo> InlinedCallStackDeserializer::
    deserialize_module_instance_info(
        const c10::IValue& iv,
        const std::shared_ptr<CompilationUnit>& cu) {
  if (iv.isNone()) {
    return std::nullopt;
  }
  auto tup = iv.toTuple();
  auto it = cached_module_instance_info_.find(tup);
  if (it != cached_module_instance_info_.end()) {
    return it->second;
  }
  const auto& tup_elems = iv.toTupleRef().elements();
  TORCH_CHECK(tup_elems.size() == 2);
  std::string type_name = tup_elems[0].toStringRef();
  std::string instance_name = tup_elems[1].toStringRef();
  // type_name might be an empty string "", in which case type_ptr will be
  // nullptr.
  auto type_ptr = cu->get_class(type_name);
  if (!type_ptr) {
    // We may have lost the type information. For example, in lowered backends
    // the original class type is no longer relevant. However, both the type
    // name and the instance name were saved so that ops can be correlated
    // with their originating modules. When the module has been absorbed by a
    // lowered backend, we augment the instance name with the type name
    // instead of losing it.
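    // For example, type_name "__torch__.foo.Bar" with instance_name "conv1"
    // (illustrative names) becomes "conv1(Bar)".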
    auto last_dot_position = type_name.find_last_of('.');
    size_t substring_pos{0};
    if (last_dot_position != std::string::npos) {
      substring_pos = last_dot_position + 1;
    }
    type_name = type_name.substr(substring_pos);
    instance_name = instance_name + "(" + type_name + ")";
  }
  cached_module_instance_info_[tup] =
      ModuleInstanceInfo(type_ptr, instance_name);
  return cached_module_instance_info_[tup];
}

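// Unpickles the buffer produced by CallStackDebugInfoPickler::pickle back
// into a map from debug handle to (source range, node name, inlined
// callstack), resolving source range tags via source_range_map and class
// types via the given CompilationUnit.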
ska::flat_hash_map<int64_t, DebugInfoTuple> CallStackDebugInfoUnpickler::
    unpickle(
        const at::DataPtr& data,
        size_t size,
        const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
        const std::shared_ptr<CompilationUnit>& cu) {
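  // Unpickle the raw bytes; type strings in the payload are parsed with
  // c10::parseType (the mobile type parser), and no tensor table is supplied.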
  auto ival = jit::unpickle(
      reinterpret_cast<const char*>(data.get()),
      size,
      nullptr,
      {},
      c10::parseType);
  ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
  const auto& ivalues = ival.toTupleRef().elements();
  for (auto& val : ivalues) {
    const auto& tup_elems = val.toTupleRef().elements();
    TORCH_CHECK(
        tup_elems.size() == 4,
        "Pickled map must have four elements: "
        "debug_handle, source_range_tag, op name, IValue(inlined_call_stack)");
    int64_t debug_handle = tup_elems[0].toInt();
    int64_t source_range_tag = tup_elems[1].toInt();
    const std::string& node_name = tup_elems[2].toStringRef();
    auto source_range_it = source_range_map.find(source_range_tag);
    TORCH_CHECK(
        source_range_it != source_range_map.end(),
        "Source range tag must exist in deserialized source range map.");
    auto source_range = source_range_it->second;
    TORCH_CHECK(
        callstack_ptrs.count(debug_handle) == 0,
        "Debug handles should be unique.");
    callstack_ptrs[debug_handle] = std::make_tuple(
        source_range,
        node_name,
        csds_.deserialize(tup_elems[3], source_range_map, cu));
  }
  return callstack_ptrs;
}

} // namespace torch::jit