xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <memory>
#include <sstream>
#include <vector>

#include <ATen/nnapi/nnapi_bind.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
8 
9 namespace torch {
10 namespace jit {
11 
12 // Implementation of Android NNAPI Backend delegate
13 
14 // The Android Neural Networks API (NNAPI) is an Android C API designed
15 // for running computationally intensive operations for machine learning on
16 // Android devices. The API is available on all Android devices running
17 // Android 8.1 (API level 27) or higher.
18 
19 // Implementation is reflective of caffe2/torch/backends/_nnapi/prepare.py
20 // NnapiModule.forward()
21 class NnapiBackend : public PyTorchBackendInterface {
22  public:
23   // Constructor.
24   explicit NnapiBackend() = default;
25   ~NnapiBackend() override = default;
26 
is_available()27   bool is_available() override {
28     return true;
29   }
30 
compile(c10::IValue processed,c10::impl::GenericDict method_compile_spec)31   c10::impl::GenericDict compile(
32       c10::IValue processed,
33       c10::impl::GenericDict method_compile_spec) override {
34     // Wrap processed in dictionary: {"forward": processed}
35     auto dict = processed.toGenericDict();
36     c10::Dict<c10::IValue, c10::IValue> handles(
37         c10::StringType::get(), c10::AnyType::get());
38     handles.insert("forward", dict);
39     return c10::impl::toGenericDict(handles);
40   }
41 
execute(c10::IValue handle,c10::impl::GenericList inputs)42   c10::impl::GenericList execute(
43       c10::IValue handle,
44       c10::impl::GenericList inputs) override {
45     // Convert inputs to Tensors
46     c10::List<at::Tensor> tensorInp;
47     for (c10::IValue element : inputs) {
48       tensorInp.push_back(element.toTensor());
49     }
50 
51     // Lazily call init()
52     if (comp_ == nullptr) {
53       init(handle, tensorInp);
54     }
55     TORCH_CHECK(comp_ != nullptr)
56 
57     c10::List<at::Tensor> outputs;
58     for (at::Tensor out : out_templates_) {
59       outputs.push_back(at::empty_like(out));
60     }
61 
62     // Adjust input memory formats
63     auto dict = handle.toGenericDict();
64     auto inp_mem_fmts = dict.at("inp_mem_fmts").toIntList();
65     TORCH_CHECK(tensorInp.size() == inp_mem_fmts.size());
66     std::vector<at::Tensor> fixed_inputs;
67     for (auto i = 0U; i < tensorInp.size(); i++) {
68       int fmt = inp_mem_fmts[i];
69       // These constants match the values in DimOrder in serializer.py
70       // 0: NCHW, 1: NHWC
71       // TODO: See if it's possible to use those directly.
72       if (fmt == 0) {
73         fixed_inputs.push_back(tensorInp.get(i).contiguous());
74       } else if (fmt == 1) {
75         fixed_inputs.push_back(
76             tensorInp.get(i).permute({0, 2, 3, 1}).contiguous());
77       } else {
78         TORCH_CHECK(false, "Invalid mem_fmt");
79       }
80     }
81 
82     comp_->run(fixed_inputs, outputs.vec());
83 
84     // Adjust output memory formats
85     auto out_mem_fmts = dict.at("out_mem_fmts").toIntList();
86     TORCH_CHECK(outputs.size() == out_mem_fmts.size());
87     for (auto i = 0U; i < outputs.size(); i++) {
88       int fmt = out_mem_fmts[i];
89       // These constants match the values in DimOrder in serializer.py
90       // 0: NCHW, 1: NHWC
91       // TODO: See if it's possible to use those directly.
92       if (fmt == 1) {
93         outputs.set(i, outputs.get(i).permute({0, 3, 1, 2}));
94       } else {
95         TORCH_CHECK(fmt == 0, "Invalid mem_fmt");
96       }
97     }
98 
99     return c10::impl::toList(outputs);
100   }
101 
102  private:
103   // The following variables are modified by init() during execution,
104   // and cannot be passed through the handles dictionary
105   std::unique_ptr<torch::nnapi::bind::NnapiCompilation> comp_;
106   c10::List<at::Tensor> out_templates_;
107 
108   // Runs once per model initialization
109   // Cannot be moved to compile(), because init() requires actual inputs
init(c10::IValue handle,c10::List<at::Tensor> inputs)110   void init(c10::IValue handle, c10::List<at::Tensor> inputs) {
111     TORCH_CHECK(comp_ == nullptr);
112     auto dict = handle.toGenericDict();
113 
114     // Get ser_model
115     auto ser_model = dict.at("ser_model").toTensor();
116     // Load shape computation module
117     std::stringstream ss;
118     auto shape_ptr = dict.at("shape_compute_module").toString();
119     ss.str(*shape_ptr);
120     auto shape_compute_module = _load_for_mobile(ss);
121     out_templates_ =
122         shape_compute_module.run_method("prepare", ser_model, inputs)
123             .toTensorList();
124 
125     // Create and initialize NnapiComilation object
126     comp_ = std::make_unique<torch::nnapi::bind::NnapiCompilation>();
127     auto weights = dict.at("weights").toTensorVector();
128     comp_->init(ser_model, weights);
129   }
130 };
131 
132 namespace {
133 constexpr auto backend_name = "nnapi";
134 static auto cls = torch::jit::backend<NnapiBackend>(backend_name);
135 } // namespace
136 
137 } // namespace jit
138 } // namespace torch
139