/*
 * Copyright 2023-2024 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
 * Arm backend for the Ethos-U baremetal driver stack. It relies on the
 * ethos-u-core-driver for hardware interaction.
 */

#include <cstring>
#include <memory>

#include <ethosu_driver.h>

#include <executorch/backends/arm/runtime/VelaBinStream.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

using namespace std;

using executorch::aten::ScalarType;
using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Result;

namespace executorch {
namespace backends {
namespace arm {

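// Per-delegate state created in init() and consumed by execute(): the
// processed Vela binary plus a flag recording whether the ahead-of-time flow
// permuted the I/O to channels-last (NHWC).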
typedef struct {
  FreeableBuffer* processed;
  bool permuted_io_flag;
} ExecutionHandle;

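// Weak hooks an application can override (for example, to time or trace
// execute() calls); the default implementations are no-ops.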
extern "C" {
void __attribute__((weak)) ArmBackend_execute_begin() {}
void __attribute__((weak)) ArmBackend_execute_end() {}
}

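// RAII helper that calls ArmBackend_execute_begin() on construction and
// ArmBackend_execute_end() on destruction, so every exit path from execute()
// is bracketed by the hooks.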
class ArmBackendExecuteCallbacks {
 public:
  ArmBackendExecuteCallbacks() {
    ArmBackend_execute_begin();
  }
  ~ArmBackendExecuteCallbacks() {
    ArmBackend_execute_end();
  }
};


class ArmBackend final : public ::executorch::runtime::BackendInterface {
 public:
  ArmBackend() {}

  ~ArmBackend() = default;

  virtual bool is_available() const override {
    // TODO: revise to use a register check/init function
    return true;
  }

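  // Validates the Vela binary produced ahead of time and stores an
  // ExecutionHandle (allocated from the runtime allocator) describing how to
  // run it.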
  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed,
      ArrayRef<CompileSpec> compile_specs) const override {
    ET_LOG(Info, "ArmBackend::init %p", processed->data());

    char* data = (char*)processed->data();
    size_t size = processed->size();

    // Verify format of vela_bin
    if (vela_bin_validate(data, size) == false) {
      ET_LOG(Error, "Malformed vela_bin_stream found");
      return Error::InvalidProgram;
    }

    MemoryAllocator* allocator = context.get_runtime_allocator();
    ExecutionHandle* handle =
        ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
    handle->processed = processed;

    handle->permuted_io_flag = false;
    for (auto& compile_spec : compile_specs) {
      if (0 == std::strcmp(compile_spec.key, "permute_memory_format") &&
          0 == std::memcmp(compile_spec.value.buffer, "nhwc", 4)) {
        handle->permuted_io_flag = true;
      }
    }

    // Return the same buffer we were passed - this data will be
    // executed directly
    return handle;
  }

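  // Copies the input EValue tensors into the Vela scratch region, invokes the
  // Ethos-U driver on the command stream, then copies the results back into
  // the output EValue tensors.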
  Error execute(
      BackendExecutionContext& context,
      DelegateHandle* input_handle,
      EValue** args) const override {
    ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle;
    VelaHandles handles;

    ArmBackendExecuteCallbacks ArmBackend_execute_callbacks;
    // Command stream - we know at this point it's aligned
    char* data = (char*)execution_handle->processed->data();
    ET_LOG(Debug, "ArmBackend::execute %p", data);

    // Read key sections from the vela_bin_stream
    if (vela_bin_read(data, &handles, execution_handle->processed->size()) ==
        false) {
      ET_LOG(Error, "ArmBackend::vela_read: error, invalid binary layout");
      return Error::InvalidProgram;
    }

    ET_LOG(
        Debug,
        "ArmBackend::execute: Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n",
        handles.cmd_data,
        handles.cmd_data_size,
        handles.weight_data,
        handles.weight_data_size,
        handles.scratch_data,
        handles.scratch_data_size);

    // Write argument values (from EValue tensor) into Ethos-U scratch
    // TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM
    // or DRAM output for compatible data layouts.
    for (int i = 0; i < handles.inputs->count; i++) {
      auto tensor_in = args[i]->toTensor();
      char* scratch_addr = handles.scratch_data + handles.inputs->io[i].offset;

      // We accept:
      bool supported = false;
      // 32 bit int (simple non-quantised test cases)
      supported |=
          (tensor_in.scalar_type() == ScalarType::Int and
           handles.inputs->io[i].elem_size == 4);
      // 8 bit int (IOQDQ pass prepared networks)
      supported |=
          (tensor_in.scalar_type() == ScalarType::Char and
           handles.inputs->io[i].elem_size == 1);
      if (!supported) {
        ET_LOG(
            Error,
            "Input %d expected Int (4 byte) or Char (1 byte) input, got ScalarType %s",
            i,
            executorch::runtime::toString(tensor_in.scalar_type()));
        return Error::InvalidProgram;
      }
      supported = executorch::runtime::is_contiguous_dim_order(
          tensor_in.dim_order().data(), tensor_in.dim());
      if (!supported) {
        ET_LOG(
            Error,
            "Input %d expected contiguous dim_order, but got non-contiguous dim_order",
            i);
        return Error::InvalidProgram;
      }

      // Check for input layouts which require permutation.
      bool permuted_input_shape;
      ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
          i,
          tensor_in,
          &handles.inputs->io[i],
          execution_handle->permuted_io_flag,
          &permuted_input_shape));
      bool both_char = tensor_in.scalar_type() == ScalarType::Char and
          handles.inputs->io[i].elem_size == 1;
      bool both_int = tensor_in.scalar_type() == ScalarType::Int and
          handles.inputs->io[i].elem_size == 4;

      // Select a compatible copy routine
      if (both_char and permuted_input_shape) {
        // Permuted byte copy CHW to HWC
        permute_CHW_to_HWC(
            tensor_in.mutable_data_ptr<char>(),
            scratch_addr,
            tensor_in.size(1),
            tensor_in.size(2),
            tensor_in.size(3));
      } else if (both_char or both_int) {
        // Shapes and element sizes match, so memcpy directly
        memcpy(
            scratch_addr,
            tensor_in.mutable_data_ptr<char>(),
            tensor_in.nbytes());
      } else {
        ET_LOG(Error, "No matching input copy routine");
        return Error::InvalidProgram;
      }
    }

    // Allocate driver handle and synchronously invoke driver
    auto driver =
        std::unique_ptr<ethosu_driver, decltype(&ethosu_release_driver)>(
            ethosu_reserve_driver(), ethosu_release_driver);
    if (driver == nullptr) {
      ET_LOG(Error, "ArmBackend::execute: ethosu_reserve_driver failed");
      return Error::InvalidState;
    }

    // The Ethos-U low-level driver expects the base pointers in a fixed order
    // for the Ethos-U55: constant weight data first, then scratch (which
    // contains the inputs and outputs). Scratch is written above in this
    // function.
    uint64_t bases[2] = {
        (uint64_t)handles.weight_data, (uint64_t)handles.scratch_data};
    size_t bases_size[2] = {
        handles.weight_data_size, handles.scratch_data_size};
    int result = ethosu_invoke_v3(
        driver.get(),
        (void*)handles.cmd_data,
        handles.cmd_data_size,
        bases,
        bases_size,
        2, /* fixed array of pointers to binary interface */
        nullptr);

    if (result != 0) {
      ET_LOG(
          Error,
          "ArmBackend::execute: Ethos-U invocation failed error (%d)",
          result);
      return Error::InvalidProgram;
    }

    // Write outputs from scratch into EValue pointers
    for (int i = 0; i < handles.outputs->count; i++) {
      const char* output_addr =
          handles.scratch_data + handles.outputs->io[i].offset;
      // Outputs are located in the args array immediately after the inputs
      auto tensor_out = args[handles.inputs->count + i]->toTensor();
      bool permuted_output_shape;
      ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
          i,
          tensor_out,
          &handles.outputs->io[i],
          execution_handle->permuted_io_flag,
          &permuted_output_shape));
      if (tensor_out.scalar_type() == ScalarType::Char and
          permuted_output_shape) {
        char* output_address = (char*)output_addr;
        permute_HWC_to_CHW(
            output_address,
            tensor_out.mutable_data_ptr<char>(),
            tensor_out.size(1),
            tensor_out.size(2),
            tensor_out.size(3));
      } else {
        for (int j = 0; j < tensor_out.numel(); j++) {
          if (tensor_out.scalar_type() == ScalarType::Char) {
            char* output_address = (char*)output_addr;
            tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
          } else {
            int* output_address = (int*)output_addr;
            tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
          }
        }
      }
    }
    return Error::Ok;
  }

  void destroy(DelegateHandle* handle) const override {
    return;
  }

 private:
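  // Decides whether a 4-D tensor needs a CHW<->HWC permutation to match the
  // Vela I/O description and cross-checks the result against the
  // permute_memory_format compile spec. For non-permuted tensors it verifies
  // that the element counts match.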
  Error check_requires_permute(
      int index,
      const executorch::aten::Tensor tensor,
      VelaIO* io,
      bool permuted_io_flag,
      bool* is_permuted) const {
    bool permuted_shape = false;
    if (tensor.dim() == 4) {
      // Special case for the NHWC workaround in AOT: the compilation has
      // permuted the graph to channels-last in a way that is not detectable
      // at runtime, so we assume the application has permuted any
      // input/output tensors in the same way.
      permuted_shape = tensor.size(0) == io->shape[0] &&
          tensor.size(1) == io->shape[3] && tensor.size(2) == io->shape[1] &&
          tensor.size(3) == io->shape[2];
      if (permuted_shape) {
        ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
      }
      if (permuted_io_flag != permuted_shape) {
        ET_LOG(
            Error,
            "Permute compile flag and permuted input/output don't agree");
        return Error::InvalidProgram;
      }
    }
    if (!permuted_shape) {
      // Check that the number of elements in the tensor and in the VelaIO
      // description match
      int tensor_count = 1;
      int io_count = 1;

      for (int i = 0; i < tensor.dim(); i++) {
        tensor_count = tensor_count * tensor.size(i);
      }

      // The VelaIO type has a shape of fixed size 4
      for (int i = 0; i < 4; i++) {
        io_count = io_count * io->shape[i];
      }

      if (tensor_count != io_count) {
        ET_LOG(Error, "Input tensor sizes do not match");
        ET_LOG(
            Error,
            "Program expects %d elements but got %d",
            io_count,
            tensor_count);
        return Error::InvalidProgram;
      }
    }
    *is_permuted = permuted_shape;
    return Error::Ok;
  }

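  // Copies a channels-first (CHW) buffer into channels-last (HWC) order; used
  // to lay out permuted 4-D byte inputs in the Vela scratch area.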
  void permute_CHW_to_HWC(char* input, char* output, int C, int H, int W)
      const {
    for (int i = 0; i != H * W; ++i) {
      for (int j = 0; j < C; ++j) {
        output[i * C + j] = input[i + j * W * H];
      }
    }
  }

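  // Inverse of permute_CHW_to_HWC: copies a channels-last (HWC) buffer back
  // into channels-first (CHW) order for permuted 4-D byte outputs.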
  void permute_HWC_to_CHW(char* input, char* output, int C, int H, int W)
      const {
    for (int i = 0; i != H * W; ++i) {
      for (int j = 0; j < C; ++j) {
        output[i + j * W * H] = input[i * C + j];
      }
    }
  }
};

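// Create a backend instance and register it under the name "ArmBackend" at
// static initialization time so delegated programs can resolve it at load.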
namespace {
auto backend = ArmBackend();
Backend backend_id{"ArmBackend", &backend};
static auto registered = register_backend(backend_id);
} // namespace

} // namespace arm
} // namespace backends
} // namespace executorch