1 #include <torch/csrc/jit/api/compilation_unit.h>
2 #include <torch/csrc/jit/mobile/type_parser.h>
3 #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
4 #include <torch/csrc/jit/serialization/pickle.h>
5
6 namespace torch::jit {
7
namespace {
// Sentinel tag written when a source range has no entry in the tag map at
// serialization time; deserialization treats it as "no source range".
constexpr int64_t kInvalidSourceRangeTag = -1;
} // namespace
11
serialize(const InlinedCallStackPtr & cs_ptr,const SourceRangeTagMap & source_range_tags)12 c10::IValue InlinedCallStackSerializer::serialize(
13 const InlinedCallStackPtr& cs_ptr,
14 const SourceRangeTagMap& source_range_tags) {
15 if (!cs_ptr) {
16 return c10::IValue();
17 }
18 auto cs_it = serialized_inlined_callstack_.find(cs_ptr);
19 if (cs_it != serialized_inlined_callstack_.end()) {
20 return cs_it->second;
21 }
22 // Inlined callstack pointer is serialized as tuple of 4 elements
23 // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack),
24 // function name} Note function name is serialized separately because Function
25 // is only in memory structure. It gets constructed by JIT from serialized
26 // Code at runtime. As such even InlinedCallStack get constructed by JIT at
27 // runtime during graph inlining. However, we introduce
28 // serialization/deserialization of it in order to generate callstack debug
29 // information, _when_ equivalent InlinedCallStack cannot be constructed at
30 // runtime. For example, in lite interpreter or delegated backend.
31 std::vector<c10::IValue> elements;
32 elements.reserve(4);
33 elements.emplace_back(
34 serialize_module_instance_info(cs_ptr->module_instance()));
35 int64_t source_range_tag{kInvalidSourceRangeTag};
36 const SourceRange& sr = cs_ptr->source_range().findSourceRangeThatGenerated()
37 ? cs_ptr->source_range().findSourceRangeThatGenerated().value()
38 : cs_ptr->source_range();
39 auto sr_it = source_range_tags.find(sr);
40 if (sr_it != source_range_tags.end()) {
41 source_range_tag = sr_it->second;
42 }
43 elements.emplace_back(source_range_tag);
44 if (cs_ptr->callee()) {
45 elements.emplace_back(
46 serialize(cs_ptr->callee().value(), source_range_tags));
47 } else {
48 elements.emplace_back();
49 }
50 auto fn_name = cs_ptr->function_name();
51 if (!fn_name.empty()) {
52 elements.emplace_back(fn_name);
53 } else {
54 elements.emplace_back("FunctionName_UNKNOWN");
55 }
56 c10::IValue serialized_cs = c10::ivalue::Tuple::create(elements);
57 serialized_inlined_callstack_[cs_ptr] = serialized_cs;
58 return serialized_cs;
59 }
60
serialize_module_instance_info(const std::optional<ModuleInstanceInfo> & m)61 c10::IValue InlinedCallStackSerializer::serialize_module_instance_info(
62 const std::optional<ModuleInstanceInfo>& m) {
63 if (!m) {
64 return c10::IValue();
65 }
66 const auto& m_val = m.value();
67 std::string module_type_name = m_val.class_type()->name()->qualifiedName();
68 auto module_instance_name = m_val.instance_name();
69 if (m_val.class_type()) {
70 module_type_name = m_val.class_type()->name()->qualifiedName();
71 }
72 auto key_val = module_type_name + module_instance_name;
73 auto m_inst_it = serialized_module_instance_info_.find(key_val);
74 if (m_inst_it != serialized_module_instance_info_.end()) {
75 return m_inst_it->second;
76 }
77 // Module instance info is serialized as
78 // {type name, instance name}
79 serialized_module_instance_info_[key_val] =
80 c10::ivalue::Tuple::create({module_type_name, module_instance_name});
81 return serialized_module_instance_info_[key_val];
82 }
83
pickle(const std::unordered_map<int64_t,DebugInfoTuple> & callstack_ptrs,const SourceRangeTagMap & source_range_tags)84 std::vector<char> CallStackDebugInfoPickler::pickle(
85 const std::unordered_map<int64_t, DebugInfoTuple>& callstack_ptrs,
86 const SourceRangeTagMap& source_range_tags) {
87 std::vector<c10::IValue> ivalues;
88 for (const auto& it : callstack_ptrs) {
89 int64_t debug_handle = it.first;
90 std::vector<c10::IValue> elements;
91 /*
92 * Debug handles and debug info (source range + inlinded callstack)
93 * are serialized as a tuple of 3 elements
94 * {debug_handle, source_range_tag, serialized_callstack}
95 */
96 elements.reserve(4);
97 elements.emplace_back(debug_handle);
98 int64_t source_range_tag{kInvalidSourceRangeTag};
99 const auto& source_range =
100 std::get<kDebugInfoTupleSourceRangeIndex>(it.second);
101 const SourceRange& sr = source_range.findSourceRangeThatGenerated()
102 ? source_range.findSourceRangeThatGenerated().value()
103 : source_range;
104 auto sr_it = source_range_tags.find(sr);
105 if (sr_it != source_range_tags.end()) {
106 source_range_tag = sr_it->second;
107 }
108 elements.emplace_back(source_range_tag);
109 elements.emplace_back(std::get<kDebugInfoTupleNodeNameIndex>(it.second));
110 const auto& inlined_cs_ptr =
111 std::get<kDebugInfoTupleInlinedCSIndex>(it.second);
112 elements.emplace_back(css_.serialize(inlined_cs_ptr, source_range_tags));
113 ivalues.emplace_back(c10::ivalue::Tuple::create(elements));
114 }
115 std::vector<at::Tensor> table;
116 c10::IValue ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
117 auto result = jit::pickle(ivalue, &table);
118 TORCH_CHECK(table.empty(), "Expected 0 tensors to be written");
119 return result;
120 }
121
InlinedCallStackPtr InlinedCallStackDeserializer::deserialize(
    const c10::IValue& iv,
    const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
    const std::shared_ptr<CompilationUnit>& cu) {
  // None encodes a null callstack (end of a caller chain).
  if (iv.isNone()) {
    return c10::intrusive_ptr<InlinedCallStack>();
  }
  auto tup = iv.toTuple();
  // Memoized by serialized-tuple identity: equal tuples deserialize to one
  // shared InlinedCallStack node.
  auto it = cached_inlined_callstacks_.find(tup);
  if (it != cached_inlined_callstacks_.end()) {
    return it->second;
  }

  const auto& tup_elems = tup->elements();
  TORCH_INTERNAL_ASSERT(tup_elems.size() == 4);
  // Layout produced by InlinedCallStackSerializer::serialize:
  // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack),
  // function name}
  auto module_instance_info =
      deserialize_module_instance_info(tup_elems[0], cu);
  int64_t source_range_tag = tup_elems[1].toInt();
  auto source_range_it = source_range_map.find(source_range_tag);
  // kInvalidSourceRangeTag means the range was never tagged at serialization
  // time; any other tag must resolve in the map.
  TORCH_CHECK(
      source_range_tag == kInvalidSourceRangeTag ||
          source_range_it != source_range_map.end(),
      "Source range tag must exist in deserialized source range map."
      " Not found source range tag:",
      source_range_tag);
  SourceRange source_range; // stays default-constructed for the invalid tag
  if (source_range_tag != kInvalidSourceRangeTag) {
    source_range = source_range_it->second;
  }
  // Recursively rebuild the caller chain before constructing this node.
  auto callee = deserialize(tup_elems[2], source_range_map, cu);
  auto function_name = tup_elems[3].toStringRef();
  InlinedCallStackPtr cs_ptr;
  // NOTE(review): the Function* argument is intentionally nullptr — the
  // in-memory Function cannot be reconstructed here; only its name survives
  // serialization (carried separately in function_name).
  if (callee) {
    cs_ptr = c10::make_intrusive<InlinedCallStack>(
        callee, nullptr, source_range, module_instance_info, function_name);
  } else {
    cs_ptr = c10::make_intrusive<InlinedCallStack>(
        nullptr, source_range, module_instance_info, function_name);
  }
  cached_inlined_callstacks_[tup] = cs_ptr;
  // cs_ptr was just copied into the cache above, so returning the named
  // local costs at most one intrusive refcount bump (and the compiler may
  // elide or move it).
  return cs_ptr;
}
170
171 std::optional<ModuleInstanceInfo> InlinedCallStackDeserializer::
deserialize_module_instance_info(const c10::IValue & iv,const std::shared_ptr<CompilationUnit> & cu)172 deserialize_module_instance_info(
173 const c10::IValue& iv,
174 const std::shared_ptr<CompilationUnit>& cu) {
175 if (iv.isNone()) {
176 return std::nullopt;
177 }
178 auto tup = iv.toTuple();
179 auto it = cached_module_instance_info_.find(tup);
180 if (it != cached_module_instance_info_.end()) {
181 return it->second;
182 }
183 const auto& tup_elems = iv.toTupleRef().elements();
184 TORCH_CHECK(tup_elems.size() == 2);
185 std::string type_name = tup_elems[0].toStringRef();
186 std::string instance_name = tup_elems[1].toStringRef();
187 // type_name might be empty string ""
188 // In that case type_ptr should be just nullptr
189 auto type_ptr = cu->get_class(type_name);
190 if (!type_ptr) {
191 // We may have lost type information. For example in lowered backends
192 // original class type has no relevance.
193 // However, to correlate ops to their original modules
194 // we saved both type name and instance name.
195 // In such cases, when module is absorbed by lowered backend
196 // we augment instance name with type name instead of losing it.
197 auto last_dot_position = type_name.find_last_of('.');
198 size_t substring_pos{0};
199 if (last_dot_position != std::string::npos) {
200 substring_pos = last_dot_position + 1;
201 }
202 type_name = type_name.substr(substring_pos);
203 instance_name = instance_name + "(" + type_name + ")";
204 }
205 cached_module_instance_info_[tup] =
206 ModuleInstanceInfo(type_ptr, instance_name);
207 return cached_module_instance_info_[tup];
208 }
209
210 ska::flat_hash_map<int64_t, DebugInfoTuple> CallStackDebugInfoUnpickler::
unpickle(const at::DataPtr & data,size_t size,const ska::flat_hash_map<int64_t,SourceRange> & source_range_map,const std::shared_ptr<CompilationUnit> & cu)211 unpickle(
212 const at::DataPtr& data,
213 size_t size,
214 const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
215 const std::shared_ptr<CompilationUnit>& cu) {
216 auto ival = jit::unpickle(
217 reinterpret_cast<const char*>(data.get()),
218 size,
219 nullptr,
220 {},
221 c10::parseType);
222 ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
223 const auto& ivalues = ival.toTupleRef().elements();
224 for (auto& val : ivalues) {
225 const auto& tup_elems = val.toTupleRef().elements();
226 TORCH_CHECK(
227 tup_elems.size() == 4,
228 "Pickled map must have four elements: "
229 "debug_handle, source_range_tag, op name, IValue(inlined_call_stack)");
230 int64_t debug_handle = tup_elems[0].toInt();
231 int64_t source_range_tag = tup_elems[1].toInt();
232 const std::string& node_name = tup_elems[2].toStringRef();
233 auto source_range_it = source_range_map.find(source_range_tag);
234 TORCH_CHECK(
235 source_range_it != source_range_map.end(),
236 "Source range tag must exist in deserialized source range map.");
237 auto source_range = source_range_it->second;
238 TORCH_CHECK(
239 callstack_ptrs.count(debug_handle) == 0,
240 "Debug handles should be unique.");
241 callstack_ptrs[debug_handle] = std::make_tuple(
242 source_range,
243 node_name,
244 csds_.deserialize(tup_elems[3], source_range_map, cu));
245 }
246 return callstack_ptrs;
247 }
248
249 } // namespace torch::jit
250