#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>

#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/torch.h>
#include <xnnpack.h>

#include <ATen/core/List.h>
#include <torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h>

namespace torch {
namespace jit {
namespace xnnpack {
namespace delegate {

// The expected method_compile_spec should look something like this:
// {
//     "forward" : {"inputs" : at::Tensor}
// }
// or
// {
//     "forward" : {
//         "inputs" : c10::List<at::Tensor>,
//         "outputs" : c10::List<at::Tensor>
//     }
// }
// in which the value for "inputs" is a sample input carrying the input
// shapes for the module. The module fed to the xnnpack backend must first
// be traced in order to propagate input shapes through the module. This is
// important for building the xnnpack_subgraph_t object.
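//
// For illustration only, a spec of this shape could be assembled in C++
// roughly as follows (the tensor shapes are made-up placeholders):
//
//   c10::Dict<c10::IValue, c10::IValue> forward_spec(
//       c10::StringType::get(), c10::AnyType::get());
//   forward_spec.insert("inputs", torch::randn({1, 3, 224, 224}));
//   forward_spec.insert("outputs", torch::randn({1, 1000}));
//   c10::Dict<c10::IValue, c10::IValue> method_compile_spec(
//       c10::StringType::get(), c10::AnyType::get());
//   method_compile_spec.insert("forward", forward_spec);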
c10::IValue preprocess(
    const Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
    const BackendDebugHandleGenerator& generate_debug_handles) {
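  // Work on a frozen clone: eval() plus freeze() fold parameters and
  // attributes into the graph as constants, leaving a self-contained graph
  // for the lowering steps below.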
  auto eval_mod = mod.clone();
  eval_mod.eval();
  eval_mod = torch::jit::freeze(eval_mod);

  c10::Dict<IValue, IValue> compiled(StringType::get(), TensorType::get());

  c10::IValue inp;
  c10::IValue out;

  TORCH_CHECK(
      method_compile_spec.contains("forward"),
      "method_compile_spec does not contain the \"forward\" key.");
  auto innerDict = method_compile_spec.at("forward");

  TORCH_CHECK(
      innerDict.isGenericDict() &&
          innerDict.toGenericDict().contains("inputs") &&
          innerDict.toGenericDict().contains("outputs"),
      "method_compile_spec does not contain a dictionary with both \"inputs\" and \"outputs\" keys under the \"forward\" key.");

  inp = innerDict.toGenericDict().at("inputs");
  out = innerDict.toGenericDict().at("outputs");

  TORCH_CHECK(
      inp.isTensor() || inp.isTensorList(),
      "method_compile_spec does not contain either a Tensor or TensorList under its \"inputs\" key.");
  TORCH_CHECK(
      out.isTensor() || out.isTensorList(),
      "method_compile_spec does not contain either a Tensor or TensorList under its \"outputs\" key.");

  // Graph preprocessing
  const auto& forward_method = eval_mod.get_method("forward");

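  // Copy the forward graph and drop its unused "self" argument so that the
  // graph's inputs line up one-to-one with the example inputs checked below.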
  auto graph = toGraphFunction(forward_method.function()).graph()->copy();
  graph = tensorexpr::removeUnusedSelfArgument(graph);
  std::vector<c10::IValue> example_inputs;
  if (inp.isTensorList()) {
    c10::List<at::Tensor> inp_list = inp.toTensorList();
    TORCH_CHECK(
        graph->inputs().size() == inp_list.size(),
        "method_compile_spec inputs do not match expected number of forward inputs");

    example_inputs.reserve(inp_list.size());
    for (const auto i : c10::irange(inp_list.size())) {
      example_inputs.emplace_back(inp_list[i]);
    }
  } else {
    TORCH_CHECK(
        graph->inputs().size() == 1,
        "method_compile_spec inputs do not match expected number of forward inputs");

    example_inputs.emplace_back(inp.toTensor());
  }

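  // Lower the TorchScript graph into an XNNPACK subgraph, using the example
  // inputs to pin down the tensor shapes (see the tracing note at the top).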
  XNNGraph graph_builder;
  graph_builder.buildXNNGraph(graph, example_inputs);
  // At this point the graph is complete; for the sake of testing preprocess,
  // we perform runtime setup here and run the graph on the sample inputs.

  // gather sample inputs from the compile spec; example_inputs already covers
  // both the single-Tensor and TensorList cases validated above
  std::vector<at::Tensor> inputs;
  inputs.reserve(example_inputs.size());
  for (const auto& example_input : example_inputs) {
    inputs.push_back(example_input.toTensor());
  }
  std::vector<at::Tensor> outputs;
  std::vector<c10::IntList> output_shapes;

  // gather sample outputs from the compile spec; out was validated above to
  // be either a single Tensor or a TensorList
  if (out.isTensorList()) {
    c10::List<at::Tensor> out_list = out.toTensorList();
    for (const auto i : c10::irange(out_list.size())) {
      outputs.push_back(out_list[i]);
    }
  } else {
    outputs.push_back(out.toTensor());
  }
  for (const auto& sample_output : outputs) {
    // also gather output shapes to forward along to the device
    output_shapes.push_back(sample_output.sizes());
  }

  // sample run on sample inputs
  graph_builder.runGraphOnInputs(inputs, outputs);
  c10::List<c10::IntList> shapes_list(output_shapes);

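  // Payload handed back to the runtime: the serialized XNNPACK model, the
  // output shapes (so the device side can preallocate outputs), and the
  // sample run's results under "Answer" for testing.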
  compiled.insert("ser_model", graph_builder.serializedXNNGraph());
  compiled.insert("outputs", shapes_list);
  compiled.insert("Answer", outputs);

  return compiled;
}
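
// Static registration: ties this preprocess function to the "xnnpack"
// backend name at load time so that module lowering can find it.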
constexpr auto backend_name = "xnnpack";
static auto pre_reg = backend_preprocess_register(backend_name, preprocess);

} // namespace delegate
} // namespace xnnpack
} // namespace jit
} // namespace torch