// xref: /aosp_15_r20/external/executorch/backends/arm/runtime/ArmBackendEthosU.cpp
// (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright 2023-2024 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
 * Arm backend for the Ethos-U bare-metal driver stack; it relies on the
 * ethos-u-core-driver for hardware interaction.
 */

13*523fa7a6SAndroid Build Coastguard Worker #include <cstring>
14*523fa7a6SAndroid Build Coastguard Worker #include <memory>
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker #include <ethosu_driver.h>
17*523fa7a6SAndroid Build Coastguard Worker 
18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/arm/runtime/VelaBinStream.h>
19*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/backend/interface.h>
20*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/error.h>
21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/evalue.h>
22*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
23*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
24*523fa7a6SAndroid Build Coastguard Worker 
25*523fa7a6SAndroid Build Coastguard Worker using namespace std;
26*523fa7a6SAndroid Build Coastguard Worker 
27*523fa7a6SAndroid Build Coastguard Worker using executorch::aten::ScalarType;
28*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::ArrayRef;
29*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Backend;
30*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::BackendExecutionContext;
31*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::BackendInitContext;
32*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::CompileSpec;
33*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::DelegateHandle;
34*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Error;
35*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::EValue;
36*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::FreeableBuffer;
37*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::MemoryAllocator;
38*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Result;
39*523fa7a6SAndroid Build Coastguard Worker 
40*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
41*523fa7a6SAndroid Build Coastguard Worker namespace backends {
42*523fa7a6SAndroid Build Coastguard Worker namespace arm {
43*523fa7a6SAndroid Build Coastguard Worker 
44*523fa7a6SAndroid Build Coastguard Worker typedef struct {
45*523fa7a6SAndroid Build Coastguard Worker   FreeableBuffer* processed;
46*523fa7a6SAndroid Build Coastguard Worker   bool permuted_io_flag;
47*523fa7a6SAndroid Build Coastguard Worker } ExecutionHandle;
48*523fa7a6SAndroid Build Coastguard Worker 
extern "C" {
// Hooks called at the start and end of every ArmBackend::execute() call.
// The defaults do nothing; they are weak symbols, so a strong definition
// elsewhere in the link replaces them.
__attribute__((weak)) void ArmBackend_execute_begin(void) {}
__attribute__((weak)) void ArmBackend_execute_end(void) {}
} // extern "C"
53*523fa7a6SAndroid Build Coastguard Worker 
54*523fa7a6SAndroid Build Coastguard Worker class ArmBackendExecuteCallbacks {
55*523fa7a6SAndroid Build Coastguard Worker  public:
ArmBackendExecuteCallbacks()56*523fa7a6SAndroid Build Coastguard Worker   ArmBackendExecuteCallbacks() {
57*523fa7a6SAndroid Build Coastguard Worker     ArmBackend_execute_begin();
58*523fa7a6SAndroid Build Coastguard Worker   }
~ArmBackendExecuteCallbacks()59*523fa7a6SAndroid Build Coastguard Worker   ~ArmBackendExecuteCallbacks() {
60*523fa7a6SAndroid Build Coastguard Worker     ArmBackend_execute_end();
61*523fa7a6SAndroid Build Coastguard Worker   }
62*523fa7a6SAndroid Build Coastguard Worker };
63*523fa7a6SAndroid Build Coastguard Worker 
64*523fa7a6SAndroid Build Coastguard Worker class ArmBackend final : public ::executorch::runtime::BackendInterface {
65*523fa7a6SAndroid Build Coastguard Worker  public:
ArmBackend()66*523fa7a6SAndroid Build Coastguard Worker   ArmBackend() {}
67*523fa7a6SAndroid Build Coastguard Worker 
68*523fa7a6SAndroid Build Coastguard Worker   ~ArmBackend() = default;
69*523fa7a6SAndroid Build Coastguard Worker 
is_available() const70*523fa7a6SAndroid Build Coastguard Worker   virtual bool is_available() const override {
71*523fa7a6SAndroid Build Coastguard Worker     // TODO: revise to use a register check/init function
72*523fa7a6SAndroid Build Coastguard Worker     return 1;
73*523fa7a6SAndroid Build Coastguard Worker   }
74*523fa7a6SAndroid Build Coastguard Worker 
init(BackendInitContext & context,FreeableBuffer * processed,ArrayRef<CompileSpec> compile_specs) const75*523fa7a6SAndroid Build Coastguard Worker   Result<DelegateHandle*> init(
76*523fa7a6SAndroid Build Coastguard Worker       BackendInitContext& context,
77*523fa7a6SAndroid Build Coastguard Worker       FreeableBuffer* processed,
78*523fa7a6SAndroid Build Coastguard Worker       ArrayRef<CompileSpec> compile_specs) const override {
79*523fa7a6SAndroid Build Coastguard Worker     ET_LOG(Info, "ArmBackend::init %p", processed->data());
80*523fa7a6SAndroid Build Coastguard Worker 
81*523fa7a6SAndroid Build Coastguard Worker     char* data = (char*)processed->data();
82*523fa7a6SAndroid Build Coastguard Worker     size_t size = processed->size();
83*523fa7a6SAndroid Build Coastguard Worker 
84*523fa7a6SAndroid Build Coastguard Worker     // Verify format of vela_bin
85*523fa7a6SAndroid Build Coastguard Worker     if (vela_bin_validate(data, size) == false) {
86*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(Error, "Malformed vela_bin_stream found");
87*523fa7a6SAndroid Build Coastguard Worker       return Error::InvalidProgram;
88*523fa7a6SAndroid Build Coastguard Worker     }
89*523fa7a6SAndroid Build Coastguard Worker 
90*523fa7a6SAndroid Build Coastguard Worker     MemoryAllocator* allocator = context.get_runtime_allocator();
91*523fa7a6SAndroid Build Coastguard Worker     ExecutionHandle* handle =
92*523fa7a6SAndroid Build Coastguard Worker         ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
93*523fa7a6SAndroid Build Coastguard Worker     handle->processed = processed;
94*523fa7a6SAndroid Build Coastguard Worker 
95*523fa7a6SAndroid Build Coastguard Worker     handle->permuted_io_flag = false;
96*523fa7a6SAndroid Build Coastguard Worker     for (auto& compile_spec : compile_specs) {
97*523fa7a6SAndroid Build Coastguard Worker       if (0 == std::strcmp(compile_spec.key, "permute_memory_format") &&
98*523fa7a6SAndroid Build Coastguard Worker           0 == std::memcmp(compile_spec.value.buffer, "nhwc", 4)) {
99*523fa7a6SAndroid Build Coastguard Worker         handle->permuted_io_flag = true;
100*523fa7a6SAndroid Build Coastguard Worker       }
101*523fa7a6SAndroid Build Coastguard Worker     }
102*523fa7a6SAndroid Build Coastguard Worker 
103*523fa7a6SAndroid Build Coastguard Worker     // Return the same buffer we were passed - this data will be
104*523fa7a6SAndroid Build Coastguard Worker     // executed directly
105*523fa7a6SAndroid Build Coastguard Worker     return handle;
106*523fa7a6SAndroid Build Coastguard Worker   }
107*523fa7a6SAndroid Build Coastguard Worker 
execute(BackendExecutionContext & context,DelegateHandle * input_handle,EValue ** args) const108*523fa7a6SAndroid Build Coastguard Worker   Error execute(
109*523fa7a6SAndroid Build Coastguard Worker       BackendExecutionContext& context,
110*523fa7a6SAndroid Build Coastguard Worker       DelegateHandle* input_handle,
111*523fa7a6SAndroid Build Coastguard Worker       EValue** args) const override {
112*523fa7a6SAndroid Build Coastguard Worker     ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle;
113*523fa7a6SAndroid Build Coastguard Worker     VelaHandles handles;
114*523fa7a6SAndroid Build Coastguard Worker 
115*523fa7a6SAndroid Build Coastguard Worker     ArmBackendExecuteCallbacks ArmBackend_execute_callbacks;
116*523fa7a6SAndroid Build Coastguard Worker     // Command stream - we know at this point it's aligned
117*523fa7a6SAndroid Build Coastguard Worker     char* data = (char*)execution_handle->processed->data();
118*523fa7a6SAndroid Build Coastguard Worker     ET_LOG(Debug, "ArmBackend::execute %p", data);
119*523fa7a6SAndroid Build Coastguard Worker 
120*523fa7a6SAndroid Build Coastguard Worker     // Read key sections from the vela_bin_stream
121*523fa7a6SAndroid Build Coastguard Worker     if (vela_bin_read(data, &handles, execution_handle->processed->size()) ==
122*523fa7a6SAndroid Build Coastguard Worker         false) {
123*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(Error, "ArmBackend::vela_read: error, invalid binary layout");
124*523fa7a6SAndroid Build Coastguard Worker       return Error::InvalidProgram;
125*523fa7a6SAndroid Build Coastguard Worker     }
126*523fa7a6SAndroid Build Coastguard Worker 
127*523fa7a6SAndroid Build Coastguard Worker     ET_LOG(
128*523fa7a6SAndroid Build Coastguard Worker         Debug,
129*523fa7a6SAndroid Build Coastguard Worker         "ArmBackend::execute: Running program data:\n  cmd %p %zu\n  weight %p %zu\n  scratch %p %zu\n",
130*523fa7a6SAndroid Build Coastguard Worker         handles.cmd_data,
131*523fa7a6SAndroid Build Coastguard Worker         handles.cmd_data_size,
132*523fa7a6SAndroid Build Coastguard Worker         handles.weight_data,
133*523fa7a6SAndroid Build Coastguard Worker         handles.weight_data_size,
134*523fa7a6SAndroid Build Coastguard Worker         handles.scratch_data,
135*523fa7a6SAndroid Build Coastguard Worker         handles.scratch_data_size);
136*523fa7a6SAndroid Build Coastguard Worker 
137*523fa7a6SAndroid Build Coastguard Worker     // Write argument values (from EValue tensor) into Ethos-U scratch
138*523fa7a6SAndroid Build Coastguard Worker     // TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM
139*523fa7a6SAndroid Build Coastguard Worker     //                     or DRAM output for compatible data layouts.
140*523fa7a6SAndroid Build Coastguard Worker     for (int i = 0; i < handles.inputs->count; i++) {
141*523fa7a6SAndroid Build Coastguard Worker       auto tensor_in = args[i]->toTensor();
142*523fa7a6SAndroid Build Coastguard Worker       char* scratch_addr = handles.scratch_data + handles.inputs->io[i].offset;
143*523fa7a6SAndroid Build Coastguard Worker 
144*523fa7a6SAndroid Build Coastguard Worker       // We accept:
145*523fa7a6SAndroid Build Coastguard Worker       bool supported = 0;
146*523fa7a6SAndroid Build Coastguard Worker       // 32 bit int (simple non-quantised test cases)
147*523fa7a6SAndroid Build Coastguard Worker       supported |=
148*523fa7a6SAndroid Build Coastguard Worker           (tensor_in.scalar_type() == ScalarType::Int and
149*523fa7a6SAndroid Build Coastguard Worker            handles.inputs->io[i].elem_size == 4);
150*523fa7a6SAndroid Build Coastguard Worker       // 8 bit int (IOQDQ pass prepared networks)
151*523fa7a6SAndroid Build Coastguard Worker       supported |=
152*523fa7a6SAndroid Build Coastguard Worker           (tensor_in.scalar_type() == ScalarType::Char and
153*523fa7a6SAndroid Build Coastguard Worker            handles.inputs->io[i].elem_size == 1);
154*523fa7a6SAndroid Build Coastguard Worker       if (!supported) {
155*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
156*523fa7a6SAndroid Build Coastguard Worker             Error,
157*523fa7a6SAndroid Build Coastguard Worker             "Input %d expected Integer (4 byte) or Char (1 byte) integer inputs, got ScalarType id %s",
158*523fa7a6SAndroid Build Coastguard Worker             i,
159*523fa7a6SAndroid Build Coastguard Worker             executorch::runtime::toString(tensor_in.scalar_type()));
160*523fa7a6SAndroid Build Coastguard Worker         return Error::InvalidProgram;
161*523fa7a6SAndroid Build Coastguard Worker       }
162*523fa7a6SAndroid Build Coastguard Worker       supported = executorch::runtime::is_contiguous_dim_order(
163*523fa7a6SAndroid Build Coastguard Worker           tensor_in.dim_order().data(), tensor_in.dim());
164*523fa7a6SAndroid Build Coastguard Worker       if (!supported) {
165*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
166*523fa7a6SAndroid Build Coastguard Worker             Error,
167*523fa7a6SAndroid Build Coastguard Worker             "Input %d expected contiguous dim_order, but got non-contiguous dim_order",
168*523fa7a6SAndroid Build Coastguard Worker             i);
169*523fa7a6SAndroid Build Coastguard Worker         return Error::InvalidProgram;
170*523fa7a6SAndroid Build Coastguard Worker       }
171*523fa7a6SAndroid Build Coastguard Worker 
172*523fa7a6SAndroid Build Coastguard Worker       // Select a compatible copy routine including checking for input layouts
173*523fa7a6SAndroid Build Coastguard Worker       // which require permutation.
174*523fa7a6SAndroid Build Coastguard Worker       bool permuted_input_shape;
175*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
176*523fa7a6SAndroid Build Coastguard Worker           i,
177*523fa7a6SAndroid Build Coastguard Worker           tensor_in,
178*523fa7a6SAndroid Build Coastguard Worker           &handles.inputs->io[i],
179*523fa7a6SAndroid Build Coastguard Worker           execution_handle->permuted_io_flag,
180*523fa7a6SAndroid Build Coastguard Worker           &permuted_input_shape));
181*523fa7a6SAndroid Build Coastguard Worker       bool both_char = tensor_in.scalar_type() == ScalarType::Char and
182*523fa7a6SAndroid Build Coastguard Worker           handles.inputs->io[i].elem_size == 1;
183*523fa7a6SAndroid Build Coastguard Worker       bool both_int = tensor_in.scalar_type() == ScalarType::Int and
184*523fa7a6SAndroid Build Coastguard Worker           handles.inputs->io[i].elem_size == 4;
185*523fa7a6SAndroid Build Coastguard Worker 
186*523fa7a6SAndroid Build Coastguard Worker       // Select a compatible copy routine
187*523fa7a6SAndroid Build Coastguard Worker       if (both_char and permuted_input_shape) {
188*523fa7a6SAndroid Build Coastguard Worker         // permuted byte copy CHW to HWC
189*523fa7a6SAndroid Build Coastguard Worker         permute_CHW_to_HWC(
190*523fa7a6SAndroid Build Coastguard Worker             tensor_in.mutable_data_ptr<char>(),
191*523fa7a6SAndroid Build Coastguard Worker             scratch_addr,
192*523fa7a6SAndroid Build Coastguard Worker             tensor_in.size(1),
193*523fa7a6SAndroid Build Coastguard Worker             tensor_in.size(2),
194*523fa7a6SAndroid Build Coastguard Worker             tensor_in.size(3));
195*523fa7a6SAndroid Build Coastguard Worker       } else if (both_char or both_int) {
196*523fa7a6SAndroid Build Coastguard Worker         // Sizes match and elt size matches so memcpy
197*523fa7a6SAndroid Build Coastguard Worker         memcpy(
198*523fa7a6SAndroid Build Coastguard Worker             scratch_addr,
199*523fa7a6SAndroid Build Coastguard Worker             tensor_in.mutable_data_ptr<char>(),
200*523fa7a6SAndroid Build Coastguard Worker             tensor_in.nbytes());
201*523fa7a6SAndroid Build Coastguard Worker       } else {
202*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(Error, "No matching input copy routine");
203*523fa7a6SAndroid Build Coastguard Worker         return Error::InvalidProgram;
204*523fa7a6SAndroid Build Coastguard Worker       }
205*523fa7a6SAndroid Build Coastguard Worker     }
206*523fa7a6SAndroid Build Coastguard Worker 
207*523fa7a6SAndroid Build Coastguard Worker     // Allocate driver handle and synchronously invoke driver
208*523fa7a6SAndroid Build Coastguard Worker     auto driver =
209*523fa7a6SAndroid Build Coastguard Worker         std::unique_ptr<ethosu_driver, decltype(&ethosu_release_driver)>(
210*523fa7a6SAndroid Build Coastguard Worker             ethosu_reserve_driver(), ethosu_release_driver);
211*523fa7a6SAndroid Build Coastguard Worker     if (driver == NULL) {
212*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(Error, "ArmBackend::execute: ethosu_reserve_driver failed");
213*523fa7a6SAndroid Build Coastguard Worker       return Error::InvalidState;
214*523fa7a6SAndroid Build Coastguard Worker     }
215*523fa7a6SAndroid Build Coastguard Worker 
216*523fa7a6SAndroid Build Coastguard Worker     // Ethos-U low level driver expected order for Ethos U-55, we have
217*523fa7a6SAndroid Build Coastguard Worker     // constant weight data, then scratch (which contains input and output)
218*523fa7a6SAndroid Build Coastguard Worker     // scratch is written above in this function.
219*523fa7a6SAndroid Build Coastguard Worker     uint64_t bases[2] = {
220*523fa7a6SAndroid Build Coastguard Worker         (uint64_t)handles.weight_data, (uint64_t)handles.scratch_data};
221*523fa7a6SAndroid Build Coastguard Worker     size_t bases_size[2] = {
222*523fa7a6SAndroid Build Coastguard Worker         handles.weight_data_size, handles.scratch_data_size};
223*523fa7a6SAndroid Build Coastguard Worker     int result = ethosu_invoke_v3(
224*523fa7a6SAndroid Build Coastguard Worker         driver.get(),
225*523fa7a6SAndroid Build Coastguard Worker         (void*)handles.cmd_data,
226*523fa7a6SAndroid Build Coastguard Worker         handles.cmd_data_size,
227*523fa7a6SAndroid Build Coastguard Worker         bases,
228*523fa7a6SAndroid Build Coastguard Worker         bases_size,
229*523fa7a6SAndroid Build Coastguard Worker         2, /* fixed array of pointers to binary interface*/
230*523fa7a6SAndroid Build Coastguard Worker         nullptr);
231*523fa7a6SAndroid Build Coastguard Worker 
232*523fa7a6SAndroid Build Coastguard Worker     if (result != 0) {
233*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(
234*523fa7a6SAndroid Build Coastguard Worker           Error,
235*523fa7a6SAndroid Build Coastguard Worker           "ArmBackend::execute: Ethos-U invocation failed error (%d)",
236*523fa7a6SAndroid Build Coastguard Worker           result);
237*523fa7a6SAndroid Build Coastguard Worker       return Error::InvalidProgram;
238*523fa7a6SAndroid Build Coastguard Worker     }
239*523fa7a6SAndroid Build Coastguard Worker 
240*523fa7a6SAndroid Build Coastguard Worker     // Write outputs from scratch into EValue pointers
241*523fa7a6SAndroid Build Coastguard Worker     for (int i = 0; i < handles.outputs->count; i++) {
242*523fa7a6SAndroid Build Coastguard Worker       const char* output_addr =
243*523fa7a6SAndroid Build Coastguard Worker           handles.scratch_data + handles.outputs->io[i].offset;
244*523fa7a6SAndroid Build Coastguard Worker       // Process input EValue into scratch
245*523fa7a6SAndroid Build Coastguard Worker       // Outputs are in the index immediately after inputs
246*523fa7a6SAndroid Build Coastguard Worker       auto tensor_out = args[handles.inputs->count + i]->toTensor();
247*523fa7a6SAndroid Build Coastguard Worker       bool permuted_output_shape;
248*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
249*523fa7a6SAndroid Build Coastguard Worker           i,
250*523fa7a6SAndroid Build Coastguard Worker           tensor_out,
251*523fa7a6SAndroid Build Coastguard Worker           &handles.outputs->io[i],
252*523fa7a6SAndroid Build Coastguard Worker           execution_handle->permuted_io_flag,
253*523fa7a6SAndroid Build Coastguard Worker           &permuted_output_shape));
254*523fa7a6SAndroid Build Coastguard Worker       if (tensor_out.scalar_type() == ScalarType::Char and
255*523fa7a6SAndroid Build Coastguard Worker           permuted_output_shape) {
256*523fa7a6SAndroid Build Coastguard Worker         char* output_address = (char*)output_addr;
257*523fa7a6SAndroid Build Coastguard Worker         permute_HWC_to_CHW(
258*523fa7a6SAndroid Build Coastguard Worker             output_address,
259*523fa7a6SAndroid Build Coastguard Worker             tensor_out.mutable_data_ptr<char>(),
260*523fa7a6SAndroid Build Coastguard Worker             tensor_out.size(1),
261*523fa7a6SAndroid Build Coastguard Worker             tensor_out.size(2),
262*523fa7a6SAndroid Build Coastguard Worker             tensor_out.size(3));
263*523fa7a6SAndroid Build Coastguard Worker       } else {
264*523fa7a6SAndroid Build Coastguard Worker         for (int j = 0; j < tensor_out.numel(); j++) {
265*523fa7a6SAndroid Build Coastguard Worker           if (tensor_out.scalar_type() == ScalarType::Char) {
266*523fa7a6SAndroid Build Coastguard Worker             char* output_address = (char*)output_addr;
267*523fa7a6SAndroid Build Coastguard Worker             tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
268*523fa7a6SAndroid Build Coastguard Worker           } else {
269*523fa7a6SAndroid Build Coastguard Worker             int* output_address = (int*)output_addr;
270*523fa7a6SAndroid Build Coastguard Worker             tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
271*523fa7a6SAndroid Build Coastguard Worker           }
272*523fa7a6SAndroid Build Coastguard Worker         }
273*523fa7a6SAndroid Build Coastguard Worker       }
274*523fa7a6SAndroid Build Coastguard Worker     }
275*523fa7a6SAndroid Build Coastguard Worker     return Error::Ok;
276*523fa7a6SAndroid Build Coastguard Worker   }
277*523fa7a6SAndroid Build Coastguard Worker 
destroy(DelegateHandle * handle) const278*523fa7a6SAndroid Build Coastguard Worker   void destroy(DelegateHandle* handle) const override {
279*523fa7a6SAndroid Build Coastguard Worker     return;
280*523fa7a6SAndroid Build Coastguard Worker   }
281*523fa7a6SAndroid Build Coastguard Worker 
282*523fa7a6SAndroid Build Coastguard Worker  private:
check_requires_permute(int index,const executorch::aten::Tensor tensor,VelaIO * io,bool permuted_io_flag,bool * is_permuted) const283*523fa7a6SAndroid Build Coastguard Worker   Error check_requires_permute(
284*523fa7a6SAndroid Build Coastguard Worker       int index,
285*523fa7a6SAndroid Build Coastguard Worker       const executorch::aten::Tensor tensor,
286*523fa7a6SAndroid Build Coastguard Worker       VelaIO* io,
287*523fa7a6SAndroid Build Coastguard Worker       bool permuted_io_flag,
288*523fa7a6SAndroid Build Coastguard Worker       bool* is_permuted) const {
289*523fa7a6SAndroid Build Coastguard Worker     bool permuted_shape = false;
290*523fa7a6SAndroid Build Coastguard Worker     if (tensor.dim() == 4) {
291*523fa7a6SAndroid Build Coastguard Worker       // special case for NHWC workaround in AOT; as the compilation has
292*523fa7a6SAndroid Build Coastguard Worker       // permuted to channel last in an undetectable way, we assume here
293*523fa7a6SAndroid Build Coastguard Worker       // that the application has similarly permuted any input/output tensors.
294*523fa7a6SAndroid Build Coastguard Worker       permuted_shape = tensor.size(0) == io->shape[0] &&
295*523fa7a6SAndroid Build Coastguard Worker           tensor.size(1) == io->shape[3] && tensor.size(2) == io->shape[1] &&
296*523fa7a6SAndroid Build Coastguard Worker           tensor.size(3) == io->shape[2];
297*523fa7a6SAndroid Build Coastguard Worker       if (permuted_shape) {
298*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
299*523fa7a6SAndroid Build Coastguard Worker       }
300*523fa7a6SAndroid Build Coastguard Worker       if (permuted_io_flag != permuted_shape) {
301*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
302*523fa7a6SAndroid Build Coastguard Worker             Error,
303*523fa7a6SAndroid Build Coastguard Worker             "Permute compile flag and permuted input/output don't agree");
304*523fa7a6SAndroid Build Coastguard Worker         return Error::InvalidProgram;
305*523fa7a6SAndroid Build Coastguard Worker       }
306*523fa7a6SAndroid Build Coastguard Worker     }
307*523fa7a6SAndroid Build Coastguard Worker     if (!permuted_shape) {
308*523fa7a6SAndroid Build Coastguard Worker       // Check the number of elements in each tensor match
309*523fa7a6SAndroid Build Coastguard Worker       int tensor_count = 1;
310*523fa7a6SAndroid Build Coastguard Worker       int io_count = 1;
311*523fa7a6SAndroid Build Coastguard Worker 
312*523fa7a6SAndroid Build Coastguard Worker       for (int i = 0; i < tensor.dim(); i++) {
313*523fa7a6SAndroid Build Coastguard Worker         tensor_count = tensor_count * tensor.size(i);
314*523fa7a6SAndroid Build Coastguard Worker       }
315*523fa7a6SAndroid Build Coastguard Worker 
316*523fa7a6SAndroid Build Coastguard Worker       // The VelaIO type has a shape of fixed size 4
317*523fa7a6SAndroid Build Coastguard Worker       for (int i = 0; i < 4; i++) {
318*523fa7a6SAndroid Build Coastguard Worker         io_count = io_count * io->shape[i];
319*523fa7a6SAndroid Build Coastguard Worker       }
320*523fa7a6SAndroid Build Coastguard Worker 
321*523fa7a6SAndroid Build Coastguard Worker       if (tensor_count != io_count) {
322*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(Error, "Input tensor sizes do not match");
323*523fa7a6SAndroid Build Coastguard Worker         ET_LOG(
324*523fa7a6SAndroid Build Coastguard Worker             Error,
325*523fa7a6SAndroid Build Coastguard Worker             "Program expects %d elements but got %d",
326*523fa7a6SAndroid Build Coastguard Worker             io_count,
327*523fa7a6SAndroid Build Coastguard Worker             tensor_count);
328*523fa7a6SAndroid Build Coastguard Worker         return Error::InvalidProgram;
329*523fa7a6SAndroid Build Coastguard Worker       }
330*523fa7a6SAndroid Build Coastguard Worker     }
331*523fa7a6SAndroid Build Coastguard Worker     *is_permuted = permuted_shape;
332*523fa7a6SAndroid Build Coastguard Worker     return Error::Ok;
333*523fa7a6SAndroid Build Coastguard Worker   }
334*523fa7a6SAndroid Build Coastguard Worker 
permute_CHW_to_HWC(char * input,char * output,int C,int H,int W) const335*523fa7a6SAndroid Build Coastguard Worker   void permute_CHW_to_HWC(char* input, char* output, int C, int H, int W)
336*523fa7a6SAndroid Build Coastguard Worker       const {
337*523fa7a6SAndroid Build Coastguard Worker     for (int i = 0; i != H * W; ++i) {
338*523fa7a6SAndroid Build Coastguard Worker       for (int j = 0; j < C; ++j) {
339*523fa7a6SAndroid Build Coastguard Worker         output[i * C + j] = input[i + j * W * H];
340*523fa7a6SAndroid Build Coastguard Worker       }
341*523fa7a6SAndroid Build Coastguard Worker     }
342*523fa7a6SAndroid Build Coastguard Worker   }
343*523fa7a6SAndroid Build Coastguard Worker 
permute_HWC_to_CHW(char * input,char * output,int C,int H,int W) const344*523fa7a6SAndroid Build Coastguard Worker   void permute_HWC_to_CHW(char* input, char* output, int C, int H, int W)
345*523fa7a6SAndroid Build Coastguard Worker       const {
346*523fa7a6SAndroid Build Coastguard Worker     for (int i = 0; i != H * W; ++i) {
347*523fa7a6SAndroid Build Coastguard Worker       for (int j = 0; j < C; ++j) {
348*523fa7a6SAndroid Build Coastguard Worker         output[i + j * W * H] = input[i * C + j];
349*523fa7a6SAndroid Build Coastguard Worker       }
350*523fa7a6SAndroid Build Coastguard Worker     }
351*523fa7a6SAndroid Build Coastguard Worker   }
352*523fa7a6SAndroid Build Coastguard Worker };
353*523fa7a6SAndroid Build Coastguard Worker 
354*523fa7a6SAndroid Build Coastguard Worker namespace {
355*523fa7a6SAndroid Build Coastguard Worker auto backend = ArmBackend();
356*523fa7a6SAndroid Build Coastguard Worker Backend backend_id{"ArmBackend", &backend};
357*523fa7a6SAndroid Build Coastguard Worker static auto registered = register_backend(backend_id);
358*523fa7a6SAndroid Build Coastguard Worker } // namespace
359*523fa7a6SAndroid Build Coastguard Worker 
360*523fa7a6SAndroid Build Coastguard Worker } // namespace arm
361*523fa7a6SAndroid Build Coastguard Worker } // namespace backends
362*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
363