#include <memory>
#include <sstream>

#include <ATen/nnapi/nnapi_bind.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>

namespace torch {
namespace jit {

// Implementation of the Android NNAPI backend delegate.

// The Android Neural Networks API (NNAPI) is an Android C API designed
// for running computationally intensive operations for machine learning on
// Android devices. The API is available on all Android devices running
// Android 8.1 (API level 27) or higher.

// This implementation mirrors NnapiModule.forward() in
// caffe2/torch/backends/_nnapi/prepare.py.
class NnapiBackend : public PyTorchBackendInterface {
 public:
  // Constructor.
  explicit NnapiBackend() = default;
  ~NnapiBackend() override = default;

  bool is_available() override {
    return true;
  }

  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    // Wrap processed in a dictionary: {"forward": processed}
    auto dict = processed.toGenericDict();
    c10::Dict<c10::IValue, c10::IValue> handles(
        c10::StringType::get(), c10::AnyType::get());
    handles.insert("forward", dict);
    return c10::impl::toGenericDict(handles);
  }

  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    // Convert inputs to Tensors
    c10::List<at::Tensor> tensorInp;
    for (c10::IValue element : inputs) {
      tensorInp.push_back(element.toTensor());
    }

    // Lazily call init()
    if (comp_ == nullptr) {
      init(handle, tensorInp);
    }
    TORCH_CHECK(comp_ != nullptr);

    // Allocate output tensors using the templates computed at init() time
    c10::List<at::Tensor> outputs;
    for (at::Tensor out : out_templates_) {
      outputs.push_back(at::empty_like(out));
    }

    // Adjust input memory formats
    auto dict = handle.toGenericDict();
    auto inp_mem_fmts = dict.at("inp_mem_fmts").toIntList();
    TORCH_CHECK(tensorInp.size() == inp_mem_fmts.size());
    std::vector<at::Tensor> fixed_inputs;
    for (auto i = 0U; i < tensorInp.size(); i++) {
      int fmt = inp_mem_fmts[i];
      // These constants match the values in DimOrder in serializer.py
      // 0: NCHW, 1: NHWC
      // TODO: See if it's possible to use those directly.
      if (fmt == 0) {
        fixed_inputs.push_back(tensorInp.get(i).contiguous());
      } else if (fmt == 1) {
        fixed_inputs.push_back(
            tensorInp.get(i).permute({0, 2, 3, 1}).contiguous());
      } else {
        TORCH_CHECK(false, "Invalid mem_fmt");
      }
    }

    comp_->run(fixed_inputs, outputs.vec());

    // Adjust output memory formats
    auto out_mem_fmts = dict.at("out_mem_fmts").toIntList();
    TORCH_CHECK(outputs.size() == out_mem_fmts.size());
    for (auto i = 0U; i < outputs.size(); i++) {
      int fmt = out_mem_fmts[i];
      // These constants match the values in DimOrder in serializer.py
      // 0: NCHW, 1: NHWC
      // TODO: See if it's possible to use those directly.
      if (fmt == 1) {
        outputs.set(i, outputs.get(i).permute({0, 3, 1, 2}));
      } else {
        TORCH_CHECK(fmt == 0, "Invalid mem_fmt");
      }
    }

    return c10::impl::toList(outputs);
  }

 private:
  // The following members are set by init() during execution
  // and cannot be passed through the handles dictionary.
  std::unique_ptr<torch::nnapi::bind::NnapiCompilation> comp_;
  c10::List<at::Tensor> out_templates_;

  // Runs once per model initialization.
  // Cannot be moved to compile(), because init() requires actual inputs.
  void init(c10::IValue handle, c10::List<at::Tensor> inputs) {
    TORCH_CHECK(comp_ == nullptr);
    auto dict = handle.toGenericDict();

    // Get ser_model
    auto ser_model = dict.at("ser_model").toTensor();
    // Load the shape computation module and use it to produce output
    // templates for the given inputs
    std::stringstream ss;
    auto shape_ptr = dict.at("shape_compute_module").toString();
    ss.str(*shape_ptr);
    auto shape_compute_module = _load_for_mobile(ss);
    out_templates_ =
        shape_compute_module.run_method("prepare", ser_model, inputs)
            .toTensorList();

    // Create and initialize the NnapiCompilation object
    comp_ = std::make_unique<torch::nnapi::bind::NnapiCompilation>();
    auto weights = dict.at("weights").toTensorVector();
    comp_->init(ser_model, weights);
  }
};

namespace {
constexpr auto backend_name = "nnapi";
static auto cls = torch::jit::backend<NnapiBackend>(backend_name);
} // namespace

} // namespace jit
} // namespace torch
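
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the backend implementation):
// a module lowered to the "nnapi" backend is saved for the lite interpreter
// and loaded on-device through the mobile runtime, whose forward() call is
// routed to NnapiBackend::execute(). The file name and input shape below are
// assumptions made for the sake of the example.
//
//   // Load a module previously lowered to the "nnapi" backend.
//   auto module = torch::jit::_load_for_mobile("model_nnapi.ptl");
//   // Run it with a sample NCHW input; outputs follow the lowered model.
//   std::vector<c10::IValue> inputs{at::rand({1, 3, 224, 224})};
//   auto output = module.forward(inputs);
// ---------------------------------------------------------------------------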