xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_cs_debug_info_serialization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <test/cpp/jit/test_utils.h>
2 
3 #include <gtest/gtest.h>
4 
5 #include <c10/core/TensorOptions.h>
6 #include <torch/csrc/autograd/generated/variable_factories.h>
7 #include <torch/csrc/jit/api/module.h>
8 #include <torch/csrc/jit/backends/backend_debug_handler.h>
9 #include <torch/csrc/jit/frontend/resolver.h>
10 #include <torch/csrc/jit/mobile/import.h>
11 #include <torch/csrc/jit/mobile/module.h>
12 #include <torch/csrc/jit/passes/inliner.h>
13 #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
14 #include <torch/csrc/jit/serialization/export.h>
15 #include <torch/csrc/jit/serialization/import.h>
16 #include <torch/custom_class.h>
17 #include <torch/torch.h>
18 
19 #include <stack>
20 #include <unordered_set>
21 
22 // Tests go in torch::jit
23 namespace torch {
24 namespace jit {
25 
26 namespace {
validate_debug_info(const DebugInfoTuple & pre_serialize,const DebugInfoTuple & post_serialize)27 bool validate_debug_info(
28     const DebugInfoTuple& pre_serialize,
29     const DebugInfoTuple& post_serialize) {
30   auto sr1 = std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize);
31   auto sr2 = std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize);
32   if (sr1 != sr2) {
33     return false;
34   }
35   auto csptr1 = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize);
36   auto csptr2 = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize);
37   if (!csptr1.defined()) {
38     return !csptr2.defined();
39   }
40   if (!csptr2.defined()) {
41     return false;
42   }
43   auto vec1 = csptr1->vec();
44   auto vec2 = csptr2->vec();
45   if (vec1.size() != vec2.size()) {
46     return false;
47   }
48   while (csptr1) {
49     auto rhs_sr = csptr1->source_range();
50     auto lhs_sr = csptr2->source_range();
51     auto rhs_module = csptr1->module_instance();
52     auto lhs_module = csptr2->module_instance();
53     std::string rhs_fn_name, lhs_fn_name;
54     if (csptr1->function()) {
55       rhs_fn_name = csptr1->function()->name();
56     } else {
57       rhs_fn_name = csptr1->function_name();
58     }
59     if (csptr2->function()) {
60       lhs_fn_name = csptr2->function()->name();
61     } else {
62       lhs_fn_name = csptr2->function_name();
63     }
64     if (!((rhs_module.has_value() == lhs_module.has_value()) &&
65           (rhs_module.has_value() &&
66            (rhs_module.value().class_type()->name().value() ==
67             lhs_module.value().class_type()->name().value()) &&
68            (rhs_module.value().instance_name() ==
69             lhs_module.value().instance_name())) &&
70           (rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) {
71       return false;
72     }
73     if (csptr1->callee()) {
74       csptr1 = csptr1->callee().value();
75       csptr2 = csptr2->callee().value();
76     } else {
77       csptr1 = c10::intrusive_ptr<InlinedCallStack>();
78     }
79   }
80   return true;
81 }
82 
TEST(CSDebugInfoSerializaitionTest,TwoSubmodules)83 TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) {
84   std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
85   Module a("A", cu);
86   a.define(R"JIT(
87     def forward(self, x):
88       return x + 1
89   )JIT");
90   Module b("B", cu);
91   b.define(R"JIT(
92     def forward(self, x):
93       return x + 2
94   )JIT");
95   Module c("C", cu);
96   c.register_module("A0", a);
97   c.register_module("B0", b);
98   c.define(R"JIT(
99     def forward(self, x):
100       return self.A0.forward(x) + self.B0.forward(x)
101   )JIT");
102 
103   BackendDebugInfoRecorder debug_info_recorder;
104   auto graph = c.get_method("forward").graph();
105   Inline(*graph);
106   std::stack<Block*> blocks_to_visit;
107 
108   // maps from source range to debug handle
109   SourceRangeTagMap source_range_tags;
110   // Maps from debug handle to source range
111   ska::flat_hash_map<int64_t, SourceRange> source_range_map;
112   int64_t source_range_tag{0};
113 
114   blocks_to_visit.push(graph->block());
115   while (!blocks_to_visit.empty()) {
116     Block* b = blocks_to_visit.top();
117     blocks_to_visit.pop();
118     for (Node* n : b->nodes()) {
119       source_range_tags[n->sourceRange()] = source_range_tag;
120       source_range_map[source_range_tag] = n->sourceRange();
121       source_range_tag++;
122       debug_info_recorder.getNextDebugHandle(n);
123       if (n->callstack().has_value()) {
124         for (const auto& e : n->callstack().value()->vec()) {
125           auto sr = std::get<1>(e);
126           source_range_tags[sr] = source_range_tag;
127           source_range_map[source_range_tag] = sr;
128           source_range_tag++;
129         }
130       }
131     }
132   }
133   auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
134   CallStackDebugInfoPickler cs_debug_info_pickler;
135   auto cs_data =
136       cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags);
137   at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU);
138   CallStackDebugInfoUnpickler unpickler;
139   auto deserialized_cs_map = unpickler.unpickle(
140       std::move(data_ptr), cs_data.size(), source_range_map, cu);
141   for (const auto& it : debug_handle_cs_ptr_map) {
142     auto handle = it.first;
143     auto debug_info_one = it.second;
144     TORCH_CHECK(
145         deserialized_cs_map.count(handle),
146         "Serialized debug handle must be in deserialized map.");
147     auto debug_info_two = deserialized_cs_map[handle];
148     ASSERT_TRUE(validate_debug_info(debug_info_one, debug_info_two));
149   }
150 }
151 
152 } // namespace
153 
154 } // namespace jit
155 } // namespace torch
156