#ifndef NNAPI_BIND_H_
#define NNAPI_BIND_H_

#include <vector>

#include <ATen/ATen.h>
#include <torch/custom_class.h>

#include <ATen/nnapi/nnapi_wrapper.h>

namespace torch {
namespace nnapi {
namespace bind {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TORCH_API extern nnapi_wrapper* nnapi;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TORCH_API extern nnapi_wrapper* check_nnapi;

// MAKE_SMART_PTR(type) defines typePtr, a std::unique_ptr over an
// ANeuralNetworks##type handle that is released through the wrapper's
// type##_free call. If the wrapper has not been loaded, the handle must be
// null and the deleter is a no-op.
#define MAKE_SMART_PTR(type) \
  struct type ## Freer { \
    void operator()(ANeuralNetworks ## type * obj) { \
      if (!nnapi) { /* obj must be null. */ return; } \
      nnapi-> type ## _free(obj); \
    } \
  }; \
  typedef std::unique_ptr<ANeuralNetworks ## type, type ## Freer> type ## Ptr;

MAKE_SMART_PTR(Model)
MAKE_SMART_PTR(Compilation)
MAKE_SMART_PTR(Execution)

#undef MAKE_SMART_PTR

struct NnapiCompilation : torch::jit::CustomClassHolder {
  NnapiCompilation() = default;
  ~NnapiCompilation() override = default;

  // Only necessary for older models that still call init().
  TORCH_API void init(
      at::Tensor serialized_model_tensor,
      std::vector<at::Tensor> parameter_buffers);

  TORCH_API void init2(
      at::Tensor serialized_model_tensor,
      const std::vector<at::Tensor>& parameter_buffers,
      int64_t compilation_preference,
      bool relax_f32_to_f16);

  TORCH_API void run(
      std::vector<at::Tensor> inputs,
      std::vector<at::Tensor> outputs);

  // Fills *operand and *dims with the NNAPI operand type describing tensor t.
  static void get_operand_type(
      const at::Tensor& t,
      ANeuralNetworksOperandType* operand,
      std::vector<uint32_t>* dims);

  ModelPtr model_;
  CompilationPtr compilation_;
  int32_t num_inputs_ {};
  int32_t num_outputs_ {};
};

} // namespace bind
} // namespace nnapi
} // namespace torch

#endif // NNAPI_BIND_H_
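// A minimal usage sketch (illustrative only, not part of this header).
// Assumptions: `serialized_model_tensor` and `parameter_buffers` come from
// PyTorch's NNAPI model-conversion flow, `inputs`/`outputs` are caller-owned
// tensors whose shapes match the compiled model, and
// ANEURALNETWORKS_PREFER_SUSTAINED_SPEED is the standard PreferenceCode from
// the Android NeuralNetworks headers.
//
//   auto compilation =
//       c10::make_intrusive<torch::nnapi::bind::NnapiCompilation>();
//   compilation->init2(
//       serialized_model_tensor,
//       parameter_buffers,
//       /*compilation_preference=*/ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
//       /*relax_f32_to_f16=*/false);
//   compilation->run(inputs, outputs);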