#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>

#include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>

namespace torch {
namespace jit {
namespace xnnpack {
namespace delegate {

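// Holds a compiled XNNExecutor inside a torch custom class so that it can be
// stored in (and later recovered from) an IValue capsule.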
class XNNModelWrapper : public CustomClassHolder {
 public:
  XNNExecutor executor_;
  explicit XNNModelWrapper(XNNExecutor executor)
      : executor_(std::move(executor)) {}

  XNNModelWrapper() = delete;

  XNNModelWrapper(const XNNModelWrapper& oldObject) = delete;
};

class XNNPackBackend : public PyTorchBackendInterface {
 public:
  // Constructor.
  // NOLINTNEXTLINE(modernize-use-equals-default)
  explicit XNNPackBackend() {}
  ~XNNPackBackend() override = default;

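  // The backend reports itself as available only if the XNNPACK library
  // initializes successfully.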
  bool is_available() override {
    return xnn_status_success == xnn_initialize(/*allocator=*/nullptr);
  }

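  // `processed` is the dict produced ahead of time by the backend's
  // preprocess step: its "ser_model" entry holds the serialized model and
  // its "outputs" entry the recorded output shapes. The returned dict maps
  // the method name ("forward") to a handle bundling the compiled runtime
  // with those output shapes.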
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto dict = processed.toGenericDict();

    // Compiling and wrapping the execution object
    const std::string& ser_model = dict.at("ser_model").toStringRef();
    XNNExecutor executor;
    XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor);

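    // Wrap the executor in a capsule so it can travel through an IValue; the
    // round-trip cast below appears to serve as a sanity check that the
    // capsule really holds an XNNModelWrapper.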
    auto model_ptr = c10::make_intrusive<XNNModelWrapper>(std::move(executor));
    auto runtime_handle = IValue::make_capsule(model_ptr);
    auto wrapper = c10::static_intrusive_pointer_cast<XNNModelWrapper>(
        runtime_handle.toCapsule());

    // Packing outputs into generic dict
    c10::Dict<c10::IValue, c10::IValue> handles(
        c10::StringType::get(), c10::AnyType::get());

    c10::Dict<c10::IValue, c10::IValue> ret(
        c10::StringType::get(), c10::AnyType::get());

    ret.insert("runtime", runtime_handle);
    ret.insert("output_shapes", dict.at("outputs"));

    handles.insert("forward", ret);

    return handles;
  }

  // Currently this is not implemented end-to-end; everything is computed
  // ahead of time, and the current implementation just takes the results
  // computed ahead of time and grabs them. The inputs are fed in through the
  // compile spec for the sake of testing. In reality, the inputs will be fed
  // in at this stage and run here.
  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    auto dict = handle.toGenericDict();
    auto output_shapes = dict.at("output_shapes").toList();

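    // Recover the compiled executor that compile() stashed in the capsule.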
    auto capsule = dict.at("runtime").toCapsule();
    auto model_wrapper =
        c10::static_intrusive_pointer_cast<XNNModelWrapper>(capsule);

    XNNExecutor& executor = model_wrapper->executor_;

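    // Collect raw float pointers for every input tensor; only float tensors
    // are supported here.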
    std::vector<float*> input_pointers;
    for (size_t i = 0; i < inputs.size(); ++i) {
      at::IValue val = inputs.get(i);
      TORCH_CHECK(val.isTensor(), "Non-tensor inputs not supported");
      input_pointers.push_back(val.toTensor().data_ptr<float>());
    }

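    // Allocate a float output tensor for each shape recorded at compile time
    // and hand XNNPACK the raw pointers to write into.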
    std::vector<at::Tensor> output_tensors;
    std::vector<float*> output_pointers;
    output_tensors.reserve(output_shapes.size());
    for (size_t i = 0; i < output_shapes.size(); ++i) {
      auto o_shape = output_shapes.get(i).toIntVector();
      auto output = at::empty(o_shape, c10::ScalarType::Float);
      output_tensors.push_back(output);
      output_pointers.push_back(output.data_ptr<float>());
    }

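    // Bind the input/output pointers to the XNNPACK runtime and run the
    // graph.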
    TORCH_CHECK(
        executor.set_inputs(input_pointers, output_pointers),
        "Number of inputs/outputs does not match expected number of inputs/outputs");
    TORCH_CHECK(executor.forward(), "Failed to invoke XNNPack runtime");

    c10::List<at::Tensor> output_list(output_tensors);
    return c10::impl::toList(output_list);
  }
};

namespace {
constexpr auto backend_name = "xnnpack";
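// Static registration of the backend under the name "xnnpack"; modules can
// then be lowered to it through the to_backend flow.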
static auto cls = torch::jit::backend<XNNPackBackend>(backend_name);
} // namespace
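
// A minimal sketch of running a module that was previously lowered to this
// backend and saved; the file name and input shape below are illustrative
// assumptions, not part of this file:
//
//   #include <torch/script.h>
//
//   torch::jit::Module m = torch::jit::load("xnnpack_lowered.pt");
//   std::vector<c10::IValue> inputs{torch::randn({1, 3, 224, 224})};
//   at::IValue out = m.forward(inputs);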

} // namespace delegate
} // namespace xnnpack
} // namespace jit
} // namespace torch