1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/method.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <cinttypes> // @donotremove
12*523fa7a6SAndroid Build Coastguard Worker #include <cstdint>
13*523fa7a6SAndroid Build Coastguard Worker #include <cstdio>
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/backend/interface.h>
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/event_tracer_hooks.h>
17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/span.h>
19*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/memory_manager.h>
20*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/platform_memory_allocator.h>
21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/program.h>
22*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/tensor_parser.h>
23*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/kernel_runtime_context.h>
24*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/operator_registry.h>
25*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/assert.h>
26*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/log.h>
27*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/profiler.h>
28*523fa7a6SAndroid Build Coastguard Worker #include <executorch/schema/program_generated.h>
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
31*523fa7a6SAndroid Build Coastguard Worker namespace runtime {
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Worker using internal::PlatformMemoryAllocator;
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker /**
36*523fa7a6SAndroid Build Coastguard Worker * Runtime state for a backend delegate.
37*523fa7a6SAndroid Build Coastguard Worker */
38*523fa7a6SAndroid Build Coastguard Worker // This lint wants to wrap the class in an anonymous namespace, but it must be
39*523fa7a6SAndroid Build Coastguard Worker // visible because it's forward-declared and used in Executor.h.
40*523fa7a6SAndroid Build Coastguard Worker // @lint-ignore CLANGTIDY facebook-hte-ShadowingClass
41*523fa7a6SAndroid Build Coastguard Worker class BackendDelegate final {
42*523fa7a6SAndroid Build Coastguard Worker public:
43*523fa7a6SAndroid Build Coastguard Worker /**
44*523fa7a6SAndroid Build Coastguard Worker * Initializes an already-allocated BackendDelegate from its serialized
45*523fa7a6SAndroid Build Coastguard Worker * representation.
46*523fa7a6SAndroid Build Coastguard Worker *
47*523fa7a6SAndroid Build Coastguard Worker * @param[in] delegate The serialized backend delegate to load.
48*523fa7a6SAndroid Build Coastguard Worker * @param[in] program The serialized program to load from.
49*523fa7a6SAndroid Build Coastguard Worker * @param[in] backend_init_context The context pointer to pass to the
50*523fa7a6SAndroid Build Coastguard Worker * backend's init() method.
51*523fa7a6SAndroid Build Coastguard Worker * @param[out] out The BackendDelegate to initialize.
52*523fa7a6SAndroid Build Coastguard Worker *
53*523fa7a6SAndroid Build Coastguard Worker * @returns Error::Ok if the initialization succeeded, or an error otherwise.
54*523fa7a6SAndroid Build Coastguard Worker */
Init(const executorch_flatbuffer::BackendDelegate & delegate,const Program * program,BackendInitContext & backend_init_context,BackendDelegate * out)55*523fa7a6SAndroid Build Coastguard Worker static Error Init(
56*523fa7a6SAndroid Build Coastguard Worker const executorch_flatbuffer::BackendDelegate& delegate,
57*523fa7a6SAndroid Build Coastguard Worker const Program* program,
58*523fa7a6SAndroid Build Coastguard Worker BackendInitContext& backend_init_context,
59*523fa7a6SAndroid Build Coastguard Worker BackendDelegate* out) {
60*523fa7a6SAndroid Build Coastguard Worker // Look up the backend.
61*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
62*523fa7a6SAndroid Build Coastguard Worker delegate.id() != nullptr, InvalidProgram, "Missing backend id");
63*523fa7a6SAndroid Build Coastguard Worker const char* backend_id = delegate.id()->c_str();
64*523fa7a6SAndroid Build Coastguard Worker BackendInterface* backend = get_backend_class(backend_id);
65*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
66*523fa7a6SAndroid Build Coastguard Worker backend != nullptr,
67*523fa7a6SAndroid Build Coastguard Worker NotFound,
68*523fa7a6SAndroid Build Coastguard Worker "Backend %s is not registered.",
69*523fa7a6SAndroid Build Coastguard Worker backend_id);
70*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
71*523fa7a6SAndroid Build Coastguard Worker backend->is_available(),
72*523fa7a6SAndroid Build Coastguard Worker NotFound,
73*523fa7a6SAndroid Build Coastguard Worker "Backend %s is not available.",
74*523fa7a6SAndroid Build Coastguard Worker backend_id);
75*523fa7a6SAndroid Build Coastguard Worker
76*523fa7a6SAndroid Build Coastguard Worker // Get the delegate data.
77*523fa7a6SAndroid Build Coastguard Worker Result<FreeableBuffer> processed_data = GetProcessedData(delegate, program);
78*523fa7a6SAndroid Build Coastguard Worker if (!processed_data.ok()) {
79*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Failed to load data for backend %s", backend_id);
80*523fa7a6SAndroid Build Coastguard Worker return processed_data.error();
81*523fa7a6SAndroid Build Coastguard Worker }
82*523fa7a6SAndroid Build Coastguard Worker
83*523fa7a6SAndroid Build Coastguard Worker // Parse compilation specs from program
84*523fa7a6SAndroid Build Coastguard Worker CompileSpec* compile_specs;
85*523fa7a6SAndroid Build Coastguard Worker Error err = PopulateCompileSpecs(
86*523fa7a6SAndroid Build Coastguard Worker delegate.compile_specs(), backend_init_context, &compile_specs);
87*523fa7a6SAndroid Build Coastguard Worker if (err != Error::Ok) {
88*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Failed to get compile specs for backend %s", backend_id);
89*523fa7a6SAndroid Build Coastguard Worker return err;
90*523fa7a6SAndroid Build Coastguard Worker }
91*523fa7a6SAndroid Build Coastguard Worker size_t num_compile_specs = delegate.compile_specs()->size();
92*523fa7a6SAndroid Build Coastguard Worker
93*523fa7a6SAndroid Build Coastguard Worker out->backend_ = backend;
94*523fa7a6SAndroid Build Coastguard Worker out->handle_ = nullptr;
95*523fa7a6SAndroid Build Coastguard Worker // Pass a pointer to this buffer to the backend. It's safe for the backend
96*523fa7a6SAndroid Build Coastguard Worker // to point its handle to this object, since it will outlive the backend.
97*523fa7a6SAndroid Build Coastguard Worker new (&out->segment_) FreeableBuffer(std::move(processed_data.get()));
98*523fa7a6SAndroid Build Coastguard Worker
99*523fa7a6SAndroid Build Coastguard Worker // Initialize the delegate.
100*523fa7a6SAndroid Build Coastguard Worker Result<DelegateHandle*> handle = backend->init(
101*523fa7a6SAndroid Build Coastguard Worker backend_init_context,
102*523fa7a6SAndroid Build Coastguard Worker &out->segment_,
103*523fa7a6SAndroid Build Coastguard Worker ArrayRef<CompileSpec>(compile_specs, num_compile_specs));
104*523fa7a6SAndroid Build Coastguard Worker if (!handle.ok()) {
105*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
106*523fa7a6SAndroid Build Coastguard Worker Error,
107*523fa7a6SAndroid Build Coastguard Worker "Init failed for backend %s: 0x%" PRIx32,
108*523fa7a6SAndroid Build Coastguard Worker backend_id,
109*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(handle.error()));
110*523fa7a6SAndroid Build Coastguard Worker out->segment_.Free();
111*523fa7a6SAndroid Build Coastguard Worker return handle.error();
112*523fa7a6SAndroid Build Coastguard Worker }
113*523fa7a6SAndroid Build Coastguard Worker out->handle_ = handle.get();
114*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
115*523fa7a6SAndroid Build Coastguard Worker }
116*523fa7a6SAndroid Build Coastguard Worker
~BackendDelegate()117*523fa7a6SAndroid Build Coastguard Worker ~BackendDelegate() {
118*523fa7a6SAndroid Build Coastguard Worker if (backend_ != nullptr) {
119*523fa7a6SAndroid Build Coastguard Worker backend_->destroy(handle_);
120*523fa7a6SAndroid Build Coastguard Worker }
121*523fa7a6SAndroid Build Coastguard Worker }
122*523fa7a6SAndroid Build Coastguard Worker
Execute(BackendExecutionContext & backend_execution_context,EValue ** args) const123*523fa7a6SAndroid Build Coastguard Worker Error Execute(
124*523fa7a6SAndroid Build Coastguard Worker BackendExecutionContext& backend_execution_context,
125*523fa7a6SAndroid Build Coastguard Worker EValue** args) const {
126*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_SCOPE_PROF("delegate_execute");
127*523fa7a6SAndroid Build Coastguard Worker return backend_->execute(backend_execution_context, handle_, args);
128*523fa7a6SAndroid Build Coastguard Worker }
129*523fa7a6SAndroid Build Coastguard Worker
130*523fa7a6SAndroid Build Coastguard Worker private:
131*523fa7a6SAndroid Build Coastguard Worker // Not constructible.
132*523fa7a6SAndroid Build Coastguard Worker BackendDelegate() = delete;
133*523fa7a6SAndroid Build Coastguard Worker
134*523fa7a6SAndroid Build Coastguard Worker // Disallow copy/move.
135*523fa7a6SAndroid Build Coastguard Worker BackendDelegate(const BackendDelegate&) = delete;
136*523fa7a6SAndroid Build Coastguard Worker BackendDelegate& operator=(const BackendDelegate&) = delete;
137*523fa7a6SAndroid Build Coastguard Worker BackendDelegate(BackendDelegate&&) = delete;
138*523fa7a6SAndroid Build Coastguard Worker BackendDelegate& operator=(BackendDelegate&&) = delete;
139*523fa7a6SAndroid Build Coastguard Worker
PopulateCompileSpecs(const flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::CompileSpec>> * compile_specs_in_program,BackendInitContext & backend_init_context,CompileSpec ** out_spec)140*523fa7a6SAndroid Build Coastguard Worker static Error PopulateCompileSpecs(
141*523fa7a6SAndroid Build Coastguard Worker const flatbuffers::Vector<flatbuffers::Offset<
142*523fa7a6SAndroid Build Coastguard Worker executorch_flatbuffer::CompileSpec>>* compile_specs_in_program,
143*523fa7a6SAndroid Build Coastguard Worker BackendInitContext& backend_init_context,
144*523fa7a6SAndroid Build Coastguard Worker CompileSpec** out_spec) {
145*523fa7a6SAndroid Build Coastguard Worker auto number_of_compile_specs = compile_specs_in_program->size();
146*523fa7a6SAndroid Build Coastguard Worker
147*523fa7a6SAndroid Build Coastguard Worker CompileSpec* compile_specs_list =
148*523fa7a6SAndroid Build Coastguard Worker backend_init_context.get_runtime_allocator()->allocateList<CompileSpec>(
149*523fa7a6SAndroid Build Coastguard Worker number_of_compile_specs);
150*523fa7a6SAndroid Build Coastguard Worker if (compile_specs_list == nullptr) {
151*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
152*523fa7a6SAndroid Build Coastguard Worker }
153*523fa7a6SAndroid Build Coastguard Worker
154*523fa7a6SAndroid Build Coastguard Worker // Initialize the spec list for each method spec
155*523fa7a6SAndroid Build Coastguard Worker for (size_t j = 0; j < number_of_compile_specs; j++) {
156*523fa7a6SAndroid Build Coastguard Worker auto compile_spec_in_program = compile_specs_in_program->Get(j);
157*523fa7a6SAndroid Build Coastguard Worker
158*523fa7a6SAndroid Build Coastguard Worker compile_specs_list[j].key = compile_spec_in_program->key()->c_str();
159*523fa7a6SAndroid Build Coastguard Worker compile_specs_list[j].value = {
160*523fa7a6SAndroid Build Coastguard Worker /*buffer*/ static_cast<void*>(
161*523fa7a6SAndroid Build Coastguard Worker const_cast<uint8_t*>(compile_spec_in_program->value()->Data())),
162*523fa7a6SAndroid Build Coastguard Worker /*nbytes*/ compile_spec_in_program->value()->size(),
163*523fa7a6SAndroid Build Coastguard Worker };
164*523fa7a6SAndroid Build Coastguard Worker }
165*523fa7a6SAndroid Build Coastguard Worker
166*523fa7a6SAndroid Build Coastguard Worker *out_spec = compile_specs_list;
167*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
168*523fa7a6SAndroid Build Coastguard Worker }
169*523fa7a6SAndroid Build Coastguard Worker
GetProcessedData(const executorch_flatbuffer::BackendDelegate & delegate,const Program * program)170*523fa7a6SAndroid Build Coastguard Worker static Result<FreeableBuffer> GetProcessedData(
171*523fa7a6SAndroid Build Coastguard Worker const executorch_flatbuffer::BackendDelegate& delegate,
172*523fa7a6SAndroid Build Coastguard Worker const Program* program) {
173*523fa7a6SAndroid Build Coastguard Worker const executorch_flatbuffer::BackendDelegateDataReference* processed =
174*523fa7a6SAndroid Build Coastguard Worker delegate.processed();
175*523fa7a6SAndroid Build Coastguard Worker switch (processed->location()) {
176*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::DataLocation::INLINE: {
177*523fa7a6SAndroid Build Coastguard Worker const void* data;
178*523fa7a6SAndroid Build Coastguard Worker size_t size;
179*523fa7a6SAndroid Build Coastguard Worker Error err = program->get_backend_delegate_data(
180*523fa7a6SAndroid Build Coastguard Worker processed->index(), &data, &size);
181*523fa7a6SAndroid Build Coastguard Worker if (err != Error::Ok) {
182*523fa7a6SAndroid Build Coastguard Worker return err;
183*523fa7a6SAndroid Build Coastguard Worker }
184*523fa7a6SAndroid Build Coastguard Worker return FreeableBuffer(
185*523fa7a6SAndroid Build Coastguard Worker data,
186*523fa7a6SAndroid Build Coastguard Worker size,
187*523fa7a6SAndroid Build Coastguard Worker /*free_fn=*/nullptr);
188*523fa7a6SAndroid Build Coastguard Worker }
189*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::DataLocation::SEGMENT: {
190*523fa7a6SAndroid Build Coastguard Worker const char* backend_id = delegate.id()->c_str();
191*523fa7a6SAndroid Build Coastguard Worker return program->LoadSegment(DataLoader::SegmentInfo(
192*523fa7a6SAndroid Build Coastguard Worker DataLoader::SegmentInfo::Type::Backend,
193*523fa7a6SAndroid Build Coastguard Worker processed->index(),
194*523fa7a6SAndroid Build Coastguard Worker backend_id));
195*523fa7a6SAndroid Build Coastguard Worker }
196*523fa7a6SAndroid Build Coastguard Worker default:
197*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
198*523fa7a6SAndroid Build Coastguard Worker Error,
199*523fa7a6SAndroid Build Coastguard Worker "Unknown data location %u",
200*523fa7a6SAndroid Build Coastguard Worker static_cast<unsigned int>(processed->location()));
201*523fa7a6SAndroid Build Coastguard Worker return Error::Internal;
202*523fa7a6SAndroid Build Coastguard Worker }
203*523fa7a6SAndroid Build Coastguard Worker }
204*523fa7a6SAndroid Build Coastguard Worker
205*523fa7a6SAndroid Build Coastguard Worker FreeableBuffer segment_;
206*523fa7a6SAndroid Build Coastguard Worker const BackendInterface* backend_;
207*523fa7a6SAndroid Build Coastguard Worker DelegateHandle* handle_;
208*523fa7a6SAndroid Build Coastguard Worker };
209*523fa7a6SAndroid Build Coastguard Worker
/**
 * Runtime state for a chain of instructions.
 *
 * NOTE(review): these fields are filled in outside this view; keep the field
 * order stable in case that code uses positional (aggregate) initialization.
 */
struct Chain {
  /// The serialized flatbuffer chain this state was built from. This struct
  /// declares no destructor, so it never frees what this points at.
  const executorch_flatbuffer::Chain* s_chain_;

  /// Each entry is the argument list for one kernel or delegate call.
  Span<InstructionArgs> argument_lists_;
  /// Kernel functions for the chain's kernel calls; delegate calls have no
  /// kernel here. NOTE(review): the indexing scheme (per instruction vs. per
  /// kernel call) is established outside this view — confirm before relying
  /// on it.
  OpFunction* kernels_;
};
222*523fa7a6SAndroid Build Coastguard Worker
223*523fa7a6SAndroid Build Coastguard Worker namespace {
224*523fa7a6SAndroid Build Coastguard Worker
gen_instruction_arguments(MemoryAllocator * method_allocator,size_t num_values,EValue * values,size_t num_args,const int32_t * arg_idxs)225*523fa7a6SAndroid Build Coastguard Worker Result<InstructionArgs> gen_instruction_arguments(
226*523fa7a6SAndroid Build Coastguard Worker MemoryAllocator* method_allocator,
227*523fa7a6SAndroid Build Coastguard Worker size_t num_values,
228*523fa7a6SAndroid Build Coastguard Worker EValue* values,
229*523fa7a6SAndroid Build Coastguard Worker size_t num_args,
230*523fa7a6SAndroid Build Coastguard Worker const int32_t* arg_idxs) {
231*523fa7a6SAndroid Build Coastguard Worker EValue** arg_list = method_allocator->allocateList<EValue*>(num_args);
232*523fa7a6SAndroid Build Coastguard Worker if (arg_list == nullptr) {
233*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
234*523fa7a6SAndroid Build Coastguard Worker }
235*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < num_args; ++i) {
236*523fa7a6SAndroid Build Coastguard Worker int32_t arg_idx = arg_idxs[i];
237*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
238*523fa7a6SAndroid Build Coastguard Worker arg_idx < num_values,
239*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
240*523fa7a6SAndroid Build Coastguard Worker "Arg index %d >= %zu",
241*523fa7a6SAndroid Build Coastguard Worker arg_idx,
242*523fa7a6SAndroid Build Coastguard Worker num_values);
243*523fa7a6SAndroid Build Coastguard Worker arg_list[i] = &values[arg_idx];
244*523fa7a6SAndroid Build Coastguard Worker }
245*523fa7a6SAndroid Build Coastguard Worker return InstructionArgs(arg_list, num_args);
246*523fa7a6SAndroid Build Coastguard Worker }
247*523fa7a6SAndroid Build Coastguard Worker
parse_cond_value(const EValue & cond_value)248*523fa7a6SAndroid Build Coastguard Worker Result<bool> parse_cond_value(const EValue& cond_value) {
249*523fa7a6SAndroid Build Coastguard Worker // The cond value attached to the JF instruction at the beginning of an
250*523fa7a6SAndroid Build Coastguard Worker // if/else branch is a Tensor which we parse and decide whether to continue
251*523fa7a6SAndroid Build Coastguard Worker // to execute the if branch or jump to the else branch.
252*523fa7a6SAndroid Build Coastguard Worker // The cond value attached to the JF instruction at the end of the if branch
253*523fa7a6SAndroid Build Coastguard Worker // is a Bool Scalar which resolves to false and points us to the instruction
254*523fa7a6SAndroid Build Coastguard Worker // to jump to which will take us to a point that is after the else branch.
255*523fa7a6SAndroid Build Coastguard Worker if (cond_value.isTensor()) {
256*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Tensor& cond_val = cond_value.toTensor();
257*523fa7a6SAndroid Build Coastguard Worker
258*523fa7a6SAndroid Build Coastguard Worker // All the tensors and scalar cond values should be of bool type
259*523fa7a6SAndroid Build Coastguard Worker // currently. If that's not the case then something is wrong in the model
260*523fa7a6SAndroid Build Coastguard Worker // and we should exit.
261*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
262*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType::Bool == cond_val.scalar_type(),
263*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
264*523fa7a6SAndroid Build Coastguard Worker "Expected dtype of %" PRId8 " got %" PRId8,
265*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(exec_aten::ScalarType::Bool),
266*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(cond_val.scalar_type()));
267*523fa7a6SAndroid Build Coastguard Worker
268*523fa7a6SAndroid Build Coastguard Worker const bool* cond_data = cond_val.const_data_ptr<bool>();
269*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < cond_val.numel(); i++) {
270*523fa7a6SAndroid Build Coastguard Worker if (!cond_data[i]) {
271*523fa7a6SAndroid Build Coastguard Worker return false;
272*523fa7a6SAndroid Build Coastguard Worker }
273*523fa7a6SAndroid Build Coastguard Worker }
274*523fa7a6SAndroid Build Coastguard Worker } else if (cond_value.isBool()) {
275*523fa7a6SAndroid Build Coastguard Worker if (!cond_value.toBool()) {
276*523fa7a6SAndroid Build Coastguard Worker return false;
277*523fa7a6SAndroid Build Coastguard Worker }
278*523fa7a6SAndroid Build Coastguard Worker } else {
279*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
280*523fa7a6SAndroid Build Coastguard Worker Error, "Unsupported JF EValue type %" PRIu32, (uint32_t)cond_value.tag);
281*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
282*523fa7a6SAndroid Build Coastguard Worker }
283*523fa7a6SAndroid Build Coastguard Worker
284*523fa7a6SAndroid Build Coastguard Worker return true;
285*523fa7a6SAndroid Build Coastguard Worker }
286*523fa7a6SAndroid Build Coastguard Worker
287*523fa7a6SAndroid Build Coastguard Worker } // namespace
288*523fa7a6SAndroid Build Coastguard Worker
parse_values()289*523fa7a6SAndroid Build Coastguard Worker Error Method::parse_values() {
290*523fa7a6SAndroid Build Coastguard Worker auto flatbuffer_values = serialization_plan_->values();
291*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
292*523fa7a6SAndroid Build Coastguard Worker flatbuffer_values != nullptr, InvalidProgram, "Missing values");
293*523fa7a6SAndroid Build Coastguard Worker size_t n_value = flatbuffer_values->size();
294*523fa7a6SAndroid Build Coastguard Worker values_ = memory_manager_->method_allocator()->allocateList<EValue>(n_value);
295*523fa7a6SAndroid Build Coastguard Worker if (values_ == nullptr) {
296*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
297*523fa7a6SAndroid Build Coastguard Worker }
298*523fa7a6SAndroid Build Coastguard Worker
299*523fa7a6SAndroid Build Coastguard Worker // n_value_ counts the number of successfully-initialized values for ~Method()
300*523fa7a6SAndroid Build Coastguard Worker // to clean up, and is incremented at the bottom of the loop. This makes it
301*523fa7a6SAndroid Build Coastguard Worker // safe for errors to return without updating any state.
302*523fa7a6SAndroid Build Coastguard Worker n_value_ = 0;
303*523fa7a6SAndroid Build Coastguard Worker
304*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < n_value; ++i) {
305*523fa7a6SAndroid Build Coastguard Worker auto serialization_value = flatbuffer_values->Get(i);
306*523fa7a6SAndroid Build Coastguard Worker // Ensure that the `val_as_X()` calls will return non-null pointers.
307*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
308*523fa7a6SAndroid Build Coastguard Worker serialization_value != nullptr &&
309*523fa7a6SAndroid Build Coastguard Worker (serialization_value->val_type() ==
310*523fa7a6SAndroid Build Coastguard Worker executorch_flatbuffer::KernelTypes::Null ||
311*523fa7a6SAndroid Build Coastguard Worker serialization_value->val() != nullptr),
312*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
313*523fa7a6SAndroid Build Coastguard Worker "Null value at index %zu",
314*523fa7a6SAndroid Build Coastguard Worker i);
315*523fa7a6SAndroid Build Coastguard Worker
316*523fa7a6SAndroid Build Coastguard Worker switch (serialization_value->val_type()) {
317*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::Null: {
318*523fa7a6SAndroid Build Coastguard Worker // Placement new as the list elements are not initialized, so calling
319*523fa7a6SAndroid Build Coastguard Worker // copy assignment is not defined if its non trivial (Imagine the
320*523fa7a6SAndroid Build Coastguard Worker // garbage in values_[i] thinks its an at::Tensor).
321*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue();
322*523fa7a6SAndroid Build Coastguard Worker } break;
323*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::Int: {
324*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue(serialization_value->val_as_Int()->int_val());
325*523fa7a6SAndroid Build Coastguard Worker } break;
326*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::Double: {
327*523fa7a6SAndroid Build Coastguard Worker new (&values_[i])
328*523fa7a6SAndroid Build Coastguard Worker EValue(serialization_value->val_as_Double()->double_val());
329*523fa7a6SAndroid Build Coastguard Worker } break;
330*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::Bool: {
331*523fa7a6SAndroid Build Coastguard Worker new (&values_[i])
332*523fa7a6SAndroid Build Coastguard Worker EValue(serialization_value->val_as_Bool()->bool_val());
333*523fa7a6SAndroid Build Coastguard Worker } break;
334*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::IntList: {
335*523fa7a6SAndroid Build Coastguard Worker const auto items = serialization_value->val_as_IntList()->items();
336*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
337*523fa7a6SAndroid Build Coastguard Worker items != nullptr, InvalidProgram, "Missing list at index %zu", i);
338*523fa7a6SAndroid Build Coastguard Worker // Allocate space for boxed and unboxed list representations using
339*523fa7a6SAndroid Build Coastguard Worker // values_ as source of truth
340*523fa7a6SAndroid Build Coastguard Worker auto* evalp_list =
341*523fa7a6SAndroid Build Coastguard Worker memory_manager_->method_allocator()->allocateList<EValue*>(
342*523fa7a6SAndroid Build Coastguard Worker items->size());
343*523fa7a6SAndroid Build Coastguard Worker auto* int_list =
344*523fa7a6SAndroid Build Coastguard Worker memory_manager_->method_allocator()->allocateList<int64_t>(
345*523fa7a6SAndroid Build Coastguard Worker items->size());
346*523fa7a6SAndroid Build Coastguard Worker
347*523fa7a6SAndroid Build Coastguard Worker // initialize boxed list
348*523fa7a6SAndroid Build Coastguard Worker for (size_t j = 0; j < items->size(); j++) {
349*523fa7a6SAndroid Build Coastguard Worker evalp_list[j] = &values_[static_cast<size_t>(items->Get(j))];
350*523fa7a6SAndroid Build Coastguard Worker }
351*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue(
352*523fa7a6SAndroid Build Coastguard Worker BoxedEvalueList<int64_t>(evalp_list, int_list, items->size()));
353*523fa7a6SAndroid Build Coastguard Worker } break;
354*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::BoolList: {
355*523fa7a6SAndroid Build Coastguard Worker const auto items = serialization_value->val_as_BoolList()->items();
356*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
357*523fa7a6SAndroid Build Coastguard Worker items != nullptr, InvalidProgram, "Missing list at index %zu", i);
358*523fa7a6SAndroid Build Coastguard Worker // NOTE: This is technically not portable. A platform could technically
359*523fa7a6SAndroid Build Coastguard Worker // define boolean as something longer than a byte. This would be an
360*523fa7a6SAndroid Build Coastguard Worker // exceptionally rare case, and this type is currently unused in any
361*523fa7a6SAndroid Build Coastguard Worker // operators in ATen that we would need to support. To be properly
362*523fa7a6SAndroid Build Coastguard Worker // portable here we need to allocate a new array of bool and copy cast
363*523fa7a6SAndroid Build Coastguard Worker // the flatbuffer data into it, but because of how exceptionally rare
364*523fa7a6SAndroid Build Coastguard Worker // this case is its low prio TODO: jakeszwe
365*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue(exec_aten::ArrayRef<bool>(
366*523fa7a6SAndroid Build Coastguard Worker (const bool*)items->data(), items->size()));
367*523fa7a6SAndroid Build Coastguard Worker } break;
368*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::DoubleList: {
369*523fa7a6SAndroid Build Coastguard Worker const auto items = serialization_value->val_as_DoubleList()->items();
370*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
371*523fa7a6SAndroid Build Coastguard Worker items != nullptr, InvalidProgram, "Missing list at index %zu", i);
372*523fa7a6SAndroid Build Coastguard Worker new (&values_[i])
373*523fa7a6SAndroid Build Coastguard Worker EValue(exec_aten::ArrayRef<double>(items->data(), items->size()));
374*523fa7a6SAndroid Build Coastguard Worker } break;
375*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::String: {
376*523fa7a6SAndroid Build Coastguard Worker const auto fb_str = serialization_value->val_as_String()->string_val();
377*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
378*523fa7a6SAndroid Build Coastguard Worker fb_str != nullptr,
379*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
380*523fa7a6SAndroid Build Coastguard Worker "Missing string at index %zu",
381*523fa7a6SAndroid Build Coastguard Worker i);
382*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue(fb_str->c_str(), fb_str->size());
383*523fa7a6SAndroid Build Coastguard Worker } break;
384*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::Tensor: {
385*523fa7a6SAndroid Build Coastguard Worker auto t = deserialization::parseTensor(
386*523fa7a6SAndroid Build Coastguard Worker program_, memory_manager_, serialization_value->val_as_Tensor());
387*523fa7a6SAndroid Build Coastguard Worker if (!t.ok()) {
388*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
389*523fa7a6SAndroid Build Coastguard Worker Error,
390*523fa7a6SAndroid Build Coastguard Worker "Failed parsing tensor at index %zu: 0x%" PRIx32,
391*523fa7a6SAndroid Build Coastguard Worker i,
392*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(t.error()));
393*523fa7a6SAndroid Build Coastguard Worker return t.error();
394*523fa7a6SAndroid Build Coastguard Worker }
395*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue(t.get());
396*523fa7a6SAndroid Build Coastguard Worker } break;
397*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::TensorList: {
398*523fa7a6SAndroid Build Coastguard Worker // get list of serialization tensors and allocate storage for executor
399*523fa7a6SAndroid Build Coastguard Worker // tensors
400*523fa7a6SAndroid Build Coastguard Worker auto tensors = deserialization::parseTensorList(
401*523fa7a6SAndroid Build Coastguard Worker serialization_value->val_as_TensorList()->items(),
402*523fa7a6SAndroid Build Coastguard Worker values_,
403*523fa7a6SAndroid Build Coastguard Worker memory_manager_);
404*523fa7a6SAndroid Build Coastguard Worker if (!tensors.ok()) {
405*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
406*523fa7a6SAndroid Build Coastguard Worker Error,
407*523fa7a6SAndroid Build Coastguard Worker "Failed parsing tensor list at index %zu: 0x%" PRIx32,
408*523fa7a6SAndroid Build Coastguard Worker i,
409*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(tensors.error()));
410*523fa7a6SAndroid Build Coastguard Worker return tensors.error();
411*523fa7a6SAndroid Build Coastguard Worker }
412*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue(tensors.get());
413*523fa7a6SAndroid Build Coastguard Worker } break;
414*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::KernelTypes::OptionalTensorList: {
415*523fa7a6SAndroid Build Coastguard Worker // Same as TensorList but optional<Tensor> instead of Tensor
416*523fa7a6SAndroid Build Coastguard Worker auto tensors =
417*523fa7a6SAndroid Build Coastguard Worker deserialization::parseListOptionalType<exec_aten::Tensor>(
418*523fa7a6SAndroid Build Coastguard Worker serialization_value->val_as_OptionalTensorList()->items(),
419*523fa7a6SAndroid Build Coastguard Worker values_,
420*523fa7a6SAndroid Build Coastguard Worker memory_manager_);
421*523fa7a6SAndroid Build Coastguard Worker if (!tensors.ok()) {
422*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
423*523fa7a6SAndroid Build Coastguard Worker Error,
424*523fa7a6SAndroid Build Coastguard Worker "Failed parsing optional tensor list at index %zu: 0x%" PRIx32,
425*523fa7a6SAndroid Build Coastguard Worker i,
426*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(tensors.error()));
427*523fa7a6SAndroid Build Coastguard Worker return tensors.error();
428*523fa7a6SAndroid Build Coastguard Worker }
429*523fa7a6SAndroid Build Coastguard Worker new (&values_[i]) EValue(tensors.get());
430*523fa7a6SAndroid Build Coastguard Worker } break;
431*523fa7a6SAndroid Build Coastguard Worker default:
432*523fa7a6SAndroid Build Coastguard Worker // flatbuffer enums start at 0, but they generate a hidden NONE enum
433*523fa7a6SAndroid Build Coastguard Worker // and give it that value. schema.fbs doesnt show this type, so I
434*523fa7a6SAndroid Build Coastguard Worker // subtract one to keep the output in 0 based indexing for a
435*523fa7a6SAndroid Build Coastguard Worker // disgruntled debugger seeing this error message and checking
436*523fa7a6SAndroid Build Coastguard Worker // schema.fbs
437*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
438*523fa7a6SAndroid Build Coastguard Worker Error,
439*523fa7a6SAndroid Build Coastguard Worker "Unknown KernelTypes value %" PRIu32 " at index %zu",
440*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(serialization_value->val_type()) - 1,
441*523fa7a6SAndroid Build Coastguard Worker i);
442*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidProgram;
443*523fa7a6SAndroid Build Coastguard Worker }
444*523fa7a6SAndroid Build Coastguard Worker
445*523fa7a6SAndroid Build Coastguard Worker // ~Method() will try to clean up n_value_ entries in the values_ array.
446*523fa7a6SAndroid Build Coastguard Worker // Only increment this once we know the entry is valid, so that we don't try
447*523fa7a6SAndroid Build Coastguard Worker // to clean up an uninitialized entry.
448*523fa7a6SAndroid Build Coastguard Worker n_value_ = i + 1;
449*523fa7a6SAndroid Build Coastguard Worker }
450*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
451*523fa7a6SAndroid Build Coastguard Worker }
452*523fa7a6SAndroid Build Coastguard Worker
453*523fa7a6SAndroid Build Coastguard Worker namespace {
454*523fa7a6SAndroid Build Coastguard Worker /**
455*523fa7a6SAndroid Build Coastguard Worker * Private/helper method for populating operator_name from the Operator.
456*523fa7a6SAndroid Build Coastguard Worker * operator_name is a char pointer that is already allocated. The size of
457*523fa7a6SAndroid Build Coastguard Worker * of this buffer is of size operator_name_size.
458*523fa7a6SAndroid Build Coastguard Worker */
populate_operator_name(const executorch_flatbuffer::Operator * const & op,const size_t operator_name_size,char * operator_name)459*523fa7a6SAndroid Build Coastguard Worker Error populate_operator_name(
460*523fa7a6SAndroid Build Coastguard Worker const executorch_flatbuffer::Operator* const& op,
461*523fa7a6SAndroid Build Coastguard Worker const size_t operator_name_size,
462*523fa7a6SAndroid Build Coastguard Worker char* operator_name) {
463*523fa7a6SAndroid Build Coastguard Worker const bool has_overload =
464*523fa7a6SAndroid Build Coastguard Worker op->overload() != nullptr && op->overload()->size() > 0;
465*523fa7a6SAndroid Build Coastguard Worker
466*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
467*523fa7a6SAndroid Build Coastguard Worker op->name() != nullptr, InvalidProgram, "Missing operator name");
468*523fa7a6SAndroid Build Coastguard Worker int cx = snprintf(
469*523fa7a6SAndroid Build Coastguard Worker operator_name,
470*523fa7a6SAndroid Build Coastguard Worker operator_name_size,
471*523fa7a6SAndroid Build Coastguard Worker "%s%s%s",
472*523fa7a6SAndroid Build Coastguard Worker op->name()->c_str(),
473*523fa7a6SAndroid Build Coastguard Worker // Don't append any overload if the overload string is empty.
474*523fa7a6SAndroid Build Coastguard Worker has_overload ? "." : "",
475*523fa7a6SAndroid Build Coastguard Worker has_overload ? op->overload()->c_str() : "");
476*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(cx >= 0, Internal, "snprintf failed: %d", cx);
477*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
478*523fa7a6SAndroid Build Coastguard Worker cx < operator_name_size,
479*523fa7a6SAndroid Build Coastguard Worker Internal,
480*523fa7a6SAndroid Build Coastguard Worker "Operator name %s%s%s with length %d "
481*523fa7a6SAndroid Build Coastguard Worker "truncated to %zu due to internal buffer limit.",
482*523fa7a6SAndroid Build Coastguard Worker op->name()->c_str(),
483*523fa7a6SAndroid Build Coastguard Worker has_overload ? "." : "",
484*523fa7a6SAndroid Build Coastguard Worker has_overload ? op->overload()->c_str() : "",
485*523fa7a6SAndroid Build Coastguard Worker cx,
486*523fa7a6SAndroid Build Coastguard Worker operator_name_size);
487*523fa7a6SAndroid Build Coastguard Worker
488*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
489*523fa7a6SAndroid Build Coastguard Worker }
490*523fa7a6SAndroid Build Coastguard Worker } // namespace
491*523fa7a6SAndroid Build Coastguard Worker
resolve_operator(int32_t op_index,OpFunction * kernels,size_t kernel_index,InstructionArgs args,size_t n_args)492*523fa7a6SAndroid Build Coastguard Worker Error Method::resolve_operator(
493*523fa7a6SAndroid Build Coastguard Worker int32_t op_index,
494*523fa7a6SAndroid Build Coastguard Worker OpFunction* kernels,
495*523fa7a6SAndroid Build Coastguard Worker size_t kernel_index,
496*523fa7a6SAndroid Build Coastguard Worker InstructionArgs args,
497*523fa7a6SAndroid Build Coastguard Worker size_t n_args) {
498*523fa7a6SAndroid Build Coastguard Worker // TODO(T153505381, T153506819) Investigate optimizing this function for both
499*523fa7a6SAndroid Build Coastguard Worker // space and time.
500*523fa7a6SAndroid Build Coastguard Worker
501*523fa7a6SAndroid Build Coastguard Worker // resolve name
502*523fa7a6SAndroid Build Coastguard Worker constexpr size_t kTempBufferSizeForName = 100;
503*523fa7a6SAndroid Build Coastguard Worker char operator_name[kTempBufferSizeForName];
504*523fa7a6SAndroid Build Coastguard Worker const auto ops = serialization_plan_->operators();
505*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
506*523fa7a6SAndroid Build Coastguard Worker ops != nullptr && op_index < ops->size(),
507*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
508*523fa7a6SAndroid Build Coastguard Worker "Op index %" PRIu32 " out of range",
509*523fa7a6SAndroid Build Coastguard Worker op_index);
510*523fa7a6SAndroid Build Coastguard Worker const auto& op = ops->Get(op_index);
511*523fa7a6SAndroid Build Coastguard Worker
512*523fa7a6SAndroid Build Coastguard Worker Error err = populate_operator_name(op, kTempBufferSizeForName, operator_name);
513*523fa7a6SAndroid Build Coastguard Worker if (err != Error::Ok) {
514*523fa7a6SAndroid Build Coastguard Worker return err;
515*523fa7a6SAndroid Build Coastguard Worker }
516*523fa7a6SAndroid Build Coastguard Worker
517*523fa7a6SAndroid Build Coastguard Worker // resolve tensor meta
518*523fa7a6SAndroid Build Coastguard Worker auto method_allocator = memory_manager_->method_allocator();
519*523fa7a6SAndroid Build Coastguard Worker TensorMeta* meta = method_allocator->allocateList<TensorMeta>(n_args);
520*523fa7a6SAndroid Build Coastguard Worker if (meta == nullptr) {
521*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
522*523fa7a6SAndroid Build Coastguard Worker }
523*523fa7a6SAndroid Build Coastguard Worker
524*523fa7a6SAndroid Build Coastguard Worker size_t count = 0;
525*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < n_args; i++) {
526*523fa7a6SAndroid Build Coastguard Worker EValue* eval = args[i];
527*523fa7a6SAndroid Build Coastguard Worker // handle tensor list as well
528*523fa7a6SAndroid Build Coastguard Worker if (eval->isTensor()) {
529*523fa7a6SAndroid Build Coastguard Worker auto tensor = eval->toTensor();
530*523fa7a6SAndroid Build Coastguard Worker meta[count].dtype_ = tensor.scalar_type();
531*523fa7a6SAndroid Build Coastguard Worker exec_aten::DimOrderType* dim_order_ptr =
532*523fa7a6SAndroid Build Coastguard Worker method_allocator->allocateList<exec_aten::DimOrderType>(tensor.dim());
533*523fa7a6SAndroid Build Coastguard Worker if (dim_order_ptr == nullptr) {
534*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
535*523fa7a6SAndroid Build Coastguard Worker }
536*523fa7a6SAndroid Build Coastguard Worker size_t size = tensor.dim();
537*523fa7a6SAndroid Build Coastguard Worker err = get_dim_order(tensor, dim_order_ptr, size);
538*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
539*523fa7a6SAndroid Build Coastguard Worker err == Error::Ok,
540*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
541*523fa7a6SAndroid Build Coastguard Worker "Error setting dim_order %zu: 0x%" PRIx32,
542*523fa7a6SAndroid Build Coastguard Worker i,
543*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(err));
544*523fa7a6SAndroid Build Coastguard Worker meta[count].dim_order_ =
545*523fa7a6SAndroid Build Coastguard Worker Span<exec_aten::DimOrderType>(dim_order_ptr, size);
546*523fa7a6SAndroid Build Coastguard Worker count++;
547*523fa7a6SAndroid Build Coastguard Worker }
548*523fa7a6SAndroid Build Coastguard Worker }
549*523fa7a6SAndroid Build Coastguard Worker
550*523fa7a6SAndroid Build Coastguard Worker // Find a kernel with the matching name and tensor meta.
551*523fa7a6SAndroid Build Coastguard Worker Result<OpFunction> op_function =
552*523fa7a6SAndroid Build Coastguard Worker get_op_function_from_registry(operator_name, {meta, count});
553*523fa7a6SAndroid Build Coastguard Worker if (!op_function.ok()) {
554*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Missing operator: [%d] %s", op_index, operator_name);
555*523fa7a6SAndroid Build Coastguard Worker return op_function.error();
556*523fa7a6SAndroid Build Coastguard Worker }
557*523fa7a6SAndroid Build Coastguard Worker kernels[kernel_index] = op_function.get();
558*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
559*523fa7a6SAndroid Build Coastguard Worker }
560*523fa7a6SAndroid Build Coastguard Worker
load(executorch_flatbuffer::ExecutionPlan * s_plan,const Program * program,MemoryManager * memory_manager,EventTracer * event_tracer)561*523fa7a6SAndroid Build Coastguard Worker Result<Method> Method::load(
562*523fa7a6SAndroid Build Coastguard Worker executorch_flatbuffer::ExecutionPlan* s_plan,
563*523fa7a6SAndroid Build Coastguard Worker const Program* program,
564*523fa7a6SAndroid Build Coastguard Worker MemoryManager* memory_manager,
565*523fa7a6SAndroid Build Coastguard Worker EventTracer* event_tracer) {
566*523fa7a6SAndroid Build Coastguard Worker MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
567*523fa7a6SAndroid Build Coastguard Worker if (temp_allocator == nullptr) {
568*523fa7a6SAndroid Build Coastguard Worker PlatformMemoryAllocator* platform_allocator =
569*523fa7a6SAndroid Build Coastguard Worker memory_manager->method_allocator()
570*523fa7a6SAndroid Build Coastguard Worker ->allocateInstance<PlatformMemoryAllocator>();
571*523fa7a6SAndroid Build Coastguard Worker if (platform_allocator == nullptr) {
572*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
573*523fa7a6SAndroid Build Coastguard Worker }
574*523fa7a6SAndroid Build Coastguard Worker new (platform_allocator) PlatformMemoryAllocator();
575*523fa7a6SAndroid Build Coastguard Worker temp_allocator = platform_allocator;
576*523fa7a6SAndroid Build Coastguard Worker }
577*523fa7a6SAndroid Build Coastguard Worker Method method(program, memory_manager, event_tracer, temp_allocator);
578*523fa7a6SAndroid Build Coastguard Worker
579*523fa7a6SAndroid Build Coastguard Worker Error err = method.init(s_plan);
580*523fa7a6SAndroid Build Coastguard Worker if (err != Error::Ok) {
581*523fa7a6SAndroid Build Coastguard Worker return err;
582*523fa7a6SAndroid Build Coastguard Worker } else {
583*523fa7a6SAndroid Build Coastguard Worker ET_CHECK(method.initialized());
584*523fa7a6SAndroid Build Coastguard Worker return method;
585*523fa7a6SAndroid Build Coastguard Worker }
586*523fa7a6SAndroid Build Coastguard Worker }
587*523fa7a6SAndroid Build Coastguard Worker
init(executorch_flatbuffer::ExecutionPlan * s_plan)588*523fa7a6SAndroid Build Coastguard Worker Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
589*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_SCOPE_PROF("Method::init");
590*523fa7a6SAndroid Build Coastguard Worker internal::EventTracerProfileMethodScope event_tracer_profile_scope =
591*523fa7a6SAndroid Build Coastguard Worker internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
592*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
593*523fa7a6SAndroid Build Coastguard Worker // Don't use !initialized() here because we also want to fail on the
594*523fa7a6SAndroid Build Coastguard Worker // InitializationFailed state.
595*523fa7a6SAndroid Build Coastguard Worker init_state_ == InitializationState::Uninitialized,
596*523fa7a6SAndroid Build Coastguard Worker InvalidState,
597*523fa7a6SAndroid Build Coastguard Worker "Method already initialized, or previously failed to initialize.");
598*523fa7a6SAndroid Build Coastguard Worker init_state_ =
599*523fa7a6SAndroid Build Coastguard Worker InitializationState::InitializationFailed; // Until proven otherwise
600*523fa7a6SAndroid Build Coastguard Worker serialization_plan_ = s_plan;
601*523fa7a6SAndroid Build Coastguard Worker auto method_allocator = memory_manager_->method_allocator();
602*523fa7a6SAndroid Build Coastguard Worker
603*523fa7a6SAndroid Build Coastguard Worker {
604*523fa7a6SAndroid Build Coastguard Worker // Parse the elements of the values_ array.
605*523fa7a6SAndroid Build Coastguard Worker Error err = parse_values();
606*523fa7a6SAndroid Build Coastguard Worker if (err != Error::Ok) {
607*523fa7a6SAndroid Build Coastguard Worker return err;
608*523fa7a6SAndroid Build Coastguard Worker }
609*523fa7a6SAndroid Build Coastguard Worker }
610*523fa7a6SAndroid Build Coastguard Worker
611*523fa7a6SAndroid Build Coastguard Worker {
612*523fa7a6SAndroid Build Coastguard Worker // Resolve delegates
613*523fa7a6SAndroid Build Coastguard Worker const auto delegates = serialization_plan_->delegates();
614*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
615*523fa7a6SAndroid Build Coastguard Worker delegates != nullptr, InvalidProgram, "Missing delegates field");
616*523fa7a6SAndroid Build Coastguard Worker size_t n_delegate = delegates->size();
617*523fa7a6SAndroid Build Coastguard Worker delegates_ = method_allocator->allocateList<BackendDelegate>(n_delegate);
618*523fa7a6SAndroid Build Coastguard Worker if (delegates_ == nullptr) {
619*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
620*523fa7a6SAndroid Build Coastguard Worker }
621*523fa7a6SAndroid Build Coastguard Worker
622*523fa7a6SAndroid Build Coastguard Worker // n_delegate_ counts the number of successfully-initialized delegates for
623*523fa7a6SAndroid Build Coastguard Worker // ~Method() to clean up, and is incremented at the bottom of the loop. This
624*523fa7a6SAndroid Build Coastguard Worker // makes it safe for errors to return without updating any state.
625*523fa7a6SAndroid Build Coastguard Worker n_delegate_ = 0;
626*523fa7a6SAndroid Build Coastguard Worker
627*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < n_delegate; ++i) {
628*523fa7a6SAndroid Build Coastguard Worker const auto& delegate = *delegates->Get(i);
629*523fa7a6SAndroid Build Coastguard Worker BackendInitContext backend_init_context(
630*523fa7a6SAndroid Build Coastguard Worker method_allocator,
631*523fa7a6SAndroid Build Coastguard Worker /*method_name=*/serialization_plan_->name()->c_str());
632*523fa7a6SAndroid Build Coastguard Worker Error err = BackendDelegate::Init(
633*523fa7a6SAndroid Build Coastguard Worker delegate, program_, backend_init_context, &delegates_[i]);
634*523fa7a6SAndroid Build Coastguard Worker if (err != Error::Ok) {
635*523fa7a6SAndroid Build Coastguard Worker return err;
636*523fa7a6SAndroid Build Coastguard Worker }
637*523fa7a6SAndroid Build Coastguard Worker // ~Method() will try to clean up n_delegate_ entries in the delegates_
638*523fa7a6SAndroid Build Coastguard Worker // array. Only increment this once we know the entry is valid, so that
639*523fa7a6SAndroid Build Coastguard Worker // we don't try to clean up an uninitialized entry.
640*523fa7a6SAndroid Build Coastguard Worker n_delegate_ = i + 1;
641*523fa7a6SAndroid Build Coastguard Worker }
642*523fa7a6SAndroid Build Coastguard Worker }
643*523fa7a6SAndroid Build Coastguard Worker
644*523fa7a6SAndroid Build Coastguard Worker {
645*523fa7a6SAndroid Build Coastguard Worker // Load chains
646*523fa7a6SAndroid Build Coastguard Worker const auto chains = serialization_plan_->chains();
647*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
648*523fa7a6SAndroid Build Coastguard Worker chains != nullptr && chains->size() > 0, InvalidProgram, "No chains");
649*523fa7a6SAndroid Build Coastguard Worker n_chains_ = chains->size();
650*523fa7a6SAndroid Build Coastguard Worker chains_ = method_allocator->allocateList<Chain>(n_chains_);
651*523fa7a6SAndroid Build Coastguard Worker if (chains_ == nullptr) {
652*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
653*523fa7a6SAndroid Build Coastguard Worker }
654*523fa7a6SAndroid Build Coastguard Worker
655*523fa7a6SAndroid Build Coastguard Worker // Try resolving all operators before failing, to make it easier to debug
656*523fa7a6SAndroid Build Coastguard Worker // multiple problems at once.
657*523fa7a6SAndroid Build Coastguard Worker Error delayed_error = Error::Ok;
658*523fa7a6SAndroid Build Coastguard Worker int32_t num_instructions_missing_op = 0;
659*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < n_chains_; ++i) {
660*523fa7a6SAndroid Build Coastguard Worker auto s_chain = chains->Get(i);
661*523fa7a6SAndroid Build Coastguard Worker auto s_instructions = s_chain->instructions();
662*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
663*523fa7a6SAndroid Build Coastguard Worker s_instructions != nullptr,
664*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
665*523fa7a6SAndroid Build Coastguard Worker "Missing instructions in chain %zu",
666*523fa7a6SAndroid Build Coastguard Worker i);
667*523fa7a6SAndroid Build Coastguard Worker auto num_instructions = s_instructions->size();
668*523fa7a6SAndroid Build Coastguard Worker auto chain_instruction_kernels =
669*523fa7a6SAndroid Build Coastguard Worker method_allocator->allocateList<OpFunction>(num_instructions);
670*523fa7a6SAndroid Build Coastguard Worker if (chain_instruction_kernels == nullptr) {
671*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
672*523fa7a6SAndroid Build Coastguard Worker }
673*523fa7a6SAndroid Build Coastguard Worker auto chain_instruction_arg_lists =
674*523fa7a6SAndroid Build Coastguard Worker method_allocator->allocateList<InstructionArgs>(num_instructions);
675*523fa7a6SAndroid Build Coastguard Worker if (chain_instruction_arg_lists == nullptr) {
676*523fa7a6SAndroid Build Coastguard Worker return Error::MemoryAllocationFailed;
677*523fa7a6SAndroid Build Coastguard Worker }
678*523fa7a6SAndroid Build Coastguard Worker
679*523fa7a6SAndroid Build Coastguard Worker // Set up the argument lists ahead of time and store pointers to them to
680*523fa7a6SAndroid Build Coastguard Worker // use when the instructions are called
681*523fa7a6SAndroid Build Coastguard Worker for (size_t instr_idx = 0; instr_idx < s_instructions->size();
682*523fa7a6SAndroid Build Coastguard Worker ++instr_idx) {
683*523fa7a6SAndroid Build Coastguard Worker const auto instruction = s_instructions->Get(instr_idx);
684*523fa7a6SAndroid Build Coastguard Worker // Ensure that the `instr_args_as_X()` calls will return non-null.
685*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
686*523fa7a6SAndroid Build Coastguard Worker instruction != nullptr && instruction->instr_args() != nullptr,
687*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
688*523fa7a6SAndroid Build Coastguard Worker "Null instruction at index %zu",
689*523fa7a6SAndroid Build Coastguard Worker instr_idx);
690*523fa7a6SAndroid Build Coastguard Worker
691*523fa7a6SAndroid Build Coastguard Worker switch (instruction->instr_args_type()) {
692*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::InstructionArguments::KernelCall: {
693*523fa7a6SAndroid Build Coastguard Worker const auto arg_idxs =
694*523fa7a6SAndroid Build Coastguard Worker instruction->instr_args_as_KernelCall()->args();
695*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
696*523fa7a6SAndroid Build Coastguard Worker arg_idxs != nullptr, InvalidProgram, "KernelCall args missing");
697*523fa7a6SAndroid Build Coastguard Worker auto res = gen_instruction_arguments(
698*523fa7a6SAndroid Build Coastguard Worker method_allocator,
699*523fa7a6SAndroid Build Coastguard Worker n_value_,
700*523fa7a6SAndroid Build Coastguard Worker values_,
701*523fa7a6SAndroid Build Coastguard Worker arg_idxs->size(),
702*523fa7a6SAndroid Build Coastguard Worker arg_idxs->data());
703*523fa7a6SAndroid Build Coastguard Worker if (!res.ok()) {
704*523fa7a6SAndroid Build Coastguard Worker return res.error();
705*523fa7a6SAndroid Build Coastguard Worker }
706*523fa7a6SAndroid Build Coastguard Worker chain_instruction_arg_lists[instr_idx] = res.get();
707*523fa7a6SAndroid Build Coastguard Worker auto err = resolve_operator(
708*523fa7a6SAndroid Build Coastguard Worker instruction->instr_args_as_KernelCall()->op_index(),
709*523fa7a6SAndroid Build Coastguard Worker chain_instruction_kernels,
710*523fa7a6SAndroid Build Coastguard Worker instr_idx,
711*523fa7a6SAndroid Build Coastguard Worker res.get(),
712*523fa7a6SAndroid Build Coastguard Worker arg_idxs->size());
713*523fa7a6SAndroid Build Coastguard Worker if (err == Error::OperatorMissing) {
714*523fa7a6SAndroid Build Coastguard Worker num_instructions_missing_op++;
715*523fa7a6SAndroid Build Coastguard Worker } else if (err == Error::MemoryAllocationFailed) {
716*523fa7a6SAndroid Build Coastguard Worker return err;
717*523fa7a6SAndroid Build Coastguard Worker } else {
718*523fa7a6SAndroid Build Coastguard Worker delayed_error = err;
719*523fa7a6SAndroid Build Coastguard Worker }
720*523fa7a6SAndroid Build Coastguard Worker } break;
721*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::InstructionArguments::DelegateCall: {
722*523fa7a6SAndroid Build Coastguard Worker const auto arg_idxs =
723*523fa7a6SAndroid Build Coastguard Worker instruction->instr_args_as_DelegateCall()->args();
724*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
725*523fa7a6SAndroid Build Coastguard Worker arg_idxs != nullptr,
726*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
727*523fa7a6SAndroid Build Coastguard Worker "DelegateCall args missing");
728*523fa7a6SAndroid Build Coastguard Worker auto res = gen_instruction_arguments(
729*523fa7a6SAndroid Build Coastguard Worker method_allocator,
730*523fa7a6SAndroid Build Coastguard Worker n_value_,
731*523fa7a6SAndroid Build Coastguard Worker values_,
732*523fa7a6SAndroid Build Coastguard Worker arg_idxs->size(),
733*523fa7a6SAndroid Build Coastguard Worker arg_idxs->data());
734*523fa7a6SAndroid Build Coastguard Worker if (!res.ok()) {
735*523fa7a6SAndroid Build Coastguard Worker return res.error();
736*523fa7a6SAndroid Build Coastguard Worker }
737*523fa7a6SAndroid Build Coastguard Worker chain_instruction_arg_lists[instr_idx] = res.get();
738*523fa7a6SAndroid Build Coastguard Worker } break;
739*523fa7a6SAndroid Build Coastguard Worker case executorch_flatbuffer::InstructionArguments::JumpFalseCall: {
740*523fa7a6SAndroid Build Coastguard Worker // Validate the index at load time so we can trust it during
741*523fa7a6SAndroid Build Coastguard Worker // execution.
742*523fa7a6SAndroid Build Coastguard Worker auto index =
743*523fa7a6SAndroid Build Coastguard Worker instruction->instr_args_as_JumpFalseCall()->cond_value_index();
744*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
745*523fa7a6SAndroid Build Coastguard Worker index >= 0 && index < n_value_,
746*523fa7a6SAndroid Build Coastguard Worker InvalidProgram,
747*523fa7a6SAndroid Build Coastguard Worker "Index %d negative or >= %zu",
748*523fa7a6SAndroid Build Coastguard Worker index,
749*523fa7a6SAndroid Build Coastguard Worker n_value_);
750*523fa7a6SAndroid Build Coastguard Worker chain_instruction_arg_lists[instr_idx] = InstructionArgs();
751*523fa7a6SAndroid Build Coastguard Worker } break;
752*523fa7a6SAndroid Build Coastguard Worker default: {
753*523fa7a6SAndroid Build Coastguard Worker chain_instruction_arg_lists[instr_idx] = InstructionArgs();
754*523fa7a6SAndroid Build Coastguard Worker } break;
755*523fa7a6SAndroid Build Coastguard Worker }
756*523fa7a6SAndroid Build Coastguard Worker }
757*523fa7a6SAndroid Build Coastguard Worker chains_[i] = Chain{
758*523fa7a6SAndroid Build Coastguard Worker s_chain,
759*523fa7a6SAndroid Build Coastguard Worker Span<InstructionArgs>(chain_instruction_arg_lists, num_instructions),
760*523fa7a6SAndroid Build Coastguard Worker chain_instruction_kernels,
761*523fa7a6SAndroid Build Coastguard Worker };
762*523fa7a6SAndroid Build Coastguard Worker }
763*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
764*523fa7a6SAndroid Build Coastguard Worker num_instructions_missing_op == 0,
765*523fa7a6SAndroid Build Coastguard Worker OperatorMissing,
766*523fa7a6SAndroid Build Coastguard Worker "There are %d instructions don't have corresponding operator registered. "
767*523fa7a6SAndroid Build Coastguard Worker "See logs for details",
768*523fa7a6SAndroid Build Coastguard Worker num_instructions_missing_op);
769*523fa7a6SAndroid Build Coastguard Worker if (delayed_error != Error::Ok) {
770*523fa7a6SAndroid Build Coastguard Worker return delayed_error;
771*523fa7a6SAndroid Build Coastguard Worker }
772*523fa7a6SAndroid Build Coastguard Worker }
773*523fa7a6SAndroid Build Coastguard Worker
774*523fa7a6SAndroid Build Coastguard Worker step_state_ = StepState{0, 0};
775*523fa7a6SAndroid Build Coastguard Worker
776*523fa7a6SAndroid Build Coastguard Worker init_state_ = InitializationState::Initialized;
777*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
778*523fa7a6SAndroid Build Coastguard Worker }
779*523fa7a6SAndroid Build Coastguard Worker
780*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD Error
set_input(const EValue & input_evalue,size_t input_idx)781*523fa7a6SAndroid Build Coastguard Worker Method::set_input(const EValue& input_evalue, size_t input_idx) {
782*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
783*523fa7a6SAndroid Build Coastguard Worker initialized(),
784*523fa7a6SAndroid Build Coastguard Worker InvalidState,
785*523fa7a6SAndroid Build Coastguard Worker "Input can not be set until method has been initialized.");
786*523fa7a6SAndroid Build Coastguard Worker
787*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
788*523fa7a6SAndroid Build Coastguard Worker step_state_.instr_idx == 0 && step_state_.chain_idx == 0,
789*523fa7a6SAndroid Build Coastguard Worker InvalidState,
790*523fa7a6SAndroid Build Coastguard Worker "Inputs can not be set mid execution.");
791*523fa7a6SAndroid Build Coastguard Worker
792*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
793*523fa7a6SAndroid Build Coastguard Worker input_idx < inputs_size(),
794*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
795*523fa7a6SAndroid Build Coastguard Worker "Given input index must be less than the number of inputs in method, but got %zu and %zu",
796*523fa7a6SAndroid Build Coastguard Worker input_idx,
797*523fa7a6SAndroid Build Coastguard Worker inputs_size());
798*523fa7a6SAndroid Build Coastguard Worker
799*523fa7a6SAndroid Build Coastguard Worker const auto& e = get_value(get_input_index(input_idx));
800*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
801*523fa7a6SAndroid Build Coastguard Worker e.isTensor() || e.isScalar(),
802*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
803*523fa7a6SAndroid Build Coastguard Worker "The %zu-th input in method is expected Tensor or prim, but received %" PRIu32,
804*523fa7a6SAndroid Build Coastguard Worker input_idx,
805*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(e.tag));
806*523fa7a6SAndroid Build Coastguard Worker
807*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
808*523fa7a6SAndroid Build Coastguard Worker e.tag == input_evalue.tag,
809*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
810*523fa7a6SAndroid Build Coastguard Worker "The %zu-th input of method should have the same type as the input_evalue, but get tag %" PRIu32
811*523fa7a6SAndroid Build Coastguard Worker " and tag %" PRIu32,
812*523fa7a6SAndroid Build Coastguard Worker input_idx,
813*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(e.tag),
814*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(input_evalue.tag));
815*523fa7a6SAndroid Build Coastguard Worker
816*523fa7a6SAndroid Build Coastguard Worker if (e.isTensor()) {
817*523fa7a6SAndroid Build Coastguard Worker const auto& t_dst = e.toTensor();
818*523fa7a6SAndroid Build Coastguard Worker const auto& t_src = input_evalue.toTensor();
819*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
820*523fa7a6SAndroid Build Coastguard Worker t_dst.scalar_type() == t_src.scalar_type(),
821*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
822*523fa7a6SAndroid Build Coastguard Worker "The %zu-th input tensor's scalartype does not meet requirement: found %" PRId8
823*523fa7a6SAndroid Build Coastguard Worker " but expected %" PRId8,
824*523fa7a6SAndroid Build Coastguard Worker input_idx,
825*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(t_src.scalar_type()),
826*523fa7a6SAndroid Build Coastguard Worker static_cast<int8_t>(t_dst.scalar_type()));
827*523fa7a6SAndroid Build Coastguard Worker // Reset the shape for the Method's input as the size of forwarded input
828*523fa7a6SAndroid Build Coastguard Worker // tensor for shape dynamism. Also is a safety check if need memcpy.
829*523fa7a6SAndroid Build Coastguard Worker Error err = resize_tensor(t_dst, t_src.sizes());
830*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
831*523fa7a6SAndroid Build Coastguard Worker err == Error::Ok,
832*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
833*523fa7a6SAndroid Build Coastguard Worker "Error setting input %zu: 0x%" PRIx32,
834*523fa7a6SAndroid Build Coastguard Worker input_idx,
835*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(err));
836*523fa7a6SAndroid Build Coastguard Worker Error error;
837*523fa7a6SAndroid Build Coastguard Worker auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
838*523fa7a6SAndroid Build Coastguard Worker if (tensor_meta->is_memory_planned()) {
839*523fa7a6SAndroid Build Coastguard Worker error = internal::copy_tensor_data(t_dst, t_src);
840*523fa7a6SAndroid Build Coastguard Worker } else {
841*523fa7a6SAndroid Build Coastguard Worker error = internal::share_tensor_data(t_dst, t_src);
842*523fa7a6SAndroid Build Coastguard Worker }
843*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
844*523fa7a6SAndroid Build Coastguard Worker error == Error::Ok,
845*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
846*523fa7a6SAndroid Build Coastguard Worker "Error setting data_ptr %zu: 0x%" PRIx32,
847*523fa7a6SAndroid Build Coastguard Worker input_idx,
848*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(error));
849*523fa7a6SAndroid Build Coastguard Worker // Prims have to be the same as what was traced
850*523fa7a6SAndroid Build Coastguard Worker } else if (e.isInt()) {
851*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
852*523fa7a6SAndroid Build Coastguard Worker e.toInt() == input_evalue.toInt(),
853*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
854*523fa7a6SAndroid Build Coastguard Worker "The %zu-th input of method should have the same value as the input_evalue, but got %" PRId64
855*523fa7a6SAndroid Build Coastguard Worker " and %" PRId64,
856*523fa7a6SAndroid Build Coastguard Worker input_idx,
857*523fa7a6SAndroid Build Coastguard Worker e.toInt(),
858*523fa7a6SAndroid Build Coastguard Worker input_evalue.toInt());
859*523fa7a6SAndroid Build Coastguard Worker } else if (e.isBool()) {
860*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
861*523fa7a6SAndroid Build Coastguard Worker e.toBool() == input_evalue.toBool(),
862*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
863*523fa7a6SAndroid Build Coastguard Worker "The %zu-th input of method should have the same value as the input_evalue, but got %" PRId64
864*523fa7a6SAndroid Build Coastguard Worker " and %" PRId64,
865*523fa7a6SAndroid Build Coastguard Worker input_idx,
866*523fa7a6SAndroid Build Coastguard Worker (int64_t)e.toBool(),
867*523fa7a6SAndroid Build Coastguard Worker (int64_t)input_evalue.toBool());
868*523fa7a6SAndroid Build Coastguard Worker } else if (e.isDouble()) {
869*523fa7a6SAndroid Build Coastguard Worker double lhs = input_evalue.toDouble();
870*523fa7a6SAndroid Build Coastguard Worker double rhs = e.toDouble();
871*523fa7a6SAndroid Build Coastguard Worker double atol = 1e-4;
872*523fa7a6SAndroid Build Coastguard Worker double rtol = 1e-5;
873*523fa7a6SAndroid Build Coastguard Worker bool is_equal = true;
874*523fa7a6SAndroid Build Coastguard Worker if (std::isnan(lhs) && std::isnan(rhs)) {
875*523fa7a6SAndroid Build Coastguard Worker // NaN == NaN
876*523fa7a6SAndroid Build Coastguard Worker } else if (
877*523fa7a6SAndroid Build Coastguard Worker !std::isfinite(lhs) && !std::isfinite(rhs) &&
878*523fa7a6SAndroid Build Coastguard Worker ((lhs > 0) == (rhs > 0))) {
879*523fa7a6SAndroid Build Coastguard Worker // -Inf == -Inf
880*523fa7a6SAndroid Build Coastguard Worker // +Inf == +Inf
881*523fa7a6SAndroid Build Coastguard Worker } else {
882*523fa7a6SAndroid Build Coastguard Worker auto allowed_error = atol + std::abs(rtol * rhs);
883*523fa7a6SAndroid Build Coastguard Worker auto actual_error = std::abs(lhs - rhs);
884*523fa7a6SAndroid Build Coastguard Worker if (!std::isfinite(actual_error) || actual_error > allowed_error) {
885*523fa7a6SAndroid Build Coastguard Worker is_equal = false;
886*523fa7a6SAndroid Build Coastguard Worker }
887*523fa7a6SAndroid Build Coastguard Worker }
888*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
889*523fa7a6SAndroid Build Coastguard Worker is_equal,
890*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
891*523fa7a6SAndroid Build Coastguard Worker "The %zu-th input of method should have the same value as the input_evalue, but get %f and %f",
892*523fa7a6SAndroid Build Coastguard Worker input_idx,
893*523fa7a6SAndroid Build Coastguard Worker lhs,
894*523fa7a6SAndroid Build Coastguard Worker rhs);
895*523fa7a6SAndroid Build Coastguard Worker } else if (e.isString()) {
896*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
897*523fa7a6SAndroid Build Coastguard Worker e.toString() == input_evalue.toString(),
898*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
899*523fa7a6SAndroid Build Coastguard Worker "The %zu-th input of method should have the same value as the input_evalue, but get %s and %s",
900*523fa7a6SAndroid Build Coastguard Worker input_idx,
901*523fa7a6SAndroid Build Coastguard Worker e.toString().data(),
902*523fa7a6SAndroid Build Coastguard Worker input_evalue.toString().data());
903*523fa7a6SAndroid Build Coastguard Worker } else {
904*523fa7a6SAndroid Build Coastguard Worker ET_LOG(Error, "Unsupported input type: %d", (int32_t)e.tag);
905*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidArgument;
906*523fa7a6SAndroid Build Coastguard Worker }
907*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
908*523fa7a6SAndroid Build Coastguard Worker }
909*523fa7a6SAndroid Build Coastguard Worker
910*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD Error
set_inputs(const exec_aten::ArrayRef<EValue> & input_evalues)911*523fa7a6SAndroid Build Coastguard Worker Method::set_inputs(const exec_aten::ArrayRef<EValue>& input_evalues) {
912*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
913*523fa7a6SAndroid Build Coastguard Worker initialized(),
914*523fa7a6SAndroid Build Coastguard Worker InvalidState,
915*523fa7a6SAndroid Build Coastguard Worker "Inputs can not be set until method has been initialized.");
916*523fa7a6SAndroid Build Coastguard Worker
917*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
918*523fa7a6SAndroid Build Coastguard Worker step_state_.instr_idx == 0 && step_state_.chain_idx == 0,
919*523fa7a6SAndroid Build Coastguard Worker InvalidState,
920*523fa7a6SAndroid Build Coastguard Worker "Inputs can not be set mid execution.");
921*523fa7a6SAndroid Build Coastguard Worker
922*523fa7a6SAndroid Build Coastguard Worker size_t input_size = inputs_size();
923*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
924*523fa7a6SAndroid Build Coastguard Worker input_size == input_evalues.size(),
925*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
926*523fa7a6SAndroid Build Coastguard Worker "The length of given input array (%zu) must be same as the number of inputs in method (%zu).",
927*523fa7a6SAndroid Build Coastguard Worker input_evalues.size(),
928*523fa7a6SAndroid Build Coastguard Worker input_size);
929*523fa7a6SAndroid Build Coastguard Worker
930*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < input_size; i++) {
931*523fa7a6SAndroid Build Coastguard Worker Error status = set_input(input_evalues[i], i);
932*523fa7a6SAndroid Build Coastguard Worker if (status != Error::Ok) {
933*523fa7a6SAndroid Build Coastguard Worker return status;
934*523fa7a6SAndroid Build Coastguard Worker }
935*523fa7a6SAndroid Build Coastguard Worker }
936*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
937*523fa7a6SAndroid Build Coastguard Worker }
938*523fa7a6SAndroid Build Coastguard Worker
939*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD Error
set_output_data_ptr(void * buffer,size_t size,size_t output_idx)940*523fa7a6SAndroid Build Coastguard Worker Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
941*523fa7a6SAndroid Build Coastguard Worker // Check method state
942*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
943*523fa7a6SAndroid Build Coastguard Worker initialized(),
944*523fa7a6SAndroid Build Coastguard Worker InvalidState,
945*523fa7a6SAndroid Build Coastguard Worker "Outputs can not be retrieved until method has been initialized.");
946*523fa7a6SAndroid Build Coastguard Worker
947*523fa7a6SAndroid Build Coastguard Worker // Check the args
948*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
949*523fa7a6SAndroid Build Coastguard Worker output_idx < outputs_size(),
950*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
951*523fa7a6SAndroid Build Coastguard Worker "output_idx: %zu > num_outputs: %zu",
952*523fa7a6SAndroid Build Coastguard Worker output_idx,
953*523fa7a6SAndroid Build Coastguard Worker outputs_size());
954*523fa7a6SAndroid Build Coastguard Worker
955*523fa7a6SAndroid Build Coastguard Worker auto& output = mutable_value(get_output_index(output_idx));
956*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
957*523fa7a6SAndroid Build Coastguard Worker output.isTensor(),
958*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
959*523fa7a6SAndroid Build Coastguard Worker "output type: %zu is not tensor",
960*523fa7a6SAndroid Build Coastguard Worker (size_t)output.tag);
961*523fa7a6SAndroid Build Coastguard Worker
962*523fa7a6SAndroid Build Coastguard Worker auto tensor_meta = this->method_meta().output_tensor_meta(output_idx);
963*523fa7a6SAndroid Build Coastguard Worker if (tensor_meta->is_memory_planned()) {
964*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
965*523fa7a6SAndroid Build Coastguard Worker Error,
966*523fa7a6SAndroid Build Coastguard Worker "Output %zu is memory planned, or is a constant. Cannot override "
967*523fa7a6SAndroid Build Coastguard Worker "the existing data pointer.",
968*523fa7a6SAndroid Build Coastguard Worker output_idx);
969*523fa7a6SAndroid Build Coastguard Worker return Error::InvalidState;
970*523fa7a6SAndroid Build Coastguard Worker }
971*523fa7a6SAndroid Build Coastguard Worker
972*523fa7a6SAndroid Build Coastguard Worker auto& t = output.toTensor();
973*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
974*523fa7a6SAndroid Build Coastguard Worker output.isTensor(),
975*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
976*523fa7a6SAndroid Build Coastguard Worker "output type: %zu is not tensor",
977*523fa7a6SAndroid Build Coastguard Worker (size_t)output.tag);
978*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
979*523fa7a6SAndroid Build Coastguard Worker t.nbytes() <= size,
980*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
981*523fa7a6SAndroid Build Coastguard Worker "buffer size: %zu is smaller then expected tensor size: %zu",
982*523fa7a6SAndroid Build Coastguard Worker size,
983*523fa7a6SAndroid Build Coastguard Worker t.nbytes());
984*523fa7a6SAndroid Build Coastguard Worker
985*523fa7a6SAndroid Build Coastguard Worker // Set data
986*523fa7a6SAndroid Build Coastguard Worker return internal::set_tensor_data(t, buffer, size);
987*523fa7a6SAndroid Build Coastguard Worker }
988*523fa7a6SAndroid Build Coastguard Worker
get_outputs(EValue * output_evalues,size_t length)989*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD Error Method::get_outputs(EValue* output_evalues, size_t length) {
990*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
991*523fa7a6SAndroid Build Coastguard Worker initialized(),
992*523fa7a6SAndroid Build Coastguard Worker InvalidState,
993*523fa7a6SAndroid Build Coastguard Worker "Outputs can not be retrieved until method has been initialized.");
994*523fa7a6SAndroid Build Coastguard Worker
995*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
996*523fa7a6SAndroid Build Coastguard Worker length >= outputs_size(),
997*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
998*523fa7a6SAndroid Build Coastguard Worker "The given array is not large enough to hold all outputs.");
999*523fa7a6SAndroid Build Coastguard Worker
1000*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < outputs_size(); i++) {
1001*523fa7a6SAndroid Build Coastguard Worker output_evalues[i] = values_[get_output_index(i)];
1002*523fa7a6SAndroid Build Coastguard Worker }
1003*523fa7a6SAndroid Build Coastguard Worker
1004*523fa7a6SAndroid Build Coastguard Worker for (size_t i = outputs_size(); i < length; i++) {
1005*523fa7a6SAndroid Build Coastguard Worker output_evalues[i] = EValue();
1006*523fa7a6SAndroid Build Coastguard Worker }
1007*523fa7a6SAndroid Build Coastguard Worker
1008*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
1009*523fa7a6SAndroid Build Coastguard Worker }
1010*523fa7a6SAndroid Build Coastguard Worker
get_inputs(EValue * input_evalues,size_t length)1011*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD Error Method::get_inputs(EValue* input_evalues, size_t length) {
1012*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
1013*523fa7a6SAndroid Build Coastguard Worker initialized(),
1014*523fa7a6SAndroid Build Coastguard Worker InvalidState,
1015*523fa7a6SAndroid Build Coastguard Worker "Inputs can not be retrieved until method has been initialized.");
1016*523fa7a6SAndroid Build Coastguard Worker
1017*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
1018*523fa7a6SAndroid Build Coastguard Worker length >= inputs_size(),
1019*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
1020*523fa7a6SAndroid Build Coastguard Worker "The given array is not large enough to hold all inputs.");
1021*523fa7a6SAndroid Build Coastguard Worker
1022*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < inputs_size(); i++) {
1023*523fa7a6SAndroid Build Coastguard Worker input_evalues[i] = values_[get_input_index(i)];
1024*523fa7a6SAndroid Build Coastguard Worker }
1025*523fa7a6SAndroid Build Coastguard Worker
1026*523fa7a6SAndroid Build Coastguard Worker for (size_t i = inputs_size(); i < length; i++) {
1027*523fa7a6SAndroid Build Coastguard Worker input_evalues[i] = EValue();
1028*523fa7a6SAndroid Build Coastguard Worker }
1029*523fa7a6SAndroid Build Coastguard Worker
1030*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
1031*523fa7a6SAndroid Build Coastguard Worker }
1032*523fa7a6SAndroid Build Coastguard Worker
/**
 * Executes the single instruction at `step_state_.instr_idx` in chain
 * `step_state_.chain_idx`.
 *
 * Dispatches on the instruction's argument type: KernelCall, DelegateCall,
 * JumpFalseCall, MoveCall or FreeCall. Any other type is rejected as
 * InvalidProgram. After the switch, `temp_allocator_` (if non-null) is reset
 * whether or not the instruction succeeded. `step_state_.instr_idx` is
 * advanced only on success. It normally moves to the next instruction, or to
 * a JumpFalseCall's destination when that call's condition is false.
 *
 * NOTE(review): `step_state_.chain_idx` is not range-checked here; callers
 * are expected to ensure it is < n_chains_.
 *
 * @returns Error::Ok on success; Error::Internal if the instruction index or
 *     a delegate index is out of range; Error::InvalidProgram for an unknown
 *     instruction type; otherwise the error reported by the kernel (via its
 *     runtime context), the delegate, or the jump-condition parser.
 */
Error Method::execute_instruction() {
  auto& chain = chains_[step_state_.chain_idx];
  auto instructions = chain.s_chain_->instructions();

  ET_CHECK_OR_RETURN_ERROR(
      step_state_.instr_idx < instructions->size(),
      Internal,
      "Instr index %zu >= chain[%zu] instr count %zu",
      step_state_.instr_idx,
      step_state_.chain_idx,
      (size_t)instructions->size());

  auto instruction = instructions->Get(step_state_.instr_idx);
  // Default successor is the following instruction; a JumpFalseCall whose
  // condition is false overwrites this with its destination below.
  size_t next_instr_idx = step_state_.instr_idx + 1;
  Error err = Error::Ok;

  switch (instruction->instr_args_type()) {
    case executorch_flatbuffer::InstructionArguments::KernelCall: {
      EXECUTORCH_SCOPE_PROF("OPERATOR_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "OPERATOR_CALL");
      // TODO(T147221312): Also expose tensor resizer via the context.
      KernelRuntimeContext context(event_tracer_, temp_allocator_);
      auto args = chain.argument_lists_[step_state_.instr_idx];
      chain.kernels_[step_state_.instr_idx](context, args.data());
      // We reset the temp_allocator after the switch statement
      // The kernel's status is read back from the context, not from a return
      // value.
      err = context.failure_state();
      if (err != Error::Ok) {
        // We know that instr_args_as_KernelCall is non-null because it was
        // checked at init time.
        auto op_index = instruction->instr_args_as_KernelCall()->op_index();
        auto op = serialization_plan_->operators()->Get(op_index);
        ET_LOG(
            Error,
            "KernelCall failed at instruction %zu:%zu in operator %s.%s: 0x%x",
            step_state_.chain_idx,
            step_state_.instr_idx,
            op->name()->c_str(),
            op->overload()->c_str(),
            (unsigned int)err);
        // Log the type tag of every argument to aid debugging.
        for (size_t i = 0; i < args.size(); ++i) {
          ET_LOG(
              Error,
              "arg %u with type id %u",
              (unsigned int)i,
              (unsigned int)args[i]->tag);
        }
        // TODO(T153804650): Consider logging the EValues to help with
        // debugging. This is a failure path, and it doesn't matter if it's a
        // little slow. Do the same for DelegateCall errors.
      }
    } break;
    case executorch_flatbuffer::InstructionArguments::DelegateCall: {
      EXECUTORCH_SCOPE_PROF("DELEGATE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "DELEGATE_CALL");
      // We know that instr_args_as_DelegateCall is non-null because it was
      // checked at init time.
      auto delegate_idx =
          instruction->instr_args_as_DelegateCall()->delegate_index();
      // This early return happens before any delegate work, so skipping the
      // temp allocator reset below is harmless.
      ET_CHECK_OR_RETURN_ERROR(
          delegate_idx < n_delegate_,
          Internal,
          "DELEGATE_CALL index %" PRIu32
          " >= num delegates %zu at instruction %zu",
          delegate_idx,
          n_delegate_,
          step_state_.instr_idx);
      BackendExecutionContext backend_execution_context(
          /*event_tracer=*/event_tracer_,
          /*temp_allocator=*/temp_allocator_,
          /*method_name=*/serialization_plan_->name()->c_str());
      err = delegates_[delegate_idx].Execute(
          backend_execution_context,
          chain.argument_lists_[step_state_.instr_idx].data());
      if (err != Error::Ok) {
        ET_LOG(
            Error,
            "CALL_DELEGATE execute failed at instruction %zu: 0x%" PRIx32,
            step_state_.instr_idx,
            static_cast<uint32_t>(err));
      }

      // Log all the arguments of the delegate call. Ideally we'd only like to
      // log the outputs of the delegate, but currently we cannot know from the
      // arguments which are the inputs and which are the outputs, so we just
      // log everything. This will be changed in the future when the inputs and
      // outputs are separate lists.
#ifdef ET_EVENT_TRACER_ENABLED
      for (size_t i = 0;
           i < chain.argument_lists_[step_state_.instr_idx].size();
           i++) {
        EValue* arg = chain.argument_lists_[step_state_.instr_idx].data()[i];
        internal::event_tracer_log_evalue(event_tracer_, *arg);
      }
#endif
    } break;
    case executorch_flatbuffer::InstructionArguments::JumpFalseCall: {
      EXECUTORCH_SCOPE_PROF("JF_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "JF_CALL");
      // We know that instr_args_as_JumpFalseCall is non-null because it was
      // checked at init time.
      auto jf_call = instruction->instr_args_as_JumpFalseCall();
      // We know that index is a valid values_ index because it was checked at
      // init time.
      auto index = jf_call->cond_value_index();
      Result<bool> jf_result = parse_cond_value(values_[index]);
      if (jf_result.ok()) {
        // Jump only when the condition is false; otherwise fall through.
        if (!jf_result.get()) {
          next_instr_idx = jf_call->destination_instruction();
        }
      } else {
        err = jf_result.error();
      }
    } break;
    case executorch_flatbuffer::InstructionArguments::MoveCall: {
      EXECUTORCH_SCOPE_PROF("MOVE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "MOVE_CALL");
      // We know that instr_args_as_MoveCall is non-null because it was checked
      // at init time.
      auto move_call = instruction->instr_args_as_MoveCall();
      // Assign the value at move_from into the slot at move_to.
      mutable_value(move_call->move_to()) = get_value(move_call->move_from());
    } break;
    case executorch_flatbuffer::InstructionArguments::FreeCall: {
      EXECUTORCH_SCOPE_PROF("FREE_CALL");
      internal::EventTracerProfileOpScope event_tracer_op_scope =
          internal::EventTracerProfileOpScope(event_tracer_, "FREE_CALL");
      // We know that instr_args_as_FreeCall is non-null because it was checked
      // at init time.
      auto free_call = instruction->instr_args_as_FreeCall();
      // NOTE(review): `t` is a local copy of the tensor handle; this relies on
      // reset_data_ptr affecting the shared underlying tensor — confirm.
      auto t = values_[free_call->value_index()].toTensor();
      internal::reset_data_ptr(t);
    } break;
    default:
      ET_LOG(
          Error,
          "Unknown instruction: %hhu",
          static_cast<uint8_t>(instruction->instr_args_type()));
      err = Error::InvalidProgram;
  }
  // Reset the temp allocator for every instruction.
  if (temp_allocator_ != nullptr) {
    temp_allocator_->reset();
  }
  // Advance only on success so a failed instruction can be inspected (or
  // retried) at the same position.
  if (err == Error::Ok) {
    step_state_.instr_idx = next_instr_idx;
  }
  return err;
}
1184*523fa7a6SAndroid Build Coastguard Worker
reset_execution()1185*523fa7a6SAndroid Build Coastguard Worker Error Method::reset_execution() {
1186*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
1187*523fa7a6SAndroid Build Coastguard Worker step_state_.chain_idx == n_chains_,
1188*523fa7a6SAndroid Build Coastguard Worker InvalidState,
1189*523fa7a6SAndroid Build Coastguard Worker "Cannot reset until EndOfMethod has been reached.");
1190*523fa7a6SAndroid Build Coastguard Worker step_state_ = StepState{0, 0};
1191*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
1192*523fa7a6SAndroid Build Coastguard Worker }
1193*523fa7a6SAndroid Build Coastguard Worker
/**
 * Older name for reset_execution(); forwards to it unchanged.
 *
 * NOTE(review): the lint suppression below presumably silences a
 * deprecation-related check on this call — confirm against the header.
 */
Error Method::experimental_reset_execution() {
  return reset_execution(); // @lint-ignore CLANGTIDY facebook-hte-Deprecated
}
1197*523fa7a6SAndroid Build Coastguard Worker
1198*523fa7a6SAndroid Build Coastguard Worker // Log all the outputs of this method to the event tracer.
log_outputs()1199*523fa7a6SAndroid Build Coastguard Worker void Method::log_outputs() {
1200*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_EVENT_TRACER_ENABLED
1201*523fa7a6SAndroid Build Coastguard Worker if (event_tracer_ != nullptr) {
1202*523fa7a6SAndroid Build Coastguard Worker if (event_tracer_->event_tracer_debug_level() >=
1203*523fa7a6SAndroid Build Coastguard Worker EventTracerDebugLogLevel::kProgramOutputs) {
1204*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < outputs_size(); i++) {
1205*523fa7a6SAndroid Build Coastguard Worker internal::event_tracer_log_evalue_output(event_tracer_, get_output(i));
1206*523fa7a6SAndroid Build Coastguard Worker }
1207*523fa7a6SAndroid Build Coastguard Worker }
1208*523fa7a6SAndroid Build Coastguard Worker }
1209*523fa7a6SAndroid Build Coastguard Worker #endif
1210*523fa7a6SAndroid Build Coastguard Worker }
1211*523fa7a6SAndroid Build Coastguard Worker
step()1212*523fa7a6SAndroid Build Coastguard Worker Error Method::step() {
1213*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(
1214*523fa7a6SAndroid Build Coastguard Worker static_cast<int32_t>(step_state_.chain_idx),
1215*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(step_state_.instr_idx));
1216*523fa7a6SAndroid Build Coastguard Worker internal::EventTracerProfileInstructionScope event_tracer_instr_scope =
1217*523fa7a6SAndroid Build Coastguard Worker internal::EventTracerProfileInstructionScope(
1218*523fa7a6SAndroid Build Coastguard Worker event_tracer_,
1219*523fa7a6SAndroid Build Coastguard Worker static_cast<int32_t>(step_state_.chain_idx),
1220*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(step_state_.instr_idx));
1221*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_SCOPE_PROF("Method::step");
1222*523fa7a6SAndroid Build Coastguard Worker EventTracerEntry event_tracer_entry =
1223*523fa7a6SAndroid Build Coastguard Worker internal::event_tracer_begin_profiling_event(
1224*523fa7a6SAndroid Build Coastguard Worker event_tracer_, "Method::step");
1225*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
1226*523fa7a6SAndroid Build Coastguard Worker initialized(),
1227*523fa7a6SAndroid Build Coastguard Worker InvalidState,
1228*523fa7a6SAndroid Build Coastguard Worker "Cannot execute until method has been initialized.");
1229*523fa7a6SAndroid Build Coastguard Worker
1230*523fa7a6SAndroid Build Coastguard Worker // If chain_step_ is on n_chains_, then we have no instructions run.
1231*523fa7a6SAndroid Build Coastguard Worker if (step_state_.chain_idx == n_chains_) {
1232*523fa7a6SAndroid Build Coastguard Worker return Error::EndOfMethod;
1233*523fa7a6SAndroid Build Coastguard Worker }
1234*523fa7a6SAndroid Build Coastguard Worker
1235*523fa7a6SAndroid Build Coastguard Worker auto num_instructions =
1236*523fa7a6SAndroid Build Coastguard Worker chains_[step_state_.chain_idx].s_chain_->instructions()->size();
1237*523fa7a6SAndroid Build Coastguard Worker
1238*523fa7a6SAndroid Build Coastguard Worker // Special case chains with no instructions. These appear for example in a
1239*523fa7a6SAndroid Build Coastguard Worker // model that just returns the input/a constant.
1240*523fa7a6SAndroid Build Coastguard Worker if (num_instructions == 0) {
1241*523fa7a6SAndroid Build Coastguard Worker step_state_.chain_idx += 1;
1242*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
1243*523fa7a6SAndroid Build Coastguard Worker }
1244*523fa7a6SAndroid Build Coastguard Worker
1245*523fa7a6SAndroid Build Coastguard Worker auto status = execute_instruction();
1246*523fa7a6SAndroid Build Coastguard Worker if (status != Error::Ok) {
1247*523fa7a6SAndroid Build Coastguard Worker return status;
1248*523fa7a6SAndroid Build Coastguard Worker }
1249*523fa7a6SAndroid Build Coastguard Worker
1250*523fa7a6SAndroid Build Coastguard Worker internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry);
1251*523fa7a6SAndroid Build Coastguard Worker // end of the current chain, advance to the next chain
1252*523fa7a6SAndroid Build Coastguard Worker if (step_state_.instr_idx == num_instructions) {
1253*523fa7a6SAndroid Build Coastguard Worker step_state_.instr_idx = 0;
1254*523fa7a6SAndroid Build Coastguard Worker step_state_.chain_idx += 1;
1255*523fa7a6SAndroid Build Coastguard Worker log_outputs();
1256*523fa7a6SAndroid Build Coastguard Worker }
1257*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
1258*523fa7a6SAndroid Build Coastguard Worker }
1259*523fa7a6SAndroid Build Coastguard Worker
experimental_step()1260*523fa7a6SAndroid Build Coastguard Worker Error Method::experimental_step() {
1261*523fa7a6SAndroid Build Coastguard Worker return step();
1262*523fa7a6SAndroid Build Coastguard Worker }
1263*523fa7a6SAndroid Build Coastguard Worker
execute()1264*523fa7a6SAndroid Build Coastguard Worker Error Method::execute() {
1265*523fa7a6SAndroid Build Coastguard Worker internal::event_tracer_create_event_block(event_tracer_, "Execute");
1266*523fa7a6SAndroid Build Coastguard Worker EventTracerEntry event_tracer_entry =
1267*523fa7a6SAndroid Build Coastguard Worker internal::event_tracer_begin_profiling_event(
1268*523fa7a6SAndroid Build Coastguard Worker event_tracer_, "Method::execute");
1269*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_SCOPE_PROF("Method::execute");
1270*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
1271*523fa7a6SAndroid Build Coastguard Worker initialized(),
1272*523fa7a6SAndroid Build Coastguard Worker NotSupported,
1273*523fa7a6SAndroid Build Coastguard Worker "Cannot execute until method has been initialized.");
1274*523fa7a6SAndroid Build Coastguard Worker
1275*523fa7a6SAndroid Build Coastguard Worker // Chains are executed sequentially today, but future async designs may
1276*523fa7a6SAndroid Build Coastguard Worker // branch and run many in parallel or out of order.
1277*523fa7a6SAndroid Build Coastguard Worker for (step_state_.chain_idx = 0; step_state_.chain_idx < n_chains_;
1278*523fa7a6SAndroid Build Coastguard Worker ++step_state_.chain_idx) {
1279*523fa7a6SAndroid Build Coastguard Worker Chain& chain = chains_[step_state_.chain_idx];
1280*523fa7a6SAndroid Build Coastguard Worker auto instructions = chain.s_chain_->instructions();
1281*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
1282*523fa7a6SAndroid Build Coastguard Worker instructions != nullptr,
1283*523fa7a6SAndroid Build Coastguard Worker Internal,
1284*523fa7a6SAndroid Build Coastguard Worker "chain %zu has no instructions field",
1285*523fa7a6SAndroid Build Coastguard Worker step_state_.chain_idx);
1286*523fa7a6SAndroid Build Coastguard Worker
1287*523fa7a6SAndroid Build Coastguard Worker // Loop over instructions
1288*523fa7a6SAndroid Build Coastguard Worker step_state_.instr_idx = 0;
1289*523fa7a6SAndroid Build Coastguard Worker while (step_state_.instr_idx < chain.s_chain_->instructions()->size()) {
1290*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(
1291*523fa7a6SAndroid Build Coastguard Worker static_cast<int32_t>(step_state_.chain_idx),
1292*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(step_state_.instr_idx));
1293*523fa7a6SAndroid Build Coastguard Worker internal::EventTracerProfileInstructionScope event_tracer_instr_scope =
1294*523fa7a6SAndroid Build Coastguard Worker internal::EventTracerProfileInstructionScope(
1295*523fa7a6SAndroid Build Coastguard Worker event_tracer_,
1296*523fa7a6SAndroid Build Coastguard Worker static_cast<ChainID>(step_state_.chain_idx),
1297*523fa7a6SAndroid Build Coastguard Worker static_cast<DebugHandle>(step_state_.instr_idx));
1298*523fa7a6SAndroid Build Coastguard Worker auto status = execute_instruction();
1299*523fa7a6SAndroid Build Coastguard Worker if (status != Error::Ok) {
1300*523fa7a6SAndroid Build Coastguard Worker return status;
1301*523fa7a6SAndroid Build Coastguard Worker }
1302*523fa7a6SAndroid Build Coastguard Worker }
1303*523fa7a6SAndroid Build Coastguard Worker }
1304*523fa7a6SAndroid Build Coastguard Worker internal::event_tracer_end_profiling_event(event_tracer_, event_tracer_entry);
1305*523fa7a6SAndroid Build Coastguard Worker log_outputs();
1306*523fa7a6SAndroid Build Coastguard Worker
1307*523fa7a6SAndroid Build Coastguard Worker // TODO(jakeszwe, dbort): Decide on calling execute back to back without
1308*523fa7a6SAndroid Build Coastguard Worker // going through the reset api first.
1309*523fa7a6SAndroid Build Coastguard Worker return reset_execution(); // @lint-ignore CLANGTIDY facebook-hte-Deprecated
1310*523fa7a6SAndroid Build Coastguard Worker }
1311*523fa7a6SAndroid Build Coastguard Worker
method_meta() const1312*523fa7a6SAndroid Build Coastguard Worker MethodMeta Method::method_meta() const {
1313*523fa7a6SAndroid Build Coastguard Worker auto name = serialization_plan_->name()->c_str();
1314*523fa7a6SAndroid Build Coastguard Worker auto method_meta = program_->method_meta(name);
1315*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
1316*523fa7a6SAndroid Build Coastguard Worker method_meta.ok(),
1317*523fa7a6SAndroid Build Coastguard Worker "Internal error: method_meta(%s) returned 0x%" PRIx32,
1318*523fa7a6SAndroid Build Coastguard Worker name,
1319*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(method_meta.error()));
1320*523fa7a6SAndroid Build Coastguard Worker return method_meta.get();
1321*523fa7a6SAndroid Build Coastguard Worker }
1322*523fa7a6SAndroid Build Coastguard Worker
get_value(size_t i) const1323*523fa7a6SAndroid Build Coastguard Worker const EValue& Method::get_value(size_t i) const {
1324*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(i < n_value_, "%zu >= %zu", i, n_value_);
1325*523fa7a6SAndroid Build Coastguard Worker return values_[i];
1326*523fa7a6SAndroid Build Coastguard Worker }
1327*523fa7a6SAndroid Build Coastguard Worker
mutable_value(size_t i)1328*523fa7a6SAndroid Build Coastguard Worker EValue& Method::mutable_value(size_t i) {
1329*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(i < n_value_, "%zu >= %zu", i, n_value_);
1330*523fa7a6SAndroid Build Coastguard Worker return values_[i];
1331*523fa7a6SAndroid Build Coastguard Worker }
1332*523fa7a6SAndroid Build Coastguard Worker
inputs_size() const1333*523fa7a6SAndroid Build Coastguard Worker size_t Method::inputs_size() const {
1334*523fa7a6SAndroid Build Coastguard Worker const auto* inputs = serialization_plan_->inputs();
1335*523fa7a6SAndroid Build Coastguard Worker return inputs == nullptr ? 0 : inputs->size();
1336*523fa7a6SAndroid Build Coastguard Worker }
1337*523fa7a6SAndroid Build Coastguard Worker
get_input_index(size_t i) const1338*523fa7a6SAndroid Build Coastguard Worker size_t Method::get_input_index(size_t i) const {
1339*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(i < inputs_size(), "%zu >= %zu", i, inputs_size());
1340*523fa7a6SAndroid Build Coastguard Worker return static_cast<size_t>(serialization_plan_->inputs()->Get(i));
1341*523fa7a6SAndroid Build Coastguard Worker }
1342*523fa7a6SAndroid Build Coastguard Worker
get_input(size_t i) const1343*523fa7a6SAndroid Build Coastguard Worker const EValue& Method::get_input(size_t i) const {
1344*523fa7a6SAndroid Build Coastguard Worker return get_value(get_input_index(i));
1345*523fa7a6SAndroid Build Coastguard Worker }
1346*523fa7a6SAndroid Build Coastguard Worker
mutable_input(size_t i)1347*523fa7a6SAndroid Build Coastguard Worker EValue& Method::mutable_input(size_t i) {
1348*523fa7a6SAndroid Build Coastguard Worker return mutable_value(get_input_index(i));
1349*523fa7a6SAndroid Build Coastguard Worker }
1350*523fa7a6SAndroid Build Coastguard Worker
outputs_size() const1351*523fa7a6SAndroid Build Coastguard Worker size_t Method::outputs_size() const {
1352*523fa7a6SAndroid Build Coastguard Worker const auto* outputs = serialization_plan_->outputs();
1353*523fa7a6SAndroid Build Coastguard Worker return outputs == nullptr ? 0 : outputs->size();
1354*523fa7a6SAndroid Build Coastguard Worker }
1355*523fa7a6SAndroid Build Coastguard Worker
get_output_index(size_t i) const1356*523fa7a6SAndroid Build Coastguard Worker size_t Method::get_output_index(size_t i) const {
1357*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(i < outputs_size(), "%zu >= %zu", i, outputs_size());
1358*523fa7a6SAndroid Build Coastguard Worker return static_cast<size_t>(serialization_plan_->outputs()->Get(i));
1359*523fa7a6SAndroid Build Coastguard Worker }
1360*523fa7a6SAndroid Build Coastguard Worker
get_output(size_t i) const1361*523fa7a6SAndroid Build Coastguard Worker const EValue& Method::get_output(size_t i) const {
1362*523fa7a6SAndroid Build Coastguard Worker return get_value(get_output_index(i));
1363*523fa7a6SAndroid Build Coastguard Worker }
1364*523fa7a6SAndroid Build Coastguard Worker
mutable_output(size_t i)1365*523fa7a6SAndroid Build Coastguard Worker EValue& Method::mutable_output(size_t i) {
1366*523fa7a6SAndroid Build Coastguard Worker return mutable_value(get_output_index(i));
1367*523fa7a6SAndroid Build Coastguard Worker }
1368*523fa7a6SAndroid Build Coastguard Worker
get_event_tracer()1369*523fa7a6SAndroid Build Coastguard Worker EventTracer* Method::get_event_tracer() {
1370*523fa7a6SAndroid Build Coastguard Worker return event_tracer_;
1371*523fa7a6SAndroid Build Coastguard Worker }
1372*523fa7a6SAndroid Build Coastguard Worker
~Method()1373*523fa7a6SAndroid Build Coastguard Worker Method::~Method() {
1374*523fa7a6SAndroid Build Coastguard Worker // Destroy the values. It's necessary in ATen mode, where the refcount of
1375*523fa7a6SAndroid Build Coastguard Worker // Tensors needs to be decremented properly.
1376*523fa7a6SAndroid Build Coastguard Worker if (values_ != nullptr) {
1377*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < n_value_; ++i) {
1378*523fa7a6SAndroid Build Coastguard Worker values_[i].~EValue();
1379*523fa7a6SAndroid Build Coastguard Worker }
1380*523fa7a6SAndroid Build Coastguard Worker }
1381*523fa7a6SAndroid Build Coastguard Worker // Free any resources associated with delegate backends.
1382*523fa7a6SAndroid Build Coastguard Worker if (delegates_ != nullptr) {
1383*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < n_delegate_; i++) {
1384*523fa7a6SAndroid Build Coastguard Worker delegates_[i].~BackendDelegate();
1385*523fa7a6SAndroid Build Coastguard Worker }
1386*523fa7a6SAndroid Build Coastguard Worker }
1387*523fa7a6SAndroid Build Coastguard Worker // All other fields are trivially destructible.
1388*523fa7a6SAndroid Build Coastguard Worker }
1389*523fa7a6SAndroid Build Coastguard Worker } // namespace runtime
1390*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
1391