1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
10*523fa7a6SAndroid Build Coastguard Worker #include <cstdio>
11*523fa7a6SAndroid Build Coastguard Worker #include <iostream>
12*523fa7a6SAndroid Build Coastguard Worker #include <memory>
13*523fa7a6SAndroid Build Coastguard Worker #include <stdexcept>
14*523fa7a6SAndroid Build Coastguard Worker #include <unordered_map>
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker #include <pybind11/iostream.h>
17*523fa7a6SAndroid Build Coastguard Worker #include <pybind11/pybind11.h>
18*523fa7a6SAndroid Build Coastguard Worker #include <pybind11/stl.h>
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker #include <executorch/devtools/bundled_program/bundled_program.h>
21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/devtools/bundled_program/schema/bundled_program_schema_generated.h>
22*523fa7a6SAndroid Build Coastguard Worker #include <executorch/devtools/etdump/etdump_flatcc.h>
23*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/buffer_data_loader.h>
24*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/mmap_data_loader.h>
25*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
26*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/data_loader.h>
27*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
28*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/method.h>
29*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/program.h>
30*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/operator_registry.h>
31*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/assert.h>
32*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/platform.h>
33*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/profiler.h>
34*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/runtime.h>
35*523fa7a6SAndroid Build Coastguard Worker
36*523fa7a6SAndroid Build Coastguard Worker #include <ATen/Functions.h>
37*523fa7a6SAndroid Build Coastguard Worker #include <ATen/Tensor.h>
38*523fa7a6SAndroid Build Coastguard Worker #include <ATen/core/functional.h>
39*523fa7a6SAndroid Build Coastguard Worker #include <c10/core/ScalarTypeToTypeMeta.h>
40*523fa7a6SAndroid Build Coastguard Worker #include <torch/csrc/utils/pybind.h>
41*523fa7a6SAndroid Build Coastguard Worker #include <torch/python.h>
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Worker #ifndef USE_ATEN_LIB
44*523fa7a6SAndroid Build Coastguard Worker #include <c10/core/impl/LocalDispatchKeySet.h>
45*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/aten_util/aten_bridge.h>
46*523fa7a6SAndroid Build Coastguard Worker #endif
47*523fa7a6SAndroid Build Coastguard Worker
/// Throws a std::runtime_error with a printf-style message if `error` is not
/// `Error::Ok`. pybind11 converts the exception into a Python exception.
/// The formatted message is truncated to 127 characters.
///
/// Expands to a single `do { ... } while (0)` statement rather than a GNU
/// statement expression, so it compiles on compilers without that extension
/// and still composes safely with unbraced if/else.
#define THROW_IF_ERROR(error, message, ...)                       \
  do {                                                            \
    if ((error) != Error::Ok) {                                   \
      char msg_buf[128];                                          \
      snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__); \
      /* pybind will convert this to a python exception. */       \
      throw std::runtime_error(msg_buf);                          \
    }                                                             \
  } while (0)
58*523fa7a6SAndroid Build Coastguard Worker
/// Like THROW_IF_ERROR, but throws std::out_of_range, which pybind11 converts
/// into a Python IndexError. The formatted message is truncated to 127
/// characters.
///
/// Expands to a single `do { ... } while (0)` statement rather than a GNU
/// statement expression, so it is portable and composes with unbraced if/else.
#define THROW_INDEX_IF_ERROR(error, message, ...)                 \
  do {                                                            \
    if ((error) != Error::Ok) {                                   \
      char msg_buf[128];                                          \
      snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__); \
      /* pybind will convert this to a python exception. */       \
      throw std::out_of_range(msg_buf);                           \
    }                                                             \
  } while (0)
68*523fa7a6SAndroid Build Coastguard Worker
69*523fa7a6SAndroid Build Coastguard Worker // Our logs work by writing to stderr. By default this is done through fprintf
70*523fa7a6SAndroid Build Coastguard Worker // (as defined in posix.cpp) which then does not show up in python environments.
71*523fa7a6SAndroid Build Coastguard Worker // Here we override the pal to use std::cerr which can be properly redirected by
72*523fa7a6SAndroid Build Coastguard Worker // scoped_estream_redirect.
et_pal_emit_log_message(et_timestamp_t timestamp,et_pal_log_level_t level,const char * filename,ET_UNUSED const char * function,size_t line,const char * message,ET_UNUSED size_t length)73*523fa7a6SAndroid Build Coastguard Worker void et_pal_emit_log_message(
74*523fa7a6SAndroid Build Coastguard Worker et_timestamp_t timestamp,
75*523fa7a6SAndroid Build Coastguard Worker et_pal_log_level_t level,
76*523fa7a6SAndroid Build Coastguard Worker const char* filename,
77*523fa7a6SAndroid Build Coastguard Worker ET_UNUSED const char* function,
78*523fa7a6SAndroid Build Coastguard Worker size_t line,
79*523fa7a6SAndroid Build Coastguard Worker const char* message,
80*523fa7a6SAndroid Build Coastguard Worker ET_UNUSED size_t length) {
81*523fa7a6SAndroid Build Coastguard Worker std::cerr << "[" << filename << ":" << line << "] " << message << std::endl;
82*523fa7a6SAndroid Build Coastguard Worker }
83*523fa7a6SAndroid Build Coastguard Worker
84*523fa7a6SAndroid Build Coastguard Worker namespace py = pybind11;
85*523fa7a6SAndroid Build Coastguard Worker using executorch::bundled_program::verify_method_outputs;
86*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::BufferDataLoader;
87*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::MallocMemoryAllocator;
88*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::MmapDataLoader;
89*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::ArrayRef;
90*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::DataLoader;
91*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Error;
92*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::EValue;
93*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::EventTracerDebugLogLevel;
94*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::get_registered_kernels;
95*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::HierarchicalAllocator;
96*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Kernel;
97*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::MemoryAllocator;
98*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::MemoryManager;
99*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Method;
100*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::prof_result_t;
101*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Program;
102*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Result;
103*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Span;
104*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Tag;
105*523fa7a6SAndroid Build Coastguard Worker using torch::executor::etdump_result;
106*523fa7a6SAndroid Build Coastguard Worker using torch::executor::ETDumpGen;
107*523fa7a6SAndroid Build Coastguard Worker
108*523fa7a6SAndroid Build Coastguard Worker #ifndef USE_ATEN_LIB
109*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::alias_attensor_to_etensor;
110*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::alias_etensor_to_attensor;
111*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::torch_to_executorch_scalar_type;
112*523fa7a6SAndroid Build Coastguard Worker #endif // !USE_ATEN_LIB
113*523fa7a6SAndroid Build Coastguard Worker
114*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
115*523fa7a6SAndroid Build Coastguard Worker namespace extension {
116*523fa7a6SAndroid Build Coastguard Worker namespace pybindings {
117*523fa7a6SAndroid Build Coastguard Worker
118*523fa7a6SAndroid Build Coastguard Worker namespace {
119*523fa7a6SAndroid Build Coastguard Worker
/// Writes `size` bytes from `buf` to the file at `path`, creating the file or
/// truncating any existing one.
///
/// @param path Destination file path.
/// @param buf Bytes to write; only read, never modified.
/// @param size Number of bytes to write from `buf`.
/// @throws std::runtime_error if the file cannot be opened, fully written, or
///     closed.
void write_data_to_file(const std::string& path, const void* buf, size_t size) {
  // Binary mode: the payload is raw bytes and must not undergo newline
  // translation on platforms that distinguish text and binary streams.
  FILE* f = fopen(path.c_str(), "wb");
  if (!f) {
    throw std::runtime_error(
        "Failed to open file " + path + ": " + strerror(errno));
  }
  const size_t num_written = fwrite(buf, 1, size, f);
  if (num_written != size) {
    fclose(f);
    throw std::runtime_error("Failed to write etdump to file " + path);
  }
  // fclose() returns EOF (not an errno value) on failure; the actual cause is
  // reported through errno.
  if (fclose(f) != 0) {
    throw std::runtime_error(
        "Failed to close etdump file " + path + ": " + strerror(errno));
  }
}
137*523fa7a6SAndroid Build Coastguard Worker
setup_output_storage(Method & method,const std::vector<Span<uint8_t>> & output_storages)138*523fa7a6SAndroid Build Coastguard Worker void setup_output_storage(
139*523fa7a6SAndroid Build Coastguard Worker Method& method,
140*523fa7a6SAndroid Build Coastguard Worker const std::vector<Span<uint8_t>>& output_storages) {
141*523fa7a6SAndroid Build Coastguard Worker if (output_storages.size() != method.outputs_size()) {
142*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
143*523fa7a6SAndroid Build Coastguard Worker Error::InvalidArgument,
144*523fa7a6SAndroid Build Coastguard Worker "number of output storages %zu does not match number of outputs %zu",
145*523fa7a6SAndroid Build Coastguard Worker output_storages.size(),
146*523fa7a6SAndroid Build Coastguard Worker method.outputs_size());
147*523fa7a6SAndroid Build Coastguard Worker }
148*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < output_storages.size(); ++i) {
149*523fa7a6SAndroid Build Coastguard Worker if (output_storages[i].size() == 0) {
150*523fa7a6SAndroid Build Coastguard Worker // Skip empty output storages, this would happen for non-tensor outputs
151*523fa7a6SAndroid Build Coastguard Worker // and memory planned outputs.
152*523fa7a6SAndroid Build Coastguard Worker continue;
153*523fa7a6SAndroid Build Coastguard Worker }
154*523fa7a6SAndroid Build Coastguard Worker Error output_status = method.set_output_data_ptr(
155*523fa7a6SAndroid Build Coastguard Worker output_storages[i].data(), output_storages[i].size(), i);
156*523fa7a6SAndroid Build Coastguard Worker // We already should be skipping non-tensor outputs, and memory planned
157*523fa7a6SAndroid Build Coastguard Worker // outputs so any error is real.
158*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
159*523fa7a6SAndroid Build Coastguard Worker output_status,
160*523fa7a6SAndroid Build Coastguard Worker "set_output_data_ptr failed for output %zu with error 0x%" PRIx32,
161*523fa7a6SAndroid Build Coastguard Worker i,
162*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(output_status));
163*523fa7a6SAndroid Build Coastguard Worker }
164*523fa7a6SAndroid Build Coastguard Worker }
165*523fa7a6SAndroid Build Coastguard Worker
166*523fa7a6SAndroid Build Coastguard Worker class Module final {
167*523fa7a6SAndroid Build Coastguard Worker public:
Module(std::unique_ptr<DataLoader> loader,std::unique_ptr<ETDumpGen> tracer=nullptr,size_t debug_buffer_size=0,Program::Verification program_verification=Program::Verification::InternalConsistency)168*523fa7a6SAndroid Build Coastguard Worker explicit Module(
169*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<DataLoader> loader,
170*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<ETDumpGen> tracer = nullptr,
171*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size = 0,
172*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification =
173*523fa7a6SAndroid Build Coastguard Worker Program::Verification::InternalConsistency)
174*523fa7a6SAndroid Build Coastguard Worker : loader_(std::move(loader)),
175*523fa7a6SAndroid Build Coastguard Worker event_tracer_(std::move(tracer)),
176*523fa7a6SAndroid Build Coastguard Worker debug_buffer_size_(debug_buffer_size) {
177*523fa7a6SAndroid Build Coastguard Worker ::executorch::runtime::runtime_init();
178*523fa7a6SAndroid Build Coastguard Worker Result<Program> program =
179*523fa7a6SAndroid Build Coastguard Worker Program::load(loader_.get(), program_verification);
180*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
181*523fa7a6SAndroid Build Coastguard Worker program.error(),
182*523fa7a6SAndroid Build Coastguard Worker "loading program failed with error: 0x%" PRIx32,
183*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(program.error()));
184*523fa7a6SAndroid Build Coastguard Worker program_ = std::make_unique<Program>(std::move(program.get()));
185*523fa7a6SAndroid Build Coastguard Worker
186*523fa7a6SAndroid Build Coastguard Worker // Figure out the size of each non_const layer we need to support every
187*523fa7a6SAndroid Build Coastguard Worker // method in the program. Map will be easier to use than a list because we
188*523fa7a6SAndroid Build Coastguard Worker // dont know how many non_const arenas there will be
189*523fa7a6SAndroid Build Coastguard Worker std::map<size_t, int64_t> non_const_buffer_sizes;
190*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < program_->num_methods(); ++i) {
191*523fa7a6SAndroid Build Coastguard Worker auto name = program_->get_method_name(i).get();
192*523fa7a6SAndroid Build Coastguard Worker auto method_meta = program_->method_meta(name).get();
193*523fa7a6SAndroid Build Coastguard Worker for (size_t j = 0; j < method_meta.num_non_const_buffers(); j++) {
194*523fa7a6SAndroid Build Coastguard Worker int64_t buffer_size = method_meta.non_const_buffer_size(j).get();
195*523fa7a6SAndroid Build Coastguard Worker if (non_const_buffer_sizes.find(j) == non_const_buffer_sizes.end()) {
196*523fa7a6SAndroid Build Coastguard Worker non_const_buffer_sizes.insert({j, buffer_size});
197*523fa7a6SAndroid Build Coastguard Worker } else {
198*523fa7a6SAndroid Build Coastguard Worker non_const_buffer_sizes[j] =
199*523fa7a6SAndroid Build Coastguard Worker std::max(non_const_buffer_sizes[j], buffer_size);
200*523fa7a6SAndroid Build Coastguard Worker }
201*523fa7a6SAndroid Build Coastguard Worker }
202*523fa7a6SAndroid Build Coastguard Worker }
203*523fa7a6SAndroid Build Coastguard Worker
204*523fa7a6SAndroid Build Coastguard Worker // Allocate the arenas. Using vector because we need to remember the size as
205*523fa7a6SAndroid Build Coastguard Worker // well, so vector is easier then unique_ptr.
206*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<uint8_t>> non_const_buffers_;
207*523fa7a6SAndroid Build Coastguard Worker for (std::map<size_t, int64_t>::iterator i = non_const_buffer_sizes.begin();
208*523fa7a6SAndroid Build Coastguard Worker i != non_const_buffer_sizes.end();
209*523fa7a6SAndroid Build Coastguard Worker i++) {
210*523fa7a6SAndroid Build Coastguard Worker non_const_buffers_.push_back(std::vector<uint8_t>(i->second));
211*523fa7a6SAndroid Build Coastguard Worker }
212*523fa7a6SAndroid Build Coastguard Worker
213*523fa7a6SAndroid Build Coastguard Worker memory_ = std::make_unique<Memory>(std::move(non_const_buffers_));
214*523fa7a6SAndroid Build Coastguard Worker if (event_tracer_ && debug_buffer_size > 0) {
215*523fa7a6SAndroid Build Coastguard Worker // If a debug buffer was requested for the ETDump, allocate it and make
216*523fa7a6SAndroid Build Coastguard Worker // sure its lifetime is as long as the event_tracer.
217*523fa7a6SAndroid Build Coastguard Worker debug_buffer_ = std::make_unique<uint8_t[]>(debug_buffer_size);
218*523fa7a6SAndroid Build Coastguard Worker event_tracer_->set_debug_buffer(get_etdump_debug_buffer());
219*523fa7a6SAndroid Build Coastguard Worker event_tracer_->set_event_tracer_debug_level(
220*523fa7a6SAndroid Build Coastguard Worker EventTracerDebugLogLevel::kIntermediateOutputs);
221*523fa7a6SAndroid Build Coastguard Worker }
222*523fa7a6SAndroid Build Coastguard Worker
223*523fa7a6SAndroid Build Coastguard Worker // Load methods
224*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < program_->num_methods(); ++i) {
225*523fa7a6SAndroid Build Coastguard Worker auto name = program_->get_method_name(i).get();
226*523fa7a6SAndroid Build Coastguard Worker // It's safe to use the same memory manager for all modules because
227*523fa7a6SAndroid Build Coastguard Worker // we can guarantee that only one will be executing at a time.
228*523fa7a6SAndroid Build Coastguard Worker // Everything in this module runs on a single thread.
229*523fa7a6SAndroid Build Coastguard Worker Result<Method> method = program_->load_method(
230*523fa7a6SAndroid Build Coastguard Worker name, memory_->mem_manager(), event_tracer_.get());
231*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
232*523fa7a6SAndroid Build Coastguard Worker method.error(),
233*523fa7a6SAndroid Build Coastguard Worker "loading method %s failed with error 0x%" PRIx32,
234*523fa7a6SAndroid Build Coastguard Worker name,
235*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(method.error()));
236*523fa7a6SAndroid Build Coastguard Worker methods_.insert(
237*523fa7a6SAndroid Build Coastguard Worker {std::string(name),
238*523fa7a6SAndroid Build Coastguard Worker std::make_unique<Method>(std::move(method.get()))});
239*523fa7a6SAndroid Build Coastguard Worker }
240*523fa7a6SAndroid Build Coastguard Worker }
241*523fa7a6SAndroid Build Coastguard Worker
242*523fa7a6SAndroid Build Coastguard Worker Module(const Module&) = delete;
243*523fa7a6SAndroid Build Coastguard Worker Module& operator=(const Module&) = delete;
244*523fa7a6SAndroid Build Coastguard Worker Module(Module&&) = default;
245*523fa7a6SAndroid Build Coastguard Worker Module& operator=(Module&&) = default;
246*523fa7a6SAndroid Build Coastguard Worker
247*523fa7a6SAndroid Build Coastguard Worker /// Executes the specified method on the provided inputs and returns its
248*523fa7a6SAndroid Build Coastguard Worker /// outputs.
run_method(const std::string & method_name,const std::vector<EValue> & args,const std::optional<std::vector<Span<uint8_t>>> & output_storages=std::nullopt)249*523fa7a6SAndroid Build Coastguard Worker std::vector<EValue> run_method(
250*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name,
251*523fa7a6SAndroid Build Coastguard Worker const std::vector<EValue>& args,
252*523fa7a6SAndroid Build Coastguard Worker const std::optional<std::vector<Span<uint8_t>>>& output_storages =
253*523fa7a6SAndroid Build Coastguard Worker std::nullopt) {
254*523fa7a6SAndroid Build Coastguard Worker auto& method = get_method(method_name);
255*523fa7a6SAndroid Build Coastguard Worker exec_aten::ArrayRef<EValue> input_evalue_list(args.data(), args.size());
256*523fa7a6SAndroid Build Coastguard Worker
257*523fa7a6SAndroid Build Coastguard Worker Error set_inputs_status = method.set_inputs(input_evalue_list);
258*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
259*523fa7a6SAndroid Build Coastguard Worker set_inputs_status,
260*523fa7a6SAndroid Build Coastguard Worker "method->set_inputs() for method '%s' failed with error 0x%" PRIx32,
261*523fa7a6SAndroid Build Coastguard Worker method_name.c_str(),
262*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(set_inputs_status));
263*523fa7a6SAndroid Build Coastguard Worker
264*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
265*523fa7a6SAndroid Build Coastguard Worker // [TLS handling] This is to workaround an assertion failure
266*523fa7a6SAndroid Build Coastguard Worker // (https://fburl.com/code/302jyn8d) running `gelu` in ATen mode in fbcode
267*523fa7a6SAndroid Build Coastguard Worker // (such as bento). The problem is ExecuTorch ATen mode doesn't have
268*523fa7a6SAndroid Build Coastguard Worker // Thread Local State, but `torch-cpp` is assuming tls init is done. There
269*523fa7a6SAndroid Build Coastguard Worker // are two more checks: MKLDNN disabled and C10_MOBILE, if any of them is
270*523fa7a6SAndroid Build Coastguard Worker // true we won't be hitting this assertion error. However in `torch-cpp`
271*523fa7a6SAndroid Build Coastguard Worker // lib both checks are false. Production impact: this should not make any
272*523fa7a6SAndroid Build Coastguard Worker // impact in production environment, given that in xplat we are depending
273*523fa7a6SAndroid Build Coastguard Worker // on a library that enables C10_MOBILE (`torch_mobile_core`).
274*523fa7a6SAndroid Build Coastguard Worker c10::impl::ExcludeDispatchKeyGuard no_autograd(
275*523fa7a6SAndroid Build Coastguard Worker c10::autograd_dispatch_keyset);
276*523fa7a6SAndroid Build Coastguard Worker #endif
277*523fa7a6SAndroid Build Coastguard Worker if (output_storages) {
278*523fa7a6SAndroid Build Coastguard Worker setup_output_storage(method, *output_storages);
279*523fa7a6SAndroid Build Coastguard Worker }
280*523fa7a6SAndroid Build Coastguard Worker Error execute_status = method.execute();
281*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
282*523fa7a6SAndroid Build Coastguard Worker execute_status,
283*523fa7a6SAndroid Build Coastguard Worker "method->execute() failed with error 0x%" PRIx32,
284*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(execute_status));
285*523fa7a6SAndroid Build Coastguard Worker // process outputs
286*523fa7a6SAndroid Build Coastguard Worker return get_outputs(method_name);
287*523fa7a6SAndroid Build Coastguard Worker }
288*523fa7a6SAndroid Build Coastguard Worker
get_outputs(const std::string & method_name)289*523fa7a6SAndroid Build Coastguard Worker std::vector<EValue> get_outputs(const std::string& method_name) {
290*523fa7a6SAndroid Build Coastguard Worker auto& method = methods_[method_name];
291*523fa7a6SAndroid Build Coastguard Worker std::vector<EValue> result(method->outputs_size());
292*523fa7a6SAndroid Build Coastguard Worker
293*523fa7a6SAndroid Build Coastguard Worker Error get_outputs_status =
294*523fa7a6SAndroid Build Coastguard Worker method->get_outputs(result.data(), method->outputs_size());
295*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
296*523fa7a6SAndroid Build Coastguard Worker get_outputs_status,
297*523fa7a6SAndroid Build Coastguard Worker "method->get_outputs() for method '%s' failed with error 0x%" PRIx32,
298*523fa7a6SAndroid Build Coastguard Worker method_name.c_str(),
299*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(get_outputs_status));
300*523fa7a6SAndroid Build Coastguard Worker
301*523fa7a6SAndroid Build Coastguard Worker return result;
302*523fa7a6SAndroid Build Coastguard Worker }
303*523fa7a6SAndroid Build Coastguard Worker
get_method(const std::string & method_name)304*523fa7a6SAndroid Build Coastguard Worker Method& get_method(const std::string& method_name) {
305*523fa7a6SAndroid Build Coastguard Worker if (methods_.count(method_name) == 0) {
306*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
307*523fa7a6SAndroid Build Coastguard Worker Error::InvalidArgument,
308*523fa7a6SAndroid Build Coastguard Worker "no such method in program: %s",
309*523fa7a6SAndroid Build Coastguard Worker method_name.c_str());
310*523fa7a6SAndroid Build Coastguard Worker }
311*523fa7a6SAndroid Build Coastguard Worker return *methods_[method_name].get();
312*523fa7a6SAndroid Build Coastguard Worker }
313*523fa7a6SAndroid Build Coastguard Worker
314*523fa7a6SAndroid Build Coastguard Worker /// Returns the names of all methods in the program.
method_names() const315*523fa7a6SAndroid Build Coastguard Worker std::vector<std::string> method_names() const {
316*523fa7a6SAndroid Build Coastguard Worker std::vector<std::string> names;
317*523fa7a6SAndroid Build Coastguard Worker for (const auto& method : methods_) {
318*523fa7a6SAndroid Build Coastguard Worker names.push_back(method.first);
319*523fa7a6SAndroid Build Coastguard Worker }
320*523fa7a6SAndroid Build Coastguard Worker return names;
321*523fa7a6SAndroid Build Coastguard Worker }
322*523fa7a6SAndroid Build Coastguard Worker
has_etdump()323*523fa7a6SAndroid Build Coastguard Worker bool has_etdump() {
324*523fa7a6SAndroid Build Coastguard Worker return static_cast<bool>(event_tracer_);
325*523fa7a6SAndroid Build Coastguard Worker }
326*523fa7a6SAndroid Build Coastguard Worker
etdump()327*523fa7a6SAndroid Build Coastguard Worker ETDumpGen& etdump() {
328*523fa7a6SAndroid Build Coastguard Worker return *event_tracer_;
329*523fa7a6SAndroid Build Coastguard Worker }
330*523fa7a6SAndroid Build Coastguard Worker
has_etdump_debug_buffer() const331*523fa7a6SAndroid Build Coastguard Worker bool has_etdump_debug_buffer() const {
332*523fa7a6SAndroid Build Coastguard Worker return static_cast<bool>(debug_buffer_);
333*523fa7a6SAndroid Build Coastguard Worker }
334*523fa7a6SAndroid Build Coastguard Worker
get_etdump_debug_buffer()335*523fa7a6SAndroid Build Coastguard Worker Span<uint8_t> get_etdump_debug_buffer() {
336*523fa7a6SAndroid Build Coastguard Worker return Span<uint8_t>(debug_buffer_.get(), debug_buffer_size_);
337*523fa7a6SAndroid Build Coastguard Worker }
338*523fa7a6SAndroid Build Coastguard Worker
339*523fa7a6SAndroid Build Coastguard Worker private:
340*523fa7a6SAndroid Build Coastguard Worker /// A wrapper/util class for executorch memory allocations/manager.
341*523fa7a6SAndroid Build Coastguard Worker class Memory {
342*523fa7a6SAndroid Build Coastguard Worker public:
Memory(std::vector<std::vector<uint8_t>> && non_const_buffers)343*523fa7a6SAndroid Build Coastguard Worker explicit Memory(std::vector<std::vector<uint8_t>>&& non_const_buffers)
344*523fa7a6SAndroid Build Coastguard Worker : runtime_allocator_(),
345*523fa7a6SAndroid Build Coastguard Worker non_const_buffers_(std::move(non_const_buffers)),
346*523fa7a6SAndroid Build Coastguard Worker non_const_spans_(create_non_const_spans()),
347*523fa7a6SAndroid Build Coastguard Worker non_const_allocator_(
348*523fa7a6SAndroid Build Coastguard Worker {non_const_spans_.data(), non_const_spans_.size()}),
349*523fa7a6SAndroid Build Coastguard Worker mem_manager_(
350*523fa7a6SAndroid Build Coastguard Worker &const_allocator_,
351*523fa7a6SAndroid Build Coastguard Worker &non_const_allocator_,
352*523fa7a6SAndroid Build Coastguard Worker &runtime_allocator_,
353*523fa7a6SAndroid Build Coastguard Worker &temp_allocator_) {}
354*523fa7a6SAndroid Build Coastguard Worker
355*523fa7a6SAndroid Build Coastguard Worker /// Returns a pointer to the internal memory manager, the Memory instance
356*523fa7a6SAndroid Build Coastguard Worker /// must outlive this pointer.
mem_manager()357*523fa7a6SAndroid Build Coastguard Worker MemoryManager* mem_manager() {
358*523fa7a6SAndroid Build Coastguard Worker return &mem_manager_;
359*523fa7a6SAndroid Build Coastguard Worker }
360*523fa7a6SAndroid Build Coastguard Worker
361*523fa7a6SAndroid Build Coastguard Worker Memory(const Memory&) = delete;
362*523fa7a6SAndroid Build Coastguard Worker Memory& operator=(const Memory&) = delete;
363*523fa7a6SAndroid Build Coastguard Worker
364*523fa7a6SAndroid Build Coastguard Worker private:
365*523fa7a6SAndroid Build Coastguard Worker MemoryAllocator const_allocator_{MemoryAllocator(0, nullptr)};
366*523fa7a6SAndroid Build Coastguard Worker
367*523fa7a6SAndroid Build Coastguard Worker MallocMemoryAllocator runtime_allocator_;
368*523fa7a6SAndroid Build Coastguard Worker
369*523fa7a6SAndroid Build Coastguard Worker MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
370*523fa7a6SAndroid Build Coastguard Worker
371*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<uint8_t>> non_const_buffers_;
372*523fa7a6SAndroid Build Coastguard Worker
373*523fa7a6SAndroid Build Coastguard Worker std::vector<Span<uint8_t>> non_const_spans_;
374*523fa7a6SAndroid Build Coastguard Worker
375*523fa7a6SAndroid Build Coastguard Worker HierarchicalAllocator non_const_allocator_;
376*523fa7a6SAndroid Build Coastguard Worker
377*523fa7a6SAndroid Build Coastguard Worker MemoryManager mem_manager_;
378*523fa7a6SAndroid Build Coastguard Worker
create_non_const_spans()379*523fa7a6SAndroid Build Coastguard Worker std::vector<Span<uint8_t>> create_non_const_spans() {
380*523fa7a6SAndroid Build Coastguard Worker std::vector<Span<uint8_t>> result;
381*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < non_const_buffers_.size(); i++) {
382*523fa7a6SAndroid Build Coastguard Worker result.push_back(
383*523fa7a6SAndroid Build Coastguard Worker {non_const_buffers_[i].data(), non_const_buffers_[i].size()});
384*523fa7a6SAndroid Build Coastguard Worker }
385*523fa7a6SAndroid Build Coastguard Worker return result;
386*523fa7a6SAndroid Build Coastguard Worker }
387*523fa7a6SAndroid Build Coastguard Worker };
388*523fa7a6SAndroid Build Coastguard Worker
389*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<Memory> memory_;
390*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<DataLoader> loader_; // program_ points to this.
391*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<const Program> program_; // methods_ entries points to this.
392*523fa7a6SAndroid Build Coastguard Worker std::unordered_map<std::string, std::unique_ptr<Method>> methods_;
393*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<ETDumpGen> event_tracer_;
394*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<uint8_t[]> debug_buffer_;
395*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size_;
396*523fa7a6SAndroid Build Coastguard Worker };
397*523fa7a6SAndroid Build Coastguard Worker
load_module_from_buffer(const void * ptr,size_t ptr_len,bool enable_etdump,size_t debug_buffer_size,Program::Verification program_verification)398*523fa7a6SAndroid Build Coastguard Worker inline std::unique_ptr<Module> load_module_from_buffer(
399*523fa7a6SAndroid Build Coastguard Worker const void* ptr,
400*523fa7a6SAndroid Build Coastguard Worker size_t ptr_len,
401*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
402*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size,
403*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification) {
404*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
405*523fa7a6SAndroid Build Coastguard Worker auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
406*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<Module>(
407*523fa7a6SAndroid Build Coastguard Worker std::move(loader),
408*523fa7a6SAndroid Build Coastguard Worker enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
409*523fa7a6SAndroid Build Coastguard Worker debug_buffer_size,
410*523fa7a6SAndroid Build Coastguard Worker program_verification);
411*523fa7a6SAndroid Build Coastguard Worker }
412*523fa7a6SAndroid Build Coastguard Worker
load_module_from_file(const std::string & path,bool enable_etdump,size_t debug_buffer_size,Program::Verification program_verification)413*523fa7a6SAndroid Build Coastguard Worker inline std::unique_ptr<Module> load_module_from_file(
414*523fa7a6SAndroid Build Coastguard Worker const std::string& path,
415*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
416*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size,
417*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification) {
418*523fa7a6SAndroid Build Coastguard Worker EXECUTORCH_SCOPE_PROF("load_module_from_file");
419*523fa7a6SAndroid Build Coastguard Worker
420*523fa7a6SAndroid Build Coastguard Worker Result<MmapDataLoader> res = MmapDataLoader::from(
421*523fa7a6SAndroid Build Coastguard Worker path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
422*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
423*523fa7a6SAndroid Build Coastguard Worker res.error(),
424*523fa7a6SAndroid Build Coastguard Worker "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
425*523fa7a6SAndroid Build Coastguard Worker path.c_str(),
426*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(res.error()));
427*523fa7a6SAndroid Build Coastguard Worker
428*523fa7a6SAndroid Build Coastguard Worker auto loader = std::make_unique<MmapDataLoader>(std::move(res.get()));
429*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<Module>(
430*523fa7a6SAndroid Build Coastguard Worker std::move(loader),
431*523fa7a6SAndroid Build Coastguard Worker enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
432*523fa7a6SAndroid Build Coastguard Worker debug_buffer_size,
433*523fa7a6SAndroid Build Coastguard Worker program_verification);
434*523fa7a6SAndroid Build Coastguard Worker }
435*523fa7a6SAndroid Build Coastguard Worker
// Default pool size for bundled inputs: 16 * 1024 = 16384 (units presumably
// bytes — TODO confirm against the bundled-program allocator that uses it).
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16U * 1024U;
437*523fa7a6SAndroid Build Coastguard Worker
438*523fa7a6SAndroid Build Coastguard Worker struct PyBundledModule final {
PyBundledModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule439*523fa7a6SAndroid Build Coastguard Worker explicit PyBundledModule(
440*523fa7a6SAndroid Build Coastguard Worker const py::bytes& buffer,
441*523fa7a6SAndroid Build Coastguard Worker uint32_t bundled_input_pool_size)
442*523fa7a6SAndroid Build Coastguard Worker : bundled_program_ptr_(buffer),
443*523fa7a6SAndroid Build Coastguard Worker program_ptr_(static_cast<const void*>(
444*523fa7a6SAndroid Build Coastguard Worker bundled_program_flatbuffer::GetBundledProgram(
445*523fa7a6SAndroid Build Coastguard Worker get_bundled_program_ptr())
446*523fa7a6SAndroid Build Coastguard Worker ->program()
447*523fa7a6SAndroid Build Coastguard Worker ->data())),
448*523fa7a6SAndroid Build Coastguard Worker program_len_(bundled_program_flatbuffer::GetBundledProgram(
449*523fa7a6SAndroid Build Coastguard Worker get_bundled_program_ptr())
450*523fa7a6SAndroid Build Coastguard Worker ->program()
451*523fa7a6SAndroid Build Coastguard Worker ->size()) {}
452*523fa7a6SAndroid Build Coastguard Worker
load_from_bufferexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule453*523fa7a6SAndroid Build Coastguard Worker static std::unique_ptr<PyBundledModule> load_from_buffer(
454*523fa7a6SAndroid Build Coastguard Worker const py::bytes& buffer,
455*523fa7a6SAndroid Build Coastguard Worker uint32_t bundled_input_pool_size) {
456*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<PyBundledModule>(buffer, bundled_input_pool_size);
457*523fa7a6SAndroid Build Coastguard Worker }
458*523fa7a6SAndroid Build Coastguard Worker
get_bundled_program_ptrexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule459*523fa7a6SAndroid Build Coastguard Worker const void* get_bundled_program_ptr() {
460*523fa7a6SAndroid Build Coastguard Worker return bundled_program_ptr_.cast<std::string_view>().data();
461*523fa7a6SAndroid Build Coastguard Worker }
462*523fa7a6SAndroid Build Coastguard Worker
get_program_ptrexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule463*523fa7a6SAndroid Build Coastguard Worker const void* get_program_ptr() {
464*523fa7a6SAndroid Build Coastguard Worker return program_ptr_;
465*523fa7a6SAndroid Build Coastguard Worker }
466*523fa7a6SAndroid Build Coastguard Worker
get_program_lenexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule467*523fa7a6SAndroid Build Coastguard Worker size_t get_program_len() {
468*523fa7a6SAndroid Build Coastguard Worker return program_len_;
469*523fa7a6SAndroid Build Coastguard Worker }
470*523fa7a6SAndroid Build Coastguard Worker
471*523fa7a6SAndroid Build Coastguard Worker private:
472*523fa7a6SAndroid Build Coastguard Worker // Store the bytes object instead of a raw pointer so that this module will
473*523fa7a6SAndroid Build Coastguard Worker // keep the bytes alive.
474*523fa7a6SAndroid Build Coastguard Worker const py::bytes bundled_program_ptr_;
475*523fa7a6SAndroid Build Coastguard Worker const void* program_ptr_;
476*523fa7a6SAndroid Build Coastguard Worker size_t program_len_;
477*523fa7a6SAndroid Build Coastguard Worker };
478*523fa7a6SAndroid Build Coastguard Worker
479*523fa7a6SAndroid Build Coastguard Worker /// Expose a subset of TensorInfo information to python.
480*523fa7a6SAndroid Build Coastguard Worker struct PyTensorInfo final {
PyTensorInfoexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo481*523fa7a6SAndroid Build Coastguard Worker explicit PyTensorInfo(
482*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<Module> module,
483*523fa7a6SAndroid Build Coastguard Worker torch::executor::TensorInfo info)
484*523fa7a6SAndroid Build Coastguard Worker : module_(std::move(module)), info_(info) {}
485*523fa7a6SAndroid Build Coastguard Worker
sizesexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo486*523fa7a6SAndroid Build Coastguard Worker py::tuple sizes() const {
487*523fa7a6SAndroid Build Coastguard Worker const auto shape = info_.sizes();
488*523fa7a6SAndroid Build Coastguard Worker py::tuple tup(shape.size());
489*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < shape.size(); ++i) {
490*523fa7a6SAndroid Build Coastguard Worker tup[i] = py::cast(shape[i]);
491*523fa7a6SAndroid Build Coastguard Worker }
492*523fa7a6SAndroid Build Coastguard Worker return tup;
493*523fa7a6SAndroid Build Coastguard Worker }
494*523fa7a6SAndroid Build Coastguard Worker
dtypeexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo495*523fa7a6SAndroid Build Coastguard Worker int8_t dtype() const {
496*523fa7a6SAndroid Build Coastguard Worker return static_cast<std::underlying_type<exec_aten::ScalarType>::type>(
497*523fa7a6SAndroid Build Coastguard Worker info_.scalar_type());
498*523fa7a6SAndroid Build Coastguard Worker }
499*523fa7a6SAndroid Build Coastguard Worker
is_memory_plannedexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo500*523fa7a6SAndroid Build Coastguard Worker bool is_memory_planned() const {
501*523fa7a6SAndroid Build Coastguard Worker return info_.is_memory_planned();
502*523fa7a6SAndroid Build Coastguard Worker }
503*523fa7a6SAndroid Build Coastguard Worker
nbytesexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo504*523fa7a6SAndroid Build Coastguard Worker size_t nbytes() const {
505*523fa7a6SAndroid Build Coastguard Worker return info_.nbytes();
506*523fa7a6SAndroid Build Coastguard Worker }
507*523fa7a6SAndroid Build Coastguard Worker
reprexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo508*523fa7a6SAndroid Build Coastguard Worker std::string repr() const {
509*523fa7a6SAndroid Build Coastguard Worker std::string size_str = "[";
510*523fa7a6SAndroid Build Coastguard Worker for (const auto& d : info_.sizes()) {
511*523fa7a6SAndroid Build Coastguard Worker size_str.append(std::to_string(d));
512*523fa7a6SAndroid Build Coastguard Worker size_str.append(", ");
513*523fa7a6SAndroid Build Coastguard Worker }
514*523fa7a6SAndroid Build Coastguard Worker if (size_str.length() >= 2) {
515*523fa7a6SAndroid Build Coastguard Worker // Pop the last two characters (command and space) and add close bracket.
516*523fa7a6SAndroid Build Coastguard Worker size_str.pop_back();
517*523fa7a6SAndroid Build Coastguard Worker size_str.pop_back();
518*523fa7a6SAndroid Build Coastguard Worker }
519*523fa7a6SAndroid Build Coastguard Worker size_str.append("]");
520*523fa7a6SAndroid Build Coastguard Worker return "TensorInfo(sizes=" + size_str + ", dtype=" +
521*523fa7a6SAndroid Build Coastguard Worker std::string(executorch::runtime::toString(info_.scalar_type())) +
522*523fa7a6SAndroid Build Coastguard Worker ", is_memory_planned=" +
523*523fa7a6SAndroid Build Coastguard Worker (info_.is_memory_planned() ? "True" : "False") +
524*523fa7a6SAndroid Build Coastguard Worker ", nbytes=" + std::to_string(info_.nbytes()) + ")";
525*523fa7a6SAndroid Build Coastguard Worker }
526*523fa7a6SAndroid Build Coastguard Worker
527*523fa7a6SAndroid Build Coastguard Worker private:
528*523fa7a6SAndroid Build Coastguard Worker // TensorInfo relies on module to be alive.
529*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<Module> module_;
530*523fa7a6SAndroid Build Coastguard Worker torch::executor::TensorInfo info_;
531*523fa7a6SAndroid Build Coastguard Worker };
532*523fa7a6SAndroid Build Coastguard Worker
533*523fa7a6SAndroid Build Coastguard Worker /// Expose a subset of MethodMeta information to python.
534*523fa7a6SAndroid Build Coastguard Worker struct PyMethodMeta final {
PyMethodMetaexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta535*523fa7a6SAndroid Build Coastguard Worker explicit PyMethodMeta(
536*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<Module> module,
537*523fa7a6SAndroid Build Coastguard Worker torch::executor::MethodMeta meta)
538*523fa7a6SAndroid Build Coastguard Worker : module_(std::move(module)), meta_(meta) {}
539*523fa7a6SAndroid Build Coastguard Worker
nameexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta540*523fa7a6SAndroid Build Coastguard Worker const char* name() const {
541*523fa7a6SAndroid Build Coastguard Worker return meta_.name();
542*523fa7a6SAndroid Build Coastguard Worker }
543*523fa7a6SAndroid Build Coastguard Worker
num_inputsexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta544*523fa7a6SAndroid Build Coastguard Worker size_t num_inputs() const {
545*523fa7a6SAndroid Build Coastguard Worker return meta_.num_inputs();
546*523fa7a6SAndroid Build Coastguard Worker }
547*523fa7a6SAndroid Build Coastguard Worker
input_tensor_metaexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta548*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<PyTensorInfo> input_tensor_meta(size_t index) const {
549*523fa7a6SAndroid Build Coastguard Worker const auto result = meta_.input_tensor_meta(index);
550*523fa7a6SAndroid Build Coastguard Worker THROW_INDEX_IF_ERROR(
551*523fa7a6SAndroid Build Coastguard Worker result.error(), "Cannot get input tensor meta at %zu", index);
552*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<PyTensorInfo>(module_, result.get());
553*523fa7a6SAndroid Build Coastguard Worker }
554*523fa7a6SAndroid Build Coastguard Worker
num_outputsexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta555*523fa7a6SAndroid Build Coastguard Worker size_t num_outputs() const {
556*523fa7a6SAndroid Build Coastguard Worker return meta_.num_outputs();
557*523fa7a6SAndroid Build Coastguard Worker }
558*523fa7a6SAndroid Build Coastguard Worker
output_tensor_metaexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta559*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<PyTensorInfo> output_tensor_meta(size_t index) const {
560*523fa7a6SAndroid Build Coastguard Worker const auto result = meta_.output_tensor_meta(index);
561*523fa7a6SAndroid Build Coastguard Worker THROW_INDEX_IF_ERROR(
562*523fa7a6SAndroid Build Coastguard Worker result.error(), "Cannot get output tensor meta at %zu", index);
563*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<PyTensorInfo>(module_, result.get());
564*523fa7a6SAndroid Build Coastguard Worker }
565*523fa7a6SAndroid Build Coastguard Worker
reprexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta566*523fa7a6SAndroid Build Coastguard Worker py::str repr() const {
567*523fa7a6SAndroid Build Coastguard Worker py::list input_meta_strs;
568*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < meta_.num_inputs(); ++i) {
569*523fa7a6SAndroid Build Coastguard Worker input_meta_strs.append(py::str(input_tensor_meta(i)->repr()));
570*523fa7a6SAndroid Build Coastguard Worker }
571*523fa7a6SAndroid Build Coastguard Worker py::list output_meta_strs;
572*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < meta_.num_outputs(); ++i) {
573*523fa7a6SAndroid Build Coastguard Worker output_meta_strs.append(py::str(output_tensor_meta(i)->repr()));
574*523fa7a6SAndroid Build Coastguard Worker }
575*523fa7a6SAndroid Build Coastguard Worker // Add quotes to be more similar to Python's repr for strings.
576*523fa7a6SAndroid Build Coastguard Worker py::str format =
577*523fa7a6SAndroid Build Coastguard Worker "MethodMeta(name='{}', num_inputs={}, input_tensor_meta={}, num_outputs={}, output_tensor_meta={})";
578*523fa7a6SAndroid Build Coastguard Worker return format.format(
579*523fa7a6SAndroid Build Coastguard Worker std::string(meta_.name()),
580*523fa7a6SAndroid Build Coastguard Worker std::to_string(meta_.num_inputs()),
581*523fa7a6SAndroid Build Coastguard Worker input_meta_strs,
582*523fa7a6SAndroid Build Coastguard Worker std::to_string(meta_.num_outputs()),
583*523fa7a6SAndroid Build Coastguard Worker output_meta_strs);
584*523fa7a6SAndroid Build Coastguard Worker }
585*523fa7a6SAndroid Build Coastguard Worker
586*523fa7a6SAndroid Build Coastguard Worker private:
587*523fa7a6SAndroid Build Coastguard Worker // Must keep the Module object alive or else the meta object is invalidated.
588*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<Module> module_;
589*523fa7a6SAndroid Build Coastguard Worker torch::executor::MethodMeta meta_;
590*523fa7a6SAndroid Build Coastguard Worker };
591*523fa7a6SAndroid Build Coastguard Worker
592*523fa7a6SAndroid Build Coastguard Worker struct PyModule final {
PyModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule593*523fa7a6SAndroid Build Coastguard Worker explicit PyModule(
594*523fa7a6SAndroid Build Coastguard Worker const py::bytes& buffer,
595*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
596*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size = 0,
597*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification =
598*523fa7a6SAndroid Build Coastguard Worker Program::Verification::InternalConsistency)
599*523fa7a6SAndroid Build Coastguard Worker : module_(load_module_from_buffer(
600*523fa7a6SAndroid Build Coastguard Worker buffer.cast<std::string_view>().data(),
601*523fa7a6SAndroid Build Coastguard Worker py::len(buffer),
602*523fa7a6SAndroid Build Coastguard Worker enable_etdump,
603*523fa7a6SAndroid Build Coastguard Worker debug_buffer_size,
604*523fa7a6SAndroid Build Coastguard Worker program_verification)) {}
605*523fa7a6SAndroid Build Coastguard Worker
PyModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule606*523fa7a6SAndroid Build Coastguard Worker explicit PyModule(
607*523fa7a6SAndroid Build Coastguard Worker const void* ptr,
608*523fa7a6SAndroid Build Coastguard Worker size_t ptr_len,
609*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
610*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size = 0,
611*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification =
612*523fa7a6SAndroid Build Coastguard Worker Program::Verification::InternalConsistency)
613*523fa7a6SAndroid Build Coastguard Worker : module_(load_module_from_buffer(
614*523fa7a6SAndroid Build Coastguard Worker ptr,
615*523fa7a6SAndroid Build Coastguard Worker ptr_len,
616*523fa7a6SAndroid Build Coastguard Worker enable_etdump,
617*523fa7a6SAndroid Build Coastguard Worker debug_buffer_size,
618*523fa7a6SAndroid Build Coastguard Worker program_verification)) {}
619*523fa7a6SAndroid Build Coastguard Worker
PyModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule620*523fa7a6SAndroid Build Coastguard Worker explicit PyModule(
621*523fa7a6SAndroid Build Coastguard Worker const std::string& path,
622*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
623*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size = 0,
624*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification =
625*523fa7a6SAndroid Build Coastguard Worker Program::Verification::InternalConsistency)
626*523fa7a6SAndroid Build Coastguard Worker : module_(load_module_from_file(
627*523fa7a6SAndroid Build Coastguard Worker path,
628*523fa7a6SAndroid Build Coastguard Worker enable_etdump,
629*523fa7a6SAndroid Build Coastguard Worker debug_buffer_size,
630*523fa7a6SAndroid Build Coastguard Worker program_verification)) {}
631*523fa7a6SAndroid Build Coastguard Worker
632*523fa7a6SAndroid Build Coastguard Worker PyModule(const PyModule&) = delete;
633*523fa7a6SAndroid Build Coastguard Worker PyModule& operator=(const PyModule&) = delete;
634*523fa7a6SAndroid Build Coastguard Worker PyModule(PyModule&&) = default;
635*523fa7a6SAndroid Build Coastguard Worker PyModule& operator=(PyModule&&) = default;
636*523fa7a6SAndroid Build Coastguard Worker
637*523fa7a6SAndroid Build Coastguard Worker // Module is only valid as long as the python buffer is alive.
load_from_bufferexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule638*523fa7a6SAndroid Build Coastguard Worker static std::unique_ptr<PyModule> load_from_buffer(
639*523fa7a6SAndroid Build Coastguard Worker const py::bytes& buffer,
640*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
641*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size = 0,
642*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification =
643*523fa7a6SAndroid Build Coastguard Worker Program::Verification::InternalConsistency) {
644*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<PyModule>(
645*523fa7a6SAndroid Build Coastguard Worker buffer, enable_etdump, debug_buffer_size, program_verification);
646*523fa7a6SAndroid Build Coastguard Worker }
load_from_fileexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule647*523fa7a6SAndroid Build Coastguard Worker static std::unique_ptr<PyModule> load_from_file(
648*523fa7a6SAndroid Build Coastguard Worker const std::string& path,
649*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
650*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size = 0,
651*523fa7a6SAndroid Build Coastguard Worker Program::Verification program_verification =
652*523fa7a6SAndroid Build Coastguard Worker Program::Verification::InternalConsistency) {
653*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<PyModule>(
654*523fa7a6SAndroid Build Coastguard Worker path, enable_etdump, debug_buffer_size, program_verification);
655*523fa7a6SAndroid Build Coastguard Worker }
656*523fa7a6SAndroid Build Coastguard Worker
load_from_bundled_programexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule657*523fa7a6SAndroid Build Coastguard Worker static std::unique_ptr<PyModule> load_from_bundled_program(
658*523fa7a6SAndroid Build Coastguard Worker PyBundledModule& m,
659*523fa7a6SAndroid Build Coastguard Worker bool enable_etdump,
660*523fa7a6SAndroid Build Coastguard Worker size_t debug_buffer_size = 0) {
661*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<PyModule>(
662*523fa7a6SAndroid Build Coastguard Worker m.get_program_ptr(),
663*523fa7a6SAndroid Build Coastguard Worker m.get_program_len(),
664*523fa7a6SAndroid Build Coastguard Worker enable_etdump,
665*523fa7a6SAndroid Build Coastguard Worker debug_buffer_size);
666*523fa7a6SAndroid Build Coastguard Worker }
667*523fa7a6SAndroid Build Coastguard Worker
run_methodexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule668*523fa7a6SAndroid Build Coastguard Worker py::list run_method(
669*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name,
670*523fa7a6SAndroid Build Coastguard Worker const py::sequence& inputs,
671*523fa7a6SAndroid Build Coastguard Worker bool clone_outputs = true) {
672*523fa7a6SAndroid Build Coastguard Worker const auto inputs_size = py::len(inputs);
673*523fa7a6SAndroid Build Coastguard Worker std::vector<EValue> cpp_inputs;
674*523fa7a6SAndroid Build Coastguard Worker cpp_inputs.reserve(inputs_size);
675*523fa7a6SAndroid Build Coastguard Worker
676*523fa7a6SAndroid Build Coastguard Worker #ifndef USE_ATEN_LIB // Portable mode
677*523fa7a6SAndroid Build Coastguard Worker // So the ETensors and their metadata stay in scope for
678*523fa7a6SAndroid Build Coastguard Worker // Module->run_method.
679*523fa7a6SAndroid Build Coastguard Worker std::vector<torch::executor::TensorImpl> input_tensors;
680*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<torch::executor::Tensor::SizesType>> input_sizes;
681*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<torch::executor::Tensor::StridesType>>
682*523fa7a6SAndroid Build Coastguard Worker input_strides;
683*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<torch::executor::Tensor::DimOrderType>>
684*523fa7a6SAndroid Build Coastguard Worker input_dim_order;
685*523fa7a6SAndroid Build Coastguard Worker // We store pointers to these vector elements so important to reserve so
686*523fa7a6SAndroid Build Coastguard Worker // that we don't lose those on a vector resize. Don't need to do this for
687*523fa7a6SAndroid Build Coastguard Worker // the others since they are vectors of vectors, and we don't store a
688*523fa7a6SAndroid Build Coastguard Worker // pointer to the root level vector data.
689*523fa7a6SAndroid Build Coastguard Worker input_tensors.reserve(inputs_size);
690*523fa7a6SAndroid Build Coastguard Worker #endif
691*523fa7a6SAndroid Build Coastguard Worker
692*523fa7a6SAndroid Build Coastguard Worker // Convert python objects into EValues.
693*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < inputs_size; ++i) {
694*523fa7a6SAndroid Build Coastguard Worker auto python_input = inputs[i];
695*523fa7a6SAndroid Build Coastguard Worker const std::string& type_str = py::str(python_input.get_type());
696*523fa7a6SAndroid Build Coastguard Worker if (type_str == "<class 'torch.Tensor'>") {
697*523fa7a6SAndroid Build Coastguard Worker auto at_tensor = python_input.cast<at::Tensor>();
698*523fa7a6SAndroid Build Coastguard Worker // alias_etensor_to_attensor will assert on this later, so to better
699*523fa7a6SAndroid Build Coastguard Worker // propogate up to python we check early and throw an exception.
700*523fa7a6SAndroid Build Coastguard Worker if (!at_tensor.is_contiguous()) {
701*523fa7a6SAndroid Build Coastguard Worker auto error_msg = "Input " + std::to_string(i) + "for method " +
702*523fa7a6SAndroid Build Coastguard Worker method_name + " is not contiguous.";
703*523fa7a6SAndroid Build Coastguard Worker throw std::runtime_error(error_msg);
704*523fa7a6SAndroid Build Coastguard Worker }
705*523fa7a6SAndroid Build Coastguard Worker
706*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
707*523fa7a6SAndroid Build Coastguard Worker EValue evalue(at_tensor);
708*523fa7a6SAndroid Build Coastguard Worker #else
709*523fa7a6SAndroid Build Coastguard Worker // convert at::Tensor to torch::executor::Tensor
710*523fa7a6SAndroid Build Coastguard Worker auto type =
711*523fa7a6SAndroid Build Coastguard Worker torch_to_executorch_scalar_type(at_tensor.options().dtype());
712*523fa7a6SAndroid Build Coastguard Worker size_t dim = at_tensor.dim();
713*523fa7a6SAndroid Build Coastguard Worker // cant directly alias at::Tensor sizes and strides due to int64 vs
714*523fa7a6SAndroid Build Coastguard Worker // int32 typing conflict
715*523fa7a6SAndroid Build Coastguard Worker input_sizes.emplace_back(
716*523fa7a6SAndroid Build Coastguard Worker at_tensor.sizes().begin(), at_tensor.sizes().end());
717*523fa7a6SAndroid Build Coastguard Worker input_strides.emplace_back(
718*523fa7a6SAndroid Build Coastguard Worker at_tensor.strides().begin(), at_tensor.strides().end());
719*523fa7a6SAndroid Build Coastguard Worker
720*523fa7a6SAndroid Build Coastguard Worker // Only works for MemoryFormat::Contiguous inputs
721*523fa7a6SAndroid Build Coastguard Worker std::vector<torch::executor::Tensor::DimOrderType> dim_order;
722*523fa7a6SAndroid Build Coastguard Worker for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
723*523fa7a6SAndroid Build Coastguard Worker dim_order.push_back(cur_dim);
724*523fa7a6SAndroid Build Coastguard Worker }
725*523fa7a6SAndroid Build Coastguard Worker input_dim_order.push_back(std::move(dim_order));
726*523fa7a6SAndroid Build Coastguard Worker input_tensors.emplace_back(
727*523fa7a6SAndroid Build Coastguard Worker type,
728*523fa7a6SAndroid Build Coastguard Worker dim,
729*523fa7a6SAndroid Build Coastguard Worker input_sizes.back().data(),
730*523fa7a6SAndroid Build Coastguard Worker nullptr,
731*523fa7a6SAndroid Build Coastguard Worker input_dim_order.back().data(),
732*523fa7a6SAndroid Build Coastguard Worker input_strides.back().data());
733*523fa7a6SAndroid Build Coastguard Worker
734*523fa7a6SAndroid Build Coastguard Worker torch::executor::Tensor temp =
735*523fa7a6SAndroid Build Coastguard Worker torch::executor::Tensor(&input_tensors.back());
736*523fa7a6SAndroid Build Coastguard Worker alias_etensor_to_attensor(at_tensor, temp);
737*523fa7a6SAndroid Build Coastguard Worker EValue evalue(temp);
738*523fa7a6SAndroid Build Coastguard Worker #endif
739*523fa7a6SAndroid Build Coastguard Worker
740*523fa7a6SAndroid Build Coastguard Worker cpp_inputs.push_back(evalue);
741*523fa7a6SAndroid Build Coastguard Worker } else if (py::isinstance<py::none>(python_input)) {
742*523fa7a6SAndroid Build Coastguard Worker cpp_inputs.push_back(EValue());
743*523fa7a6SAndroid Build Coastguard Worker } else if (py::isinstance<py::bool_>(python_input)) {
744*523fa7a6SAndroid Build Coastguard Worker cpp_inputs.push_back(EValue(py::cast<bool>(python_input)));
745*523fa7a6SAndroid Build Coastguard Worker } else if (py::isinstance<py::int_>(python_input)) {
746*523fa7a6SAndroid Build Coastguard Worker cpp_inputs.push_back(EValue(py::cast<int64_t>(python_input)));
747*523fa7a6SAndroid Build Coastguard Worker } else {
748*523fa7a6SAndroid Build Coastguard Worker ET_ASSERT_UNREACHABLE_MSG("Unsupported pytype: %s", type_str.c_str());
749*523fa7a6SAndroid Build Coastguard Worker }
750*523fa7a6SAndroid Build Coastguard Worker }
751*523fa7a6SAndroid Build Coastguard Worker
752*523fa7a6SAndroid Build Coastguard Worker const auto& method = module_->get_method(method_name);
753*523fa7a6SAndroid Build Coastguard Worker const auto num_outputs = method.outputs_size();
754*523fa7a6SAndroid Build Coastguard Worker output_storages_ = make_output_storages(method);
755*523fa7a6SAndroid Build Coastguard Worker std::vector<Span<uint8_t>> output_storage_spans(num_outputs);
756*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < output_storages_.size(); ++i) {
757*523fa7a6SAndroid Build Coastguard Worker output_storage_spans[i] =
758*523fa7a6SAndroid Build Coastguard Worker Span<uint8_t>(output_storages_[i].data(), output_storages_[i].size());
759*523fa7a6SAndroid Build Coastguard Worker }
760*523fa7a6SAndroid Build Coastguard Worker auto outputs =
761*523fa7a6SAndroid Build Coastguard Worker module_->run_method(method_name, cpp_inputs, output_storage_spans);
762*523fa7a6SAndroid Build Coastguard Worker
763*523fa7a6SAndroid Build Coastguard Worker // Retrieve outputs
764*523fa7a6SAndroid Build Coastguard Worker return get_outputs_as_py_list(outputs, clone_outputs);
765*523fa7a6SAndroid Build Coastguard Worker }
766*523fa7a6SAndroid Build Coastguard Worker
forwardexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule767*523fa7a6SAndroid Build Coastguard Worker py::list forward(const py::sequence& inputs, bool clone_outputs = true) {
768*523fa7a6SAndroid Build Coastguard Worker return run_method("forward", inputs, clone_outputs);
769*523fa7a6SAndroid Build Coastguard Worker }
770*523fa7a6SAndroid Build Coastguard Worker
forward_single_inputexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule771*523fa7a6SAndroid Build Coastguard Worker py::list forward_single_input(
772*523fa7a6SAndroid Build Coastguard Worker const torch::Tensor& inputTensor,
773*523fa7a6SAndroid Build Coastguard Worker bool clone_outputs = true) {
774*523fa7a6SAndroid Build Coastguard Worker py::list py_list;
775*523fa7a6SAndroid Build Coastguard Worker py_list.append(py::cast(inputTensor));
776*523fa7a6SAndroid Build Coastguard Worker return run_method("forward", py_list, clone_outputs);
777*523fa7a6SAndroid Build Coastguard Worker }
778*523fa7a6SAndroid Build Coastguard Worker
has_etdumpexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule779*523fa7a6SAndroid Build Coastguard Worker bool has_etdump() {
780*523fa7a6SAndroid Build Coastguard Worker return module_->has_etdump();
781*523fa7a6SAndroid Build Coastguard Worker }
782*523fa7a6SAndroid Build Coastguard Worker
write_etdump_result_to_fileexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule783*523fa7a6SAndroid Build Coastguard Worker void write_etdump_result_to_file(
784*523fa7a6SAndroid Build Coastguard Worker const std::string& path,
785*523fa7a6SAndroid Build Coastguard Worker const py::object& debug_buffer_path) {
786*523fa7a6SAndroid Build Coastguard Worker if (!has_etdump()) {
787*523fa7a6SAndroid Build Coastguard Worker throw std::runtime_error("No etdump found");
788*523fa7a6SAndroid Build Coastguard Worker }
789*523fa7a6SAndroid Build Coastguard Worker auto& etdump = module_->etdump();
790*523fa7a6SAndroid Build Coastguard Worker etdump_result result = etdump.get_etdump_data();
791*523fa7a6SAndroid Build Coastguard Worker if (result.buf != nullptr && result.size > 0) {
792*523fa7a6SAndroid Build Coastguard Worker write_data_to_file(path, result.buf, result.size);
793*523fa7a6SAndroid Build Coastguard Worker free(result.buf);
794*523fa7a6SAndroid Build Coastguard Worker if (module_->has_etdump_debug_buffer() &&
795*523fa7a6SAndroid Build Coastguard Worker py::isinstance<py::str>(debug_buffer_path)) {
796*523fa7a6SAndroid Build Coastguard Worker // Also write out the debug buffer to a separate file if requested.
797*523fa7a6SAndroid Build Coastguard Worker std::string debug_buffer_path_str =
798*523fa7a6SAndroid Build Coastguard Worker py::cast<py::str>(debug_buffer_path);
799*523fa7a6SAndroid Build Coastguard Worker const auto debug_buffer = module_->get_etdump_debug_buffer();
800*523fa7a6SAndroid Build Coastguard Worker write_data_to_file(
801*523fa7a6SAndroid Build Coastguard Worker debug_buffer_path_str, debug_buffer.data(), debug_buffer.size());
802*523fa7a6SAndroid Build Coastguard Worker }
803*523fa7a6SAndroid Build Coastguard Worker } else {
804*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
805*523fa7a6SAndroid Build Coastguard Worker Info,
806*523fa7a6SAndroid Build Coastguard Worker "No etdump data found, try rebuilding with "
807*523fa7a6SAndroid Build Coastguard Worker "the CMake option EXECUTORCH_ENABLE_EVENT_TRACER or with "
808*523fa7a6SAndroid Build Coastguard Worker "buck run --config executorch.event_tracer_enabled=true");
809*523fa7a6SAndroid Build Coastguard Worker }
810*523fa7a6SAndroid Build Coastguard Worker }
811*523fa7a6SAndroid Build Coastguard Worker
load_bundled_inputexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule812*523fa7a6SAndroid Build Coastguard Worker void load_bundled_input(
813*523fa7a6SAndroid Build Coastguard Worker PyBundledModule& m,
814*523fa7a6SAndroid Build Coastguard Worker const std::string method_name,
815*523fa7a6SAndroid Build Coastguard Worker size_t testset_idx) {
816*523fa7a6SAndroid Build Coastguard Worker const void* bundled_program_ptr = m.get_bundled_program_ptr();
817*523fa7a6SAndroid Build Coastguard Worker Error status = executorch::bundled_program::load_bundled_input(
818*523fa7a6SAndroid Build Coastguard Worker module_->get_method(method_name), bundled_program_ptr, testset_idx);
819*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
820*523fa7a6SAndroid Build Coastguard Worker status,
821*523fa7a6SAndroid Build Coastguard Worker "load_bundled_input failed with status 0x%" PRIx32,
822*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(status));
823*523fa7a6SAndroid Build Coastguard Worker }
824*523fa7a6SAndroid Build Coastguard Worker
verify_result_with_bundled_expected_outputexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule825*523fa7a6SAndroid Build Coastguard Worker py::list verify_result_with_bundled_expected_output(
826*523fa7a6SAndroid Build Coastguard Worker PyBundledModule& m,
827*523fa7a6SAndroid Build Coastguard Worker const std::string method_name,
828*523fa7a6SAndroid Build Coastguard Worker size_t testset_idx,
829*523fa7a6SAndroid Build Coastguard Worker double rtol = 1e-5,
830*523fa7a6SAndroid Build Coastguard Worker double atol = 1e-8) {
831*523fa7a6SAndroid Build Coastguard Worker const void* bundled_program_ptr = m.get_bundled_program_ptr();
832*523fa7a6SAndroid Build Coastguard Worker auto& method = module_->get_method(method_name);
833*523fa7a6SAndroid Build Coastguard Worker Error status = executorch::bundled_program::load_bundled_input(
834*523fa7a6SAndroid Build Coastguard Worker method, bundled_program_ptr, testset_idx);
835*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
836*523fa7a6SAndroid Build Coastguard Worker status,
837*523fa7a6SAndroid Build Coastguard Worker "load_bundled_input failed with status 0x%" PRIx32,
838*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(status));
839*523fa7a6SAndroid Build Coastguard Worker py::list outputs = plan_execute(method_name);
840*523fa7a6SAndroid Build Coastguard Worker status = executorch::bundled_program::verify_method_outputs(
841*523fa7a6SAndroid Build Coastguard Worker method, bundled_program_ptr, testset_idx, rtol, atol);
842*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
843*523fa7a6SAndroid Build Coastguard Worker status,
844*523fa7a6SAndroid Build Coastguard Worker "Result verification failed with status %" PRIu32,
845*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(status));
846*523fa7a6SAndroid Build Coastguard Worker return outputs;
847*523fa7a6SAndroid Build Coastguard Worker }
848*523fa7a6SAndroid Build Coastguard Worker
plan_executeexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule849*523fa7a6SAndroid Build Coastguard Worker py::list plan_execute(
850*523fa7a6SAndroid Build Coastguard Worker const std::string method_name,
851*523fa7a6SAndroid Build Coastguard Worker bool clone_outputs = true) {
852*523fa7a6SAndroid Build Coastguard Worker auto& method = module_->get_method(method_name);
853*523fa7a6SAndroid Build Coastguard Worker // Need to pre-allocate space for outputs just like in run_method.
854*523fa7a6SAndroid Build Coastguard Worker const auto num_outputs = method.outputs_size();
855*523fa7a6SAndroid Build Coastguard Worker output_storages_ = make_output_storages(method);
856*523fa7a6SAndroid Build Coastguard Worker std::vector<Span<uint8_t>> output_storage_spans(num_outputs);
857*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < output_storages_.size(); ++i) {
858*523fa7a6SAndroid Build Coastguard Worker output_storage_spans[i] =
859*523fa7a6SAndroid Build Coastguard Worker Span<uint8_t>(output_storages_[i].data(), output_storages_[i].size());
860*523fa7a6SAndroid Build Coastguard Worker }
861*523fa7a6SAndroid Build Coastguard Worker setup_output_storage(method, output_storage_spans);
862*523fa7a6SAndroid Build Coastguard Worker auto status = method.execute();
863*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
864*523fa7a6SAndroid Build Coastguard Worker status,
865*523fa7a6SAndroid Build Coastguard Worker "executing execution plan for method 'forward' failed with error: 0x%" PRIx32,
866*523fa7a6SAndroid Build Coastguard Worker static_cast<uint32_t>(status));
867*523fa7a6SAndroid Build Coastguard Worker const auto outputs = module_->get_outputs(method_name);
868*523fa7a6SAndroid Build Coastguard Worker return get_outputs_as_py_list(outputs, clone_outputs);
869*523fa7a6SAndroid Build Coastguard Worker }
870*523fa7a6SAndroid Build Coastguard Worker
get_outputs_as_py_listexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule871*523fa7a6SAndroid Build Coastguard Worker py::list get_outputs_as_py_list(
872*523fa7a6SAndroid Build Coastguard Worker const std::vector<EValue>& outputs,
873*523fa7a6SAndroid Build Coastguard Worker bool clone_outputs = true) {
874*523fa7a6SAndroid Build Coastguard Worker const auto outputs_size = outputs.size();
875*523fa7a6SAndroid Build Coastguard Worker py::list list(outputs_size);
876*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < outputs_size; ++i) {
877*523fa7a6SAndroid Build Coastguard Worker auto& v = outputs[i];
878*523fa7a6SAndroid Build Coastguard Worker if (Tag::None == v.tag) {
879*523fa7a6SAndroid Build Coastguard Worker list[i] = py::none();
880*523fa7a6SAndroid Build Coastguard Worker } else if (Tag::Int == v.tag) {
881*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(v.toInt());
882*523fa7a6SAndroid Build Coastguard Worker } else if (Tag::Double == v.tag) {
883*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(v.toDouble());
884*523fa7a6SAndroid Build Coastguard Worker } else if (Tag::Bool == v.tag) {
885*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(v.toBool());
886*523fa7a6SAndroid Build Coastguard Worker } else if (Tag::String == v.tag) {
887*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(std::string(v.toString().data()));
888*523fa7a6SAndroid Build Coastguard Worker } else if (Tag::Tensor == v.tag) {
889*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
890*523fa7a6SAndroid Build Coastguard Worker // Clone so the outputs in python do not share a lifetime with the
891*523fa7a6SAndroid Build Coastguard Worker // module object
892*523fa7a6SAndroid Build Coastguard Worker if (clone_outputs) {
893*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(v.toTensor().clone());
894*523fa7a6SAndroid Build Coastguard Worker } else {
895*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(v.toTensor());
896*523fa7a6SAndroid Build Coastguard Worker }
897*523fa7a6SAndroid Build Coastguard Worker #else
898*523fa7a6SAndroid Build Coastguard Worker if (clone_outputs) {
899*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
900*523fa7a6SAndroid Build Coastguard Worker } else {
901*523fa7a6SAndroid Build Coastguard Worker list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
902*523fa7a6SAndroid Build Coastguard Worker }
903*523fa7a6SAndroid Build Coastguard Worker #endif
904*523fa7a6SAndroid Build Coastguard Worker } else {
905*523fa7a6SAndroid Build Coastguard Worker ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
906*523fa7a6SAndroid Build Coastguard Worker }
907*523fa7a6SAndroid Build Coastguard Worker }
908*523fa7a6SAndroid Build Coastguard Worker return list;
909*523fa7a6SAndroid Build Coastguard Worker }
910*523fa7a6SAndroid Build Coastguard Worker
method_metaexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule911*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
912*523fa7a6SAndroid Build Coastguard Worker auto& method = module_->get_method(method_name);
913*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<PyMethodMeta>(module_, method.method_meta());
914*523fa7a6SAndroid Build Coastguard Worker }
915*523fa7a6SAndroid Build Coastguard Worker
method_namesexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule916*523fa7a6SAndroid Build Coastguard Worker std::vector<std::string> method_names() {
917*523fa7a6SAndroid Build Coastguard Worker return module_->method_names();
918*523fa7a6SAndroid Build Coastguard Worker }
919*523fa7a6SAndroid Build Coastguard Worker
920*523fa7a6SAndroid Build Coastguard Worker private:
921*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<Module> module_;
922*523fa7a6SAndroid Build Coastguard Worker // Need to keep-alive output storages until they can be compared in case of
923*523fa7a6SAndroid Build Coastguard Worker // bundled programs.
924*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<uint8_t>> output_storages_;
925*523fa7a6SAndroid Build Coastguard Worker
make_output_storagesexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule926*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<uint8_t>> make_output_storages(const Method& method) {
927*523fa7a6SAndroid Build Coastguard Worker const auto num_outputs = method.outputs_size();
928*523fa7a6SAndroid Build Coastguard Worker // Create a buffer for each output tensor. Memory planned outputs and non
929*523fa7a6SAndroid Build Coastguard Worker // tensor outputs get an empty buffer in this list which is ignored later.
930*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<uint8_t>> output_storages;
931*523fa7a6SAndroid Build Coastguard Worker output_storages_.reserve(num_outputs);
932*523fa7a6SAndroid Build Coastguard Worker auto meta = method.method_meta();
933*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < num_outputs; ++i) {
934*523fa7a6SAndroid Build Coastguard Worker auto output_type = meta.output_tag(i);
935*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
936*523fa7a6SAndroid Build Coastguard Worker output_type.error(), "Failed to get output type for output %zu", i);
937*523fa7a6SAndroid Build Coastguard Worker if (output_type.get() != Tag::Tensor) {
938*523fa7a6SAndroid Build Coastguard Worker // Skip allocating storage for non-tensor outputs.
939*523fa7a6SAndroid Build Coastguard Worker output_storages.emplace_back();
940*523fa7a6SAndroid Build Coastguard Worker continue;
941*523fa7a6SAndroid Build Coastguard Worker }
942*523fa7a6SAndroid Build Coastguard Worker const auto& output_tensor_meta =
943*523fa7a6SAndroid Build Coastguard Worker method.method_meta().output_tensor_meta(i);
944*523fa7a6SAndroid Build Coastguard Worker THROW_IF_ERROR(
945*523fa7a6SAndroid Build Coastguard Worker output_tensor_meta.error(),
946*523fa7a6SAndroid Build Coastguard Worker "Failed to get output tensor meta for output %zu",
947*523fa7a6SAndroid Build Coastguard Worker i);
948*523fa7a6SAndroid Build Coastguard Worker if (output_tensor_meta.get().is_memory_planned()) {
949*523fa7a6SAndroid Build Coastguard Worker // Skip allocating storage for planned memory outputs.
950*523fa7a6SAndroid Build Coastguard Worker output_storages.emplace_back();
951*523fa7a6SAndroid Build Coastguard Worker continue;
952*523fa7a6SAndroid Build Coastguard Worker }
953*523fa7a6SAndroid Build Coastguard Worker // Allocate storage for the output tensor.
954*523fa7a6SAndroid Build Coastguard Worker const size_t output_size = output_tensor_meta.get().nbytes();
955*523fa7a6SAndroid Build Coastguard Worker output_storages.emplace_back(output_size);
956*523fa7a6SAndroid Build Coastguard Worker }
957*523fa7a6SAndroid Build Coastguard Worker return output_storages;
958*523fa7a6SAndroid Build Coastguard Worker }
959*523fa7a6SAndroid Build Coastguard Worker };
960*523fa7a6SAndroid Build Coastguard Worker
// Starts a new named profiling block by passing `name` to the
// EXECUTORCH_PROFILE_CREATE_BLOCK macro. What the macro does (and whether it
// is a no-op when profiling is compiled out) is defined elsewhere; confirm in
// the profiler header.
void create_profile_block(const std::string& name) {
  EXECUTORCH_PROFILE_CREATE_BLOCK(name.c_str());
}
964*523fa7a6SAndroid Build Coastguard Worker
get_operator_names()965*523fa7a6SAndroid Build Coastguard Worker py::list get_operator_names() {
966*523fa7a6SAndroid Build Coastguard Worker Span<const Kernel> kernels = get_registered_kernels();
967*523fa7a6SAndroid Build Coastguard Worker py::list res;
968*523fa7a6SAndroid Build Coastguard Worker for (const Kernel& k : kernels) {
969*523fa7a6SAndroid Build Coastguard Worker if (k.name_ != nullptr) {
970*523fa7a6SAndroid Build Coastguard Worker res.append(py::cast(k.name_));
971*523fa7a6SAndroid Build Coastguard Worker }
972*523fa7a6SAndroid Build Coastguard Worker }
973*523fa7a6SAndroid Build Coastguard Worker return res;
974*523fa7a6SAndroid Build Coastguard Worker }
975*523fa7a6SAndroid Build Coastguard Worker
976*523fa7a6SAndroid Build Coastguard Worker } // namespace
977*523fa7a6SAndroid Build Coastguard Worker
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
  // Bindings that take `call_guard` redirect std::cout / std::cerr to
  // Python's sys.stdout / sys.stderr for the duration of the call.
  auto call_guard = py::
      call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();

  // Bind the verification enum to python.
  py::enum_<Program::Verification>(m, "Verification")
      .value("Minimal", Program::Verification::Minimal)
      .value("InternalConsistency", Program::Verification::InternalConsistency);

  // Module loaders.
  m.def(
      "_load_for_executorch",
      &PyModule::load_from_file,
      py::arg("path"),
      py::arg("enable_etdump") = false,
      py::arg("debug_buffer_size") = 0,
      py::arg("program_verification") =
          Program::Verification::InternalConsistency,
      call_guard);
  m.def(
      "_load_for_executorch_from_buffer",
      &PyModule::load_from_buffer,
      py::arg("buffer"),
      py::arg("enable_etdump") = false,
      py::arg("debug_buffer_size") = 0,
      py::arg("program_verification") =
          Program::Verification::InternalConsistency,
      call_guard);
  m.def(
      "_load_for_executorch_from_bundled_program",
      &PyModule::load_from_bundled_program,
      py::arg("ptr"),
      py::arg("enable_etdump") = false,
      py::arg("debug_buffer_size") = 0,
      call_guard);
  m.def(
      "_load_bundled_program_from_buffer",
      &PyBundledModule::load_from_buffer,
      py::arg("buffer"),
      py::arg("non_const_pool_size") = kDEFAULT_BUNDLED_INPUT_POOL_SIZE,
      call_guard);

  // Profiling helpers.
  m.def(
      "_dump_profile_results",
      []() {
        prof_result_t prof_result;
        EXECUTORCH_DUMP_PROFILE_RESULTS(&prof_result);
        return py::bytes(
            reinterpret_cast<const char*>(prof_result.prof_data),
            prof_result.num_bytes);
      },
      call_guard);
  m.def("_get_operator_names", &get_operator_names);
  m.def("_create_profile_block", &create_profile_block, call_guard);
  m.def(
      "_reset_profile_results",
      []() { EXECUTORCH_RESET_PROFILE_RESULTS(); },
      call_guard);

  py::class_<PyModule>(m, "ExecuTorchModule")
      // Argument names match verify_result_with_bundled_expected_output so
      // both bundled-program entry points accept the same keyword arguments.
      // Positional calls are unaffected.
      .def(
          "load_bundled_input",
          &PyModule::load_bundled_input,
          py::arg("bundle"),
          py::arg("method_name"),
          py::arg("testset_idx"),
          call_guard)
      .def(
          "verify_result_with_bundled_expected_output",
          &PyModule::verify_result_with_bundled_expected_output,
          py::arg("bundle"),
          py::arg("method_name"),
          py::arg("testset_idx"),
          py::arg("rtol") = 1e-5,
          py::arg("atol") = 1e-8,
          call_guard)
      .def(
          "plan_execute",
          &PyModule::plan_execute,
          py::arg("method_name"),
          py::arg("clone_outputs") = true,
          call_guard)
      .def(
          "method_meta",
          &PyModule::method_meta,
          py::arg("method_name"),
          call_guard)
      .def("method_names", &PyModule::method_names, call_guard)
      .def(
          "run_method",
          &PyModule::run_method,
          py::arg("method_name"),
          py::arg("inputs") = py::list(),
          py::arg("clone_outputs") = true,
          call_guard)
      .def(
          "forward",
          &PyModule::forward,
          py::arg("inputs") = py::list(),
          py::arg("clone_outputs") = true,
          call_guard)
      .def("has_etdump", &PyModule::has_etdump, call_guard)
      .def(
          "write_etdump_result_to_file",
          &PyModule::write_etdump_result_to_file,
          py::arg("path"),
          py::arg("debug_buffer_path") = py::none(),
          call_guard)
      // pybind11 tries overloads in registration order: the list-taking
      // forward is attempted first, and forward_single_input is the fallback.
      .def(
          "__call__",
          &PyModule::forward,
          py::arg("inputs") = py::list(),
          py::arg("clone_outputs") = true,
          call_guard)
      .def(
          "__call__",
          &PyModule::forward_single_input,
          py::arg("inputs") = py::list(),
          py::arg("clone_outputs") = true,
          call_guard);

  py::class_<PyBundledModule>(m, "BundledModule");
  py::class_<PyTensorInfo>(m, "TensorInfo")
      .def("sizes", &PyTensorInfo::sizes, call_guard)
      .def("dtype", &PyTensorInfo::dtype, call_guard)
      .def("is_memory_planned", &PyTensorInfo::is_memory_planned, call_guard)
      .def("nbytes", &PyTensorInfo::nbytes, call_guard)
      .def("__repr__", &PyTensorInfo::repr, call_guard);
  py::class_<PyMethodMeta>(m, "MethodMeta")
      .def("name", &PyMethodMeta::name, call_guard)
      .def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
      .def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
      .def(
          "input_tensor_meta",
          &PyMethodMeta::input_tensor_meta,
          py::arg("index"),
          call_guard)
      .def(
          "output_tensor_meta",
          &PyMethodMeta::output_tensor_meta,
          py::arg("index"),
          call_guard)
      .def("__repr__", &PyMethodMeta::repr, call_guard);
}
1115*523fa7a6SAndroid Build Coastguard Worker
1116*523fa7a6SAndroid Build Coastguard Worker } // namespace pybindings
1117*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
1118*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
1119