#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>

#include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>

namespace torch {
namespace jit {
namespace xnnpack {
namespace delegate {

// Holds the compiled XNNExecutor so it can be stashed inside an IValue
// capsule and retrieved again at execute() time.
class XNNModelWrapper : public CustomClassHolder {
 public:
  XNNExecutor executor_;
  XNNModelWrapper(XNNExecutor executor) : executor_(std::move(executor)) {}

  XNNModelWrapper() = delete;

  XNNModelWrapper(const XNNModelWrapper& oldObject) = delete;
};

class XNNPackBackend : public PyTorchBackendInterface {
 public:
  // Constructor.
  // NOLINTNEXTLINE(modernize-use-equals-default)
  explicit XNNPackBackend() {}
  ~XNNPackBackend() override = default;

  bool is_available() override {
    return xnn_status_success == xnn_initialize(/*allocator=*/nullptr);
  }

  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto dict = processed.toGenericDict();

    // Compile the serialized model and wrap the resulting execution object.
    const std::string& ser_model = dict.at("ser_model").toStringRef();
    XNNExecutor executor;
    XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor);

    auto model_ptr = c10::make_intrusive<XNNModelWrapper>(std::move(executor));
    auto runtime_handle = IValue::make_capsule(model_ptr);
    // Round-trip the capsule; the result is unused beyond validating the cast.
    auto wrapper = c10::static_intrusive_pointer_cast<XNNModelWrapper>(
        runtime_handle.toCapsule());

    // Pack the outputs into a generic dict.
    c10::Dict<c10::IValue, c10::IValue> handles(
        c10::StringType::get(), c10::AnyType::get());

    c10::Dict<c10::IValue, c10::IValue> ret(
        c10::StringType::get(), c10::AnyType::get());

    ret.insert("runtime", runtime_handle);
    ret.insert("output_shapes", dict.at("outputs"));

    handles.insert("forward", ret);

    return handles;
  }
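  // The handles dict returned above is what the to_backend runtime attaches
  // to the lowered module; execute() below receives the inner per-method
  // entry as its `handle` argument. Sketch of the layout (the shape lists
  // are whatever the "outputs" entry of `processed` held):
  //
  //   {
  //     "forward": {
  //       "runtime":       capsule wrapping the XNNModelWrapper,
  //       "output_shapes": [[d0, d1, ...], ...],
  //     }
  //   }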
  // Currently this is not implemented, and everything is computed ahead of
  // time; the current implementation just takes the precomputed results and
  // grabs them. The inputs are fed in through the compile spec for the sake
  // of testing. In reality, the inputs will be fed in at this stage and run
  // here.
  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    auto dict = handle.toGenericDict();
    auto output_shapes = dict.at("output_shapes").toList();

    // Recover the compiled executor from the capsule created in compile().
    auto capsule = dict.at("runtime").toCapsule();
    auto model_wrapper =
        c10::static_intrusive_pointer_cast<XNNModelWrapper>(capsule);

    XNNExecutor& executor = model_wrapper->executor_;

    // Gather a raw float pointer for each input tensor.
    std::vector<float*> input_pointers;
    for (size_t i = 0; i < inputs.size(); ++i) {
      at::IValue val = inputs.get(i);
      TORCH_CHECK(val.isTensor(), "Non-tensor inputs not supported");
      input_pointers.push_back(val.toTensor().data_ptr<float>());
    }

    // Allocate output tensors from the shapes recorded at compile time.
    std::vector<at::Tensor> output_tensors;
    std::vector<float*> output_pointers;
    output_tensors.reserve(output_shapes.size());
    for (size_t i = 0; i < output_shapes.size(); i++) {
      auto o_shape = output_shapes.get(i).toIntVector();
      auto output = at::empty(o_shape, c10::ScalarType::Float);
      output_tensors.push_back(output);
      output_pointers.push_back(output.data_ptr<float>());
    }

    TORCH_CHECK(
        executor.set_inputs(input_pointers, output_pointers),
        "Number of inputs/outputs does not match expected number of inputs/outputs");
    TORCH_CHECK(executor.forward(), "Failed to invoke XNNPack runtime");

    c10::List<at::Tensor> output_list(output_tensors);
    return c10::impl::toList(output_list);
  }
};

namespace {
constexpr auto backend_name = "xnnpack";
static auto cls = torch::jit::backend<XNNPackBackend>(backend_name);
} // namespace

} // namespace delegate
} // namespace xnnpack
} // namespace jit
} // namespace torch
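
// Usage sketch (assumption: the standard to_backend lowering flow; the
// "ser_model" flatbuffer consumed by compile() above is produced by this
// backend's Python-side preprocess step, which is not part of this file):
//
//   # Python
//   scripted = torch.jit.script(model)
//   lowered = torch._C._jit_to_backend("xnnpack", scripted, method_compile_spec)
//   outputs = lowered.forward(inputs)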