xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_backend_compiler_preprocess.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/backends/backend.h>
2 #include <torch/csrc/jit/backends/backend_preprocess.h>
3 #include <torch/csrc/jit/passes/dead_code_elimination.h>
4 #include <torch/csrc/jit/passes/inliner.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace {
9 // For this backend, the actual compilation happens in preprocess function AOT.
10 // Put here for demonstration of backend
11 // as a whole piece. It's used when compilation is required. A dummy function
12 // can be passed when there's no usage of compilation in runtime backend lib.
preprocess(const Module & mod,const c10::Dict<IValue,IValue> & method_compile_spec,const BackendDebugHandleGenerator & generate_debug_handles)13 c10::IValue preprocess(
14     const Module& mod,
15     const c10::Dict<IValue, IValue>& method_compile_spec,
16     const BackendDebugHandleGenerator& generate_debug_handles) {
17   // The output of this process would produce a dictionary
18   // Key: method name.
19   // Val: compiled blob (represented by a string).
20   c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
21 
22   for (const auto& method : mod.get_methods()) {
23     auto graph = toGraphFunction(method.function()).graph()->copy();
24     // Must inline the graph for debug info map.
25     Inline(*graph);
26     // This is here because to test module hierarchy we will have
27     // getattr nodes which after inlining dont serve any purpose.
28     // Without removing them we will run into compilation errors.
29     // So eliminate deadcode just remove those getattr nodes.
30     EliminateDeadCode(graph);
31     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
32     auto key = method.name();
33     auto node_debug_handles = generate_debug_handles(graph);
34     std::stringstream ss;
35     for (const auto& node : graph->nodes()) {
36       switch (node->kind()) {
37         case prim::Constant:
38           ss << node->kind().toDisplayString() << "#"
39              << toIValue(node->output()).value();
40           ss << "<debug_handle>" << node_debug_handles[node];
41           break;
42         // NOLINTNEXTLINE(bugprone-branch-clone)
43         case aten::add:
44           ss << node->kind().toQualString();
45           ss << "<debug_handle>" << node_debug_handles[node];
46           break;
47         case aten::sub:
48           ss << node->kind().toQualString();
49           ss << "<debug_handle>" << node_debug_handles[node];
50           break;
51         default:
52           TORCH_CHECK(
53               false,
54               "The node of ",
55               node->kind().toQualString(),
56               " is not supported in this compiler. Source code: ",
57               node->sourceRange().str());
58           break;
59       }
60       ss << ",";
61     }
62     std::string blob = ss.str();
63     if (!blob.empty()) {
64       blob.pop_back();
65     }
66     compiled.insert(method.name(), blob);
67   }
68   return compiled;
69 }
70 
71 constexpr auto backend_name = "backend_with_compiler_demo";
72 static auto pre_reg = backend_preprocess_register(backend_name, preprocess);
73 } // namespace
74 
75 } // namespace jit
76 } // namespace torch
77