xref: /aosp_15_r20/external/executorch/runtime/executor/method.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/executor/method.h>
10 
11 #include <cinttypes> // @donotremove
12 #include <cstdint>
13 #include <cstdio>
14 
15 #include <executorch/runtime/backend/interface.h>
16 #include <executorch/runtime/core/event_tracer_hooks.h>
17 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
18 #include <executorch/runtime/core/span.h>
19 #include <executorch/runtime/executor/memory_manager.h>
20 #include <executorch/runtime/executor/platform_memory_allocator.h>
21 #include <executorch/runtime/executor/program.h>
22 #include <executorch/runtime/executor/tensor_parser.h>
23 #include <executorch/runtime/kernel/kernel_runtime_context.h>
24 #include <executorch/runtime/kernel/operator_registry.h>
25 #include <executorch/runtime/platform/assert.h>
26 #include <executorch/runtime/platform/log.h>
27 #include <executorch/runtime/platform/profiler.h>
28 #include <executorch/schema/program_generated.h>
29 
30 namespace executorch {
31 namespace runtime {
32 
33 using internal::PlatformMemoryAllocator;
34 
35 /**
36  * Runtime state for a backend delegate.
37  */
38 // This lint wants to wrap the class in an anonymous namespace, but it must be
39 // visible because it's forward-declared and used in Executor.h.
40 // @lint-ignore CLANGTIDY facebook-hte-ShadowingClass
41 class BackendDelegate final {
42  public:
43   /**
44    * Initializes an already-allocated BackendDelegate from its serialized
45    * representation.
46    *
47    * @param[in] delegate The serialized backend delegate to load.
48    * @param[in] program The serialized program to load from.
49    * @param[in] backend_init_context The context pointer to pass to the
50    *     backend's init() method.
51    * @param[out] out The BackendDelegate to initialize.
52    *
53    * @returns Error::Ok if the initialization succeeded, or an error otherwise.
54    */
Init(const executorch_flatbuffer::BackendDelegate & delegate,const Program * program,BackendInitContext & backend_init_context,BackendDelegate * out)55   static Error Init(
56       const executorch_flatbuffer::BackendDelegate& delegate,
57       const Program* program,
58       BackendInitContext& backend_init_context,
59       BackendDelegate* out) {
60     // Look up the backend.
61     ET_CHECK_OR_RETURN_ERROR(
62         delegate.id() != nullptr, InvalidProgram, "Missing backend id");
63     const char* backend_id = delegate.id()->c_str();
64     BackendInterface* backend = get_backend_class(backend_id);
65     ET_CHECK_OR_RETURN_ERROR(
66         backend != nullptr,
67         NotFound,
68         "Backend %s is not registered.",
69         backend_id);
70     ET_CHECK_OR_RETURN_ERROR(
71         backend->is_available(),
72         NotFound,
73         "Backend %s is not available.",
74         backend_id);
75 
76     // Get the delegate data.
77     Result<FreeableBuffer> processed_data = GetProcessedData(delegate, program);
78     if (!processed_data.ok()) {
79       ET_LOG(Error, "Failed to load data for backend %s", backend_id);
80       return processed_data.error();
81     }
82 
83     // Parse compilation specs from program
84     CompileSpec* compile_specs;
85     Error err = PopulateCompileSpecs(
86         delegate.compile_specs(), backend_init_context, &compile_specs);
87     if (err != Error::Ok) {
88       ET_LOG(Error, "Failed to get compile specs for backend %s", backend_id);
89       return err;
90     }
91     size_t num_compile_specs = delegate.compile_specs()->size();
92 
93     out->backend_ = backend;
94     out->handle_ = nullptr;
95     // Pass a pointer to this buffer to the backend. It's safe for the backend
96     // to point its handle to this object, since it will outlive the backend.
97     new (&out->segment_) FreeableBuffer(std::move(processed_data.get()));
98 
99     // Initialize the delegate.
100     Result<DelegateHandle*> handle = backend->init(
101         backend_init_context,
102         &out->segment_,
103         ArrayRef<CompileSpec>(compile_specs, num_compile_specs));
104     if (!handle.ok()) {
105       ET_LOG(
106           Error,
107           "Init failed for backend %s: 0x%" PRIx32,
108           backend_id,
109           static_cast<uint32_t>(handle.error()));
110       out->segment_.Free();
111       return handle.error();
112     }
113     out->handle_ = handle.get();
114     return Error::Ok;
115   }
116 
~BackendDelegate()117   ~BackendDelegate() {
118     if (backend_ != nullptr) {
119       backend_->destroy(handle_);
120     }
121   }
122 
Execute(BackendExecutionContext & backend_execution_context,EValue ** args) const123   Error Execute(
124       BackendExecutionContext& backend_execution_context,
125       EValue** args) const {
126     EXECUTORCH_SCOPE_PROF("delegate_execute");
127     return backend_->execute(backend_execution_context, handle_, args);
128   }
129 
130  private:
131   // Not constructible.
132   BackendDelegate() = delete;
133 
134   // Disallow copy/move.
135   BackendDelegate(const BackendDelegate&) = delete;
136   BackendDelegate& operator=(const BackendDelegate&) = delete;
137   BackendDelegate(BackendDelegate&&) = delete;
138   BackendDelegate& operator=(BackendDelegate&&) = delete;
139 
PopulateCompileSpecs(const flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::CompileSpec>> * compile_specs_in_program,BackendInitContext & backend_init_context,CompileSpec ** out_spec)140   static Error PopulateCompileSpecs(
141       const flatbuffers::Vector<flatbuffers::Offset<
142           executorch_flatbuffer::CompileSpec>>* compile_specs_in_program,
143       BackendInitContext& backend_init_context,
144       CompileSpec** out_spec) {
145     auto number_of_compile_specs = compile_specs_in_program->size();
146 
147     CompileSpec* compile_specs_list =
148         backend_init_context.get_runtime_allocator()->allocateList<CompileSpec>(
149             number_of_compile_specs);
150     if (compile_specs_list == nullptr) {
151       return Error::MemoryAllocationFailed;
152     }
153 
154     // Initialize the spec list for each method spec
155     for (size_t j = 0; j < number_of_compile_specs; j++) {
156       auto compile_spec_in_program = compile_specs_in_program->Get(j);
157 
158       compile_specs_list[j].key = compile_spec_in_program->key()->c_str();
159       compile_specs_list[j].value = {
160           /*buffer*/ static_cast<void*>(
161               const_cast<uint8_t*>(compile_spec_in_program->value()->Data())),
162           /*nbytes*/ compile_spec_in_program->value()->size(),
163       };
164     }
165 
166     *out_spec = compile_specs_list;
167     return Error::Ok;
168   }
169 
GetProcessedData(const executorch_flatbuffer::BackendDelegate & delegate,const Program * program)170   static Result<FreeableBuffer> GetProcessedData(
171       const executorch_flatbuffer::BackendDelegate& delegate,
172       const Program* program) {
173     const executorch_flatbuffer::BackendDelegateDataReference* processed =
174         delegate.processed();
175     switch (processed->location()) {
176       case executorch_flatbuffer::DataLocation::INLINE: {
177         const void* data;
178         size_t size;
179         Error err = program->get_backend_delegate_data(
180             processed->index(), &data, &size);
181         if (err != Error::Ok) {
182           return err;
183         }
184         return FreeableBuffer(
185             data,
186             size,
187             /*free_fn=*/nullptr);
188       }
189       case executorch_flatbuffer::DataLocation::SEGMENT: {
190         const char* backend_id = delegate.id()->c_str();
191         return program->LoadSegment(DataLoader::SegmentInfo(
192             DataLoader::SegmentInfo::Type::Backend,
193             processed->index(),
194             backend_id));
195       }
196       default:
197         ET_LOG(
198             Error,
199             "Unknown data location %u",
200             static_cast<unsigned int>(processed->location()));
201         return Error::Internal;
202     }
203   }
204 
205   FreeableBuffer segment_;
206   const BackendInterface* backend_;
207   DelegateHandle* handle_;
208 };
209 
210 /**
211  * Runtime state for a chain of instructions.
212  */
213 struct Chain {
214   /// Pointer to the associated flatbuffer chain.
215   const executorch_flatbuffer::Chain* s_chain_;
216 
217   /// Each entry is a list of parameters for a kernel or delegate call.
218   Span<InstructionArgs> argument_lists_;
219   /// Each instruction will have one kernel (not for delegate).
220   OpFunction* kernels_;
221 };
222 
223 namespace {
224 
gen_instruction_arguments(MemoryAllocator * method_allocator,size_t num_values,EValue * values,size_t num_args,const int32_t * arg_idxs)225 Result<InstructionArgs> gen_instruction_arguments(
226     MemoryAllocator* method_allocator,
227     size_t num_values,
228     EValue* values,
229     size_t num_args,
230     const int32_t* arg_idxs) {
231   EValue** arg_list = method_allocator->allocateList<EValue*>(num_args);
232   if (arg_list == nullptr) {
233     return Error::MemoryAllocationFailed;
234   }
235   for (size_t i = 0; i < num_args; ++i) {
236     int32_t arg_idx = arg_idxs[i];
237     ET_CHECK_OR_RETURN_ERROR(
238         arg_idx < num_values,
239         InvalidProgram,
240         "Arg index %d >= %zu",
241         arg_idx,
242         num_values);
243     arg_list[i] = &values[arg_idx];
244   }
245   return InstructionArgs(arg_list, num_args);
246 }
247 
parse_cond_value(const EValue & cond_value)248 Result<bool> parse_cond_value(const EValue& cond_value) {
249   // The cond value attached to the JF instruction at the beginning of an
250   // if/else branch is a Tensor which we parse and decide whether to continue
251   // to execute the if branch or jump to the else branch.
252   // The cond value attached to the JF instruction at the end of the if branch
253   // is a Bool Scalar which resolves to false and points us to the instruction
254   // to jump to which will take us to a point that is after the else branch.
255   if (cond_value.isTensor()) {
256     const exec_aten::Tensor& cond_val = cond_value.toTensor();
257 
258     // All the tensors and scalar cond values should be of bool type
259     // currently. If that's not the case then something is wrong in the model
260     // and we should exit.
261     ET_CHECK_OR_RETURN_ERROR(
262         exec_aten::ScalarType::Bool == cond_val.scalar_type(),
263         InvalidProgram,
264         "Expected dtype of %" PRId8 " got %" PRId8,
265         static_cast<int8_t>(exec_aten::ScalarType::Bool),
266         static_cast<int8_t>(cond_val.scalar_type()));
267 
268     const bool* cond_data = cond_val.const_data_ptr<bool>();
269     for (size_t i = 0; i < cond_val.numel(); i++) {
270       if (!cond_data[i]) {
271         return false;
272       }
273     }
274   } else if (cond_value.isBool()) {
275     if (!cond_value.toBool()) {
276       return false;
277     }
278   } else {
279     ET_LOG(
280         Error, "Unsupported JF EValue type %" PRIu32, (uint32_t)cond_value.tag);
281     return Error::InvalidProgram;
282   }
283 
284   return true;
285 }
286 
287 } // namespace
288 
parse_values()289 Error Method::parse_values() {
290   auto flatbuffer_values = serialization_plan_->values();
291   ET_CHECK_OR_RETURN_ERROR(
292       flatbuffer_values != nullptr, InvalidProgram, "Missing values");
293   size_t n_value = flatbuffer_values->size();
294   values_ = memory_manager_->method_allocator()->allocateList<EValue>(n_value);
295   if (values_ == nullptr) {
296     return Error::MemoryAllocationFailed;
297   }
298 
299   // n_value_ counts the number of successfully-initialized values for ~Method()
300   // to clean up, and is incremented at the bottom of the loop. This makes it
301   // safe for errors to return without updating any state.
302   n_value_ = 0;
303 
304   for (size_t i = 0; i < n_value; ++i) {
305     auto serialization_value = flatbuffer_values->Get(i);
306     // Ensure that the `val_as_X()` calls will return non-null pointers.
307     ET_CHECK_OR_RETURN_ERROR(
308         serialization_value != nullptr &&
309             (serialization_value->val_type() ==
310                  executorch_flatbuffer::KernelTypes::Null ||
311              serialization_value->val() != nullptr),
312         InvalidProgram,
313         "Null value at index %zu",
314         i);
315 
316     switch (serialization_value->val_type()) {
317       case executorch_flatbuffer::KernelTypes::Null: {
318         // Placement new as the list elements are not initialized, so calling
319         // copy assignment is not defined if its non trivial (Imagine the
320         // garbage in values_[i] thinks its an at::Tensor).
321         new (&values_[i]) EValue();
322       } break;
323       case executorch_flatbuffer::KernelTypes::Int: {
324         new (&values_[i]) EValue(serialization_value->val_as_Int()->int_val());
325       } break;
326       case executorch_flatbuffer::KernelTypes::Double: {
327         new (&values_[i])
328             EValue(serialization_value->val_as_Double()->double_val());
329       } break;
330       case executorch_flatbuffer::KernelTypes::Bool: {
331         new (&values_[i])
332             EValue(serialization_value->val_as_Bool()->bool_val());
333       } break;
334       case executorch_flatbuffer::KernelTypes::IntList: {
335         const auto items = serialization_value->val_as_IntList()->items();
336         ET_CHECK_OR_RETURN_ERROR(
337             items != nullptr, InvalidProgram, "Missing list at index %zu", i);
338         // Allocate space for boxed and unboxed list representations using
339         // values_ as source of truth
340         auto* evalp_list =
341             memory_manager_->method_allocator()->allocateList<EValue*>(
342                 items->size());
343         auto* int_list =
344             memory_manager_->method_allocator()->allocateList<int64_t>(
345                 items->size());
346 
347         // initialize boxed list
348         for (size_t j = 0; j < items->size(); j++) {
349           evalp_list[j] = &values_[static_cast<size_t>(items->Get(j))];
350         }
351         new (&values_[i]) EValue(
352             BoxedEvalueList<int64_t>(evalp_list, int_list, items->size()));
353       } break;
354       case executorch_flatbuffer::KernelTypes::BoolList: {
355         const auto items = serialization_value->val_as_BoolList()->items();
356         ET_CHECK_OR_RETURN_ERROR(
357             items != nullptr, InvalidProgram, "Missing list at index %zu", i);
358         // NOTE: This is technically not portable. A platform could technically
359         // define boolean as something longer than a byte. This would be an
360         // exceptionally rare case, and this type is currently unused in any
361         // operators in ATen that we would need to support. To be properly
362         // portable here we need to allocate a new array of bool and copy cast
363         // the flatbuffer data into it, but because of how exceptionally rare
364         // this case is its low prio TODO: jakeszwe
365         new (&values_[i]) EValue(exec_aten::ArrayRef<bool>(
366             (const bool*)items->data(), items->size()));
367       } break;
368       case executorch_flatbuffer::KernelTypes::DoubleList: {
369         const auto items = serialization_value->val_as_DoubleList()->items();
370         ET_CHECK_OR_RETURN_ERROR(
371             items != nullptr, InvalidProgram, "Missing list at index %zu", i);
372         new (&values_[i])
373             EValue(exec_aten::ArrayRef<double>(items->data(), items->size()));
374       } break;
375       case executorch_flatbuffer::KernelTypes::String: {
376         const auto fb_str = serialization_value->val_as_String()->string_val();
377         ET_CHECK_OR_RETURN_ERROR(
378             fb_str != nullptr,
379             InvalidProgram,
380             "Missing string at index %zu",
381             i);
382         new (&values_[i]) EValue(fb_str->c_str(), fb_str->size());
383       } break;
384       case executorch_flatbuffer::KernelTypes::Tensor: {
385         auto t = deserialization::parseTensor(
386             program_, memory_manager_, serialization_value->val_as_Tensor());
387         if (!t.ok()) {
388           ET_LOG(
389               Error,
390               "Failed parsing tensor at index %zu: 0x%" PRIx32,
391               i,
392               static_cast<uint32_t>(t.error()));
393           return t.error();
394         }
395         new (&values_[i]) EValue(t.get());
396       } break;
397       case executorch_flatbuffer::KernelTypes::TensorList: {
398         // get list of serialization tensors and allocate storage for executor
399         // tensors
400         auto tensors = deserialization::parseTensorList(
401             serialization_value->val_as_TensorList()->items(),
402             values_,
403             memory_manager_);
404         if (!tensors.ok()) {
405           ET_LOG(
406               Error,
407               "Failed parsing tensor list at index %zu: 0x%" PRIx32,
408               i,
409               static_cast<uint32_t>(tensors.error()));
410           return tensors.error();
411         }
412         new (&values_[i]) EValue(tensors.get());
413       } break;
414       case executorch_flatbuffer::KernelTypes::OptionalTensorList: {
415         // Same as TensorList but optional<Tensor> instead of Tensor
416         auto tensors =
417             deserialization::parseListOptionalType<exec_aten::Tensor>(
418                 serialization_value->val_as_OptionalTensorList()->items(),
419                 values_,
420                 memory_manager_);
421         if (!tensors.ok()) {
422           ET_LOG(
423               Error,
424               "Failed parsing optional tensor list at index %zu: 0x%" PRIx32,
425               i,
426               static_cast<uint32_t>(tensors.error()));
427           return tensors.error();
428         }
429         new (&values_[i]) EValue(tensors.get());
430       } break;
431       default:
432         // flatbuffer enums start at 0, but they generate a hidden NONE enum
433         // and give it that value. schema.fbs doesnt show this type, so I
434         // subtract one to keep the output in 0 based indexing for a
435         // disgruntled debugger seeing this error message and checking
436         // schema.fbs
437         ET_LOG(
438             Error,
439             "Unknown KernelTypes value %" PRIu32 " at index %zu",
440             static_cast<uint32_t>(serialization_value->val_type()) - 1,
441             i);
442         return Error::InvalidProgram;
443     }
444 
445     // ~Method() will try to clean up n_value_ entries in the values_ array.
446     // Only increment this once we know the entry is valid, so that we don't try
447     // to clean up an uninitialized entry.
448     n_value_ = i + 1;
449   }
450   return Error::Ok;
451 }
452 
453 namespace {
454 /**
455  * Private/helper method for populating operator_name from the Operator.
456  * operator_name is a char pointer that is already allocated. The size of
457  * of this buffer is of size operator_name_size.
458  */
populate_operator_name(const executorch_flatbuffer::Operator * const & op,const size_t operator_name_size,char * operator_name)459 Error populate_operator_name(
460     const executorch_flatbuffer::Operator* const& op,
461     const size_t operator_name_size,
462     char* operator_name) {
463   const bool has_overload =
464       op->overload() != nullptr && op->overload()->size() > 0;
465 
466   ET_CHECK_OR_RETURN_ERROR(
467       op->name() != nullptr, InvalidProgram, "Missing operator name");
468   int cx = snprintf(
469       operator_name,
470       operator_name_size,
471       "%s%s%s",
472       op->name()->c_str(),
473       // Don't append any overload if the overload string is empty.
474       has_overload ? "." : "",
475       has_overload ? op->overload()->c_str() : "");
476   ET_CHECK_OR_RETURN_ERROR(cx >= 0, Internal, "snprintf failed: %d", cx);
477   ET_CHECK_OR_RETURN_ERROR(
478       cx < operator_name_size,
479       Internal,
480       "Operator name %s%s%s with length %d "
481       "truncated to %zu due to internal buffer limit.",
482       op->name()->c_str(),
483       has_overload ? "." : "",
484       has_overload ? op->overload()->c_str() : "",
485       cx,
486       operator_name_size);
487 
488   return Error::Ok;
489 }
490 } // namespace
491 
resolve_operator(int32_t op_index,OpFunction * kernels,size_t kernel_index,InstructionArgs args,size_t n_args)492 Error Method::resolve_operator(
493     int32_t op_index,
494     OpFunction* kernels,
495     size_t kernel_index,
496     InstructionArgs args,
497     size_t n_args) {
498   // TODO(T153505381, T153506819) Investigate optimizing this function for both
499   // space and time.
500 
501   // resolve name
502   constexpr size_t kTempBufferSizeForName = 100;
503   char operator_name[kTempBufferSizeForName];
504   const auto ops = serialization_plan_->operators();
505   ET_CHECK_OR_RETURN_ERROR(
506       ops != nullptr && op_index < ops->size(),
507       InvalidProgram,
508       "Op index %" PRIu32 " out of range",
509       op_index);
510   const auto& op = ops->Get(op_index);
511 
512   Error err = populate_operator_name(op, kTempBufferSizeForName, operator_name);
513   if (err != Error::Ok) {
514     return err;
515   }
516 
517   // resolve tensor meta
518   auto method_allocator = memory_manager_->method_allocator();
519   TensorMeta* meta = method_allocator->allocateList<TensorMeta>(n_args);
520   if (meta == nullptr) {
521     return Error::MemoryAllocationFailed;
522   }
523 
524   size_t count = 0;
525   for (size_t i = 0; i < n_args; i++) {
526     EValue* eval = args[i];
527     // handle tensor list as well
528     if (eval->isTensor()) {
529       auto tensor = eval->toTensor();
530       meta[count].dtype_ = tensor.scalar_type();
531       exec_aten::DimOrderType* dim_order_ptr =
532           method_allocator->allocateList<exec_aten::DimOrderType>(tensor.dim());
533       if (dim_order_ptr == nullptr) {
534         return Error::MemoryAllocationFailed;
535       }
536       size_t size = tensor.dim();
537       err = get_dim_order(tensor, dim_order_ptr, size);
538       ET_CHECK_OR_RETURN_ERROR(
539           err == Error::Ok,
540           InvalidArgument,
541           "Error setting dim_order %zu: 0x%" PRIx32,
542           i,
543           static_cast<uint32_t>(err));
544       meta[count].dim_order_ =
545           Span<exec_aten::DimOrderType>(dim_order_ptr, size);
546       count++;
547     }
548   }
549 
550   // Find a kernel with the matching name and tensor meta.
551   Result<OpFunction> op_function =
552       get_op_function_from_registry(operator_name, {meta, count});
553   if (!op_function.ok()) {
554     ET_LOG(Error, "Missing operator: [%d] %s", op_index, operator_name);
555     return op_function.error();
556   }
557   kernels[kernel_index] = op_function.get();
558   return Error::Ok;
559 }
560 
load(executorch_flatbuffer::ExecutionPlan * s_plan,const Program * program,MemoryManager * memory_manager,EventTracer * event_tracer)561 Result<Method> Method::load(
562     executorch_flatbuffer::ExecutionPlan* s_plan,
563     const Program* program,
564     MemoryManager* memory_manager,
565     EventTracer* event_tracer) {
566   MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
567   if (temp_allocator == nullptr) {
568     PlatformMemoryAllocator* platform_allocator =
569         memory_manager->method_allocator()
570             ->allocateInstance<PlatformMemoryAllocator>();
571     if (platform_allocator == nullptr) {
572       return Error::MemoryAllocationFailed;
573     }
574     new (platform_allocator) PlatformMemoryAllocator();
575     temp_allocator = platform_allocator;
576   }
577   Method method(program, memory_manager, event_tracer, temp_allocator);
578 
579   Error err = method.init(s_plan);
580   if (err != Error::Ok) {
581     return err;
582   } else {
583     ET_CHECK(method.initialized());
584     return method;
585   }
586 }
587 
init(executorch_flatbuffer::ExecutionPlan * s_plan)588 Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
589   EXECUTORCH_SCOPE_PROF("Method::init");
590   internal::EventTracerProfileMethodScope event_tracer_profile_scope =
591       internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
592   ET_CHECK_OR_RETURN_ERROR(
593       // Don't use !initialized() here because we also want to fail on the
594       // InitializationFailed state.
595       init_state_ == InitializationState::Uninitialized,
596       InvalidState,
597       "Method already initialized, or previously failed to initialize.");
598   init_state_ =
599       InitializationState::InitializationFailed; // Until proven otherwise
600   serialization_plan_ = s_plan;
601   auto method_allocator = memory_manager_->method_allocator();
602 
603   {
604     // Parse the elements of the values_ array.
605     Error err = parse_values();
606     if (err != Error::Ok) {
607       return err;
608     }
609   }
610 
611   {
612     // Resolve delegates
613     const auto delegates = serialization_plan_->delegates();
614     ET_CHECK_OR_RETURN_ERROR(
615         delegates != nullptr, InvalidProgram, "Missing delegates field");
616     size_t n_delegate = delegates->size();
617     delegates_ = method_allocator->allocateList<BackendDelegate>(n_delegate);
618     if (delegates_ == nullptr) {
619       return Error::MemoryAllocationFailed;
620     }
621 
622     // n_delegate_ counts the number of successfully-initialized delegates for
623     // ~Method() to clean up, and is incremented at the bottom of the loop. This
624     // makes it safe for errors to return without updating any state.
625     n_delegate_ = 0;
626 
627     for (size_t i = 0; i < n_delegate; ++i) {
628       const auto& delegate = *delegates->Get(i);
629       BackendInitContext backend_init_context(
630           method_allocator,
631           /*method_name=*/serialization_plan_->name()->c_str());
632       Error err = BackendDelegate::Init(
633           delegate, program_, backend_init_context, &delegates_[i]);
634       if (err != Error::Ok) {
635         return err;
636       }
637       // ~Method() will try to clean up n_delegate_ entries in the delegates_
638       // array. Only increment this once we know the entry is valid, so that
639       // we don't try to clean up an uninitialized entry.
640       n_delegate_ = i + 1;
641     }
642   }
643 
644   {
645     // Load chains
646     const auto chains = serialization_plan_->chains();
647     ET_CHECK_OR_RETURN_ERROR(
648         chains != nullptr && chains->size() > 0, InvalidProgram, "No chains");
649     n_chains_ = chains->size();
650     chains_ = method_allocator->allocateList<Chain>(n_chains_);
651     if (chains_ == nullptr) {
652       return Error::MemoryAllocationFailed;
653     }
654 
655     // Try resolving all operators before failing, to make it easier to debug
656     // multiple problems at once.
657     Error delayed_error = Error::Ok;
658     int32_t num_instructions_missing_op = 0;
659     for (size_t i = 0; i < n_chains_; ++i) {
660       auto s_chain = chains->Get(i);
661       auto s_instructions = s_chain->instructions();
662       ET_CHECK_OR_RETURN_ERROR(
663           s_instructions != nullptr,
664           InvalidProgram,
665           "Missing instructions in chain %zu",
666           i);
667       auto num_instructions = s_instructions->size();
668       auto chain_instruction_kernels =
669           method_allocator->allocateList<OpFunction>(num_instructions);
670       if (chain_instruction_kernels == nullptr) {
671         return Error::MemoryAllocationFailed;
672       }
673       auto chain_instruction_arg_lists =
674           method_allocator->allocateList<InstructionArgs>(num_instructions);
675       if (chain_instruction_arg_lists == nullptr) {
676         return Error::MemoryAllocationFailed;
677       }
678 
679       // Set up the argument lists ahead of time and store pointers to them to
680       // use when the instructions are called
681       for (size_t instr_idx = 0; instr_idx < s_instructions->size();
682            ++instr_idx) {
683         const auto instruction = s_instructions->Get(instr_idx);
684         // Ensure that the `instr_args_as_X()` calls will return non-null.
685         ET_CHECK_OR_RETURN_ERROR(
686             instruction != nullptr && instruction->instr_args() != nullptr,
687             InvalidProgram,
688             "Null instruction at index %zu",
689             instr_idx);
690 
691         switch (instruction->instr_args_type()) {
692           case executorch_flatbuffer::InstructionArguments::KernelCall: {
693             const auto arg_idxs =
694                 instruction->instr_args_as_KernelCall()->args();
695             ET_CHECK_OR_RETURN_ERROR(
696                 arg_idxs != nullptr, InvalidProgram, "KernelCall args missing");
697             auto res = gen_instruction_arguments(
698                 method_allocator,
699                 n_value_,
700                 values_,
701                 arg_idxs->size(),
702                 arg_idxs->data());
703             if (!res.ok()) {
704               return res.error();
705             }
706             chain_instruction_arg_lists[instr_idx] = res.get();
707             auto err = resolve_operator(
708                 instruction->instr_args_as_KernelCall()->op_index(),
709                 chain_instruction_kernels,
710                 instr_idx,
711                 res.get(),
712                 arg_idxs->size());
713             if (err == Error::OperatorMissing) {
714               num_instructions_missing_op++;
715             } else if (err == Error::MemoryAllocationFailed) {
716               return err;
717             } else {
718               delayed_error = err;
719             }
720           } break;
721           case executorch_flatbuffer::InstructionArguments::DelegateCall: {
722             const auto arg_idxs =
723                 instruction->instr_args_as_DelegateCall()->args();
724             ET_CHECK_OR_RETURN_ERROR(
725                 arg_idxs != nullptr,
726                 InvalidProgram,
727                 "DelegateCall args missing");
728             auto res = gen_instruction_arguments(
729                 method_allocator,
730                 n_value_,
731                 values_,
732                 arg_idxs->size(),
733                 arg_idxs->data());
734             if (!res.ok()) {
735               return res.error();
736             }
737             chain_instruction_arg_lists[instr_idx] = res.get();
738           } break;
739           case executorch_flatbuffer::InstructionArguments::JumpFalseCall: {
740             // Validate the index at load time so we can trust it during
741             // execution.
742             auto index =
743                 instruction->instr_args_as_JumpFalseCall()->cond_value_index();
744             ET_CHECK_OR_RETURN_ERROR(
745                 index >= 0 && index < n_value_,
746                 InvalidProgram,
747                 "Index %d negative or >= %zu",
748                 index,
749                 n_value_);
750             chain_instruction_arg_lists[instr_idx] = InstructionArgs();
751           } break;
752           default: {
753             chain_instruction_arg_lists[instr_idx] = InstructionArgs();
754           } break;
755         }
756       }
757       chains_[i] = Chain{
758           s_chain,
759           Span<InstructionArgs>(chain_instruction_arg_lists, num_instructions),
760           chain_instruction_kernels,
761       };
762     }
763     ET_CHECK_OR_RETURN_ERROR(
764         num_instructions_missing_op == 0,
765         OperatorMissing,
766         "There are %d instructions don't have corresponding operator registered. "
767         "See logs for details",
768         num_instructions_missing_op);
769     if (delayed_error != Error::Ok) {
770       return delayed_error;
771     }
772   }
773 
774   step_state_ = StepState{0, 0};
775 
776   init_state_ = InitializationState::Initialized;
777   return Error::Ok;
778 }
779 
780 ET_NODISCARD Error
set_input(const EValue & input_evalue,size_t input_idx)781 Method::set_input(const EValue& input_evalue, size_t input_idx) {
782   ET_CHECK_OR_RETURN_ERROR(
783       initialized(),
784       InvalidState,
785       "Input can not be set until method has been initialized.");
786 
787   ET_CHECK_OR_RETURN_ERROR(
788       step_state_.instr_idx == 0 && step_state_.chain_idx == 0,
789       InvalidState,
790       "Inputs can not be set mid execution.");
791 
792   ET_CHECK_OR_RETURN_ERROR(
793       input_idx < inputs_size(),
794       InvalidArgument,
795       "Given input index must be less than the number of inputs in method, but got %zu and %zu",
796       input_idx,
797       inputs_size());
798 
799   const auto& e = get_value(get_input_index(input_idx));
800   ET_CHECK_OR_RETURN_ERROR(
801       e.isTensor() || e.isScalar(),
802       InvalidArgument,
803       "The %zu-th input in method is expected Tensor or prim, but received %" PRIu32,
804       input_idx,
805       static_cast<uint32_t>(e.tag));
806 
807   ET_CHECK_OR_RETURN_ERROR(
808       e.tag == input_evalue.tag,
809       InvalidArgument,
810       "The %zu-th input of method should have the same type as the input_evalue, but get tag %" PRIu32
811       " and tag %" PRIu32,
812       input_idx,
813       static_cast<uint32_t>(e.tag),
814       static_cast<uint32_t>(input_evalue.tag));
815 
816   if (e.isTensor()) {
817     const auto& t_dst = e.toTensor();
818     const auto& t_src = input_evalue.toTensor();
819     ET_CHECK_OR_RETURN_ERROR(
820         t_dst.scalar_type() == t_src.scalar_type(),
821         InvalidArgument,
822         "The %zu-th input tensor's scalartype does not meet requirement: found %" PRId8
823         " but expected %" PRId8,
824         input_idx,
825         static_cast<int8_t>(t_src.scalar_type()),
826         static_cast<int8_t>(t_dst.scalar_type()));
827     // Reset the shape for the Method's input as the size of forwarded input
828     // tensor for shape dynamism. Also is a safety check if need memcpy.
829     Error err = resize_tensor(t_dst, t_src.sizes());
830     ET_CHECK_OR_RETURN_ERROR(
831         err == Error::Ok,
832         InvalidArgument,
833         "Error setting input %zu: 0x%" PRIx32,
834         input_idx,
835         static_cast<uint32_t>(err));
836     Error error;
837     auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
838     if (tensor_meta->is_memory_planned()) {
839       error = internal::copy_tensor_data(t_dst, t_src);
840     } else {
841       error = internal::share_tensor_data(t_dst, t_src);
842     }
843     ET_CHECK_OR_RETURN_ERROR(
844         error == Error::Ok,
845         InvalidArgument,
846         "Error setting data_ptr %zu: 0x%" PRIx32,
847         input_idx,
848         static_cast<uint32_t>(error));
849     // Prims have to be the same as what was traced
850   } else if (e.isInt()) {
851     ET_CHECK_OR_RETURN_ERROR(
852         e.toInt() == input_evalue.toInt(),
853         InvalidArgument,
854         "The %zu-th input of method should have the same value as the input_evalue, but got %" PRId64
855         " and %" PRId64,
856         input_idx,
857         e.toInt(),
858         input_evalue.toInt());
859   } else if (e.isBool()) {
860     ET_CHECK_OR_RETURN_ERROR(
861         e.toBool() == input_evalue.toBool(),
862         InvalidArgument,
863         "The %zu-th input of method should have the same value as the input_evalue, but got %" PRId64
864         " and %" PRId64,
865         input_idx,
866         (int64_t)e.toBool(),
867         (int64_t)input_evalue.toBool());
868   } else if (e.isDouble()) {
869     double lhs = input_evalue.toDouble();
870     double rhs = e.toDouble();
871     double atol = 1e-4;
872     double rtol = 1e-5;
873     bool is_equal = true;
874     if (std::isnan(lhs) && std::isnan(rhs)) {
875       // NaN == NaN
876     } else if (
877         !std::isfinite(lhs) && !std::isfinite(rhs) &&
878         ((lhs > 0) == (rhs > 0))) {
879       // -Inf == -Inf
880       // +Inf == +Inf
881     } else {
882       auto allowed_error = atol + std::abs(rtol * rhs);
883       auto actual_error = std::abs(lhs - rhs);
884       if (!std::isfinite(actual_error) || actual_error > allowed_error) {
885         is_equal = false;
886       }
887     }
888     ET_CHECK_OR_RETURN_ERROR(
889         is_equal,
890         InvalidArgument,
891         "The %zu-th input of method should have the same value as the input_evalue, but get %f and %f",
892         input_idx,
893         lhs,
894         rhs);
895   } else if (e.isString()) {
896     ET_CHECK_OR_RETURN_ERROR(
897         e.toString() == input_evalue.toString(),
898         InvalidArgument,
899         "The %zu-th input of method should have the same value as the input_evalue, but get %s and %s",
900         input_idx,
901         e.toString().data(),
902         input_evalue.toString().data());
903   } else {
904     ET_LOG(Error, "Unsupported input type: %d", (int32_t)e.tag);
905     return Error::InvalidArgument;
906   }
907   return Error::Ok;
908 }
909 
910 ET_NODISCARD Error
set_inputs(const exec_aten::ArrayRef<EValue> & input_evalues)911 Method::set_inputs(const exec_aten::ArrayRef<EValue>& input_evalues) {
912   ET_CHECK_OR_RETURN_ERROR(
913       initialized(),
914       InvalidState,
915       "Inputs can not be set until method has been initialized.");
916 
917   ET_CHECK_OR_RETURN_ERROR(
918       step_state_.instr_idx == 0 && step_state_.chain_idx == 0,
919       InvalidState,
920       "Inputs can not be set mid execution.");
921 
922   size_t input_size = inputs_size();
923   ET_CHECK_OR_RETURN_ERROR(
924       input_size == input_evalues.size(),
925       InvalidArgument,
926       "The length of given input array (%zu) must be same as the number of inputs in method (%zu).",
927       input_evalues.size(),
928       input_size);
929 
930   for (size_t i = 0; i < input_size; i++) {
931     Error status = set_input(input_evalues[i], i);
932     if (status != Error::Ok) {
933       return status;
934     }
935   }
936   return Error::Ok;
937 }
938 
939 ET_NODISCARD Error
set_output_data_ptr(void * buffer,size_t size,size_t output_idx)940 Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
941   // Check method state
942   ET_CHECK_OR_RETURN_ERROR(
943       initialized(),
944       InvalidState,
945       "Outputs can not be retrieved until method has been initialized.");
946 
947   // Check the args
948   ET_CHECK_OR_RETURN_ERROR(
949       output_idx < outputs_size(),
950       InvalidArgument,
951       "output_idx: %zu > num_outputs: %zu",
952       output_idx,
953       outputs_size());
954 
955   auto& output = mutable_value(get_output_index(output_idx));
956   ET_CHECK_OR_RETURN_ERROR(
957       output.isTensor(),
958       InvalidArgument,
959       "output type: %zu is not tensor",
960       (size_t)output.tag);
961 
962   auto tensor_meta = this->method_meta().output_tensor_meta(output_idx);
963   if (tensor_meta->is_memory_planned()) {
964     ET_LOG(
965         Error,
966         "Output %zu is memory planned, or is a constant. Cannot override "
967         "the existing data pointer.",
968         output_idx);
969     return Error::InvalidState;
970   }
971 
972   auto& t = output.toTensor();
973   ET_CHECK_OR_RETURN_ERROR(
974       output.isTensor(),
975       InvalidArgument,
976       "output type: %zu is not tensor",
977       (size_t)output.tag);
978   ET_CHECK_OR_RETURN_ERROR(
979       t.nbytes() <= size,
980       InvalidArgument,
981       "buffer size: %zu is smaller then expected tensor size: %zu",
982       size,
983       t.nbytes());
984 
985   // Set data
986   return internal::set_tensor_data(t, buffer, size);
987 }
988 
get_outputs(EValue * output_evalues,size_t length)989 ET_NODISCARD Error Method::get_outputs(EValue* output_evalues, size_t length) {
990   ET_CHECK_OR_RETURN_ERROR(
991       initialized(),
992       InvalidState,
993       "Outputs can not be retrieved until method has been initialized.");
994 
995   ET_CHECK_OR_RETURN_ERROR(
996       length >= outputs_size(),
997       InvalidArgument,
998       "The given array is not large enough to hold all outputs.");
999 
1000   for (size_t i = 0; i < outputs_size(); i++) {
1001     output_evalues[i] = values_[get_output_index(i)];
1002   }
1003 
1004   for (size_t i = outputs_size(); i < length; i++) {
1005     output_evalues[i] = EValue();
1006   }
1007 
1008   return Error::Ok;
1009 }
1010 
get_inputs(EValue * input_evalues,size_t length)1011 ET_NODISCARD Error Method::get_inputs(EValue* input_evalues, size_t length) {
1012   ET_CHECK_OR_RETURN_ERROR(
1013       initialized(),
1014       InvalidState,
1015       "Inputs can not be retrieved until method has been initialized.");
1016 
1017   ET_CHECK_OR_RETURN_ERROR(
1018       length >= inputs_size(),
1019       InvalidArgument,
1020       "The given array is not large enough to hold all inputs.");
1021 
1022   for (size_t i = 0; i < inputs_size(); i++) {
1023     input_evalues[i] = values_[get_input_index(i)];
1024   }
1025 
1026   for (size_t i = inputs_size(); i < length; i++) {
1027     input_evalues[i] = EValue();
1028   }
1029 
1030   return Error::Ok;
1031 }
1032 
/**
 * Executes the single instruction at (step_state_.chain_idx,
 * step_state_.instr_idx).
 *
 * On success, step_state_.instr_idx advances to instr_idx + 1. For a
 * JumpFalseCall whose condition is false, it advances to the jump's
 * destination instead. On failure, step_state_ is left unchanged. This
 * function never changes chain_idx; step() and execute() handle moving
 * between chains. temp_allocator_ (if non-null) is reset after every
 * instruction that reaches the end of the switch.
 *
 * @returns Error::Ok on success. Internal for an out-of-range instruction or
 *     delegate index. InvalidProgram for an unknown instruction type.
 *     Otherwise the error reported by the kernel, the delegate, or the
 *     jump-condition parsing.
 */
Error Method::execute_instruction() {
  auto& chain = chains_[step_state_.chain_idx];
  auto instructions = chain.s_chain_->instructions();

  ET_CHECK_OR_RETURN_ERROR(
      step_state_.instr_idx < instructions->size(),
      Internal,
      "Instr index %zu >= chain[%zu] instr count %zu",
      step_state_.instr_idx,
      step_state_.chain_idx,
      (size_t)instructions->size());

  auto instruction = instructions->Get(step_state_.instr_idx);
  // Default fall-through target; a JumpFalseCall may override it below.
  size_t next_instr_idx = step_state_.instr_idx + 1;
  Error err = Error::Ok;

  switch (instruction->instr_args_type()) {
    case executorch_flatbuffer::InstructionArguments::KernelCall: {
      EXECUTORCH_SCOPE_PROF("OPERATOR_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "OPERATOR_CALL");
      // TODO(T147221312): Also expose tensor resizer via the context.
      KernelRuntimeContext context(event_tracer_, temp_allocator_);
      auto args = chain.argument_lists_[step_state_.instr_idx];
      // The kernel returns nothing; failures are read back from the context.
      chain.kernels_[step_state_.instr_idx](context, args.data());
      // We reset the temp_allocator after the switch statement
      err = context.failure_state();
      if (err != Error::Ok) {
        // We know that instr_args_as_KernelCall is non-null because it was
        // checked at init time.
        auto op_index = instruction->instr_args_as_KernelCall()->op_index();
        auto op = serialization_plan_->operators()->Get(op_index);
        ET_LOG(
            Error,
            "KernelCall failed at instruction %zu:%zu in operator %s.%s: 0x%x",
            step_state_.chain_idx,
            step_state_.instr_idx,
            op->name()->c_str(),
            op->overload()->c_str(),
            (unsigned int)err);
        for (size_t i = 0; i < args.size(); ++i) {
          ET_LOG(
              Error,
              "arg %u with type id %u",
              (unsigned int)i,
              (unsigned int)args[i]->tag);
        }
        // TODO(T153804650): Consider logging the EValues to help with
        // debugging. This is a failure path, and it doesn't matter if it's a
        // little slow. Do the same for DelegateCall errors.
      }
    } break;
    case executorch_flatbuffer::InstructionArguments::DelegateCall: {
      EXECUTORCH_SCOPE_PROF("DELEGATE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "DELEGATE_CALL");
      // We know that instr_args_as_DelegateCall is non-null because it was
      // checked at init time.
      auto delegate_idx =
          instruction->instr_args_as_DelegateCall()->delegate_index();
      // NOTE(review): this early return skips the temp_allocator_->reset()
      // after the switch.
      ET_CHECK_OR_RETURN_ERROR(
          delegate_idx < n_delegate_,
          Internal,
          "DELEGATE_CALL index %" PRIu32
          " >= num delegates %zu at instruction %zu",
          delegate_idx,
          n_delegate_,
          step_state_.instr_idx);
      BackendExecutionContext backend_execution_context(
          /*event_tracer=*/event_tracer_,
          /*temp_allocator=*/temp_allocator_,
          /*method_name=*/serialization_plan_->name()->c_str());
      err = delegates_[delegate_idx].Execute(
          backend_execution_context,
          chain.argument_lists_[step_state_.instr_idx].data());
      if (err != Error::Ok) {
        ET_LOG(
            Error,
            "CALL_DELEGATE execute failed at instruction %zu: 0x%" PRIx32,
            step_state_.instr_idx,
            static_cast<uint32_t>(err));
      }

      // Log all the arguments of the delegate call. Ideally we'd only like to
      // log the outputs of the delegate, but currently we cannot know from the
      // arguments which are the inputs and which are the outputs, so we just
      // log everything. This will be changed in the future when the inputs and
      // outputs are separate lists.
#ifdef ET_EVENT_TRACER_ENABLED
      for (size_t i = 0;
           i < chain.argument_lists_[step_state_.instr_idx].size();
           i++) {
        EValue* arg = chain.argument_lists_[step_state_.instr_idx].data()[i];
        internal::event_tracer_log_evalue(event_tracer_, *arg);
      }
#endif
    } break;
    case executorch_flatbuffer::InstructionArguments::JumpFalseCall: {
      EXECUTORCH_SCOPE_PROF("JF_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "JF_CALL");
      // We know that instr_args_as_JumpFalseCall is non-null because it was
      // checked at init time.
      auto jf_call = instruction->instr_args_as_JumpFalseCall();
      // We know that index is a valid values_ index because it was checked at
      // init time.
      auto index = jf_call->cond_value_index();
      Result<bool> jf_result = parse_cond_value(values_[index]);
      if (jf_result.ok()) {
        if (!jf_result.get()) {
          // NOTE(review): destination_instruction() is not range-checked
          // here. An out-of-range target ends the chain's loop in execute(),
          // but makes the next execute_instruction() call from step() fail
          // the Internal instr-index check above.
          next_instr_idx = jf_call->destination_instruction();
        }
      } else {
        err = jf_result.error();
      }
    } break;
    case executorch_flatbuffer::InstructionArguments::MoveCall: {
      EXECUTORCH_SCOPE_PROF("MOVE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "MOVE_CALL");
      // We know that instr_args_as_MoveCall is non-null because it was checked
      // at init time.
      auto move_call = instruction->instr_args_as_MoveCall();
      // Copy-assigns the source EValue into the destination slot; both
      // indices are bounds-checked by mutable_value()/get_value().
      mutable_value(move_call->move_to()) = get_value(move_call->move_from());
    } break;
    case executorch_flatbuffer::InstructionArguments::FreeCall: {
      EXECUTORCH_SCOPE_PROF("FREE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "FREE_CALL");
      // We know that instr_args_as_FreeCall is non-null because it was checked
      // at init time.
      auto free_call = instruction->instr_args_as_FreeCall();
      // NOTE(review): values_ is indexed directly here (no get_value() bounds
      // check) — relies on init-time validation of value_index; confirm.
      auto t = values_[free_call->value_index()].toTensor();
      internal::reset_data_ptr(t);
    } break;
    default:
      ET_LOG(
          Error,
          "Unknown instruction: %hhu",
          static_cast<uint8_t>(instruction->instr_args_type()));
      err = Error::InvalidProgram;
  }
  // Reset the temp allocator for every instruction.
  if (temp_allocator_ != nullptr) {
    temp_allocator_->reset();
  }
  // Only advance on success, so a failed instruction stays the current one.
  if (err == Error::Ok) {
    step_state_.instr_idx = next_instr_idx;
  }
  return err;
}
1184 
reset_execution()1185 Error Method::reset_execution() {
1186   ET_CHECK_OR_RETURN_ERROR(
1187       step_state_.chain_idx == n_chains_,
1188       InvalidState,
1189       "Cannot reset until EndOfMethod has been reached.");
1190   step_state_ = StepState{0, 0};
1191   return Error::Ok;
1192 }
1193 
experimental_reset_execution()1194 Error Method::experimental_reset_execution() {
1195   return reset_execution(); // @lint-ignore CLANGTIDY facebook-hte-Deprecated
1196 }
1197 
1198 // Log all the outputs of this method to the event tracer.
log_outputs()1199 void Method::log_outputs() {
1200 #ifdef ET_EVENT_TRACER_ENABLED
1201   if (event_tracer_ != nullptr) {
1202     if (event_tracer_->event_tracer_debug_level() >=
1203         EventTracerDebugLogLevel::kProgramOutputs) {
1204       for (size_t i = 0; i < outputs_size(); i++) {
1205         internal::event_tracer_log_evalue_output(event_tracer_, get_output(i));
1206       }
1207     }
1208   }
1209 #endif
1210 }
1211 
step()1212 Error Method::step() {
1213   EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(
1214       static_cast<int32_t>(step_state_.chain_idx),
1215       static_cast<uint32_t>(step_state_.instr_idx));
1216   internal::EventTracerProfileInstructionScope event_tracer_instr_scope =
1217       internal::EventTracerProfileInstructionScope(
1218           event_tracer_,
1219           static_cast<int32_t>(step_state_.chain_idx),
1220           static_cast<uint32_t>(step_state_.instr_idx));
1221   EXECUTORCH_SCOPE_PROF("Method::step");
1222   EventTracerEntry event_tracer_entry =
1223       internal::event_tracer_begin_profiling_event(
1224           event_tracer_, "Method::step");
1225   ET_CHECK_OR_RETURN_ERROR(
1226       initialized(),
1227       InvalidState,
1228       "Cannot execute until method has been initialized.");
1229 
1230   // If chain_step_ is on n_chains_, then we have no instructions run.
1231   if (step_state_.chain_idx == n_chains_) {
1232     return Error::EndOfMethod;
1233   }
1234 
1235   auto num_instructions =
1236       chains_[step_state_.chain_idx].s_chain_->instructions()->size();
1237 
1238   // Special case chains with no instructions. These appear for example in a
1239   // model that just returns the input/a constant.
1240   if (num_instructions == 0) {
1241     step_state_.chain_idx += 1;
1242     return Error::Ok;
1243   }
1244 
1245   auto status = execute_instruction();
1246   if (status != Error::Ok) {
1247     return status;
1248   }
1249 
1250   internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry);
1251   // end of the current chain, advance to the next chain
1252   if (step_state_.instr_idx == num_instructions) {
1253     step_state_.instr_idx = 0;
1254     step_state_.chain_idx += 1;
1255     log_outputs();
1256   }
1257   return Error::Ok;
1258 }
1259 
experimental_step()1260 Error Method::experimental_step() {
1261   return step();
1262 }
1263 
execute()1264 Error Method::execute() {
1265   internal::event_tracer_create_event_block(event_tracer_, "Execute");
1266   EventTracerEntry event_tracer_entry =
1267       internal::event_tracer_begin_profiling_event(
1268           event_tracer_, "Method::execute");
1269   EXECUTORCH_SCOPE_PROF("Method::execute");
1270   ET_CHECK_OR_RETURN_ERROR(
1271       initialized(),
1272       NotSupported,
1273       "Cannot execute until method has been initialized.");
1274 
1275   // Chains are executed sequentially today, but future async designs may
1276   // branch and run many in parallel or out of order.
1277   for (step_state_.chain_idx = 0; step_state_.chain_idx < n_chains_;
1278        ++step_state_.chain_idx) {
1279     Chain& chain = chains_[step_state_.chain_idx];
1280     auto instructions = chain.s_chain_->instructions();
1281     ET_CHECK_OR_RETURN_ERROR(
1282         instructions != nullptr,
1283         Internal,
1284         "chain %zu has no instructions field",
1285         step_state_.chain_idx);
1286 
1287     // Loop over instructions
1288     step_state_.instr_idx = 0;
1289     while (step_state_.instr_idx < chain.s_chain_->instructions()->size()) {
1290       EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(
1291           static_cast<int32_t>(step_state_.chain_idx),
1292           static_cast<uint32_t>(step_state_.instr_idx));
1293       internal::EventTracerProfileInstructionScope event_tracer_instr_scope =
1294           internal::EventTracerProfileInstructionScope(
1295               event_tracer_,
1296               static_cast<ChainID>(step_state_.chain_idx),
1297               static_cast<DebugHandle>(step_state_.instr_idx));
1298       auto status = execute_instruction();
1299       if (status != Error::Ok) {
1300         return status;
1301       }
1302     }
1303   }
1304   internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry);
1305   log_outputs();
1306 
1307   // TODO(jakeszwe, dbort): Decide on calling execute back to back without
1308   // going through the reset api first.
1309   return reset_execution(); // @lint-ignore CLANGTIDY facebook-hte-Deprecated
1310 }
1311 
method_meta() const1312 MethodMeta Method::method_meta() const {
1313   auto name = serialization_plan_->name()->c_str();
1314   auto method_meta = program_->method_meta(name);
1315   ET_CHECK_MSG(
1316       method_meta.ok(),
1317       "Internal error: method_meta(%s) returned 0x%" PRIx32,
1318       name,
1319       static_cast<uint32_t>(method_meta.error()));
1320   return method_meta.get();
1321 }
1322 
get_value(size_t i) const1323 const EValue& Method::get_value(size_t i) const {
1324   ET_CHECK_MSG(i < n_value_, "%zu >= %zu", i, n_value_);
1325   return values_[i];
1326 }
1327 
mutable_value(size_t i)1328 EValue& Method::mutable_value(size_t i) {
1329   ET_CHECK_MSG(i < n_value_, "%zu >= %zu", i, n_value_);
1330   return values_[i];
1331 }
1332 
inputs_size() const1333 size_t Method::inputs_size() const {
1334   const auto* inputs = serialization_plan_->inputs();
1335   return inputs == nullptr ? 0 : inputs->size();
1336 }
1337 
get_input_index(size_t i) const1338 size_t Method::get_input_index(size_t i) const {
1339   ET_CHECK_MSG(i < inputs_size(), "%zu >= %zu", i, inputs_size());
1340   return static_cast<size_t>(serialization_plan_->inputs()->Get(i));
1341 }
1342 
get_input(size_t i) const1343 const EValue& Method::get_input(size_t i) const {
1344   return get_value(get_input_index(i));
1345 }
1346 
mutable_input(size_t i)1347 EValue& Method::mutable_input(size_t i) {
1348   return mutable_value(get_input_index(i));
1349 }
1350 
outputs_size() const1351 size_t Method::outputs_size() const {
1352   const auto* outputs = serialization_plan_->outputs();
1353   return outputs == nullptr ? 0 : outputs->size();
1354 }
1355 
get_output_index(size_t i) const1356 size_t Method::get_output_index(size_t i) const {
1357   ET_CHECK_MSG(i < outputs_size(), "%zu >= %zu", i, outputs_size());
1358   return static_cast<size_t>(serialization_plan_->outputs()->Get(i));
1359 }
1360 
get_output(size_t i) const1361 const EValue& Method::get_output(size_t i) const {
1362   return get_value(get_output_index(i));
1363 }
1364 
mutable_output(size_t i)1365 EValue& Method::mutable_output(size_t i) {
1366   return mutable_value(get_output_index(i));
1367 }
1368 
get_event_tracer()1369 EventTracer* Method::get_event_tracer() {
1370   return event_tracer_;
1371 }
1372 
~Method()1373 Method::~Method() {
1374   // Destroy the values. It's necessary in ATen mode, where the refcount of
1375   // Tensors needs to be decremented properly.
1376   if (values_ != nullptr) {
1377     for (int i = 0; i < n_value_; ++i) {
1378       values_[i].~EValue();
1379     }
1380   }
1381   // Free any resources associated with delegate backends.
1382   if (delegates_ != nullptr) {
1383     for (int i = 0; i < n_delegate_; i++) {
1384       delegates_[i].~BackendDelegate();
1385     }
1386   }
1387   // All other fields are trivially destructible.
1388 }
1389 } // namespace runtime
1390 } // namespace executorch
1391