/*
 * Copyright 2023-2024 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
7*523fa7a6SAndroid Build Coastguard Worker
/*
 * Arm backend for the Ethos-U baremetal driver stack. This relies on the
 * ethos-u-core-driver for hardware interaction.
 */
12*523fa7a6SAndroid Build Coastguard Worker
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>

#include <ethosu_driver.h>

#include <executorch/backends/arm/runtime/VelaBinStream.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

using namespace std;

using executorch::aten::ScalarType;
using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Result;
39*523fa7a6SAndroid Build Coastguard Worker
40*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
41*523fa7a6SAndroid Build Coastguard Worker namespace backends {
42*523fa7a6SAndroid Build Coastguard Worker namespace arm {
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Worker typedef struct {
45*523fa7a6SAndroid Build Coastguard Worker FreeableBuffer* processed;
46*523fa7a6SAndroid Build Coastguard Worker bool permuted_io_flag;
47*523fa7a6SAndroid Build Coastguard Worker } ExecutionHandle;
48*523fa7a6SAndroid Build Coastguard Worker
// Hooks invoked around every ArmBackend::execute() call. Both are defined as
// weak no-ops with C linkage; the weak attribute lets a strong definition of
// the same symbol elsewhere in the link take precedence over these defaults.
extern "C" {
// Called when the execute() scope guard is constructed.
void __attribute__((weak)) ArmBackend_execute_begin() {}
// Called when the execute() scope guard is destroyed.
void __attribute__((weak)) ArmBackend_execute_end() {}
}
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Worker class ArmBackendExecuteCallbacks {
55*523fa7a6SAndroid Build Coastguard Worker public:
ArmBackendExecuteCallbacks()56*523fa7a6SAndroid Build Coastguard Worker ArmBackendExecuteCallbacks() {
57*523fa7a6SAndroid Build Coastguard Worker ArmBackend_execute_begin();
58*523fa7a6SAndroid Build Coastguard Worker }
~ArmBackendExecuteCallbacks()59*523fa7a6SAndroid Build Coastguard Worker ~ArmBackendExecuteCallbacks() {
60*523fa7a6SAndroid Build Coastguard Worker ArmBackend_execute_end();
61*523fa7a6SAndroid Build Coastguard Worker }
62*523fa7a6SAndroid Build Coastguard Worker };
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker class ArmBackend final : public ::executorch::runtime::BackendInterface {
65*523fa7a6SAndroid Build Coastguard Worker public:
ArmBackend()66*523fa7a6SAndroid Build Coastguard Worker ArmBackend() {}
67*523fa7a6SAndroid Build Coastguard Worker
68*523fa7a6SAndroid Build Coastguard Worker ~ArmBackend() = default;
69*523fa7a6SAndroid Build Coastguard Worker
is_available() const70*523fa7a6SAndroid Build Coastguard Worker virtual bool is_available() const override {
71*523fa7a6SAndroid Build Coastguard Worker // TODO: revise to use a register check/init function
72*523fa7a6SAndroid Build Coastguard Worker return 1;
73*523fa7a6SAndroid Build Coastguard Worker }
74*523fa7a6SAndroid Build Coastguard Worker
init(BackendInitContext & context,FreeableBuffer * processed,ArrayRef<CompileSpec> compile_specs) const75*523fa7a6SAndroid Build Coastguard Worker Result<DelegateHandle*> init(
76*523fa7a6SAndroid Build Coastguard Worker BackendInitContext& context,
77*523fa7a6SAndroid Build Coastguard Worker FreeableBuffer* processed,
78*523fa7a6SAndroid Build Coastguard Worker ArrayRef<CompileSpec> compile_specs) const override {
79*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Info, "ArmBackend::init %p", processed->data());
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker char* data = (char*)processed->data();
82*523fa7a6SAndroid Build Coastguard Worker size_t size = processed->size();
83*523fa7a6SAndroid Build Coastguard Worker
84*523fa7a6SAndroid Build Coastguard Worker // Verify format of vela_bin
85*523fa7a6SAndroid Build Coastguard Worker if (vela_bin_validate(data, size) == false) {
86*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Malformed vela_bin_stream found");
87*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
88*523fa7a6SAndroid Build Coastguard Worker }
89*523fa7a6SAndroid Build Coastguard Worker
90*523fa7a6SAndroid Build Coastguard Worker MemoryAllocator* allocator = context.get_runtime_allocator();
91*523fa7a6SAndroid Build Coastguard Worker ExecutionHandle* handle =
92*523fa7a6SAndroid Build Coastguard Worker ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
93*523fa7a6SAndroid Build Coastguard Worker handle->processed = processed;
94*523fa7a6SAndroid Build Coastguard Worker
95*523fa7a6SAndroid Build Coastguard Worker handle->permuted_io_flag = false;
96*523fa7a6SAndroid Build Coastguard Worker for (auto& compile_spec : compile_specs) {
97*523fa7a6SAndroid Build Coastguard Worker if (0 == std::strcmp(compile_spec.key, "permute_memory_format") &&
98*523fa7a6SAndroid Build Coastguard Worker 0 == std::memcmp(compile_spec.value.buffer, "nhwc", 4)) {
99*523fa7a6SAndroid Build Coastguard Worker handle->permuted_io_flag = true;
100*523fa7a6SAndroid Build Coastguard Worker }
101*523fa7a6SAndroid Build Coastguard Worker }
102*523fa7a6SAndroid Build Coastguard Worker
103*523fa7a6SAndroid Build Coastguard Worker // Return the same buffer we were passed - this data will be
104*523fa7a6SAndroid Build Coastguard Worker // executed directly
105*523fa7a6SAndroid Build Coastguard Worker return handle;
106*523fa7a6SAndroid Build Coastguard Worker }
107*523fa7a6SAndroid Build Coastguard Worker
execute(BackendExecutionContext & context,DelegateHandle * input_handle,EValue ** args) const108*523fa7a6SAndroid Build Coastguard Worker Error execute(
109*523fa7a6SAndroid Build Coastguard Worker BackendExecutionContext& context,
110*523fa7a6SAndroid Build Coastguard Worker DelegateHandle* input_handle,
111*523fa7a6SAndroid Build Coastguard Worker EValue** args) const override {
112*523fa7a6SAndroid Build Coastguard Worker ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle;
113*523fa7a6SAndroid Build Coastguard Worker VelaHandles handles;
114*523fa7a6SAndroid Build Coastguard Worker
115*523fa7a6SAndroid Build Coastguard Worker ArmBackendExecuteCallbacks ArmBackend_execute_callbacks;
116*523fa7a6SAndroid Build Coastguard Worker // Command stream - we know at this point it's aligned
117*523fa7a6SAndroid Build Coastguard Worker char* data = (char*)execution_handle->processed->data();
118*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Debug, "ArmBackend::execute %p", data);
119*523fa7a6SAndroid Build Coastguard Worker
120*523fa7a6SAndroid Build Coastguard Worker // Read key sections from the vela_bin_stream
121*523fa7a6SAndroid Build Coastguard Worker if (vela_bin_read(data, &handles, execution_handle->processed->size()) ==
122*523fa7a6SAndroid Build Coastguard Worker false) {
123*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "ArmBackend::vela_read: error, invalid binary layout");
124*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
125*523fa7a6SAndroid Build Coastguard Worker }
126*523fa7a6SAndroid Build Coastguard Worker
127*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
128*523fa7a6SAndroid Build Coastguard Worker Debug,
129*523fa7a6SAndroid Build Coastguard Worker "ArmBackend::execute: Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n",
130*523fa7a6SAndroid Build Coastguard Worker handles.cmd_data,
131*523fa7a6SAndroid Build Coastguard Worker handles.cmd_data_size,
132*523fa7a6SAndroid Build Coastguard Worker handles.weight_data,
133*523fa7a6SAndroid Build Coastguard Worker handles.weight_data_size,
134*523fa7a6SAndroid Build Coastguard Worker handles.scratch_data,
135*523fa7a6SAndroid Build Coastguard Worker handles.scratch_data_size);
136*523fa7a6SAndroid Build Coastguard Worker
137*523fa7a6SAndroid Build Coastguard Worker // Write argument values (from EValue tensor) into Ethos-U scratch
138*523fa7a6SAndroid Build Coastguard Worker // TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM
139*523fa7a6SAndroid Build Coastguard Worker // or DRAM output for compatible data layouts.
140*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < handles.inputs->count; i++) {
141*523fa7a6SAndroid Build Coastguard Worker auto tensor_in = args[i]->toTensor();
142*523fa7a6SAndroid Build Coastguard Worker char* scratch_addr = handles.scratch_data + handles.inputs->io[i].offset;
143*523fa7a6SAndroid Build Coastguard Worker
144*523fa7a6SAndroid Build Coastguard Worker // We accept:
145*523fa7a6SAndroid Build Coastguard Worker bool supported = 0;
146*523fa7a6SAndroid Build Coastguard Worker // 32 bit int (simple non-quantised test cases)
147*523fa7a6SAndroid Build Coastguard Worker supported |=
148*523fa7a6SAndroid Build Coastguard Worker (tensor_in.scalar_type() == ScalarType::Int and
149*523fa7a6SAndroid Build Coastguard Worker handles.inputs->io[i].elem_size == 4);
150*523fa7a6SAndroid Build Coastguard Worker // 8 bit int (IOQDQ pass prepared networks)
151*523fa7a6SAndroid Build Coastguard Worker supported |=
152*523fa7a6SAndroid Build Coastguard Worker (tensor_in.scalar_type() == ScalarType::Char and
153*523fa7a6SAndroid Build Coastguard Worker handles.inputs->io[i].elem_size == 1);
154*523fa7a6SAndroid Build Coastguard Worker if (!supported) {
155*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
156*523fa7a6SAndroid Build Coastguard Worker Error,
157*523fa7a6SAndroid Build Coastguard Worker "Input %d expected Integer (4 byte) or Char (1 byte) integer inputs, got ScalarType id %s",
158*523fa7a6SAndroid Build Coastguard Worker i,
159*523fa7a6SAndroid Build Coastguard Worker executorch::runtime::toString(tensor_in.scalar_type()));
160*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
161*523fa7a6SAndroid Build Coastguard Worker }
162*523fa7a6SAndroid Build Coastguard Worker supported = executorch::runtime::is_contiguous_dim_order(
163*523fa7a6SAndroid Build Coastguard Worker tensor_in.dim_order().data(), tensor_in.dim());
164*523fa7a6SAndroid Build Coastguard Worker if (!supported) {
165*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
166*523fa7a6SAndroid Build Coastguard Worker Error,
167*523fa7a6SAndroid Build Coastguard Worker "Input %d expected contiguous dim_order, but got non-contiguous dim_order",
168*523fa7a6SAndroid Build Coastguard Worker i);
169*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
170*523fa7a6SAndroid Build Coastguard Worker }
171*523fa7a6SAndroid Build Coastguard Worker
172*523fa7a6SAndroid Build Coastguard Worker // Select a compatible copy routine including checking for input layouts
173*523fa7a6SAndroid Build Coastguard Worker // which require permutation.
174*523fa7a6SAndroid Build Coastguard Worker bool permuted_input_shape;
175*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
176*523fa7a6SAndroid Build Coastguard Worker i,
177*523fa7a6SAndroid Build Coastguard Worker tensor_in,
178*523fa7a6SAndroid Build Coastguard Worker &handles.inputs->io[i],
179*523fa7a6SAndroid Build Coastguard Worker execution_handle->permuted_io_flag,
180*523fa7a6SAndroid Build Coastguard Worker &permuted_input_shape));
181*523fa7a6SAndroid Build Coastguard Worker bool both_char = tensor_in.scalar_type() == ScalarType::Char and
182*523fa7a6SAndroid Build Coastguard Worker handles.inputs->io[i].elem_size == 1;
183*523fa7a6SAndroid Build Coastguard Worker bool both_int = tensor_in.scalar_type() == ScalarType::Int and
184*523fa7a6SAndroid Build Coastguard Worker handles.inputs->io[i].elem_size == 4;
185*523fa7a6SAndroid Build Coastguard Worker
186*523fa7a6SAndroid Build Coastguard Worker // Select a compatible copy routine
187*523fa7a6SAndroid Build Coastguard Worker if (both_char and permuted_input_shape) {
188*523fa7a6SAndroid Build Coastguard Worker // permuted byte copy CHW to HWC
189*523fa7a6SAndroid Build Coastguard Worker permute_CHW_to_HWC(
190*523fa7a6SAndroid Build Coastguard Worker tensor_in.mutable_data_ptr<char>(),
191*523fa7a6SAndroid Build Coastguard Worker scratch_addr,
192*523fa7a6SAndroid Build Coastguard Worker tensor_in.size(1),
193*523fa7a6SAndroid Build Coastguard Worker tensor_in.size(2),
194*523fa7a6SAndroid Build Coastguard Worker tensor_in.size(3));
195*523fa7a6SAndroid Build Coastguard Worker } else if (both_char or both_int) {
196*523fa7a6SAndroid Build Coastguard Worker // Sizes match and elt size matches so memcpy
197*523fa7a6SAndroid Build Coastguard Worker memcpy(
198*523fa7a6SAndroid Build Coastguard Worker scratch_addr,
199*523fa7a6SAndroid Build Coastguard Worker tensor_in.mutable_data_ptr<char>(),
200*523fa7a6SAndroid Build Coastguard Worker tensor_in.nbytes());
201*523fa7a6SAndroid Build Coastguard Worker } else {
202*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "No matching input copy routine");
203*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
204*523fa7a6SAndroid Build Coastguard Worker }
205*523fa7a6SAndroid Build Coastguard Worker }
206*523fa7a6SAndroid Build Coastguard Worker
207*523fa7a6SAndroid Build Coastguard Worker // Allocate driver handle and synchronously invoke driver
208*523fa7a6SAndroid Build Coastguard Worker auto driver =
209*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<ethosu_driver, decltype(ðosu_release_driver)>(
210*523fa7a6SAndroid Build Coastguard Worker ethosu_reserve_driver(), ethosu_release_driver);
211*523fa7a6SAndroid Build Coastguard Worker if (driver == NULL) {
212*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "ArmBackend::execute: ethosu_reserve_driver failed");
213*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidState;
214*523fa7a6SAndroid Build Coastguard Worker }
215*523fa7a6SAndroid Build Coastguard Worker
216*523fa7a6SAndroid Build Coastguard Worker // Ethos-U low level driver expected order for Ethos U-55, we have
217*523fa7a6SAndroid Build Coastguard Worker // constant weight data, then scratch (which contains input and output)
218*523fa7a6SAndroid Build Coastguard Worker // scratch is written above in this function.
219*523fa7a6SAndroid Build Coastguard Worker uint64_t bases[2] = {
220*523fa7a6SAndroid Build Coastguard Worker (uint64_t)handles.weight_data, (uint64_t)handles.scratch_data};
221*523fa7a6SAndroid Build Coastguard Worker size_t bases_size[2] = {
222*523fa7a6SAndroid Build Coastguard Worker handles.weight_data_size, handles.scratch_data_size};
223*523fa7a6SAndroid Build Coastguard Worker int result = ethosu_invoke_v3(
224*523fa7a6SAndroid Build Coastguard Worker driver.get(),
225*523fa7a6SAndroid Build Coastguard Worker (void*)handles.cmd_data,
226*523fa7a6SAndroid Build Coastguard Worker handles.cmd_data_size,
227*523fa7a6SAndroid Build Coastguard Worker bases,
228*523fa7a6SAndroid Build Coastguard Worker bases_size,
229*523fa7a6SAndroid Build Coastguard Worker 2, /* fixed array of pointers to binary interface*/
230*523fa7a6SAndroid Build Coastguard Worker nullptr);
231*523fa7a6SAndroid Build Coastguard Worker
232*523fa7a6SAndroid Build Coastguard Worker if (result != 0) {
233*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
234*523fa7a6SAndroid Build Coastguard Worker Error,
235*523fa7a6SAndroid Build Coastguard Worker "ArmBackend::execute: Ethos-U invocation failed error (%d)",
236*523fa7a6SAndroid Build Coastguard Worker result);
237*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
238*523fa7a6SAndroid Build Coastguard Worker }
239*523fa7a6SAndroid Build Coastguard Worker
240*523fa7a6SAndroid Build Coastguard Worker // Write outputs from scratch into EValue pointers
241*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < handles.outputs->count; i++) {
242*523fa7a6SAndroid Build Coastguard Worker const char* output_addr =
243*523fa7a6SAndroid Build Coastguard Worker handles.scratch_data + handles.outputs->io[i].offset;
244*523fa7a6SAndroid Build Coastguard Worker // Process input EValue into scratch
245*523fa7a6SAndroid Build Coastguard Worker // Outputs are in the index immediately after inputs
246*523fa7a6SAndroid Build Coastguard Worker auto tensor_out = args[handles.inputs->count + i]->toTensor();
247*523fa7a6SAndroid Build Coastguard Worker bool permuted_output_shape;
248*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
249*523fa7a6SAndroid Build Coastguard Worker i,
250*523fa7a6SAndroid Build Coastguard Worker tensor_out,
251*523fa7a6SAndroid Build Coastguard Worker &handles.outputs->io[i],
252*523fa7a6SAndroid Build Coastguard Worker execution_handle->permuted_io_flag,
253*523fa7a6SAndroid Build Coastguard Worker &permuted_output_shape));
254*523fa7a6SAndroid Build Coastguard Worker if (tensor_out.scalar_type() == ScalarType::Char and
255*523fa7a6SAndroid Build Coastguard Worker permuted_output_shape) {
256*523fa7a6SAndroid Build Coastguard Worker char* output_address = (char*)output_addr;
257*523fa7a6SAndroid Build Coastguard Worker permute_HWC_to_CHW(
258*523fa7a6SAndroid Build Coastguard Worker output_address,
259*523fa7a6SAndroid Build Coastguard Worker tensor_out.mutable_data_ptr<char>(),
260*523fa7a6SAndroid Build Coastguard Worker tensor_out.size(1),
261*523fa7a6SAndroid Build Coastguard Worker tensor_out.size(2),
262*523fa7a6SAndroid Build Coastguard Worker tensor_out.size(3));
263*523fa7a6SAndroid Build Coastguard Worker } else {
264*523fa7a6SAndroid Build Coastguard Worker for (int j = 0; j < tensor_out.numel(); j++) {
265*523fa7a6SAndroid Build Coastguard Worker if (tensor_out.scalar_type() == ScalarType::Char) {
266*523fa7a6SAndroid Build Coastguard Worker char* output_address = (char*)output_addr;
267*523fa7a6SAndroid Build Coastguard Worker tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
268*523fa7a6SAndroid Build Coastguard Worker } else {
269*523fa7a6SAndroid Build Coastguard Worker int* output_address = (int*)output_addr;
270*523fa7a6SAndroid Build Coastguard Worker tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
271*523fa7a6SAndroid Build Coastguard Worker }
272*523fa7a6SAndroid Build Coastguard Worker }
273*523fa7a6SAndroid Build Coastguard Worker }
274*523fa7a6SAndroid Build Coastguard Worker }
275*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
276*523fa7a6SAndroid Build Coastguard Worker }
277*523fa7a6SAndroid Build Coastguard Worker
destroy(DelegateHandle * handle) const278*523fa7a6SAndroid Build Coastguard Worker void destroy(DelegateHandle* handle) const override {
279*523fa7a6SAndroid Build Coastguard Worker return;
280*523fa7a6SAndroid Build Coastguard Worker }
281*523fa7a6SAndroid Build Coastguard Worker
282*523fa7a6SAndroid Build Coastguard Worker private:
check_requires_permute(int index,const executorch::aten::Tensor tensor,VelaIO * io,bool permuted_io_flag,bool * is_permuted) const283*523fa7a6SAndroid Build Coastguard Worker Error check_requires_permute(
284*523fa7a6SAndroid Build Coastguard Worker int index,
285*523fa7a6SAndroid Build Coastguard Worker const executorch::aten::Tensor tensor,
286*523fa7a6SAndroid Build Coastguard Worker VelaIO* io,
287*523fa7a6SAndroid Build Coastguard Worker bool permuted_io_flag,
288*523fa7a6SAndroid Build Coastguard Worker bool* is_permuted) const {
289*523fa7a6SAndroid Build Coastguard Worker bool permuted_shape = false;
290*523fa7a6SAndroid Build Coastguard Worker if (tensor.dim() == 4) {
291*523fa7a6SAndroid Build Coastguard Worker // special case for NHWC workaround in AOT; as the compilation has
292*523fa7a6SAndroid Build Coastguard Worker // permuted to channel last in an undetectable way, we assume here
293*523fa7a6SAndroid Build Coastguard Worker // that the application has similarly permuted any input/output tensors.
294*523fa7a6SAndroid Build Coastguard Worker permuted_shape = tensor.size(0) == io->shape[0] &&
295*523fa7a6SAndroid Build Coastguard Worker tensor.size(1) == io->shape[3] && tensor.size(2) == io->shape[1] &&
296*523fa7a6SAndroid Build Coastguard Worker tensor.size(3) == io->shape[2];
297*523fa7a6SAndroid Build Coastguard Worker if (permuted_shape) {
298*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
299*523fa7a6SAndroid Build Coastguard Worker }
300*523fa7a6SAndroid Build Coastguard Worker if (permuted_io_flag != permuted_shape) {
301*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
302*523fa7a6SAndroid Build Coastguard Worker Error,
303*523fa7a6SAndroid Build Coastguard Worker "Permute compile flag and permuted input/output don't agree");
304*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
305*523fa7a6SAndroid Build Coastguard Worker }
306*523fa7a6SAndroid Build Coastguard Worker }
307*523fa7a6SAndroid Build Coastguard Worker if (!permuted_shape) {
308*523fa7a6SAndroid Build Coastguard Worker // Check the number of elements in each tensor match
309*523fa7a6SAndroid Build Coastguard Worker int tensor_count = 1;
310*523fa7a6SAndroid Build Coastguard Worker int io_count = 1;
311*523fa7a6SAndroid Build Coastguard Worker
312*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < tensor.dim(); i++) {
313*523fa7a6SAndroid Build Coastguard Worker tensor_count = tensor_count * tensor.size(i);
314*523fa7a6SAndroid Build Coastguard Worker }
315*523fa7a6SAndroid Build Coastguard Worker
316*523fa7a6SAndroid Build Coastguard Worker // The VelaIO type has a shape of fixed size 4
317*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < 4; i++) {
318*523fa7a6SAndroid Build Coastguard Worker io_count = io_count * io->shape[i];
319*523fa7a6SAndroid Build Coastguard Worker }
320*523fa7a6SAndroid Build Coastguard Worker
321*523fa7a6SAndroid Build Coastguard Worker if (tensor_count != io_count) {
322*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Input tensor sizes do not match");
323*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
324*523fa7a6SAndroid Build Coastguard Worker Error,
325*523fa7a6SAndroid Build Coastguard Worker "Program expects %d elements but got %d",
326*523fa7a6SAndroid Build Coastguard Worker io_count,
327*523fa7a6SAndroid Build Coastguard Worker tensor_count);
328*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
329*523fa7a6SAndroid Build Coastguard Worker }
330*523fa7a6SAndroid Build Coastguard Worker }
331*523fa7a6SAndroid Build Coastguard Worker *is_permuted = permuted_shape;
332*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
333*523fa7a6SAndroid Build Coastguard Worker }
334*523fa7a6SAndroid Build Coastguard Worker
permute_CHW_to_HWC(char * input,char * output,int C,int H,int W) const335*523fa7a6SAndroid Build Coastguard Worker void permute_CHW_to_HWC(char* input, char* output, int C, int H, int W)
336*523fa7a6SAndroid Build Coastguard Worker const {
337*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i != H * W; ++i) {
338*523fa7a6SAndroid Build Coastguard Worker for (int j = 0; j < C; ++j) {
339*523fa7a6SAndroid Build Coastguard Worker output[i * C + j] = input[i + j * W * H];
340*523fa7a6SAndroid Build Coastguard Worker }
341*523fa7a6SAndroid Build Coastguard Worker }
342*523fa7a6SAndroid Build Coastguard Worker }
343*523fa7a6SAndroid Build Coastguard Worker
permute_HWC_to_CHW(char * input,char * output,int C,int H,int W) const344*523fa7a6SAndroid Build Coastguard Worker void permute_HWC_to_CHW(char* input, char* output, int C, int H, int W)
345*523fa7a6SAndroid Build Coastguard Worker const {
346*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i != H * W; ++i) {
347*523fa7a6SAndroid Build Coastguard Worker for (int j = 0; j < C; ++j) {
348*523fa7a6SAndroid Build Coastguard Worker output[i + j * W * H] = input[i * C + j];
349*523fa7a6SAndroid Build Coastguard Worker }
350*523fa7a6SAndroid Build Coastguard Worker }
351*523fa7a6SAndroid Build Coastguard Worker }
352*523fa7a6SAndroid Build Coastguard Worker };
353*523fa7a6SAndroid Build Coastguard Worker
354*523fa7a6SAndroid Build Coastguard Worker namespace {
355*523fa7a6SAndroid Build Coastguard Worker auto backend = ArmBackend();
356*523fa7a6SAndroid Build Coastguard Worker Backend backend_id{"ArmBackend", &backend};
357*523fa7a6SAndroid Build Coastguard Worker static auto registered = register_backend(backend_id);
358*523fa7a6SAndroid Build Coastguard Worker } // namespace
359*523fa7a6SAndroid Build Coastguard Worker
360*523fa7a6SAndroid Build Coastguard Worker } // namespace arm
361*523fa7a6SAndroid Build Coastguard Worker } // namespace backends
362*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
363