xref: /aosp_15_r20/external/pytorch/aten/src/ATen/nnapi/nnapi_bind.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef NNAPI_BIND_H_
2 #define NNAPI_BIND_H_
3 
#include <cstdint>
#include <memory>
#include <vector>

#include <ATen/ATen.h>
#include <torch/custom_class.h>

#include <ATen/nnapi/nnapi_wrapper.h>
10 
11 namespace torch {
12 namespace nnapi {
13 namespace bind {
14 
15 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
16 TORCH_API extern nnapi_wrapper* nnapi;
17 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
18 TORCH_API extern nnapi_wrapper* check_nnapi;
19 
20 #define MAKE_SMART_PTR(type) \
21   struct type ## Freer { \
22     void operator()(ANeuralNetworks ## type * obj) { \
23       if (!nnapi) { /* obj must be null. */ return; } \
24       nnapi-> type ## _free(obj); \
25     } \
26   }; \
27   typedef std::unique_ptr<ANeuralNetworks ## type, type ## Freer> type ## Ptr;
28 
29 MAKE_SMART_PTR(Model)
30 MAKE_SMART_PTR(Compilation)
31 MAKE_SMART_PTR(Execution)
32 
33 #undef MAKE_SMART_PTR
34 
35 struct NnapiCompilation : torch::jit::CustomClassHolder {
36     NnapiCompilation() = default;
37     ~NnapiCompilation() override = default;
38 
39     // only necessary for older models that still call init()
40     TORCH_API void init(
41       at::Tensor serialized_model_tensor,
42       std::vector<at::Tensor> parameter_buffers
43     );
44 
45     TORCH_API void init2(
46       at::Tensor serialized_model_tensor,
47       const std::vector<at::Tensor>& parameter_buffers,
48       int64_t compilation_preference,
49       bool relax_f32_to_f16
50     );
51 
52 
53     TORCH_API void run(std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs);
54     static void get_operand_type(const at::Tensor& t, ANeuralNetworksOperandType* operand, std::vector<uint32_t>* dims);
55 
56     ModelPtr model_;
57     CompilationPtr compilation_;
58     int32_t num_inputs_ {};
59     int32_t num_outputs_ {};
60 };
61 
62 } // namespace bind
63 } // namespace nnapi
64 } // namespace torch
65 
66 #endif // NNAPI_BIND_H_
67