/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/executor/method.h>

#include <cinttypes> // @donotremove
#include <cmath> // For std::isnan/std::isfinite/std::abs in set_input().
#include <cstdint>
#include <cstdio>

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/memory_manager.h>
#include <executorch/runtime/executor/platform_memory_allocator.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/tensor_parser.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/schema/program_generated.h>

namespace executorch {
namespace runtime {

using internal::PlatformMemoryAllocator;

/**
 * Runtime state for a backend delegate.
 */
// This lint wants to wrap the class in an anonymous namespace, but it must be
// visible because it's forward-declared and used in Executor.h.
// @lint-ignore CLANGTIDY facebook-hte-ShadowingClass
class BackendDelegate final {
 public:
  /**
   * Initializes an already-allocated BackendDelegate from its serialized
   * representation.
   *
   * @param[in] delegate The serialized backend delegate to load.
   * @param[in] program The serialized program to load from.
   * @param[in] backend_init_context The context pointer to pass to the
   *     backend's init() method.
   * @param[out] out The BackendDelegate to initialize.
   *
   * @returns Error::Ok if the initialization succeeded, or an error otherwise.
   */
  static Error Init(
      const executorch_flatbuffer::BackendDelegate& delegate,
      const Program* program,
      BackendInitContext& backend_init_context,
      BackendDelegate* out) {
    // Look up the backend.
    ET_CHECK_OR_RETURN_ERROR(
        delegate.id() != nullptr, InvalidProgram, "Missing backend id");
    const char* backend_id = delegate.id()->c_str();
    BackendInterface* backend = get_backend_class(backend_id);
    ET_CHECK_OR_RETURN_ERROR(
        backend != nullptr,
        NotFound,
        "Backend %s is not registered.",
        backend_id);
    ET_CHECK_OR_RETURN_ERROR(
        backend->is_available(),
        NotFound,
        "Backend %s is not available.",
        backend_id);

    // Get the delegate data.
    Result<FreeableBuffer> processed_data = GetProcessedData(delegate, program);
    if (!processed_data.ok()) {
      ET_LOG(Error, "Failed to load data for backend %s", backend_id);
      return processed_data.error();
    }

    // Parse compilation specs from program
    CompileSpec* compile_specs;
    Error err = PopulateCompileSpecs(
        delegate.compile_specs(), backend_init_context, &compile_specs);
    if (err != Error::Ok) {
      ET_LOG(Error, "Failed to get compile specs for backend %s", backend_id);
      return err;
    }
    size_t num_compile_specs = delegate.compile_specs()->size();

    out->backend_ = backend;
    out->handle_ = nullptr;
    // Pass a pointer to this buffer to the backend. It's safe for the backend
    // to point its handle to this object, since it will outlive the backend.
    new (&out->segment_) FreeableBuffer(std::move(processed_data.get()));

    // Initialize the delegate.
    Result<DelegateHandle*> handle = backend->init(
        backend_init_context,
        &out->segment_,
        ArrayRef<CompileSpec>(compile_specs, num_compile_specs));
    if (!handle.ok()) {
      ET_LOG(
          Error,
          "Init failed for backend %s: 0x%" PRIx32,
          backend_id,
          static_cast<uint32_t>(handle.error()));
      out->segment_.Free();
      return handle.error();
    }
    out->handle_ = handle.get();
    return Error::Ok;
  }

  ~BackendDelegate() {
    if (backend_ != nullptr) {
      backend_->destroy(handle_);
    }
  }

  Error Execute(
      BackendExecutionContext& backend_execution_context,
      EValue** args) const {
    EXECUTORCH_SCOPE_PROF("delegate_execute");
    return backend_->execute(backend_execution_context, handle_, args);
  }

 private:
  // Not constructible.
  BackendDelegate() = delete;

  // Disallow copy/move.
  BackendDelegate(const BackendDelegate&) = delete;
  BackendDelegate& operator=(const BackendDelegate&) = delete;
  BackendDelegate(BackendDelegate&&) = delete;
  BackendDelegate& operator=(BackendDelegate&&) = delete;

  static Error PopulateCompileSpecs(
      const flatbuffers::Vector<flatbuffers::Offset<
          executorch_flatbuffer::CompileSpec>>* compile_specs_in_program,
      BackendInitContext& backend_init_context,
      CompileSpec** out_spec) {
    auto number_of_compile_specs = compile_specs_in_program->size();

    CompileSpec* compile_specs_list =
        backend_init_context.get_runtime_allocator()->allocateList<CompileSpec>(
            number_of_compile_specs);
    if (compile_specs_list == nullptr) {
      return Error::MemoryAllocationFailed;
    }

    // Initialize the spec list for each method spec
    for (size_t j = 0; j < number_of_compile_specs; j++) {
      auto compile_spec_in_program = compile_specs_in_program->Get(j);

      compile_specs_list[j].key = compile_spec_in_program->key()->c_str();
      compile_specs_list[j].value = {
          /*buffer*/ static_cast<void*>(
              const_cast<uint8_t*>(compile_spec_in_program->value()->Data())),
          /*nbytes*/ compile_spec_in_program->value()->size(),
      };
    }

    *out_spec = compile_specs_list;
    return Error::Ok;
  }

  static Result<FreeableBuffer> GetProcessedData(
      const executorch_flatbuffer::BackendDelegate& delegate,
      const Program* program) {
    const executorch_flatbuffer::BackendDelegateDataReference* processed =
        delegate.processed();
    switch (processed->location()) {
      case executorch_flatbuffer::DataLocation::INLINE: {
        const void* data;
        size_t size;
        Error err = program->get_backend_delegate_data(
            processed->index(), &data, &size);
        if (err != Error::Ok) {
          return err;
        }
        return FreeableBuffer(
            data,
            size,
            /*free_fn=*/nullptr);
      }
      case executorch_flatbuffer::DataLocation::SEGMENT: {
        const char* backend_id = delegate.id()->c_str();
        return program->LoadSegment(DataLoader::SegmentInfo(
            DataLoader::SegmentInfo::Type::Backend,
            processed->index(),
            backend_id));
      }
      default:
        ET_LOG(
            Error,
            "Unknown data location %u",
            static_cast<unsigned int>(processed->location()));
        return Error::Internal;
    }
  }

  FreeableBuffer segment_;
  const BackendInterface* backend_;
  DelegateHandle* handle_;
};
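
// For reference, a backend becomes visible to get_backend_class() above only
// after it has been registered with the runtime. A minimal sketch of what
// that registration might look like (MyBackend is a hypothetical
// BackendInterface implementation, not part of this file):
//
//   class MyBackend final : public BackendInterface {
//     // ... implements is_available(), init(), execute(), destroy() ...
//   };
//
//   static MyBackend backend_impl;
//   static auto success = register_backend({"MyBackend", &backend_impl});
//
// The "MyBackend" name must match the delegate id serialized in the program.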

/**
 * Runtime state for a chain of instructions.
 */
struct Chain {
  /// Pointer to the associated flatbuffer chain.
  const executorch_flatbuffer::Chain* s_chain_;

  /// Each entry is a list of parameters for a kernel or delegate call.
  Span<InstructionArgs> argument_lists_;
  /// One kernel entry per instruction; unused for delegate calls.
  OpFunction* kernels_;
};

namespace {

Result<InstructionArgs> gen_instruction_arguments(
    MemoryAllocator* method_allocator,
    size_t num_values,
    EValue* values,
    size_t num_args,
    const int32_t* arg_idxs) {
  EValue** arg_list = method_allocator->allocateList<EValue*>(num_args);
  if (arg_list == nullptr) {
    return Error::MemoryAllocationFailed;
  }
  for (size_t i = 0; i < num_args; ++i) {
    int32_t arg_idx = arg_idxs[i];
    ET_CHECK_OR_RETURN_ERROR(
        arg_idx < num_values,
        InvalidProgram,
        "Arg index %d >= %zu",
        arg_idx,
        num_values);
    arg_list[i] = &values[arg_idx];
  }
  return InstructionArgs(arg_list, num_args);
}

Result<bool> parse_cond_value(const EValue& cond_value) {
  // The cond value attached to the JF instruction at the beginning of an
  // if/else branch is a Tensor, which we parse to decide whether to continue
  // executing the if branch or jump to the else branch.
  // The cond value attached to the JF instruction at the end of the if branch
  // is a Bool Scalar that resolves to false and points us to the instruction
  // to jump to, which takes us to a point after the else branch.
  if (cond_value.isTensor()) {
    const exec_aten::Tensor& cond_val = cond_value.toTensor();

    // All the tensor and scalar cond values should currently be of bool type.
    // If that's not the case then something is wrong in the model and we
    // should exit.
    ET_CHECK_OR_RETURN_ERROR(
        exec_aten::ScalarType::Bool == cond_val.scalar_type(),
        InvalidProgram,
        "Expected dtype of %" PRId8 " got %" PRId8,
        static_cast<int8_t>(exec_aten::ScalarType::Bool),
        static_cast<int8_t>(cond_val.scalar_type()));

    const bool* cond_data = cond_val.const_data_ptr<bool>();
    for (size_t i = 0; i < cond_val.numel(); i++) {
      if (!cond_data[i]) {
        return false;
      }
    }
  } else if (cond_value.isBool()) {
    if (!cond_value.toBool()) {
      return false;
    }
  } else {
    ET_LOG(
        Error, "Unsupported JF EValue type %" PRIu32, (uint32_t)cond_value.tag);
    return Error::InvalidProgram;
  }

  return true;
}

} // namespace

Error Method::parse_values() {
  auto flatbuffer_values = serialization_plan_->values();
  ET_CHECK_OR_RETURN_ERROR(
      flatbuffer_values != nullptr, InvalidProgram, "Missing values");
  size_t n_value = flatbuffer_values->size();
  values_ = memory_manager_->method_allocator()->allocateList<EValue>(n_value);
  if (values_ == nullptr) {
    return Error::MemoryAllocationFailed;
  }

  // n_value_ counts the number of successfully-initialized values for
  // ~Method() to clean up, and is incremented at the bottom of the loop. This
  // makes it safe for errors to return without updating any state.
  n_value_ = 0;

  for (size_t i = 0; i < n_value; ++i) {
    auto serialization_value = flatbuffer_values->Get(i);
    // Ensure that the `val_as_X()` calls will return non-null pointers.
    ET_CHECK_OR_RETURN_ERROR(
        serialization_value != nullptr &&
            (serialization_value->val_type() ==
                 executorch_flatbuffer::KernelTypes::Null ||
             serialization_value->val() != nullptr),
        InvalidProgram,
        "Null value at index %zu",
        i);

    switch (serialization_value->val_type()) {
      case executorch_flatbuffer::KernelTypes::Null: {
        // Use placement new since the list elements are not initialized;
        // calling copy assignment is undefined if the type is non-trivial
        // (imagine the garbage in values_[i] thinking it's an at::Tensor).
        new (&values_[i]) EValue();
      } break;
      case executorch_flatbuffer::KernelTypes::Int: {
        new (&values_[i]) EValue(serialization_value->val_as_Int()->int_val());
      } break;
      case executorch_flatbuffer::KernelTypes::Double: {
        new (&values_[i])
            EValue(serialization_value->val_as_Double()->double_val());
      } break;
      case executorch_flatbuffer::KernelTypes::Bool: {
        new (&values_[i])
            EValue(serialization_value->val_as_Bool()->bool_val());
      } break;
      case executorch_flatbuffer::KernelTypes::IntList: {
        const auto items = serialization_value->val_as_IntList()->items();
        ET_CHECK_OR_RETURN_ERROR(
            items != nullptr, InvalidProgram, "Missing list at index %zu", i);
        // Allocate space for boxed and unboxed list representations, using
        // values_ as the source of truth.
        auto* evalp_list =
            memory_manager_->method_allocator()->allocateList<EValue*>(
                items->size());
        auto* int_list =
            memory_manager_->method_allocator()->allocateList<int64_t>(
                items->size());
        if (evalp_list == nullptr || int_list == nullptr) {
          return Error::MemoryAllocationFailed;
        }

        // Initialize the boxed list.
        for (size_t j = 0; j < items->size(); j++) {
          evalp_list[j] = &values_[static_cast<size_t>(items->Get(j))];
        }
        new (&values_[i]) EValue(
            BoxedEvalueList<int64_t>(evalp_list, int_list, items->size()));
      } break;
      case executorch_flatbuffer::KernelTypes::BoolList: {
        const auto items = serialization_value->val_as_BoolList()->items();
        ET_CHECK_OR_RETURN_ERROR(
            items != nullptr, InvalidProgram, "Missing list at index %zu", i);
        // NOTE: This is technically not portable. A platform could define
        // bool as something longer than a byte. That would be an
        // exceptionally rare case, and this type is currently unused in any
        // operators in ATen that we would need to support. To be properly
        // portable here we would need to allocate a new array of bool and
        // copy-cast the flatbuffer data into it, but since this case is so
        // rare it's low priority. TODO: jakeszwe
        new (&values_[i]) EValue(exec_aten::ArrayRef<bool>(
            (const bool*)items->data(), items->size()));
      } break;
      case executorch_flatbuffer::KernelTypes::DoubleList: {
        const auto items = serialization_value->val_as_DoubleList()->items();
        ET_CHECK_OR_RETURN_ERROR(
            items != nullptr, InvalidProgram, "Missing list at index %zu", i);
        new (&values_[i])
            EValue(exec_aten::ArrayRef<double>(items->data(), items->size()));
      } break;
      case executorch_flatbuffer::KernelTypes::String: {
        const auto fb_str = serialization_value->val_as_String()->string_val();
        ET_CHECK_OR_RETURN_ERROR(
            fb_str != nullptr,
            InvalidProgram,
            "Missing string at index %zu",
            i);
        new (&values_[i]) EValue(fb_str->c_str(), fb_str->size());
      } break;
      case executorch_flatbuffer::KernelTypes::Tensor: {
        auto t = deserialization::parseTensor(
            program_, memory_manager_, serialization_value->val_as_Tensor());
        if (!t.ok()) {
          ET_LOG(
              Error,
              "Failed parsing tensor at index %zu: 0x%" PRIx32,
              i,
              static_cast<uint32_t>(t.error()));
          return t.error();
        }
        new (&values_[i]) EValue(t.get());
      } break;
      case executorch_flatbuffer::KernelTypes::TensorList: {
        // Get the list of serialized tensors and allocate storage for the
        // executor tensors.
        auto tensors = deserialization::parseTensorList(
            serialization_value->val_as_TensorList()->items(),
            values_,
            memory_manager_);
        if (!tensors.ok()) {
          ET_LOG(
              Error,
              "Failed parsing tensor list at index %zu: 0x%" PRIx32,
              i,
              static_cast<uint32_t>(tensors.error()));
          return tensors.error();
        }
        new (&values_[i]) EValue(tensors.get());
      } break;
      case executorch_flatbuffer::KernelTypes::OptionalTensorList: {
        // Same as TensorList, but optional<Tensor> instead of Tensor.
        auto tensors =
            deserialization::parseListOptionalType<exec_aten::Tensor>(
                serialization_value->val_as_OptionalTensorList()->items(),
                values_,
                memory_manager_);
        if (!tensors.ok()) {
          ET_LOG(
              Error,
              "Failed parsing optional tensor list at index %zu: 0x%" PRIx32,
              i,
              static_cast<uint32_t>(tensors.error()));
          return tensors.error();
        }
        new (&values_[i]) EValue(tensors.get());
      } break;
      default:
        // Flatbuffer enums start at 0, but the generator adds a hidden NONE
        // enum and gives it that value. schema.fbs doesn't show this type, so
        // subtract one to keep the output 0-based for a disgruntled debugger
        // who sees this error message and checks schema.fbs.
        ET_LOG(
            Error,
            "Unknown KernelTypes value %" PRIu32 " at index %zu",
            static_cast<uint32_t>(serialization_value->val_type()) - 1,
            i);
        return Error::InvalidProgram;
    }

    // ~Method() will try to clean up n_value_ entries in the values_ array.
    // Only increment this once we know the entry is valid, so that we don't
    // try to clean up an uninitialized entry.
    n_value_ = i + 1;
  }
  return Error::Ok;
}

namespace {
/**
 * Private helper for populating operator_name from the Operator.
 * operator_name is a char buffer that is already allocated, of size
 * operator_name_size.
 */
Error populate_operator_name(
    const executorch_flatbuffer::Operator* const& op,
    const size_t operator_name_size,
    char* operator_name) {
  const bool has_overload =
      op->overload() != nullptr && op->overload()->size() > 0;

  ET_CHECK_OR_RETURN_ERROR(
      op->name() != nullptr, InvalidProgram, "Missing operator name");
  int cx = snprintf(
      operator_name,
      operator_name_size,
      "%s%s%s",
      op->name()->c_str(),
      // Don't append any overload if the overload string is empty.
      has_overload ? "." : "",
      has_overload ? op->overload()->c_str() : "");
  ET_CHECK_OR_RETURN_ERROR(cx >= 0, Internal, "snprintf failed: %d", cx);
  ET_CHECK_OR_RETURN_ERROR(
      cx < operator_name_size,
      Internal,
      "Operator name %s%s%s with length %d "
      "truncated to %zu due to internal buffer limit.",
      op->name()->c_str(),
      has_overload ? "." : "",
      has_overload ? op->overload()->c_str() : "",
      cx,
      operator_name_size);

  return Error::Ok;
}
} // namespace
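
// For illustration: given an operator with name "aten::add" and overload
// "out" (hypothetical values), populate_operator_name() above would write the
// registry key "aten::add.out" into the buffer; with a null or empty overload
// it would write just "aten::add":
//
//   char name[100];
//   Error err = populate_operator_name(op, sizeof(name), name);
//   // On success, name == "aten::add.out"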

Error Method::resolve_operator(
    int32_t op_index,
    OpFunction* kernels,
    size_t kernel_index,
    InstructionArgs args,
    size_t n_args) {
  // TODO(T153505381, T153506819) Investigate optimizing this function for both
  // space and time.

  // resolve name
  constexpr size_t kTempBufferSizeForName = 100;
  char operator_name[kTempBufferSizeForName];
  const auto ops = serialization_plan_->operators();
  ET_CHECK_OR_RETURN_ERROR(
      ops != nullptr && op_index < ops->size(),
      InvalidProgram,
      "Op index %" PRId32 " out of range",
      op_index);
  const auto& op = ops->Get(op_index);

  Error err = populate_operator_name(op, kTempBufferSizeForName, operator_name);
  if (err != Error::Ok) {
    return err;
  }

  // resolve tensor meta
  auto method_allocator = memory_manager_->method_allocator();
  TensorMeta* meta = method_allocator->allocateList<TensorMeta>(n_args);
  if (meta == nullptr) {
    return Error::MemoryAllocationFailed;
  }

  size_t count = 0;
  for (size_t i = 0; i < n_args; i++) {
    EValue* eval = args[i];
    // handle tensor list as well
    if (eval->isTensor()) {
      auto tensor = eval->toTensor();
      meta[count].dtype_ = tensor.scalar_type();
      exec_aten::DimOrderType* dim_order_ptr =
          method_allocator->allocateList<exec_aten::DimOrderType>(tensor.dim());
      if (dim_order_ptr == nullptr) {
        return Error::MemoryAllocationFailed;
      }
      size_t size = tensor.dim();
      err = get_dim_order(tensor, dim_order_ptr, size);
      ET_CHECK_OR_RETURN_ERROR(
          err == Error::Ok,
          InvalidArgument,
          "Error setting dim_order %zu: 0x%" PRIx32,
          i,
          static_cast<uint32_t>(err));
      meta[count].dim_order_ =
          Span<exec_aten::DimOrderType>(dim_order_ptr, size);
      count++;
    }
  }

  // Find a kernel with the matching name and tensor meta.
  Result<OpFunction> op_function =
      get_op_function_from_registry(operator_name, {meta, count});
  if (!op_function.ok()) {
    ET_LOG(Error, "Missing operator: [%d] %s", op_index, operator_name);
    return op_function.error();
  }
  kernels[kernel_index] = op_function.get();
  return Error::Ok;
}

Result<Method> Method::load(
    executorch_flatbuffer::ExecutionPlan* s_plan,
    const Program* program,
    MemoryManager* memory_manager,
    EventTracer* event_tracer) {
  MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
  if (temp_allocator == nullptr) {
    PlatformMemoryAllocator* platform_allocator =
        memory_manager->method_allocator()
            ->allocateInstance<PlatformMemoryAllocator>();
    if (platform_allocator == nullptr) {
      return Error::MemoryAllocationFailed;
    }
    new (platform_allocator) PlatformMemoryAllocator();
    temp_allocator = platform_allocator;
  }
  Method method(program, memory_manager, event_tracer, temp_allocator);

  Error err = method.init(s_plan);
  if (err != Error::Ok) {
    return err;
  } else {
    ET_CHECK(method.initialized());
    return method;
  }
}
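
// A minimal usage sketch for the loading path above, assuming the caller has
// already constructed a Program and a MemoryManager (setup details are
// illustrative, not part of this file). Callers normally go through
// Program::load_method() rather than calling Method::load() directly:
//
//   Result<Method> method = program->load_method("forward", &memory_manager);
//   if (method.ok()) {
//     Error err = method->set_input(input_evalue, /*input_idx=*/0);
//     if (err == Error::Ok) {
//       err = method->execute();
//     }
//   }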

Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
  EXECUTORCH_SCOPE_PROF("Method::init");
  internal::EventTracerProfileMethodScope event_tracer_profile_scope =
      internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
  ET_CHECK_OR_RETURN_ERROR(
      // Don't use !initialized() here because we also want to fail on the
      // InitializationFailed state.
      init_state_ == InitializationState::Uninitialized,
      InvalidState,
      "Method already initialized, or previously failed to initialize.");
  init_state_ =
      InitializationState::InitializationFailed; // Until proven otherwise
  serialization_plan_ = s_plan;
  auto method_allocator = memory_manager_->method_allocator();

  {
    // Parse the elements of the values_ array.
    Error err = parse_values();
    if (err != Error::Ok) {
      return err;
    }
  }

  {
    // Resolve delegates.
    const auto delegates = serialization_plan_->delegates();
    ET_CHECK_OR_RETURN_ERROR(
        delegates != nullptr, InvalidProgram, "Missing delegates field");
    size_t n_delegate = delegates->size();
    delegates_ = method_allocator->allocateList<BackendDelegate>(n_delegate);
    if (delegates_ == nullptr) {
      return Error::MemoryAllocationFailed;
    }

    // n_delegate_ counts the number of successfully-initialized delegates for
    // ~Method() to clean up, and is incremented at the bottom of the loop.
    // This makes it safe for errors to return without updating any state.
    n_delegate_ = 0;

    for (size_t i = 0; i < n_delegate; ++i) {
      const auto& delegate = *delegates->Get(i);
      BackendInitContext backend_init_context(
          method_allocator,
          /*method_name=*/serialization_plan_->name()->c_str());
      Error err = BackendDelegate::Init(
          delegate, program_, backend_init_context, &delegates_[i]);
      if (err != Error::Ok) {
        return err;
      }
      // ~Method() will try to clean up n_delegate_ entries in the delegates_
      // array. Only increment this once we know the entry is valid, so that
      // we don't try to clean up an uninitialized entry.
      n_delegate_ = i + 1;
    }
  }

  {
    // Load chains.
    const auto chains = serialization_plan_->chains();
    ET_CHECK_OR_RETURN_ERROR(
        chains != nullptr && chains->size() > 0, InvalidProgram, "No chains");
    n_chains_ = chains->size();
    chains_ = method_allocator->allocateList<Chain>(n_chains_);
    if (chains_ == nullptr) {
      return Error::MemoryAllocationFailed;
    }

    // Try resolving all operators before failing, to make it easier to debug
    // multiple problems at once.
    Error delayed_error = Error::Ok;
    int32_t num_instructions_missing_op = 0;
    for (size_t i = 0; i < n_chains_; ++i) {
      auto s_chain = chains->Get(i);
      auto s_instructions = s_chain->instructions();
      ET_CHECK_OR_RETURN_ERROR(
          s_instructions != nullptr,
          InvalidProgram,
          "Missing instructions in chain %zu",
          i);
      auto num_instructions = s_instructions->size();
      auto chain_instruction_kernels =
          method_allocator->allocateList<OpFunction>(num_instructions);
      if (chain_instruction_kernels == nullptr) {
        return Error::MemoryAllocationFailed;
      }
      auto chain_instruction_arg_lists =
          method_allocator->allocateList<InstructionArgs>(num_instructions);
      if (chain_instruction_arg_lists == nullptr) {
        return Error::MemoryAllocationFailed;
      }

      // Set up the argument lists ahead of time and store pointers to them to
      // use when the instructions are called.
      for (size_t instr_idx = 0; instr_idx < s_instructions->size();
           ++instr_idx) {
        const auto instruction = s_instructions->Get(instr_idx);
        // Ensure that the `instr_args_as_X()` calls will return non-null.
        ET_CHECK_OR_RETURN_ERROR(
            instruction != nullptr && instruction->instr_args() != nullptr,
            InvalidProgram,
            "Null instruction at index %zu",
            instr_idx);

        switch (instruction->instr_args_type()) {
          case executorch_flatbuffer::InstructionArguments::KernelCall: {
            const auto arg_idxs =
                instruction->instr_args_as_KernelCall()->args();
            ET_CHECK_OR_RETURN_ERROR(
                arg_idxs != nullptr, InvalidProgram, "KernelCall args missing");
            auto res = gen_instruction_arguments(
                method_allocator,
                n_value_,
                values_,
                arg_idxs->size(),
                arg_idxs->data());
            if (!res.ok()) {
              return res.error();
            }
            chain_instruction_arg_lists[instr_idx] = res.get();
            auto err = resolve_operator(
                instruction->instr_args_as_KernelCall()->op_index(),
                chain_instruction_kernels,
                instr_idx,
                res.get(),
                arg_idxs->size());
            if (err == Error::OperatorMissing) {
              num_instructions_missing_op++;
            } else if (err == Error::MemoryAllocationFailed) {
              return err;
            } else if (err != Error::Ok) {
              delayed_error = err;
            }
          } break;
          case executorch_flatbuffer::InstructionArguments::DelegateCall: {
            const auto arg_idxs =
                instruction->instr_args_as_DelegateCall()->args();
            ET_CHECK_OR_RETURN_ERROR(
                arg_idxs != nullptr,
                InvalidProgram,
                "DelegateCall args missing");
            auto res = gen_instruction_arguments(
                method_allocator,
                n_value_,
                values_,
                arg_idxs->size(),
                arg_idxs->data());
            if (!res.ok()) {
              return res.error();
            }
            chain_instruction_arg_lists[instr_idx] = res.get();
          } break;
          case executorch_flatbuffer::InstructionArguments::JumpFalseCall: {
            // Validate the index at load time so we can trust it during
            // execution.
            auto index =
                instruction->instr_args_as_JumpFalseCall()->cond_value_index();
            ET_CHECK_OR_RETURN_ERROR(
                index >= 0 && index < n_value_,
                InvalidProgram,
                "Index %d negative or >= %zu",
                index,
                n_value_);
            chain_instruction_arg_lists[instr_idx] = InstructionArgs();
          } break;
          default: {
            chain_instruction_arg_lists[instr_idx] = InstructionArgs();
          } break;
        }
      }
      chains_[i] = Chain{
          s_chain,
          Span<InstructionArgs>(chain_instruction_arg_lists, num_instructions),
          chain_instruction_kernels,
      };
    }
    ET_CHECK_OR_RETURN_ERROR(
        num_instructions_missing_op == 0,
        OperatorMissing,
        "%d instructions are missing their corresponding registered "
        "operators. See logs for details.",
        num_instructions_missing_op);
    if (delayed_error != Error::Ok) {
      return delayed_error;
    }
  }

  step_state_ = StepState{0, 0};

  init_state_ = InitializationState::Initialized;
  return Error::Ok;
}

ET_NODISCARD Error
Method::set_input(const EValue& input_evalue, size_t input_idx) {
  ET_CHECK_OR_RETURN_ERROR(
      initialized(),
      InvalidState,
      "Input cannot be set until method has been initialized.");

  ET_CHECK_OR_RETURN_ERROR(
      step_state_.instr_idx == 0 && step_state_.chain_idx == 0,
      InvalidState,
      "Inputs cannot be set mid execution.");

  ET_CHECK_OR_RETURN_ERROR(
      input_idx < inputs_size(),
      InvalidArgument,
      "Given input index must be less than the number of inputs in method, but got %zu and %zu",
      input_idx,
      inputs_size());

  const auto& e = get_value(get_input_index(input_idx));
  ET_CHECK_OR_RETURN_ERROR(
      e.isTensor() || e.isScalar(),
      InvalidArgument,
      "The %zu-th input in method is expected to be a Tensor or prim, but received tag %" PRIu32,
      input_idx,
      static_cast<uint32_t>(e.tag));

  ET_CHECK_OR_RETURN_ERROR(
      e.tag == input_evalue.tag,
      InvalidArgument,
      "The %zu-th input of method should have the same type as the input_evalue, but got tag %" PRIu32
      " and tag %" PRIu32,
      input_idx,
      static_cast<uint32_t>(e.tag),
      static_cast<uint32_t>(input_evalue.tag));

  if (e.isTensor()) {
    const auto& t_dst = e.toTensor();
    const auto& t_src = input_evalue.toTensor();
    ET_CHECK_OR_RETURN_ERROR(
        t_dst.scalar_type() == t_src.scalar_type(),
        InvalidArgument,
        "The %zu-th input tensor's scalartype does not meet requirement: found %" PRId8
        " but expected %" PRId8,
        input_idx,
        static_cast<int8_t>(t_src.scalar_type()),
        static_cast<int8_t>(t_dst.scalar_type()));
    // Resize the Method's input tensor to the size of the forwarded input
    // tensor, to support shape dynamism. This also acts as a safety check
    // before any memcpy.
    Error err = resize_tensor(t_dst, t_src.sizes());
    ET_CHECK_OR_RETURN_ERROR(
        err == Error::Ok,
        InvalidArgument,
        "Error setting input %zu: 0x%" PRIx32,
        input_idx,
        static_cast<uint32_t>(err));
    Error error;
    auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
    if (tensor_meta->is_memory_planned()) {
      error = internal::copy_tensor_data(t_dst, t_src);
    } else {
      error = internal::share_tensor_data(t_dst, t_src);
    }
    ET_CHECK_OR_RETURN_ERROR(
        error == Error::Ok,
        InvalidArgument,
        "Error setting data_ptr %zu: 0x%" PRIx32,
        input_idx,
        static_cast<uint32_t>(error));
    // Prims have to be the same as what was traced.
  } else if (e.isInt()) {
    ET_CHECK_OR_RETURN_ERROR(
        e.toInt() == input_evalue.toInt(),
        InvalidArgument,
        "The %zu-th input of method should have the same value as the input_evalue, but got %" PRId64
        " and %" PRId64,
        input_idx,
        e.toInt(),
        input_evalue.toInt());
  } else if (e.isBool()) {
    ET_CHECK_OR_RETURN_ERROR(
        e.toBool() == input_evalue.toBool(),
        InvalidArgument,
        "The %zu-th input of method should have the same value as the input_evalue, but got %" PRId64
        " and %" PRId64,
        input_idx,
        (int64_t)e.toBool(),
        (int64_t)input_evalue.toBool());
  } else if (e.isDouble()) {
    double lhs = input_evalue.toDouble();
    double rhs = e.toDouble();
    double atol = 1e-4;
    double rtol = 1e-5;
    bool is_equal = true;
    if (std::isnan(lhs) && std::isnan(rhs)) {
      // NaN == NaN
    } else if (
        !std::isfinite(lhs) && !std::isfinite(rhs) &&
        ((lhs > 0) == (rhs > 0))) {
      // -Inf == -Inf
      // +Inf == +Inf
    } else {
      auto allowed_error = atol + std::abs(rtol * rhs);
      auto actual_error = std::abs(lhs - rhs);
      if (!std::isfinite(actual_error) || actual_error > allowed_error) {
        is_equal = false;
      }
    }
    ET_CHECK_OR_RETURN_ERROR(
        is_equal,
        InvalidArgument,
        "The %zu-th input of method should have the same value as the input_evalue, but got %f and %f",
        input_idx,
        lhs,
        rhs);
  } else if (e.isString()) {
    ET_CHECK_OR_RETURN_ERROR(
        e.toString() == input_evalue.toString(),
        InvalidArgument,
        "The %zu-th input of method should have the same value as the input_evalue, but got %s and %s",
        input_idx,
        e.toString().data(),
        input_evalue.toString().data());
  } else {
    ET_LOG(Error, "Unsupported input type: %d", (int32_t)e.tag);
    return Error::InvalidArgument;
  }
  return Error::Ok;
}
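
// A usage sketch for the input path above (illustrative only; how the input
// tensor itself is constructed depends on the runtime configuration):
//
//   exec_aten::Tensor input_tensor = /* same dtype as the traced input */;
//   Error err = method.set_input(EValue(input_tensor), /*input_idx=*/0);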

ET_NODISCARD Error
Method::set_inputs(const exec_aten::ArrayRef<EValue>& input_evalues) {
  ET_CHECK_OR_RETURN_ERROR(
      initialized(),
      InvalidState,
      "Inputs cannot be set until method has been initialized.");

  ET_CHECK_OR_RETURN_ERROR(
      step_state_.instr_idx == 0 && step_state_.chain_idx == 0,
      InvalidState,
      "Inputs cannot be set mid execution.");

  size_t input_size = inputs_size();
  ET_CHECK_OR_RETURN_ERROR(
      input_size == input_evalues.size(),
      InvalidArgument,
      "The length of given input array (%zu) must be same as the number of inputs in method (%zu).",
      input_evalues.size(),
      input_size);

  for (size_t i = 0; i < input_size; i++) {
    Error status = set_input(input_evalues[i], i);
    if (status != Error::Ok) {
      return status;
    }
  }
  return Error::Ok;
}

ET_NODISCARD Error
Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
  // Check method state.
  ET_CHECK_OR_RETURN_ERROR(
      initialized(),
      InvalidState,
      "Outputs cannot be set until method has been initialized.");

  // Check the args.
  ET_CHECK_OR_RETURN_ERROR(
      output_idx < outputs_size(),
      InvalidArgument,
      "output_idx: %zu >= num_outputs: %zu",
      output_idx,
      outputs_size());

  auto& output = mutable_value(get_output_index(output_idx));
  ET_CHECK_OR_RETURN_ERROR(
      output.isTensor(),
      InvalidArgument,
      "output type: %zu is not tensor",
      (size_t)output.tag);

  auto tensor_meta = this->method_meta().output_tensor_meta(output_idx);
  if (tensor_meta->is_memory_planned()) {
    ET_LOG(
        Error,
        "Output %zu is memory planned, or is a constant. Cannot override "
        "the existing data pointer.",
        output_idx);
    return Error::InvalidState;
  }

  auto& t = output.toTensor();
  ET_CHECK_OR_RETURN_ERROR(
      t.nbytes() <= size,
      InvalidArgument,
      "buffer size: %zu is smaller than expected tensor size: %zu",
      size,
      t.nbytes());

  // Set data.
  return internal::set_tensor_data(t, buffer, size);
}
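
// A usage sketch for the output-override path above, assuming the output at
// index 0 was exported without memory planning. The buffer's lifetime is the
// caller's responsibility, and kOutputNbytes is a hypothetical size constant:
//
//   static uint8_t out_buf[kOutputNbytes];
//   Error err = method.set_output_data_ptr(out_buf, sizeof(out_buf), 0);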

ET_NODISCARD Error Method::get_outputs(EValue* output_evalues, size_t length) {
  ET_CHECK_OR_RETURN_ERROR(
      initialized(),
      InvalidState,
      "Outputs cannot be retrieved until method has been initialized.");

  ET_CHECK_OR_RETURN_ERROR(
      length >= outputs_size(),
      InvalidArgument,
      "The given array is not large enough to hold all outputs.");

  for (size_t i = 0; i < outputs_size(); i++) {
    output_evalues[i] = values_[get_output_index(i)];
  }

  for (size_t i = outputs_size(); i < length; i++) {
    output_evalues[i] = EValue();
  }

  return Error::Ok;
}

ET_NODISCARD Error Method::get_inputs(EValue* input_evalues, size_t length) {
  ET_CHECK_OR_RETURN_ERROR(
      initialized(),
      InvalidState,
      "Inputs cannot be retrieved until method has been initialized.");

  ET_CHECK_OR_RETURN_ERROR(
      length >= inputs_size(),
      InvalidArgument,
      "The given array is not large enough to hold all inputs.");

  for (size_t i = 0; i < inputs_size(); i++) {
    input_evalues[i] = values_[get_input_index(i)];
  }

  for (size_t i = inputs_size(); i < length; i++) {
    input_evalues[i] = EValue();
  }

  return Error::Ok;
}

Error Method::execute_instruction() {
  auto& chain = chains_[step_state_.chain_idx];
  auto instructions = chain.s_chain_->instructions();

  ET_CHECK_OR_RETURN_ERROR(
      step_state_.instr_idx < instructions->size(),
      Internal,
      "Instr index %zu >= chain[%zu] instr count %zu",
      step_state_.instr_idx,
      step_state_.chain_idx,
      (size_t)instructions->size());

  auto instruction = instructions->Get(step_state_.instr_idx);
  size_t next_instr_idx = step_state_.instr_idx + 1;
  Error err = Error::Ok;

  switch (instruction->instr_args_type()) {
    case executorch_flatbuffer::InstructionArguments::KernelCall: {
      EXECUTORCH_SCOPE_PROF("OPERATOR_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "OPERATOR_CALL");
      // TODO(T147221312): Also expose tensor resizer via the context.
      KernelRuntimeContext context(event_tracer_, temp_allocator_);
      auto args = chain.argument_lists_[step_state_.instr_idx];
      chain.kernels_[step_state_.instr_idx](context, args.data());
      // We reset the temp_allocator after the switch statement.
      err = context.failure_state();
      if (err != Error::Ok) {
        // We know that instr_args_as_KernelCall is non-null because it was
        // checked at init time.
        auto op_index = instruction->instr_args_as_KernelCall()->op_index();
        auto op = serialization_plan_->operators()->Get(op_index);
        ET_LOG(
            Error,
            "KernelCall failed at instruction %zu:%zu in operator %s.%s: 0x%x",
            step_state_.chain_idx,
            step_state_.instr_idx,
            op->name()->c_str(),
            op->overload()->c_str(),
            (unsigned int)err);
        for (size_t i = 0; i < args.size(); ++i) {
          ET_LOG(
              Error,
              "arg %u with type id %u",
              (unsigned int)i,
              (unsigned int)args[i]->tag);
        }
        // TODO(T153804650): Consider logging the EValues to help with
        // debugging. This is a failure path, and it doesn't matter if it's a
        // little slow. Do the same for DelegateCall errors.
      }
    } break;
    case executorch_flatbuffer::InstructionArguments::DelegateCall: {
      EXECUTORCH_SCOPE_PROF("DELEGATE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "DELEGATE_CALL");
      // We know that instr_args_as_DelegateCall is non-null because it was
      // checked at init time.
      auto delegate_idx =
          instruction->instr_args_as_DelegateCall()->delegate_index();
      ET_CHECK_OR_RETURN_ERROR(
          delegate_idx < n_delegate_,
          Internal,
          "DELEGATE_CALL index %" PRIu32
          " >= num delegates %zu at instruction %zu",
          delegate_idx,
          n_delegate_,
          step_state_.instr_idx);
      BackendExecutionContext backend_execution_context(
          /*event_tracer=*/event_tracer_,
          /*temp_allocator=*/temp_allocator_,
          /*method_name=*/serialization_plan_->name()->c_str());
      err = delegates_[delegate_idx].Execute(
          backend_execution_context,
          chain.argument_lists_[step_state_.instr_idx].data());
      if (err != Error::Ok) {
        ET_LOG(
            Error,
            "CALL_DELEGATE execute failed at instruction %zu: 0x%" PRIx32,
            step_state_.instr_idx,
            static_cast<uint32_t>(err));
      }

      // Log all the arguments of the delegate call. Ideally we'd only like to
      // log the outputs of the delegate, but currently we cannot know from
      // the arguments which are the inputs and which are the outputs, so we
      // just log everything. This will be changed in the future when the
      // inputs and outputs are separate lists.
#ifdef ET_EVENT_TRACER_ENABLED
      for (size_t i = 0;
           i < chain.argument_lists_[step_state_.instr_idx].size();
           i++) {
        EValue* arg = chain.argument_lists_[step_state_.instr_idx].data()[i];
        internal::event_tracer_log_evalue(event_tracer_, *arg);
      }
#endif
    } break;
    case executorch_flatbuffer::InstructionArguments::JumpFalseCall: {
      EXECUTORCH_SCOPE_PROF("JF_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "JF_CALL");
      // We know that instr_args_as_JumpFalseCall is non-null because it was
      // checked at init time.
      auto jf_call = instruction->instr_args_as_JumpFalseCall();
      // We know that index is a valid values_ index because it was checked at
      // init time.
      auto index = jf_call->cond_value_index();
      Result<bool> jf_result = parse_cond_value(values_[index]);
      if (jf_result.ok()) {
        if (!jf_result.get()) {
          next_instr_idx = jf_call->destination_instruction();
        }
      } else {
        err = jf_result.error();
      }
    } break;
    case executorch_flatbuffer::InstructionArguments::MoveCall: {
      EXECUTORCH_SCOPE_PROF("MOVE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "MOVE_CALL");
      // We know that instr_args_as_MoveCall is non-null because it was checked
      // at init time.
      auto move_call = instruction->instr_args_as_MoveCall();
      mutable_value(move_call->move_to()) = get_value(move_call->move_from());
    } break;
    case executorch_flatbuffer::InstructionArguments::FreeCall: {
      EXECUTORCH_SCOPE_PROF("FREE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "FREE_CALL");
      // We know that instr_args_as_FreeCall is non-null because it was checked
      // at init time.
      auto free_call = instruction->instr_args_as_FreeCall();
      auto t = values_[free_call->value_index()].toTensor();
      internal::reset_data_ptr(t);
    } break;
    default:
      ET_LOG(
          Error,
          "Unknown instruction: %hhu",
          static_cast<uint8_t>(instruction->instr_args_type()));
      err = Error::InvalidProgram;
  }
  // Reset the temp allocator for every instruction.
  if (temp_allocator_ != nullptr) {
    temp_allocator_->reset();
  }
  if (err == Error::Ok) {
    step_state_.instr_idx = next_instr_idx;
  }
  return err;
}

Error Method::reset_execution() {
  ET_CHECK_OR_RETURN_ERROR(
      step_state_.chain_idx == n_chains_,
      InvalidState,
      "Cannot reset until EndOfMethod has been reached.");
  step_state_ = StepState{0, 0};
  return Error::Ok;
}

Error Method::experimental_reset_execution() {
  return reset_execution(); // @lint-ignore CLANGTIDY facebook-hte-Deprecated
}

// Log all the outputs of this method to the event tracer.
void Method::log_outputs() {
#ifdef ET_EVENT_TRACER_ENABLED
  if (event_tracer_ != nullptr) {
    if (event_tracer_->event_tracer_debug_level() >=
        EventTracerDebugLogLevel::kProgramOutputs) {
      for (size_t i = 0; i < outputs_size(); i++) {
        internal::event_tracer_log_evalue_output(event_tracer_, get_output(i));
      }
    }
  }
#endif
}

Error Method::step() {
  EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(
      static_cast<int32_t>(step_state_.chain_idx),
      static_cast<uint32_t>(step_state_.instr_idx));
  internal::EventTracerProfileInstructionScope event_tracer_instr_scope =
      internal::EventTracerProfileInstructionScope(
          event_tracer_,
          static_cast<int32_t>(step_state_.chain_idx),
          static_cast<uint32_t>(step_state_.instr_idx));
  EXECUTORCH_SCOPE_PROF("Method::step");
  EventTracerEntry event_tracer_entry =
      internal::event_tracer_begin_profiling_event(
          event_tracer_, "Method::step");
  ET_CHECK_OR_RETURN_ERROR(
      initialized(),
      InvalidState,
      "Cannot execute until method has been initialized.");

  // If chain_idx has reached n_chains_, there are no more instructions to run.
  if (step_state_.chain_idx == n_chains_) {
    return Error::EndOfMethod;
  }

  auto num_instructions =
      chains_[step_state_.chain_idx].s_chain_->instructions()->size();

  // Special-case chains with no instructions. These appear, for example, in a
  // model that just returns the input or a constant.
  if (num_instructions == 0) {
    step_state_.chain_idx += 1;
    return Error::Ok;
  }

  auto status = execute_instruction();
  if (status != Error::Ok) {
    return status;
  }

  internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry);
  // At the end of the current chain, advance to the next chain.
  if (step_state_.instr_idx == num_instructions) {
    step_state_.instr_idx = 0;
    step_state_.chain_idx += 1;
    log_outputs();
  }
  return Error::Ok;
}
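
// A sketch of manual stepping, which executes one instruction per call until
// the method reports EndOfMethod. This can be useful for cooperative
// scheduling on single-threaded targets (error handling elided for brevity):
//
//   Error err = Error::Ok;
//   while ((err = method.step()) == Error::Ok) {
//     // Optionally yield to other work between instructions.
//   }
//   if (err == Error::EndOfMethod) {
//     err = method.reset_execution(); // Ready to run again.
//   }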

Error Method::experimental_step() {
  return step();
}

Error Method::execute() {
  internal::event_tracer_create_event_block(event_tracer_, "Execute");
  EventTracerEntry event_tracer_entry =
      internal::event_tracer_begin_profiling_event(
          event_tracer_, "Method::execute");
  EXECUTORCH_SCOPE_PROF("Method::execute");
  ET_CHECK_OR_RETURN_ERROR(
      initialized(),
      NotSupported,
      "Cannot execute until method has been initialized.");

  // Chains are executed sequentially today, but future async designs may
  // branch and run many in parallel or out of order.
  for (step_state_.chain_idx = 0; step_state_.chain_idx < n_chains_;
       ++step_state_.chain_idx) {
    Chain& chain = chains_[step_state_.chain_idx];
    auto instructions = chain.s_chain_->instructions();
    ET_CHECK_OR_RETURN_ERROR(
        instructions != nullptr,
        Internal,
        "chain %zu has no instructions field",
        step_state_.chain_idx);

    // Loop over instructions
    step_state_.instr_idx = 0;
    while (step_state_.instr_idx < chain.s_chain_->instructions()->size()) {
      EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(
          static_cast<int32_t>(step_state_.chain_idx),
          static_cast<uint32_t>(step_state_.instr_idx));
      internal::EventTracerProfileInstructionScope event_tracer_instr_scope =
          internal::EventTracerProfileInstructionScope(
              event_tracer_,
              static_cast<ChainID>(step_state_.chain_idx),
              static_cast<DebugHandle>(step_state_.instr_idx));
      auto status = execute_instruction();
      if (status != Error::Ok) {
        return status;
      }
    }
  }
  internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry);
  log_outputs();

  // TODO(jakeszwe, dbort): Decide on calling execute back to back without
  // going through the reset api first.
  return reset_execution(); // @lint-ignore CLANGTIDY facebook-hte-Deprecated
}

MethodMeta Method::method_meta() const {
  auto name = serialization_plan_->name()->c_str();
  auto method_meta = program_->method_meta(name);
  ET_CHECK_MSG(
      method_meta.ok(),
      "Internal error: method_meta(%s) returned 0x%" PRIx32,
      name,
      static_cast<uint32_t>(method_meta.error()));
  return method_meta.get();
}

const EValue& Method::get_value(size_t i) const {
  ET_CHECK_MSG(i < n_value_, "%zu >= %zu", i, n_value_);
  return values_[i];
}

EValue& Method::mutable_value(size_t i) {
  ET_CHECK_MSG(i < n_value_, "%zu >= %zu", i, n_value_);
  return values_[i];
}

size_t Method::inputs_size() const {
  const auto* inputs = serialization_plan_->inputs();
  return inputs == nullptr ? 0 : inputs->size();
}

size_t Method::get_input_index(size_t i) const {
  ET_CHECK_MSG(i < inputs_size(), "%zu >= %zu", i, inputs_size());
  return static_cast<size_t>(serialization_plan_->inputs()->Get(i));
}

const EValue& Method::get_input(size_t i) const {
  return get_value(get_input_index(i));
}

EValue& Method::mutable_input(size_t i) {
  return mutable_value(get_input_index(i));
}

size_t Method::outputs_size() const {
  const auto* outputs = serialization_plan_->outputs();
  return outputs == nullptr ? 0 : outputs->size();
}

size_t Method::get_output_index(size_t i) const {
  ET_CHECK_MSG(i < outputs_size(), "%zu >= %zu", i, outputs_size());
  return static_cast<size_t>(serialization_plan_->outputs()->Get(i));
}

const EValue& Method::get_output(size_t i) const {
  return get_value(get_output_index(i));
}

EValue& Method::mutable_output(size_t i) {
  return mutable_value(get_output_index(i));
}

EventTracer* Method::get_event_tracer() {
  return event_tracer_;
}

Method::~Method() {
  // Destroy the values. This is necessary in ATen mode, where the refcounts
  // of the Tensors need to be decremented properly.
  if (values_ != nullptr) {
    for (size_t i = 0; i < n_value_; ++i) {
      values_[i].~EValue();
    }
  }
  // Free any resources associated with delegate backends.
  if (delegates_ != nullptr) {
    for (size_t i = 0; i < n_delegate_; i++) {
      delegates_[i].~BackendDelegate();
    }
  }
  // All other fields are trivially destructible.
}
} // namespace runtime
} // namespace executorch