#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>

#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/torch.h>
#include <xnnpack.h>

#include <ATen/core/List.h>
#include <torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h>

namespace torch {
namespace jit {
namespace xnnpack {
namespace delegate {

// The expected method_compile_spec should look something like this:
// {
//     "forward" : {
//                  "inputs" : at::Tensor,
//                  "outputs" : at::Tensor
//                  }
// }
// or
// {
//     "forward" : {
//                  "inputs" : c10::List<at::Tensor>,
//                  "outputs" : c10::List<at::Tensor>
//                  }
// }
// where the values under "inputs" and "outputs" are sample tensors carrying
// the shapes of the module's inputs and outputs. The module fed to the
// xnnpack backend must first be traced in order to propagate input shapes
// through the module. This is important for building the xnnpack_subgraph_t
// object.
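//
// As an illustrative sketch only (the tensor shapes below are arbitrary
// placeholders, not tied to any real model), a caller might assemble such a
// spec like so:
//
//   c10::Dict<c10::IValue, c10::IValue> forward_spec(
//       c10::StringType::get(), c10::AnyType::get());
//   forward_spec.insert(
//       "inputs", c10::List<at::Tensor>({torch::randn({1, 3, 224, 224})}));
//   forward_spec.insert(
//       "outputs", c10::List<at::Tensor>({torch::randn({1, 1000})}));
//   c10::Dict<c10::IValue, c10::IValue> method_compile_spec(
//       c10::StringType::get(), c10::AnyType::get());
//   method_compile_spec.insert("forward", forward_spec);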
c10::IValue preprocess(
    const Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
    const BackendDebugHandleGenerator& generate_debug_handles) {
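  // Work on a clone so the caller's module is untouched; freezing (in eval
  // mode) inlines parameters and attributes into the graph as constants,
  // leaving a self-contained graph for the XNNPack graph builder below.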
  auto eval_mod = mod.clone();
  eval_mod.eval();
  eval_mod = torch::jit::freeze(eval_mod);

  c10::Dict<IValue, IValue> compiled(StringType::get(), TensorType::get());

  c10::IValue inp;
  c10::IValue out;

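  // Validate the compile spec's structure before unpacking it: a "forward"
  // entry holding a dict with "inputs" and "outputs" keys.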
  TORCH_CHECK(
      method_compile_spec.contains("forward"),
      "method_compile_spec does not contain the \"forward\" key.");
  auto innerDict = method_compile_spec.at("forward");

  TORCH_CHECK(
      innerDict.isGenericDict() &&
          innerDict.toGenericDict().contains("inputs") &&
          innerDict.toGenericDict().contains("outputs"),
      "method_compile_spec does not contain a dictionary with \"inputs\" and \"outputs\" keys, under its \"forward\" key.");

  inp = innerDict.toGenericDict().at("inputs");
  out = innerDict.toGenericDict().at("outputs");

  TORCH_CHECK(
      inp.isTensor() || inp.isTensorList(),
      "method_compile_spec does not contain either a Tensor or TensorList under its \"inputs\" key.");
  TORCH_CHECK(
      out.isTensor() || out.isTensorList(),
      "method_compile_spec does not contain either a Tensor or TensorList under its \"outputs\" key.");

  // Graph preprocessing
  const auto& forward_method = eval_mod.get_method("forward");

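  // Copy the forward graph so the module's own graph stays intact; dropping
  // the now-unused %self argument leaves the graph inputs matching the sample
  // inputs from the compile spec one-to-one.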
  auto graph = toGraphFunction(forward_method.function()).graph()->copy();
  graph = tensorexpr::removeUnusedSelfArgument(graph);
  std::vector<c10::IValue> example_inputs;
  if (inp.isTensorList()) {
    c10::List<at::Tensor> inp_list = inp.toTensorList();
    TORCH_CHECK(
        graph->inputs().size() == inp_list.size(),
        "method_compile_spec inputs do not match expected number of forward inputs");

    example_inputs.reserve(inp_list.size());
    for (const auto i : c10::irange(inp_list.size())) {
      example_inputs.emplace_back(inp_list[i]);
    }
  } else {
    TORCH_CHECK(
        graph->inputs().size() == 1,
        "method_compile_spec inputs do not match expected number of forward inputs");

    example_inputs.emplace_back(inp.toTensor());
  }

  // inp above has been confirmed to be either a Tensor or a TensorList
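  // buildXNNGraph lowers the TorchScript graph into XNNPack's graph
  // representation, using the example inputs to pin down tensor shapes (see
  // the note on tracing at the top of this file).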
  XNNGraph graph_builder;
  graph_builder.buildXNNGraph(graph, example_inputs);
  // At this point the graph is complete; for the sake of testing preprocess,
  // we do runtime setup here and run the graph on the sample values from the
  // compile spec.

  // Gather sample inputs from the compile spec.
  std::vector<at::Tensor> inputs;
  if (inp.isTensorList()) {
    c10::List<at::Tensor> inp_list = inp.toTensorList();
    inputs.reserve(inp_list.size());
    for (const auto i : c10::irange(inp_list.size())) {
      inputs.push_back(inp_list[i]);
    }
  } else {
    // A bare Tensor is accepted by the checks above; toList() would throw on
    // it, so handle it directly.
    inputs.push_back(inp.toTensor());
  }
  std::vector<at::Tensor> outputs;
  std::vector<c10::IntList> output_shapes;

  // Gather sample outputs from the compile spec, along with their shapes to
  // forward along to the device.
  if (out.isTensorList()) {
    c10::List<at::Tensor> out_list = out.toTensorList();
    outputs.reserve(out_list.size());
    for (const auto i : c10::irange(out_list.size())) {
      const at::Tensor sample_output = out_list[i];
      outputs.push_back(sample_output);
      output_shapes.push_back(sample_output.sizes());
    }
  } else {
    const at::Tensor sample_output = out.toTensor();
    outputs.push_back(sample_output);
    output_shapes.push_back(sample_output.sizes());
  }

  // Sample run on the sample inputs
  graph_builder.runGraphOnInputs(inputs, outputs);
  c10::List<c10::IntList> shapes_list(output_shapes);

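  // Payload handed back for this method: the serialized model from the graph
  // builder, the output shapes for the device to allocate against, and the
  // sample-run outputs (under "Answer") used when testing preprocess.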
  compiled.insert("ser_model", graph_builder.serializedXNNGraph());
  compiled.insert("outputs", shapes_list);
  compiled.insert("Answer", outputs);

  return compiled;
}
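
// Register this preprocess function with the backend registry under the
// "xnnpack" name; to_backend lowering looks it up by that name.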
constexpr auto backend_name = "xnnpack";
static auto pre_reg = backend_preprocess_register(backend_name, preprocess);

} // namespace delegate
} // namespace xnnpack
} // namespace jit
} // namespace torch