xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/callstack_debug_info_serialization.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Allocator.h>
4 #include <torch/csrc/jit/frontend/source_range.h>
5 #include <torch/csrc/jit/ir/scope.h>
6 
7 #include <ATen/core/ivalue.h>
8 
9 #include <vector>
10 
11 #include <c10/util/flat_hash_map.h>
12 
13 namespace c10 {
14 struct IValue;
15 }
16 
17 namespace torch::jit {
18 
19 class Pickler;
20 class InlinedCallStackSerializer {
21  public:
22   // Serialize InlinedCallStack as
23   // SerializedInlinedCallStack =
24   // [module_info, source range tag, SerializedInlinedCallStack]
25   // module_info = [ClassType.qualifiedName, instance_name]
26   // source_range_tag = unique source range id
27   c10::IValue serialize(
28       const InlinedCallStackPtr& cs_ptr,
29       const SourceRangeTagMap& source_range_tags);
30 
31  private:
32   // module_info = [ClassType.qualifiedName, instance_name]
33   c10::IValue serialize_module_instance_info(
34       const std::optional<ModuleInstanceInfo>& m);
35 
36   // This caches serialized inlined callstack ptr, since many
37   // InlinedCallStackPtr can refer to the same one.
38   ska::flat_hash_map<InlinedCallStackPtr, c10::IValue>
39       serialized_inlined_callstack_;
40   // This caches serialized module instance info.
41   // There might be many nodes that are part of the same
42   // parent, grandparent etc. module.
43   ska::flat_hash_map<std::string, c10::IValue> serialized_module_instance_info_;
44 };
45 
46 class TORCH_API CallStackDebugInfoPickler {
47  public:
48   CallStackDebugInfoPickler() = default;
49 
50   std::vector<char> pickle(
51       const std::unordered_map<int64_t, DebugInfoTuple>& callstack_ptrs,
52       const SourceRangeTagMap& source_range_tags);
53 
54  private:
55   InlinedCallStackSerializer css_;
56 };
57 
58 class InlinedCallStackDeserializer {
59  public:
60   InlinedCallStackPtr deserialize(
61       const c10::IValue& iv,
62       const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
63       const std::shared_ptr<CompilationUnit>& cu);
64 
65  private:
66   std::optional<ModuleInstanceInfo> deserialize_module_instance_info(
67       const c10::IValue& iv,
68       const std::shared_ptr<CompilationUnit>& cu);
69 
70   ska::
71       flat_hash_map<c10::intrusive_ptr<c10::ivalue::Tuple>, InlinedCallStackPtr>
72           cached_inlined_callstacks_;
73   ska::flat_hash_map<c10::intrusive_ptr<c10::ivalue::Tuple>, ModuleInstanceInfo>
74       cached_module_instance_info_;
75 };
76 
77 class TORCH_API CallStackDebugInfoUnpickler {
78  public:
79   ska::flat_hash_map<int64_t, DebugInfoTuple> unpickle(
80       const at::DataPtr& data,
81       size_t size,
82       const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
83       const std::shared_ptr<CompilationUnit>& cu);
84 
85  private:
86   InlinedCallStackDeserializer csds_;
87 };
88 
89 } // namespace torch::jit
90