xref: /aosp_15_r20/external/executorch/extension/pybindings/pybindings.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
10*523fa7a6SAndroid Build Coastguard Worker #include <cstdio>
11*523fa7a6SAndroid Build Coastguard Worker #include <iostream>
12*523fa7a6SAndroid Build Coastguard Worker #include <memory>
13*523fa7a6SAndroid Build Coastguard Worker #include <stdexcept>
14*523fa7a6SAndroid Build Coastguard Worker #include <unordered_map>
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker #include <pybind11/iostream.h>
17*523fa7a6SAndroid Build Coastguard Worker #include <pybind11/pybind11.h>
18*523fa7a6SAndroid Build Coastguard Worker #include <pybind11/stl.h>
19*523fa7a6SAndroid Build Coastguard Worker 
20*523fa7a6SAndroid Build Coastguard Worker #include <executorch/devtools/bundled_program/bundled_program.h>
21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/devtools/bundled_program/schema/bundled_program_schema_generated.h>
22*523fa7a6SAndroid Build Coastguard Worker #include <executorch/devtools/etdump/etdump_flatcc.h>
23*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/buffer_data_loader.h>
24*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/mmap_data_loader.h>
25*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
26*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/data_loader.h>
27*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
28*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/method.h>
29*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/program.h>
30*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/kernel/operator_registry.h>
31*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/assert.h>
32*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/platform.h>
33*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/profiler.h>
34*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/runtime.h>
35*523fa7a6SAndroid Build Coastguard Worker 
36*523fa7a6SAndroid Build Coastguard Worker #include <ATen/Functions.h>
37*523fa7a6SAndroid Build Coastguard Worker #include <ATen/Tensor.h>
38*523fa7a6SAndroid Build Coastguard Worker #include <ATen/core/functional.h>
39*523fa7a6SAndroid Build Coastguard Worker #include <c10/core/ScalarTypeToTypeMeta.h>
40*523fa7a6SAndroid Build Coastguard Worker #include <torch/csrc/utils/pybind.h>
41*523fa7a6SAndroid Build Coastguard Worker #include <torch/python.h>
42*523fa7a6SAndroid Build Coastguard Worker 
43*523fa7a6SAndroid Build Coastguard Worker #ifndef USE_ATEN_LIB
44*523fa7a6SAndroid Build Coastguard Worker #include <c10/core/impl/LocalDispatchKeySet.h>
45*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/aten_util/aten_bridge.h>
46*523fa7a6SAndroid Build Coastguard Worker #endif
47*523fa7a6SAndroid Build Coastguard Worker 
/// Throws a std::runtime_error with a printf-style message if `error` is not
/// `Error::Ok`. pybind11 translates the exception into a Python RuntimeError.
///
/// Written as `do { ... } while (0)` rather than a GNU statement expression so
/// it is portable across compilers and still acts as one statement (safe after
/// an unbraced `if`/`else`). Messages longer than the buffer are truncated by
/// snprintf.
#define THROW_IF_ERROR(error, message, ...)                       \
  do {                                                            \
    if ((error) != Error::Ok) {                                   \
      char msg_buf[512];                                          \
      snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__); \
      /* pybind will convert this to a python exception. */       \
      throw std::runtime_error(msg_buf);                          \
    }                                                             \
  } while (0)

/// Same as THROW_IF_ERROR, but throws std::out_of_range, which pybind11
/// translates into a Python IndexError.
#define THROW_INDEX_IF_ERROR(error, message, ...)                 \
  do {                                                            \
    if ((error) != Error::Ok) {                                   \
      char msg_buf[512];                                          \
      snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__); \
      /* pybind will convert this to a python exception. */       \
      throw std::out_of_range(msg_buf);                           \
    }                                                             \
  } while (0)
68*523fa7a6SAndroid Build Coastguard Worker 
69*523fa7a6SAndroid Build Coastguard Worker // Our logs work by writing to stderr. By default this is done through fprintf
70*523fa7a6SAndroid Build Coastguard Worker // (as defined in posix.cpp) which then does not show up in python environments.
71*523fa7a6SAndroid Build Coastguard Worker // Here we override the pal to use std::cerr which can be properly redirected by
72*523fa7a6SAndroid Build Coastguard Worker // scoped_estream_redirect.
et_pal_emit_log_message(et_timestamp_t timestamp,et_pal_log_level_t level,const char * filename,ET_UNUSED const char * function,size_t line,const char * message,ET_UNUSED size_t length)73*523fa7a6SAndroid Build Coastguard Worker void et_pal_emit_log_message(
74*523fa7a6SAndroid Build Coastguard Worker     et_timestamp_t timestamp,
75*523fa7a6SAndroid Build Coastguard Worker     et_pal_log_level_t level,
76*523fa7a6SAndroid Build Coastguard Worker     const char* filename,
77*523fa7a6SAndroid Build Coastguard Worker     ET_UNUSED const char* function,
78*523fa7a6SAndroid Build Coastguard Worker     size_t line,
79*523fa7a6SAndroid Build Coastguard Worker     const char* message,
80*523fa7a6SAndroid Build Coastguard Worker     ET_UNUSED size_t length) {
81*523fa7a6SAndroid Build Coastguard Worker   std::cerr << "[" << filename << ":" << line << "] " << message << std::endl;
82*523fa7a6SAndroid Build Coastguard Worker }
83*523fa7a6SAndroid Build Coastguard Worker 
84*523fa7a6SAndroid Build Coastguard Worker namespace py = pybind11;
85*523fa7a6SAndroid Build Coastguard Worker using executorch::bundled_program::verify_method_outputs;
86*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::BufferDataLoader;
87*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::MallocMemoryAllocator;
88*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::MmapDataLoader;
89*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::ArrayRef;
90*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::DataLoader;
91*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Error;
92*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::EValue;
93*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::EventTracerDebugLogLevel;
94*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::get_registered_kernels;
95*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::HierarchicalAllocator;
96*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Kernel;
97*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::MemoryAllocator;
98*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::MemoryManager;
99*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Method;
100*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::prof_result_t;
101*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Program;
102*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Result;
103*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Span;
104*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Tag;
105*523fa7a6SAndroid Build Coastguard Worker using torch::executor::etdump_result;
106*523fa7a6SAndroid Build Coastguard Worker using torch::executor::ETDumpGen;
107*523fa7a6SAndroid Build Coastguard Worker 
108*523fa7a6SAndroid Build Coastguard Worker #ifndef USE_ATEN_LIB
109*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::alias_attensor_to_etensor;
110*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::alias_etensor_to_attensor;
111*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::torch_to_executorch_scalar_type;
112*523fa7a6SAndroid Build Coastguard Worker #endif // !USE_ATEN_LIB
113*523fa7a6SAndroid Build Coastguard Worker 
114*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
115*523fa7a6SAndroid Build Coastguard Worker namespace extension {
116*523fa7a6SAndroid Build Coastguard Worker namespace pybindings {
117*523fa7a6SAndroid Build Coastguard Worker 
118*523fa7a6SAndroid Build Coastguard Worker namespace {
119*523fa7a6SAndroid Build Coastguard Worker 
/// Writes `size` bytes from `buf` to the file at `path`, replacing any
/// existing contents.
///
/// @param path Destination file path; the file is created or truncated.
/// @param buf Bytes to write; may be null only when `size` is 0.
/// @param size Number of bytes to write.
/// @throws std::runtime_error if the file cannot be opened, fully written, or
///     closed.
void write_data_to_file(const std::string& path, void* buf, size_t size) {
  // Binary mode: the payload is raw bytes, and text mode would translate
  // newline bytes on platforms such as Windows.
  FILE* f = fopen(path.c_str(), "wb");
  if (!f) {
    throw std::runtime_error(
        "Failed to open file " + path + ": " + strerror(errno));
  }
  // Skip fwrite for an empty payload so a null `buf` is never passed to it.
  size_t num_written = size == 0 ? 0 : fwrite(buf, 1, size, f);
  if (num_written != size) {
    fclose(f);
    throw std::runtime_error("Failed to write etdump to file " + path);
  }
  // fclose() returns EOF (not an errno value) on failure; the actual cause
  // is reported through errno.
  if (fclose(f) != 0) {
    throw std::runtime_error(
        "Failed to close etdump file " + path + ": " + strerror(errno));
  }
}
137*523fa7a6SAndroid Build Coastguard Worker 
setup_output_storage(Method & method,const std::vector<Span<uint8_t>> & output_storages)138*523fa7a6SAndroid Build Coastguard Worker void setup_output_storage(
139*523fa7a6SAndroid Build Coastguard Worker     Method& method,
140*523fa7a6SAndroid Build Coastguard Worker     const std::vector<Span<uint8_t>>& output_storages) {
141*523fa7a6SAndroid Build Coastguard Worker   if (output_storages.size() != method.outputs_size()) {
142*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
143*523fa7a6SAndroid Build Coastguard Worker         Error::InvalidArgument,
144*523fa7a6SAndroid Build Coastguard Worker         "number of output storages %zu does not match number of outputs %zu",
145*523fa7a6SAndroid Build Coastguard Worker         output_storages.size(),
146*523fa7a6SAndroid Build Coastguard Worker         method.outputs_size());
147*523fa7a6SAndroid Build Coastguard Worker   }
148*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < output_storages.size(); ++i) {
149*523fa7a6SAndroid Build Coastguard Worker     if (output_storages[i].size() == 0) {
150*523fa7a6SAndroid Build Coastguard Worker       // Skip empty output storages, this would happen for non-tensor outputs
151*523fa7a6SAndroid Build Coastguard Worker       // and memory planned outputs.
152*523fa7a6SAndroid Build Coastguard Worker       continue;
153*523fa7a6SAndroid Build Coastguard Worker     }
154*523fa7a6SAndroid Build Coastguard Worker     Error output_status = method.set_output_data_ptr(
155*523fa7a6SAndroid Build Coastguard Worker         output_storages[i].data(), output_storages[i].size(), i);
156*523fa7a6SAndroid Build Coastguard Worker     // We already should be skipping non-tensor outputs, and memory planned
157*523fa7a6SAndroid Build Coastguard Worker     // outputs so any error is real.
158*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
159*523fa7a6SAndroid Build Coastguard Worker         output_status,
160*523fa7a6SAndroid Build Coastguard Worker         "set_output_data_ptr failed for output %zu with error 0x%" PRIx32,
161*523fa7a6SAndroid Build Coastguard Worker         i,
162*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(output_status));
163*523fa7a6SAndroid Build Coastguard Worker   }
164*523fa7a6SAndroid Build Coastguard Worker }
165*523fa7a6SAndroid Build Coastguard Worker 
166*523fa7a6SAndroid Build Coastguard Worker class Module final {
167*523fa7a6SAndroid Build Coastguard Worker  public:
Module(std::unique_ptr<DataLoader> loader,std::unique_ptr<ETDumpGen> tracer=nullptr,size_t debug_buffer_size=0,Program::Verification program_verification=Program::Verification::InternalConsistency)168*523fa7a6SAndroid Build Coastguard Worker   explicit Module(
169*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<DataLoader> loader,
170*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<ETDumpGen> tracer = nullptr,
171*523fa7a6SAndroid Build Coastguard Worker       size_t debug_buffer_size = 0,
172*523fa7a6SAndroid Build Coastguard Worker       Program::Verification program_verification =
173*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency)
174*523fa7a6SAndroid Build Coastguard Worker       : loader_(std::move(loader)),
175*523fa7a6SAndroid Build Coastguard Worker         event_tracer_(std::move(tracer)),
176*523fa7a6SAndroid Build Coastguard Worker         debug_buffer_size_(debug_buffer_size) {
177*523fa7a6SAndroid Build Coastguard Worker     ::executorch::runtime::runtime_init();
178*523fa7a6SAndroid Build Coastguard Worker     Result<Program> program =
179*523fa7a6SAndroid Build Coastguard Worker         Program::load(loader_.get(), program_verification);
180*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
181*523fa7a6SAndroid Build Coastguard Worker         program.error(),
182*523fa7a6SAndroid Build Coastguard Worker         "loading program failed with error: 0x%" PRIx32,
183*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(program.error()));
184*523fa7a6SAndroid Build Coastguard Worker     program_ = std::make_unique<Program>(std::move(program.get()));
185*523fa7a6SAndroid Build Coastguard Worker 
186*523fa7a6SAndroid Build Coastguard Worker     // Figure out the size of each non_const layer we need to support every
187*523fa7a6SAndroid Build Coastguard Worker     // method in the program. Map will be easier to use than a list because we
188*523fa7a6SAndroid Build Coastguard Worker     // dont know how many non_const arenas there will be
189*523fa7a6SAndroid Build Coastguard Worker     std::map<size_t, int64_t> non_const_buffer_sizes;
190*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < program_->num_methods(); ++i) {
191*523fa7a6SAndroid Build Coastguard Worker       auto name = program_->get_method_name(i).get();
192*523fa7a6SAndroid Build Coastguard Worker       auto method_meta = program_->method_meta(name).get();
193*523fa7a6SAndroid Build Coastguard Worker       for (size_t j = 0; j < method_meta.num_non_const_buffers(); j++) {
194*523fa7a6SAndroid Build Coastguard Worker         int64_t buffer_size = method_meta.non_const_buffer_size(j).get();
195*523fa7a6SAndroid Build Coastguard Worker         if (non_const_buffer_sizes.find(j) == non_const_buffer_sizes.end()) {
196*523fa7a6SAndroid Build Coastguard Worker           non_const_buffer_sizes.insert({j, buffer_size});
197*523fa7a6SAndroid Build Coastguard Worker         } else {
198*523fa7a6SAndroid Build Coastguard Worker           non_const_buffer_sizes[j] =
199*523fa7a6SAndroid Build Coastguard Worker               std::max(non_const_buffer_sizes[j], buffer_size);
200*523fa7a6SAndroid Build Coastguard Worker         }
201*523fa7a6SAndroid Build Coastguard Worker       }
202*523fa7a6SAndroid Build Coastguard Worker     }
203*523fa7a6SAndroid Build Coastguard Worker 
204*523fa7a6SAndroid Build Coastguard Worker     // Allocate the arenas. Using vector because we need to remember the size as
205*523fa7a6SAndroid Build Coastguard Worker     // well, so vector is easier then unique_ptr.
206*523fa7a6SAndroid Build Coastguard Worker     std::vector<std::vector<uint8_t>> non_const_buffers_;
207*523fa7a6SAndroid Build Coastguard Worker     for (std::map<size_t, int64_t>::iterator i = non_const_buffer_sizes.begin();
208*523fa7a6SAndroid Build Coastguard Worker          i != non_const_buffer_sizes.end();
209*523fa7a6SAndroid Build Coastguard Worker          i++) {
210*523fa7a6SAndroid Build Coastguard Worker       non_const_buffers_.push_back(std::vector<uint8_t>(i->second));
211*523fa7a6SAndroid Build Coastguard Worker     }
212*523fa7a6SAndroid Build Coastguard Worker 
213*523fa7a6SAndroid Build Coastguard Worker     memory_ = std::make_unique<Memory>(std::move(non_const_buffers_));
214*523fa7a6SAndroid Build Coastguard Worker     if (event_tracer_ && debug_buffer_size > 0) {
215*523fa7a6SAndroid Build Coastguard Worker       // If a debug buffer was requested for the ETDump, allocate it and make
216*523fa7a6SAndroid Build Coastguard Worker       // sure its lifetime is as long as the event_tracer.
217*523fa7a6SAndroid Build Coastguard Worker       debug_buffer_ = std::make_unique<uint8_t[]>(debug_buffer_size);
218*523fa7a6SAndroid Build Coastguard Worker       event_tracer_->set_debug_buffer(get_etdump_debug_buffer());
219*523fa7a6SAndroid Build Coastguard Worker       event_tracer_->set_event_tracer_debug_level(
220*523fa7a6SAndroid Build Coastguard Worker           EventTracerDebugLogLevel::kIntermediateOutputs);
221*523fa7a6SAndroid Build Coastguard Worker     }
222*523fa7a6SAndroid Build Coastguard Worker 
223*523fa7a6SAndroid Build Coastguard Worker     // Load methods
224*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < program_->num_methods(); ++i) {
225*523fa7a6SAndroid Build Coastguard Worker       auto name = program_->get_method_name(i).get();
226*523fa7a6SAndroid Build Coastguard Worker       // It's safe to use the same memory manager for all modules because
227*523fa7a6SAndroid Build Coastguard Worker       // we can guarantee that only one will be executing at a time.
228*523fa7a6SAndroid Build Coastguard Worker       // Everything in this module runs on a single thread.
229*523fa7a6SAndroid Build Coastguard Worker       Result<Method> method = program_->load_method(
230*523fa7a6SAndroid Build Coastguard Worker           name, memory_->mem_manager(), event_tracer_.get());
231*523fa7a6SAndroid Build Coastguard Worker       THROW_IF_ERROR(
232*523fa7a6SAndroid Build Coastguard Worker           method.error(),
233*523fa7a6SAndroid Build Coastguard Worker           "loading method %s failed with error 0x%" PRIx32,
234*523fa7a6SAndroid Build Coastguard Worker           name,
235*523fa7a6SAndroid Build Coastguard Worker           static_cast<uint32_t>(method.error()));
236*523fa7a6SAndroid Build Coastguard Worker       methods_.insert(
237*523fa7a6SAndroid Build Coastguard Worker           {std::string(name),
238*523fa7a6SAndroid Build Coastguard Worker            std::make_unique<Method>(std::move(method.get()))});
239*523fa7a6SAndroid Build Coastguard Worker     }
240*523fa7a6SAndroid Build Coastguard Worker   }
241*523fa7a6SAndroid Build Coastguard Worker 
242*523fa7a6SAndroid Build Coastguard Worker   Module(const Module&) = delete;
243*523fa7a6SAndroid Build Coastguard Worker   Module& operator=(const Module&) = delete;
244*523fa7a6SAndroid Build Coastguard Worker   Module(Module&&) = default;
245*523fa7a6SAndroid Build Coastguard Worker   Module& operator=(Module&&) = default;
246*523fa7a6SAndroid Build Coastguard Worker 
247*523fa7a6SAndroid Build Coastguard Worker   /// Executes the specified method on the provided inputs and returns its
248*523fa7a6SAndroid Build Coastguard Worker   /// outputs.
run_method(const std::string & method_name,const std::vector<EValue> & args,const std::optional<std::vector<Span<uint8_t>>> & output_storages=std::nullopt)249*523fa7a6SAndroid Build Coastguard Worker   std::vector<EValue> run_method(
250*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
251*523fa7a6SAndroid Build Coastguard Worker       const std::vector<EValue>& args,
252*523fa7a6SAndroid Build Coastguard Worker       const std::optional<std::vector<Span<uint8_t>>>& output_storages =
253*523fa7a6SAndroid Build Coastguard Worker           std::nullopt) {
254*523fa7a6SAndroid Build Coastguard Worker     auto& method = get_method(method_name);
255*523fa7a6SAndroid Build Coastguard Worker     exec_aten::ArrayRef<EValue> input_evalue_list(args.data(), args.size());
256*523fa7a6SAndroid Build Coastguard Worker 
257*523fa7a6SAndroid Build Coastguard Worker     Error set_inputs_status = method.set_inputs(input_evalue_list);
258*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
259*523fa7a6SAndroid Build Coastguard Worker         set_inputs_status,
260*523fa7a6SAndroid Build Coastguard Worker         "method->set_inputs() for method '%s' failed with error 0x%" PRIx32,
261*523fa7a6SAndroid Build Coastguard Worker         method_name.c_str(),
262*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(set_inputs_status));
263*523fa7a6SAndroid Build Coastguard Worker 
264*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
265*523fa7a6SAndroid Build Coastguard Worker     // [TLS handling] This is to workaround an assertion failure
266*523fa7a6SAndroid Build Coastguard Worker     // (https://fburl.com/code/302jyn8d) running `gelu` in ATen mode in fbcode
267*523fa7a6SAndroid Build Coastguard Worker     // (such as bento). The problem is ExecuTorch ATen mode doesn't have
268*523fa7a6SAndroid Build Coastguard Worker     // Thread Local State, but `torch-cpp` is assuming tls init is done. There
269*523fa7a6SAndroid Build Coastguard Worker     // are two more checks: MKLDNN disabled and C10_MOBILE, if any of them is
270*523fa7a6SAndroid Build Coastguard Worker     // true we won't be hitting this assertion error. However in `torch-cpp`
271*523fa7a6SAndroid Build Coastguard Worker     // lib both checks are false. Production impact: this should not make any
272*523fa7a6SAndroid Build Coastguard Worker     // impact in production environment, given that in xplat we are depending
273*523fa7a6SAndroid Build Coastguard Worker     // on a library that enables C10_MOBILE (`torch_mobile_core`).
274*523fa7a6SAndroid Build Coastguard Worker     c10::impl::ExcludeDispatchKeyGuard no_autograd(
275*523fa7a6SAndroid Build Coastguard Worker         c10::autograd_dispatch_keyset);
276*523fa7a6SAndroid Build Coastguard Worker #endif
277*523fa7a6SAndroid Build Coastguard Worker     if (output_storages) {
278*523fa7a6SAndroid Build Coastguard Worker       setup_output_storage(method, *output_storages);
279*523fa7a6SAndroid Build Coastguard Worker     }
280*523fa7a6SAndroid Build Coastguard Worker     Error execute_status = method.execute();
281*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
282*523fa7a6SAndroid Build Coastguard Worker         execute_status,
283*523fa7a6SAndroid Build Coastguard Worker         "method->execute() failed with error 0x%" PRIx32,
284*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(execute_status));
285*523fa7a6SAndroid Build Coastguard Worker     // process outputs
286*523fa7a6SAndroid Build Coastguard Worker     return get_outputs(method_name);
287*523fa7a6SAndroid Build Coastguard Worker   }
288*523fa7a6SAndroid Build Coastguard Worker 
get_outputs(const std::string & method_name)289*523fa7a6SAndroid Build Coastguard Worker   std::vector<EValue> get_outputs(const std::string& method_name) {
290*523fa7a6SAndroid Build Coastguard Worker     auto& method = methods_[method_name];
291*523fa7a6SAndroid Build Coastguard Worker     std::vector<EValue> result(method->outputs_size());
292*523fa7a6SAndroid Build Coastguard Worker 
293*523fa7a6SAndroid Build Coastguard Worker     Error get_outputs_status =
294*523fa7a6SAndroid Build Coastguard Worker         method->get_outputs(result.data(), method->outputs_size());
295*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
296*523fa7a6SAndroid Build Coastguard Worker         get_outputs_status,
297*523fa7a6SAndroid Build Coastguard Worker         "method->get_outputs() for method '%s' failed with error 0x%" PRIx32,
298*523fa7a6SAndroid Build Coastguard Worker         method_name.c_str(),
299*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(get_outputs_status));
300*523fa7a6SAndroid Build Coastguard Worker 
301*523fa7a6SAndroid Build Coastguard Worker     return result;
302*523fa7a6SAndroid Build Coastguard Worker   }
303*523fa7a6SAndroid Build Coastguard Worker 
get_method(const std::string & method_name)304*523fa7a6SAndroid Build Coastguard Worker   Method& get_method(const std::string& method_name) {
305*523fa7a6SAndroid Build Coastguard Worker     if (methods_.count(method_name) == 0) {
306*523fa7a6SAndroid Build Coastguard Worker       THROW_IF_ERROR(
307*523fa7a6SAndroid Build Coastguard Worker           Error::InvalidArgument,
308*523fa7a6SAndroid Build Coastguard Worker           "no such method in program: %s",
309*523fa7a6SAndroid Build Coastguard Worker           method_name.c_str());
310*523fa7a6SAndroid Build Coastguard Worker     }
311*523fa7a6SAndroid Build Coastguard Worker     return *methods_[method_name].get();
312*523fa7a6SAndroid Build Coastguard Worker   }
313*523fa7a6SAndroid Build Coastguard Worker 
314*523fa7a6SAndroid Build Coastguard Worker   /// Returns the names of all methods in the program.
method_names() const315*523fa7a6SAndroid Build Coastguard Worker   std::vector<std::string> method_names() const {
316*523fa7a6SAndroid Build Coastguard Worker     std::vector<std::string> names;
317*523fa7a6SAndroid Build Coastguard Worker     for (const auto& method : methods_) {
318*523fa7a6SAndroid Build Coastguard Worker       names.push_back(method.first);
319*523fa7a6SAndroid Build Coastguard Worker     }
320*523fa7a6SAndroid Build Coastguard Worker     return names;
321*523fa7a6SAndroid Build Coastguard Worker   }
322*523fa7a6SAndroid Build Coastguard Worker 
has_etdump()323*523fa7a6SAndroid Build Coastguard Worker   bool has_etdump() {
324*523fa7a6SAndroid Build Coastguard Worker     return static_cast<bool>(event_tracer_);
325*523fa7a6SAndroid Build Coastguard Worker   }
326*523fa7a6SAndroid Build Coastguard Worker 
etdump()327*523fa7a6SAndroid Build Coastguard Worker   ETDumpGen& etdump() {
328*523fa7a6SAndroid Build Coastguard Worker     return *event_tracer_;
329*523fa7a6SAndroid Build Coastguard Worker   }
330*523fa7a6SAndroid Build Coastguard Worker 
has_etdump_debug_buffer() const331*523fa7a6SAndroid Build Coastguard Worker   bool has_etdump_debug_buffer() const {
332*523fa7a6SAndroid Build Coastguard Worker     return static_cast<bool>(debug_buffer_);
333*523fa7a6SAndroid Build Coastguard Worker   }
334*523fa7a6SAndroid Build Coastguard Worker 
get_etdump_debug_buffer()335*523fa7a6SAndroid Build Coastguard Worker   Span<uint8_t> get_etdump_debug_buffer() {
336*523fa7a6SAndroid Build Coastguard Worker     return Span<uint8_t>(debug_buffer_.get(), debug_buffer_size_);
337*523fa7a6SAndroid Build Coastguard Worker   }
338*523fa7a6SAndroid Build Coastguard Worker 
339*523fa7a6SAndroid Build Coastguard Worker  private:
340*523fa7a6SAndroid Build Coastguard Worker   /// A wrapper/util class for executorch memory allocations/manager.
341*523fa7a6SAndroid Build Coastguard Worker   class Memory {
342*523fa7a6SAndroid Build Coastguard Worker    public:
Memory(std::vector<std::vector<uint8_t>> && non_const_buffers)343*523fa7a6SAndroid Build Coastguard Worker     explicit Memory(std::vector<std::vector<uint8_t>>&& non_const_buffers)
344*523fa7a6SAndroid Build Coastguard Worker         : runtime_allocator_(),
345*523fa7a6SAndroid Build Coastguard Worker           non_const_buffers_(std::move(non_const_buffers)),
346*523fa7a6SAndroid Build Coastguard Worker           non_const_spans_(create_non_const_spans()),
347*523fa7a6SAndroid Build Coastguard Worker           non_const_allocator_(
348*523fa7a6SAndroid Build Coastguard Worker               {non_const_spans_.data(), non_const_spans_.size()}),
349*523fa7a6SAndroid Build Coastguard Worker           mem_manager_(
350*523fa7a6SAndroid Build Coastguard Worker               &const_allocator_,
351*523fa7a6SAndroid Build Coastguard Worker               &non_const_allocator_,
352*523fa7a6SAndroid Build Coastguard Worker               &runtime_allocator_,
353*523fa7a6SAndroid Build Coastguard Worker               &temp_allocator_) {}
354*523fa7a6SAndroid Build Coastguard Worker 
355*523fa7a6SAndroid Build Coastguard Worker     /// Returns a pointer to the internal memory manager, the Memory instance
356*523fa7a6SAndroid Build Coastguard Worker     /// must outlive this pointer.
mem_manager()357*523fa7a6SAndroid Build Coastguard Worker     MemoryManager* mem_manager() {
358*523fa7a6SAndroid Build Coastguard Worker       return &mem_manager_;
359*523fa7a6SAndroid Build Coastguard Worker     }
360*523fa7a6SAndroid Build Coastguard Worker 
361*523fa7a6SAndroid Build Coastguard Worker     Memory(const Memory&) = delete;
362*523fa7a6SAndroid Build Coastguard Worker     Memory& operator=(const Memory&) = delete;
363*523fa7a6SAndroid Build Coastguard Worker 
364*523fa7a6SAndroid Build Coastguard Worker    private:
365*523fa7a6SAndroid Build Coastguard Worker     MemoryAllocator const_allocator_{MemoryAllocator(0, nullptr)};
366*523fa7a6SAndroid Build Coastguard Worker 
367*523fa7a6SAndroid Build Coastguard Worker     MallocMemoryAllocator runtime_allocator_;
368*523fa7a6SAndroid Build Coastguard Worker 
369*523fa7a6SAndroid Build Coastguard Worker     MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
370*523fa7a6SAndroid Build Coastguard Worker 
371*523fa7a6SAndroid Build Coastguard Worker     std::vector<std::vector<uint8_t>> non_const_buffers_;
372*523fa7a6SAndroid Build Coastguard Worker 
373*523fa7a6SAndroid Build Coastguard Worker     std::vector<Span<uint8_t>> non_const_spans_;
374*523fa7a6SAndroid Build Coastguard Worker 
375*523fa7a6SAndroid Build Coastguard Worker     HierarchicalAllocator non_const_allocator_;
376*523fa7a6SAndroid Build Coastguard Worker 
377*523fa7a6SAndroid Build Coastguard Worker     MemoryManager mem_manager_;
378*523fa7a6SAndroid Build Coastguard Worker 
create_non_const_spans()379*523fa7a6SAndroid Build Coastguard Worker     std::vector<Span<uint8_t>> create_non_const_spans() {
380*523fa7a6SAndroid Build Coastguard Worker       std::vector<Span<uint8_t>> result;
381*523fa7a6SAndroid Build Coastguard Worker       for (size_t i = 0; i < non_const_buffers_.size(); i++) {
382*523fa7a6SAndroid Build Coastguard Worker         result.push_back(
383*523fa7a6SAndroid Build Coastguard Worker             {non_const_buffers_[i].data(), non_const_buffers_[i].size()});
384*523fa7a6SAndroid Build Coastguard Worker       }
385*523fa7a6SAndroid Build Coastguard Worker       return result;
386*523fa7a6SAndroid Build Coastguard Worker     }
387*523fa7a6SAndroid Build Coastguard Worker   };
388*523fa7a6SAndroid Build Coastguard Worker 
389*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<Memory> memory_;
390*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<DataLoader> loader_; // program_ points to this.
391*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<const Program> program_; // methods_ entries points to this.
392*523fa7a6SAndroid Build Coastguard Worker   std::unordered_map<std::string, std::unique_ptr<Method>> methods_;
393*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<ETDumpGen> event_tracer_;
394*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<uint8_t[]> debug_buffer_;
395*523fa7a6SAndroid Build Coastguard Worker   size_t debug_buffer_size_;
396*523fa7a6SAndroid Build Coastguard Worker };
397*523fa7a6SAndroid Build Coastguard Worker 
load_module_from_buffer(const void * ptr,size_t ptr_len,bool enable_etdump,size_t debug_buffer_size,Program::Verification program_verification)398*523fa7a6SAndroid Build Coastguard Worker inline std::unique_ptr<Module> load_module_from_buffer(
399*523fa7a6SAndroid Build Coastguard Worker     const void* ptr,
400*523fa7a6SAndroid Build Coastguard Worker     size_t ptr_len,
401*523fa7a6SAndroid Build Coastguard Worker     bool enable_etdump,
402*523fa7a6SAndroid Build Coastguard Worker     size_t debug_buffer_size,
403*523fa7a6SAndroid Build Coastguard Worker     Program::Verification program_verification) {
404*523fa7a6SAndroid Build Coastguard Worker   EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
405*523fa7a6SAndroid Build Coastguard Worker   auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
406*523fa7a6SAndroid Build Coastguard Worker   return std::make_unique<Module>(
407*523fa7a6SAndroid Build Coastguard Worker       std::move(loader),
408*523fa7a6SAndroid Build Coastguard Worker       enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
409*523fa7a6SAndroid Build Coastguard Worker       debug_buffer_size,
410*523fa7a6SAndroid Build Coastguard Worker       program_verification);
411*523fa7a6SAndroid Build Coastguard Worker }
412*523fa7a6SAndroid Build Coastguard Worker 
load_module_from_file(const std::string & path,bool enable_etdump,size_t debug_buffer_size,Program::Verification program_verification)413*523fa7a6SAndroid Build Coastguard Worker inline std::unique_ptr<Module> load_module_from_file(
414*523fa7a6SAndroid Build Coastguard Worker     const std::string& path,
415*523fa7a6SAndroid Build Coastguard Worker     bool enable_etdump,
416*523fa7a6SAndroid Build Coastguard Worker     size_t debug_buffer_size,
417*523fa7a6SAndroid Build Coastguard Worker     Program::Verification program_verification) {
418*523fa7a6SAndroid Build Coastguard Worker   EXECUTORCH_SCOPE_PROF("load_module_from_file");
419*523fa7a6SAndroid Build Coastguard Worker 
420*523fa7a6SAndroid Build Coastguard Worker   Result<MmapDataLoader> res = MmapDataLoader::from(
421*523fa7a6SAndroid Build Coastguard Worker       path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
422*523fa7a6SAndroid Build Coastguard Worker   THROW_IF_ERROR(
423*523fa7a6SAndroid Build Coastguard Worker       res.error(),
424*523fa7a6SAndroid Build Coastguard Worker       "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
425*523fa7a6SAndroid Build Coastguard Worker       path.c_str(),
426*523fa7a6SAndroid Build Coastguard Worker       static_cast<uint32_t>(res.error()));
427*523fa7a6SAndroid Build Coastguard Worker 
428*523fa7a6SAndroid Build Coastguard Worker   auto loader = std::make_unique<MmapDataLoader>(std::move(res.get()));
429*523fa7a6SAndroid Build Coastguard Worker   return std::make_unique<Module>(
430*523fa7a6SAndroid Build Coastguard Worker       std::move(loader),
431*523fa7a6SAndroid Build Coastguard Worker       enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
432*523fa7a6SAndroid Build Coastguard Worker       debug_buffer_size,
433*523fa7a6SAndroid Build Coastguard Worker       program_verification);
434*523fa7a6SAndroid Build Coastguard Worker }
435*523fa7a6SAndroid Build Coastguard Worker 
// Default bundled input pool size: 16 * 1024 = 16384 (presumably bytes —
// confirm against the code that consumes it).
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = size_t{16} * 1024;
437*523fa7a6SAndroid Build Coastguard Worker 
438*523fa7a6SAndroid Build Coastguard Worker struct PyBundledModule final {
PyBundledModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule439*523fa7a6SAndroid Build Coastguard Worker   explicit PyBundledModule(
440*523fa7a6SAndroid Build Coastguard Worker       const py::bytes& buffer,
441*523fa7a6SAndroid Build Coastguard Worker       uint32_t bundled_input_pool_size)
442*523fa7a6SAndroid Build Coastguard Worker       : bundled_program_ptr_(buffer),
443*523fa7a6SAndroid Build Coastguard Worker         program_ptr_(static_cast<const void*>(
444*523fa7a6SAndroid Build Coastguard Worker             bundled_program_flatbuffer::GetBundledProgram(
445*523fa7a6SAndroid Build Coastguard Worker                 get_bundled_program_ptr())
446*523fa7a6SAndroid Build Coastguard Worker                 ->program()
447*523fa7a6SAndroid Build Coastguard Worker                 ->data())),
448*523fa7a6SAndroid Build Coastguard Worker         program_len_(bundled_program_flatbuffer::GetBundledProgram(
449*523fa7a6SAndroid Build Coastguard Worker                          get_bundled_program_ptr())
450*523fa7a6SAndroid Build Coastguard Worker                          ->program()
451*523fa7a6SAndroid Build Coastguard Worker                          ->size()) {}
452*523fa7a6SAndroid Build Coastguard Worker 
load_from_bufferexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule453*523fa7a6SAndroid Build Coastguard Worker   static std::unique_ptr<PyBundledModule> load_from_buffer(
454*523fa7a6SAndroid Build Coastguard Worker       const py::bytes& buffer,
455*523fa7a6SAndroid Build Coastguard Worker       uint32_t bundled_input_pool_size) {
456*523fa7a6SAndroid Build Coastguard Worker     return std::make_unique<PyBundledModule>(buffer, bundled_input_pool_size);
457*523fa7a6SAndroid Build Coastguard Worker   }
458*523fa7a6SAndroid Build Coastguard Worker 
get_bundled_program_ptrexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule459*523fa7a6SAndroid Build Coastguard Worker   const void* get_bundled_program_ptr() {
460*523fa7a6SAndroid Build Coastguard Worker     return bundled_program_ptr_.cast<std::string_view>().data();
461*523fa7a6SAndroid Build Coastguard Worker   }
462*523fa7a6SAndroid Build Coastguard Worker 
get_program_ptrexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule463*523fa7a6SAndroid Build Coastguard Worker   const void* get_program_ptr() {
464*523fa7a6SAndroid Build Coastguard Worker     return program_ptr_;
465*523fa7a6SAndroid Build Coastguard Worker   }
466*523fa7a6SAndroid Build Coastguard Worker 
get_program_lenexecutorch::extension::pybindings::__anon7828bc6a0111::PyBundledModule467*523fa7a6SAndroid Build Coastguard Worker   size_t get_program_len() {
468*523fa7a6SAndroid Build Coastguard Worker     return program_len_;
469*523fa7a6SAndroid Build Coastguard Worker   }
470*523fa7a6SAndroid Build Coastguard Worker 
471*523fa7a6SAndroid Build Coastguard Worker  private:
472*523fa7a6SAndroid Build Coastguard Worker   // Store the bytes object instead of a raw pointer so that this module will
473*523fa7a6SAndroid Build Coastguard Worker   // keep the bytes alive.
474*523fa7a6SAndroid Build Coastguard Worker   const py::bytes bundled_program_ptr_;
475*523fa7a6SAndroid Build Coastguard Worker   const void* program_ptr_;
476*523fa7a6SAndroid Build Coastguard Worker   size_t program_len_;
477*523fa7a6SAndroid Build Coastguard Worker };
478*523fa7a6SAndroid Build Coastguard Worker 
479*523fa7a6SAndroid Build Coastguard Worker /// Expose a subset of TensorInfo information to python.
480*523fa7a6SAndroid Build Coastguard Worker struct PyTensorInfo final {
PyTensorInfoexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo481*523fa7a6SAndroid Build Coastguard Worker   explicit PyTensorInfo(
482*523fa7a6SAndroid Build Coastguard Worker       std::shared_ptr<Module> module,
483*523fa7a6SAndroid Build Coastguard Worker       torch::executor::TensorInfo info)
484*523fa7a6SAndroid Build Coastguard Worker       : module_(std::move(module)), info_(info) {}
485*523fa7a6SAndroid Build Coastguard Worker 
sizesexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo486*523fa7a6SAndroid Build Coastguard Worker   py::tuple sizes() const {
487*523fa7a6SAndroid Build Coastguard Worker     const auto shape = info_.sizes();
488*523fa7a6SAndroid Build Coastguard Worker     py::tuple tup(shape.size());
489*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < shape.size(); ++i) {
490*523fa7a6SAndroid Build Coastguard Worker       tup[i] = py::cast(shape[i]);
491*523fa7a6SAndroid Build Coastguard Worker     }
492*523fa7a6SAndroid Build Coastguard Worker     return tup;
493*523fa7a6SAndroid Build Coastguard Worker   }
494*523fa7a6SAndroid Build Coastguard Worker 
dtypeexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo495*523fa7a6SAndroid Build Coastguard Worker   int8_t dtype() const {
496*523fa7a6SAndroid Build Coastguard Worker     return static_cast<std::underlying_type<exec_aten::ScalarType>::type>(
497*523fa7a6SAndroid Build Coastguard Worker         info_.scalar_type());
498*523fa7a6SAndroid Build Coastguard Worker   }
499*523fa7a6SAndroid Build Coastguard Worker 
is_memory_plannedexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo500*523fa7a6SAndroid Build Coastguard Worker   bool is_memory_planned() const {
501*523fa7a6SAndroid Build Coastguard Worker     return info_.is_memory_planned();
502*523fa7a6SAndroid Build Coastguard Worker   }
503*523fa7a6SAndroid Build Coastguard Worker 
nbytesexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo504*523fa7a6SAndroid Build Coastguard Worker   size_t nbytes() const {
505*523fa7a6SAndroid Build Coastguard Worker     return info_.nbytes();
506*523fa7a6SAndroid Build Coastguard Worker   }
507*523fa7a6SAndroid Build Coastguard Worker 
reprexecutorch::extension::pybindings::__anon7828bc6a0111::PyTensorInfo508*523fa7a6SAndroid Build Coastguard Worker   std::string repr() const {
509*523fa7a6SAndroid Build Coastguard Worker     std::string size_str = "[";
510*523fa7a6SAndroid Build Coastguard Worker     for (const auto& d : info_.sizes()) {
511*523fa7a6SAndroid Build Coastguard Worker       size_str.append(std::to_string(d));
512*523fa7a6SAndroid Build Coastguard Worker       size_str.append(", ");
513*523fa7a6SAndroid Build Coastguard Worker     }
514*523fa7a6SAndroid Build Coastguard Worker     if (size_str.length() >= 2) {
515*523fa7a6SAndroid Build Coastguard Worker       // Pop the last two characters (command and space) and add close bracket.
516*523fa7a6SAndroid Build Coastguard Worker       size_str.pop_back();
517*523fa7a6SAndroid Build Coastguard Worker       size_str.pop_back();
518*523fa7a6SAndroid Build Coastguard Worker     }
519*523fa7a6SAndroid Build Coastguard Worker     size_str.append("]");
520*523fa7a6SAndroid Build Coastguard Worker     return "TensorInfo(sizes=" + size_str + ", dtype=" +
521*523fa7a6SAndroid Build Coastguard Worker         std::string(executorch::runtime::toString(info_.scalar_type())) +
522*523fa7a6SAndroid Build Coastguard Worker         ", is_memory_planned=" +
523*523fa7a6SAndroid Build Coastguard Worker         (info_.is_memory_planned() ? "True" : "False") +
524*523fa7a6SAndroid Build Coastguard Worker         ", nbytes=" + std::to_string(info_.nbytes()) + ")";
525*523fa7a6SAndroid Build Coastguard Worker   }
526*523fa7a6SAndroid Build Coastguard Worker 
527*523fa7a6SAndroid Build Coastguard Worker  private:
528*523fa7a6SAndroid Build Coastguard Worker   // TensorInfo relies on module to be alive.
529*523fa7a6SAndroid Build Coastguard Worker   std::shared_ptr<Module> module_;
530*523fa7a6SAndroid Build Coastguard Worker   torch::executor::TensorInfo info_;
531*523fa7a6SAndroid Build Coastguard Worker };
532*523fa7a6SAndroid Build Coastguard Worker 
533*523fa7a6SAndroid Build Coastguard Worker /// Expose a subset of MethodMeta information to python.
534*523fa7a6SAndroid Build Coastguard Worker struct PyMethodMeta final {
PyMethodMetaexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta535*523fa7a6SAndroid Build Coastguard Worker   explicit PyMethodMeta(
536*523fa7a6SAndroid Build Coastguard Worker       std::shared_ptr<Module> module,
537*523fa7a6SAndroid Build Coastguard Worker       torch::executor::MethodMeta meta)
538*523fa7a6SAndroid Build Coastguard Worker       : module_(std::move(module)), meta_(meta) {}
539*523fa7a6SAndroid Build Coastguard Worker 
nameexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta540*523fa7a6SAndroid Build Coastguard Worker   const char* name() const {
541*523fa7a6SAndroid Build Coastguard Worker     return meta_.name();
542*523fa7a6SAndroid Build Coastguard Worker   }
543*523fa7a6SAndroid Build Coastguard Worker 
num_inputsexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta544*523fa7a6SAndroid Build Coastguard Worker   size_t num_inputs() const {
545*523fa7a6SAndroid Build Coastguard Worker     return meta_.num_inputs();
546*523fa7a6SAndroid Build Coastguard Worker   }
547*523fa7a6SAndroid Build Coastguard Worker 
input_tensor_metaexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta548*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<PyTensorInfo> input_tensor_meta(size_t index) const {
549*523fa7a6SAndroid Build Coastguard Worker     const auto result = meta_.input_tensor_meta(index);
550*523fa7a6SAndroid Build Coastguard Worker     THROW_INDEX_IF_ERROR(
551*523fa7a6SAndroid Build Coastguard Worker         result.error(), "Cannot get input tensor meta at %zu", index);
552*523fa7a6SAndroid Build Coastguard Worker     return std::make_unique<PyTensorInfo>(module_, result.get());
553*523fa7a6SAndroid Build Coastguard Worker   }
554*523fa7a6SAndroid Build Coastguard Worker 
num_outputsexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta555*523fa7a6SAndroid Build Coastguard Worker   size_t num_outputs() const {
556*523fa7a6SAndroid Build Coastguard Worker     return meta_.num_outputs();
557*523fa7a6SAndroid Build Coastguard Worker   }
558*523fa7a6SAndroid Build Coastguard Worker 
output_tensor_metaexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta559*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<PyTensorInfo> output_tensor_meta(size_t index) const {
560*523fa7a6SAndroid Build Coastguard Worker     const auto result = meta_.output_tensor_meta(index);
561*523fa7a6SAndroid Build Coastguard Worker     THROW_INDEX_IF_ERROR(
562*523fa7a6SAndroid Build Coastguard Worker         result.error(), "Cannot get output tensor meta at %zu", index);
563*523fa7a6SAndroid Build Coastguard Worker     return std::make_unique<PyTensorInfo>(module_, result.get());
564*523fa7a6SAndroid Build Coastguard Worker   }
565*523fa7a6SAndroid Build Coastguard Worker 
reprexecutorch::extension::pybindings::__anon7828bc6a0111::PyMethodMeta566*523fa7a6SAndroid Build Coastguard Worker   py::str repr() const {
567*523fa7a6SAndroid Build Coastguard Worker     py::list input_meta_strs;
568*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < meta_.num_inputs(); ++i) {
569*523fa7a6SAndroid Build Coastguard Worker       input_meta_strs.append(py::str(input_tensor_meta(i)->repr()));
570*523fa7a6SAndroid Build Coastguard Worker     }
571*523fa7a6SAndroid Build Coastguard Worker     py::list output_meta_strs;
572*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < meta_.num_outputs(); ++i) {
573*523fa7a6SAndroid Build Coastguard Worker       output_meta_strs.append(py::str(output_tensor_meta(i)->repr()));
574*523fa7a6SAndroid Build Coastguard Worker     }
575*523fa7a6SAndroid Build Coastguard Worker     // Add quotes to be more similar to Python's repr for strings.
576*523fa7a6SAndroid Build Coastguard Worker     py::str format =
577*523fa7a6SAndroid Build Coastguard Worker         "MethodMeta(name='{}', num_inputs={}, input_tensor_meta={}, num_outputs={}, output_tensor_meta={})";
578*523fa7a6SAndroid Build Coastguard Worker     return format.format(
579*523fa7a6SAndroid Build Coastguard Worker         std::string(meta_.name()),
580*523fa7a6SAndroid Build Coastguard Worker         std::to_string(meta_.num_inputs()),
581*523fa7a6SAndroid Build Coastguard Worker         input_meta_strs,
582*523fa7a6SAndroid Build Coastguard Worker         std::to_string(meta_.num_outputs()),
583*523fa7a6SAndroid Build Coastguard Worker         output_meta_strs);
584*523fa7a6SAndroid Build Coastguard Worker   }
585*523fa7a6SAndroid Build Coastguard Worker 
586*523fa7a6SAndroid Build Coastguard Worker  private:
587*523fa7a6SAndroid Build Coastguard Worker   // Must keep the Module object alive or else the meta object is invalidated.
588*523fa7a6SAndroid Build Coastguard Worker   std::shared_ptr<Module> module_;
589*523fa7a6SAndroid Build Coastguard Worker   torch::executor::MethodMeta meta_;
590*523fa7a6SAndroid Build Coastguard Worker };
591*523fa7a6SAndroid Build Coastguard Worker 
592*523fa7a6SAndroid Build Coastguard Worker struct PyModule final {
PyModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule593*523fa7a6SAndroid Build Coastguard Worker   explicit PyModule(
594*523fa7a6SAndroid Build Coastguard Worker       const py::bytes& buffer,
595*523fa7a6SAndroid Build Coastguard Worker       bool enable_etdump,
596*523fa7a6SAndroid Build Coastguard Worker       size_t debug_buffer_size = 0,
597*523fa7a6SAndroid Build Coastguard Worker       Program::Verification program_verification =
598*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency)
599*523fa7a6SAndroid Build Coastguard Worker       : module_(load_module_from_buffer(
600*523fa7a6SAndroid Build Coastguard Worker             buffer.cast<std::string_view>().data(),
601*523fa7a6SAndroid Build Coastguard Worker             py::len(buffer),
602*523fa7a6SAndroid Build Coastguard Worker             enable_etdump,
603*523fa7a6SAndroid Build Coastguard Worker             debug_buffer_size,
604*523fa7a6SAndroid Build Coastguard Worker             program_verification)) {}
605*523fa7a6SAndroid Build Coastguard Worker 
PyModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule606*523fa7a6SAndroid Build Coastguard Worker   explicit PyModule(
607*523fa7a6SAndroid Build Coastguard Worker       const void* ptr,
608*523fa7a6SAndroid Build Coastguard Worker       size_t ptr_len,
609*523fa7a6SAndroid Build Coastguard Worker       bool enable_etdump,
610*523fa7a6SAndroid Build Coastguard Worker       size_t debug_buffer_size = 0,
611*523fa7a6SAndroid Build Coastguard Worker       Program::Verification program_verification =
612*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency)
613*523fa7a6SAndroid Build Coastguard Worker       : module_(load_module_from_buffer(
614*523fa7a6SAndroid Build Coastguard Worker             ptr,
615*523fa7a6SAndroid Build Coastguard Worker             ptr_len,
616*523fa7a6SAndroid Build Coastguard Worker             enable_etdump,
617*523fa7a6SAndroid Build Coastguard Worker             debug_buffer_size,
618*523fa7a6SAndroid Build Coastguard Worker             program_verification)) {}
619*523fa7a6SAndroid Build Coastguard Worker 
PyModuleexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule620*523fa7a6SAndroid Build Coastguard Worker   explicit PyModule(
621*523fa7a6SAndroid Build Coastguard Worker       const std::string& path,
622*523fa7a6SAndroid Build Coastguard Worker       bool enable_etdump,
623*523fa7a6SAndroid Build Coastguard Worker       size_t debug_buffer_size = 0,
624*523fa7a6SAndroid Build Coastguard Worker       Program::Verification program_verification =
625*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency)
626*523fa7a6SAndroid Build Coastguard Worker       : module_(load_module_from_file(
627*523fa7a6SAndroid Build Coastguard Worker             path,
628*523fa7a6SAndroid Build Coastguard Worker             enable_etdump,
629*523fa7a6SAndroid Build Coastguard Worker             debug_buffer_size,
630*523fa7a6SAndroid Build Coastguard Worker             program_verification)) {}
631*523fa7a6SAndroid Build Coastguard Worker 
632*523fa7a6SAndroid Build Coastguard Worker   PyModule(const PyModule&) = delete;
633*523fa7a6SAndroid Build Coastguard Worker   PyModule& operator=(const PyModule&) = delete;
634*523fa7a6SAndroid Build Coastguard Worker   PyModule(PyModule&&) = default;
635*523fa7a6SAndroid Build Coastguard Worker   PyModule& operator=(PyModule&&) = default;
636*523fa7a6SAndroid Build Coastguard Worker 
637*523fa7a6SAndroid Build Coastguard Worker   // Module is only valid as long as the python buffer is alive.
load_from_bufferexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule638*523fa7a6SAndroid Build Coastguard Worker   static std::unique_ptr<PyModule> load_from_buffer(
639*523fa7a6SAndroid Build Coastguard Worker       const py::bytes& buffer,
640*523fa7a6SAndroid Build Coastguard Worker       bool enable_etdump,
641*523fa7a6SAndroid Build Coastguard Worker       size_t debug_buffer_size = 0,
642*523fa7a6SAndroid Build Coastguard Worker       Program::Verification program_verification =
643*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency) {
644*523fa7a6SAndroid Build Coastguard Worker     return std::make_unique<PyModule>(
645*523fa7a6SAndroid Build Coastguard Worker         buffer, enable_etdump, debug_buffer_size, program_verification);
646*523fa7a6SAndroid Build Coastguard Worker   }
load_from_fileexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule647*523fa7a6SAndroid Build Coastguard Worker   static std::unique_ptr<PyModule> load_from_file(
648*523fa7a6SAndroid Build Coastguard Worker       const std::string& path,
649*523fa7a6SAndroid Build Coastguard Worker       bool enable_etdump,
650*523fa7a6SAndroid Build Coastguard Worker       size_t debug_buffer_size = 0,
651*523fa7a6SAndroid Build Coastguard Worker       Program::Verification program_verification =
652*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency) {
653*523fa7a6SAndroid Build Coastguard Worker     return std::make_unique<PyModule>(
654*523fa7a6SAndroid Build Coastguard Worker         path, enable_etdump, debug_buffer_size, program_verification);
655*523fa7a6SAndroid Build Coastguard Worker   }
656*523fa7a6SAndroid Build Coastguard Worker 
load_from_bundled_programexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule657*523fa7a6SAndroid Build Coastguard Worker   static std::unique_ptr<PyModule> load_from_bundled_program(
658*523fa7a6SAndroid Build Coastguard Worker       PyBundledModule& m,
659*523fa7a6SAndroid Build Coastguard Worker       bool enable_etdump,
660*523fa7a6SAndroid Build Coastguard Worker       size_t debug_buffer_size = 0) {
661*523fa7a6SAndroid Build Coastguard Worker     return std::make_unique<PyModule>(
662*523fa7a6SAndroid Build Coastguard Worker         m.get_program_ptr(),
663*523fa7a6SAndroid Build Coastguard Worker         m.get_program_len(),
664*523fa7a6SAndroid Build Coastguard Worker         enable_etdump,
665*523fa7a6SAndroid Build Coastguard Worker         debug_buffer_size);
666*523fa7a6SAndroid Build Coastguard Worker   }
667*523fa7a6SAndroid Build Coastguard Worker 
run_methodexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule668*523fa7a6SAndroid Build Coastguard Worker   py::list run_method(
669*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
670*523fa7a6SAndroid Build Coastguard Worker       const py::sequence& inputs,
671*523fa7a6SAndroid Build Coastguard Worker       bool clone_outputs = true) {
672*523fa7a6SAndroid Build Coastguard Worker     const auto inputs_size = py::len(inputs);
673*523fa7a6SAndroid Build Coastguard Worker     std::vector<EValue> cpp_inputs;
674*523fa7a6SAndroid Build Coastguard Worker     cpp_inputs.reserve(inputs_size);
675*523fa7a6SAndroid Build Coastguard Worker 
676*523fa7a6SAndroid Build Coastguard Worker #ifndef USE_ATEN_LIB // Portable mode
677*523fa7a6SAndroid Build Coastguard Worker     // So the ETensors and their metadata stay in scope for
678*523fa7a6SAndroid Build Coastguard Worker     // Module->run_method.
679*523fa7a6SAndroid Build Coastguard Worker     std::vector<torch::executor::TensorImpl> input_tensors;
680*523fa7a6SAndroid Build Coastguard Worker     std::vector<std::vector<torch::executor::Tensor::SizesType>> input_sizes;
681*523fa7a6SAndroid Build Coastguard Worker     std::vector<std::vector<torch::executor::Tensor::StridesType>>
682*523fa7a6SAndroid Build Coastguard Worker         input_strides;
683*523fa7a6SAndroid Build Coastguard Worker     std::vector<std::vector<torch::executor::Tensor::DimOrderType>>
684*523fa7a6SAndroid Build Coastguard Worker         input_dim_order;
685*523fa7a6SAndroid Build Coastguard Worker     // We store pointers to these vector elements so important to reserve so
686*523fa7a6SAndroid Build Coastguard Worker     // that we don't lose those on a vector resize. Don't need to do this for
687*523fa7a6SAndroid Build Coastguard Worker     // the others since they are vectors of vectors, and we don't store a
688*523fa7a6SAndroid Build Coastguard Worker     // pointer to the root level vector data.
689*523fa7a6SAndroid Build Coastguard Worker     input_tensors.reserve(inputs_size);
690*523fa7a6SAndroid Build Coastguard Worker #endif
691*523fa7a6SAndroid Build Coastguard Worker 
692*523fa7a6SAndroid Build Coastguard Worker     // Convert python objects into EValues.
693*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < inputs_size; ++i) {
694*523fa7a6SAndroid Build Coastguard Worker       auto python_input = inputs[i];
695*523fa7a6SAndroid Build Coastguard Worker       const std::string& type_str = py::str(python_input.get_type());
696*523fa7a6SAndroid Build Coastguard Worker       if (type_str == "<class 'torch.Tensor'>") {
697*523fa7a6SAndroid Build Coastguard Worker         auto at_tensor = python_input.cast<at::Tensor>();
698*523fa7a6SAndroid Build Coastguard Worker         // alias_etensor_to_attensor will assert on this later, so to better
699*523fa7a6SAndroid Build Coastguard Worker         // propogate up to python we check early and throw an exception.
700*523fa7a6SAndroid Build Coastguard Worker         if (!at_tensor.is_contiguous()) {
701*523fa7a6SAndroid Build Coastguard Worker           auto error_msg = "Input " + std::to_string(i) + "for method " +
702*523fa7a6SAndroid Build Coastguard Worker               method_name + " is not contiguous.";
703*523fa7a6SAndroid Build Coastguard Worker           throw std::runtime_error(error_msg);
704*523fa7a6SAndroid Build Coastguard Worker         }
705*523fa7a6SAndroid Build Coastguard Worker 
706*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
707*523fa7a6SAndroid Build Coastguard Worker         EValue evalue(at_tensor);
708*523fa7a6SAndroid Build Coastguard Worker #else
709*523fa7a6SAndroid Build Coastguard Worker         // convert at::Tensor to torch::executor::Tensor
710*523fa7a6SAndroid Build Coastguard Worker         auto type =
711*523fa7a6SAndroid Build Coastguard Worker             torch_to_executorch_scalar_type(at_tensor.options().dtype());
712*523fa7a6SAndroid Build Coastguard Worker         size_t dim = at_tensor.dim();
713*523fa7a6SAndroid Build Coastguard Worker         // cant directly alias at::Tensor sizes and strides due to int64 vs
714*523fa7a6SAndroid Build Coastguard Worker         // int32 typing conflict
715*523fa7a6SAndroid Build Coastguard Worker         input_sizes.emplace_back(
716*523fa7a6SAndroid Build Coastguard Worker             at_tensor.sizes().begin(), at_tensor.sizes().end());
717*523fa7a6SAndroid Build Coastguard Worker         input_strides.emplace_back(
718*523fa7a6SAndroid Build Coastguard Worker             at_tensor.strides().begin(), at_tensor.strides().end());
719*523fa7a6SAndroid Build Coastguard Worker 
720*523fa7a6SAndroid Build Coastguard Worker         // Only works for MemoryFormat::Contiguous inputs
721*523fa7a6SAndroid Build Coastguard Worker         std::vector<torch::executor::Tensor::DimOrderType> dim_order;
722*523fa7a6SAndroid Build Coastguard Worker         for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
723*523fa7a6SAndroid Build Coastguard Worker           dim_order.push_back(cur_dim);
724*523fa7a6SAndroid Build Coastguard Worker         }
725*523fa7a6SAndroid Build Coastguard Worker         input_dim_order.push_back(std::move(dim_order));
726*523fa7a6SAndroid Build Coastguard Worker         input_tensors.emplace_back(
727*523fa7a6SAndroid Build Coastguard Worker             type,
728*523fa7a6SAndroid Build Coastguard Worker             dim,
729*523fa7a6SAndroid Build Coastguard Worker             input_sizes.back().data(),
730*523fa7a6SAndroid Build Coastguard Worker             nullptr,
731*523fa7a6SAndroid Build Coastguard Worker             input_dim_order.back().data(),
732*523fa7a6SAndroid Build Coastguard Worker             input_strides.back().data());
733*523fa7a6SAndroid Build Coastguard Worker 
734*523fa7a6SAndroid Build Coastguard Worker         torch::executor::Tensor temp =
735*523fa7a6SAndroid Build Coastguard Worker             torch::executor::Tensor(&input_tensors.back());
736*523fa7a6SAndroid Build Coastguard Worker         alias_etensor_to_attensor(at_tensor, temp);
737*523fa7a6SAndroid Build Coastguard Worker         EValue evalue(temp);
738*523fa7a6SAndroid Build Coastguard Worker #endif
739*523fa7a6SAndroid Build Coastguard Worker 
740*523fa7a6SAndroid Build Coastguard Worker         cpp_inputs.push_back(evalue);
741*523fa7a6SAndroid Build Coastguard Worker       } else if (py::isinstance<py::none>(python_input)) {
742*523fa7a6SAndroid Build Coastguard Worker         cpp_inputs.push_back(EValue());
743*523fa7a6SAndroid Build Coastguard Worker       } else if (py::isinstance<py::bool_>(python_input)) {
744*523fa7a6SAndroid Build Coastguard Worker         cpp_inputs.push_back(EValue(py::cast<bool>(python_input)));
745*523fa7a6SAndroid Build Coastguard Worker       } else if (py::isinstance<py::int_>(python_input)) {
746*523fa7a6SAndroid Build Coastguard Worker         cpp_inputs.push_back(EValue(py::cast<int64_t>(python_input)));
747*523fa7a6SAndroid Build Coastguard Worker       } else {
748*523fa7a6SAndroid Build Coastguard Worker         ET_ASSERT_UNREACHABLE_MSG("Unsupported pytype: %s", type_str.c_str());
749*523fa7a6SAndroid Build Coastguard Worker       }
750*523fa7a6SAndroid Build Coastguard Worker     }
751*523fa7a6SAndroid Build Coastguard Worker 
752*523fa7a6SAndroid Build Coastguard Worker     const auto& method = module_->get_method(method_name);
753*523fa7a6SAndroid Build Coastguard Worker     const auto num_outputs = method.outputs_size();
754*523fa7a6SAndroid Build Coastguard Worker     output_storages_ = make_output_storages(method);
755*523fa7a6SAndroid Build Coastguard Worker     std::vector<Span<uint8_t>> output_storage_spans(num_outputs);
756*523fa7a6SAndroid Build Coastguard Worker     for (int i = 0; i < output_storages_.size(); ++i) {
757*523fa7a6SAndroid Build Coastguard Worker       output_storage_spans[i] =
758*523fa7a6SAndroid Build Coastguard Worker           Span<uint8_t>(output_storages_[i].data(), output_storages_[i].size());
759*523fa7a6SAndroid Build Coastguard Worker     }
760*523fa7a6SAndroid Build Coastguard Worker     auto outputs =
761*523fa7a6SAndroid Build Coastguard Worker         module_->run_method(method_name, cpp_inputs, output_storage_spans);
762*523fa7a6SAndroid Build Coastguard Worker 
763*523fa7a6SAndroid Build Coastguard Worker     // Retrieve outputs
764*523fa7a6SAndroid Build Coastguard Worker     return get_outputs_as_py_list(outputs, clone_outputs);
765*523fa7a6SAndroid Build Coastguard Worker   }
766*523fa7a6SAndroid Build Coastguard Worker 
forwardexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule767*523fa7a6SAndroid Build Coastguard Worker   py::list forward(const py::sequence& inputs, bool clone_outputs = true) {
768*523fa7a6SAndroid Build Coastguard Worker     return run_method("forward", inputs, clone_outputs);
769*523fa7a6SAndroid Build Coastguard Worker   }
770*523fa7a6SAndroid Build Coastguard Worker 
forward_single_inputexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule771*523fa7a6SAndroid Build Coastguard Worker   py::list forward_single_input(
772*523fa7a6SAndroid Build Coastguard Worker       const torch::Tensor& inputTensor,
773*523fa7a6SAndroid Build Coastguard Worker       bool clone_outputs = true) {
774*523fa7a6SAndroid Build Coastguard Worker     py::list py_list;
775*523fa7a6SAndroid Build Coastguard Worker     py_list.append(py::cast(inputTensor));
776*523fa7a6SAndroid Build Coastguard Worker     return run_method("forward", py_list, clone_outputs);
777*523fa7a6SAndroid Build Coastguard Worker   }
778*523fa7a6SAndroid Build Coastguard Worker 
has_etdumpexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule779*523fa7a6SAndroid Build Coastguard Worker   bool has_etdump() {
780*523fa7a6SAndroid Build Coastguard Worker     return module_->has_etdump();
781*523fa7a6SAndroid Build Coastguard Worker   }
782*523fa7a6SAndroid Build Coastguard Worker 
write_etdump_result_to_fileexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule783*523fa7a6SAndroid Build Coastguard Worker   void write_etdump_result_to_file(
784*523fa7a6SAndroid Build Coastguard Worker       const std::string& path,
785*523fa7a6SAndroid Build Coastguard Worker       const py::object& debug_buffer_path) {
786*523fa7a6SAndroid Build Coastguard Worker     if (!has_etdump()) {
787*523fa7a6SAndroid Build Coastguard Worker       throw std::runtime_error("No etdump found");
788*523fa7a6SAndroid Build Coastguard Worker     }
789*523fa7a6SAndroid Build Coastguard Worker     auto& etdump = module_->etdump();
790*523fa7a6SAndroid Build Coastguard Worker     etdump_result result = etdump.get_etdump_data();
791*523fa7a6SAndroid Build Coastguard Worker     if (result.buf != nullptr && result.size > 0) {
792*523fa7a6SAndroid Build Coastguard Worker       write_data_to_file(path, result.buf, result.size);
793*523fa7a6SAndroid Build Coastguard Worker       free(result.buf);
794*523fa7a6SAndroid Build Coastguard Worker       if (module_->has_etdump_debug_buffer() &&
795*523fa7a6SAndroid Build Coastguard Worker           py::isinstance<py::str>(debug_buffer_path)) {
796*523fa7a6SAndroid Build Coastguard Worker         // Also write out the debug buffer to a separate file if requested.
797*523fa7a6SAndroid Build Coastguard Worker         std::string debug_buffer_path_str =
798*523fa7a6SAndroid Build Coastguard Worker             py::cast<py::str>(debug_buffer_path);
799*523fa7a6SAndroid Build Coastguard Worker         const auto debug_buffer = module_->get_etdump_debug_buffer();
800*523fa7a6SAndroid Build Coastguard Worker         write_data_to_file(
801*523fa7a6SAndroid Build Coastguard Worker             debug_buffer_path_str, debug_buffer.data(), debug_buffer.size());
802*523fa7a6SAndroid Build Coastguard Worker       }
803*523fa7a6SAndroid Build Coastguard Worker     } else {
804*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(
805*523fa7a6SAndroid Build Coastguard Worker           Info,
806*523fa7a6SAndroid Build Coastguard Worker           "No etdump data found, try rebuilding with "
807*523fa7a6SAndroid Build Coastguard Worker           "the CMake option EXECUTORCH_ENABLE_EVENT_TRACER or with "
808*523fa7a6SAndroid Build Coastguard Worker           "buck run --config executorch.event_tracer_enabled=true");
809*523fa7a6SAndroid Build Coastguard Worker     }
810*523fa7a6SAndroid Build Coastguard Worker   }
811*523fa7a6SAndroid Build Coastguard Worker 
load_bundled_inputexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule812*523fa7a6SAndroid Build Coastguard Worker   void load_bundled_input(
813*523fa7a6SAndroid Build Coastguard Worker       PyBundledModule& m,
814*523fa7a6SAndroid Build Coastguard Worker       const std::string method_name,
815*523fa7a6SAndroid Build Coastguard Worker       size_t testset_idx) {
816*523fa7a6SAndroid Build Coastguard Worker     const void* bundled_program_ptr = m.get_bundled_program_ptr();
817*523fa7a6SAndroid Build Coastguard Worker     Error status = executorch::bundled_program::load_bundled_input(
818*523fa7a6SAndroid Build Coastguard Worker         module_->get_method(method_name), bundled_program_ptr, testset_idx);
819*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
820*523fa7a6SAndroid Build Coastguard Worker         status,
821*523fa7a6SAndroid Build Coastguard Worker         "load_bundled_input failed with status 0x%" PRIx32,
822*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(status));
823*523fa7a6SAndroid Build Coastguard Worker   }
824*523fa7a6SAndroid Build Coastguard Worker 
verify_result_with_bundled_expected_outputexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule825*523fa7a6SAndroid Build Coastguard Worker   py::list verify_result_with_bundled_expected_output(
826*523fa7a6SAndroid Build Coastguard Worker       PyBundledModule& m,
827*523fa7a6SAndroid Build Coastguard Worker       const std::string method_name,
828*523fa7a6SAndroid Build Coastguard Worker       size_t testset_idx,
829*523fa7a6SAndroid Build Coastguard Worker       double rtol = 1e-5,
830*523fa7a6SAndroid Build Coastguard Worker       double atol = 1e-8) {
831*523fa7a6SAndroid Build Coastguard Worker     const void* bundled_program_ptr = m.get_bundled_program_ptr();
832*523fa7a6SAndroid Build Coastguard Worker     auto& method = module_->get_method(method_name);
833*523fa7a6SAndroid Build Coastguard Worker     Error status = executorch::bundled_program::load_bundled_input(
834*523fa7a6SAndroid Build Coastguard Worker         method, bundled_program_ptr, testset_idx);
835*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
836*523fa7a6SAndroid Build Coastguard Worker         status,
837*523fa7a6SAndroid Build Coastguard Worker         "load_bundled_input failed with status 0x%" PRIx32,
838*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(status));
839*523fa7a6SAndroid Build Coastguard Worker     py::list outputs = plan_execute(method_name);
840*523fa7a6SAndroid Build Coastguard Worker     status = executorch::bundled_program::verify_method_outputs(
841*523fa7a6SAndroid Build Coastguard Worker         method, bundled_program_ptr, testset_idx, rtol, atol);
842*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
843*523fa7a6SAndroid Build Coastguard Worker         status,
844*523fa7a6SAndroid Build Coastguard Worker         "Result verification failed with status %" PRIu32,
845*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(status));
846*523fa7a6SAndroid Build Coastguard Worker     return outputs;
847*523fa7a6SAndroid Build Coastguard Worker   }
848*523fa7a6SAndroid Build Coastguard Worker 
plan_executeexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule849*523fa7a6SAndroid Build Coastguard Worker   py::list plan_execute(
850*523fa7a6SAndroid Build Coastguard Worker       const std::string method_name,
851*523fa7a6SAndroid Build Coastguard Worker       bool clone_outputs = true) {
852*523fa7a6SAndroid Build Coastguard Worker     auto& method = module_->get_method(method_name);
853*523fa7a6SAndroid Build Coastguard Worker     // Need to pre-allocate space for outputs just like in run_method.
854*523fa7a6SAndroid Build Coastguard Worker     const auto num_outputs = method.outputs_size();
855*523fa7a6SAndroid Build Coastguard Worker     output_storages_ = make_output_storages(method);
856*523fa7a6SAndroid Build Coastguard Worker     std::vector<Span<uint8_t>> output_storage_spans(num_outputs);
857*523fa7a6SAndroid Build Coastguard Worker     for (int i = 0; i < output_storages_.size(); ++i) {
858*523fa7a6SAndroid Build Coastguard Worker       output_storage_spans[i] =
859*523fa7a6SAndroid Build Coastguard Worker           Span<uint8_t>(output_storages_[i].data(), output_storages_[i].size());
860*523fa7a6SAndroid Build Coastguard Worker     }
861*523fa7a6SAndroid Build Coastguard Worker     setup_output_storage(method, output_storage_spans);
862*523fa7a6SAndroid Build Coastguard Worker     auto status = method.execute();
863*523fa7a6SAndroid Build Coastguard Worker     THROW_IF_ERROR(
864*523fa7a6SAndroid Build Coastguard Worker         status,
865*523fa7a6SAndroid Build Coastguard Worker         "executing execution plan for method 'forward' failed with error: 0x%" PRIx32,
866*523fa7a6SAndroid Build Coastguard Worker         static_cast<uint32_t>(status));
867*523fa7a6SAndroid Build Coastguard Worker     const auto outputs = module_->get_outputs(method_name);
868*523fa7a6SAndroid Build Coastguard Worker     return get_outputs_as_py_list(outputs, clone_outputs);
869*523fa7a6SAndroid Build Coastguard Worker   }
870*523fa7a6SAndroid Build Coastguard Worker 
get_outputs_as_py_listexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule871*523fa7a6SAndroid Build Coastguard Worker   py::list get_outputs_as_py_list(
872*523fa7a6SAndroid Build Coastguard Worker       const std::vector<EValue>& outputs,
873*523fa7a6SAndroid Build Coastguard Worker       bool clone_outputs = true) {
874*523fa7a6SAndroid Build Coastguard Worker     const auto outputs_size = outputs.size();
875*523fa7a6SAndroid Build Coastguard Worker     py::list list(outputs_size);
876*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < outputs_size; ++i) {
877*523fa7a6SAndroid Build Coastguard Worker       auto& v = outputs[i];
878*523fa7a6SAndroid Build Coastguard Worker       if (Tag::None == v.tag) {
879*523fa7a6SAndroid Build Coastguard Worker         list[i] = py::none();
880*523fa7a6SAndroid Build Coastguard Worker       } else if (Tag::Int == v.tag) {
881*523fa7a6SAndroid Build Coastguard Worker         list[i] = py::cast(v.toInt());
882*523fa7a6SAndroid Build Coastguard Worker       } else if (Tag::Double == v.tag) {
883*523fa7a6SAndroid Build Coastguard Worker         list[i] = py::cast(v.toDouble());
884*523fa7a6SAndroid Build Coastguard Worker       } else if (Tag::Bool == v.tag) {
885*523fa7a6SAndroid Build Coastguard Worker         list[i] = py::cast(v.toBool());
886*523fa7a6SAndroid Build Coastguard Worker       } else if (Tag::String == v.tag) {
887*523fa7a6SAndroid Build Coastguard Worker         list[i] = py::cast(std::string(v.toString().data()));
888*523fa7a6SAndroid Build Coastguard Worker       } else if (Tag::Tensor == v.tag) {
889*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
890*523fa7a6SAndroid Build Coastguard Worker         // Clone so the outputs in python do not share a lifetime with the
891*523fa7a6SAndroid Build Coastguard Worker         // module object
892*523fa7a6SAndroid Build Coastguard Worker         if (clone_outputs) {
893*523fa7a6SAndroid Build Coastguard Worker           list[i] = py::cast(v.toTensor().clone());
894*523fa7a6SAndroid Build Coastguard Worker         } else {
895*523fa7a6SAndroid Build Coastguard Worker           list[i] = py::cast(v.toTensor());
896*523fa7a6SAndroid Build Coastguard Worker         }
897*523fa7a6SAndroid Build Coastguard Worker #else
898*523fa7a6SAndroid Build Coastguard Worker         if (clone_outputs) {
899*523fa7a6SAndroid Build Coastguard Worker           list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
900*523fa7a6SAndroid Build Coastguard Worker         } else {
901*523fa7a6SAndroid Build Coastguard Worker           list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
902*523fa7a6SAndroid Build Coastguard Worker         }
903*523fa7a6SAndroid Build Coastguard Worker #endif
904*523fa7a6SAndroid Build Coastguard Worker       } else {
905*523fa7a6SAndroid Build Coastguard Worker         ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
906*523fa7a6SAndroid Build Coastguard Worker       }
907*523fa7a6SAndroid Build Coastguard Worker     }
908*523fa7a6SAndroid Build Coastguard Worker     return list;
909*523fa7a6SAndroid Build Coastguard Worker   }
910*523fa7a6SAndroid Build Coastguard Worker 
method_metaexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule911*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
912*523fa7a6SAndroid Build Coastguard Worker     auto& method = module_->get_method(method_name);
913*523fa7a6SAndroid Build Coastguard Worker     return std::make_unique<PyMethodMeta>(module_, method.method_meta());
914*523fa7a6SAndroid Build Coastguard Worker   }
915*523fa7a6SAndroid Build Coastguard Worker 
method_namesexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule916*523fa7a6SAndroid Build Coastguard Worker   std::vector<std::string> method_names() {
917*523fa7a6SAndroid Build Coastguard Worker     return module_->method_names();
918*523fa7a6SAndroid Build Coastguard Worker   }
919*523fa7a6SAndroid Build Coastguard Worker 
920*523fa7a6SAndroid Build Coastguard Worker  private:
921*523fa7a6SAndroid Build Coastguard Worker   std::shared_ptr<Module> module_;
922*523fa7a6SAndroid Build Coastguard Worker   // Need to keep-alive output storages until they can be compared in case of
923*523fa7a6SAndroid Build Coastguard Worker   // bundled programs.
924*523fa7a6SAndroid Build Coastguard Worker   std::vector<std::vector<uint8_t>> output_storages_;
925*523fa7a6SAndroid Build Coastguard Worker 
make_output_storagesexecutorch::extension::pybindings::__anon7828bc6a0111::PyModule926*523fa7a6SAndroid Build Coastguard Worker   std::vector<std::vector<uint8_t>> make_output_storages(const Method& method) {
927*523fa7a6SAndroid Build Coastguard Worker     const auto num_outputs = method.outputs_size();
928*523fa7a6SAndroid Build Coastguard Worker     // Create a buffer for each output tensor. Memory planned outputs and non
929*523fa7a6SAndroid Build Coastguard Worker     // tensor outputs get an empty buffer in this list which is ignored later.
930*523fa7a6SAndroid Build Coastguard Worker     std::vector<std::vector<uint8_t>> output_storages;
931*523fa7a6SAndroid Build Coastguard Worker     output_storages_.reserve(num_outputs);
932*523fa7a6SAndroid Build Coastguard Worker     auto meta = method.method_meta();
933*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < num_outputs; ++i) {
934*523fa7a6SAndroid Build Coastguard Worker       auto output_type = meta.output_tag(i);
935*523fa7a6SAndroid Build Coastguard Worker       THROW_IF_ERROR(
936*523fa7a6SAndroid Build Coastguard Worker           output_type.error(), "Failed to get output type for output %zu", i);
937*523fa7a6SAndroid Build Coastguard Worker       if (output_type.get() != Tag::Tensor) {
938*523fa7a6SAndroid Build Coastguard Worker         // Skip allocating storage for non-tensor outputs.
939*523fa7a6SAndroid Build Coastguard Worker         output_storages.emplace_back();
940*523fa7a6SAndroid Build Coastguard Worker         continue;
941*523fa7a6SAndroid Build Coastguard Worker       }
942*523fa7a6SAndroid Build Coastguard Worker       const auto& output_tensor_meta =
943*523fa7a6SAndroid Build Coastguard Worker           method.method_meta().output_tensor_meta(i);
944*523fa7a6SAndroid Build Coastguard Worker       THROW_IF_ERROR(
945*523fa7a6SAndroid Build Coastguard Worker           output_tensor_meta.error(),
946*523fa7a6SAndroid Build Coastguard Worker           "Failed to get output tensor meta for output %zu",
947*523fa7a6SAndroid Build Coastguard Worker           i);
948*523fa7a6SAndroid Build Coastguard Worker       if (output_tensor_meta.get().is_memory_planned()) {
949*523fa7a6SAndroid Build Coastguard Worker         // Skip allocating storage for planned memory outputs.
950*523fa7a6SAndroid Build Coastguard Worker         output_storages.emplace_back();
951*523fa7a6SAndroid Build Coastguard Worker         continue;
952*523fa7a6SAndroid Build Coastguard Worker       }
953*523fa7a6SAndroid Build Coastguard Worker       // Allocate storage for the output tensor.
954*523fa7a6SAndroid Build Coastguard Worker       const size_t output_size = output_tensor_meta.get().nbytes();
955*523fa7a6SAndroid Build Coastguard Worker       output_storages.emplace_back(output_size);
956*523fa7a6SAndroid Build Coastguard Worker     }
957*523fa7a6SAndroid Build Coastguard Worker     return output_storages;
958*523fa7a6SAndroid Build Coastguard Worker   }
959*523fa7a6SAndroid Build Coastguard Worker };
960*523fa7a6SAndroid Build Coastguard Worker 
create_profile_block(const std::string & name)961*523fa7a6SAndroid Build Coastguard Worker void create_profile_block(const std::string& name) {
962*523fa7a6SAndroid Build Coastguard Worker   EXECUTORCH_PROFILE_CREATE_BLOCK(name.c_str());
963*523fa7a6SAndroid Build Coastguard Worker }
964*523fa7a6SAndroid Build Coastguard Worker 
get_operator_names()965*523fa7a6SAndroid Build Coastguard Worker py::list get_operator_names() {
966*523fa7a6SAndroid Build Coastguard Worker   Span<const Kernel> kernels = get_registered_kernels();
967*523fa7a6SAndroid Build Coastguard Worker   py::list res;
968*523fa7a6SAndroid Build Coastguard Worker   for (const Kernel& k : kernels) {
969*523fa7a6SAndroid Build Coastguard Worker     if (k.name_ != nullptr) {
970*523fa7a6SAndroid Build Coastguard Worker       res.append(py::cast(k.name_));
971*523fa7a6SAndroid Build Coastguard Worker     }
972*523fa7a6SAndroid Build Coastguard Worker   }
973*523fa7a6SAndroid Build Coastguard Worker   return res;
974*523fa7a6SAndroid Build Coastguard Worker }
975*523fa7a6SAndroid Build Coastguard Worker 
976*523fa7a6SAndroid Build Coastguard Worker } // namespace
977*523fa7a6SAndroid Build Coastguard Worker 
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME,m)978*523fa7a6SAndroid Build Coastguard Worker PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
979*523fa7a6SAndroid Build Coastguard Worker   // Redirects cout and cerr for function calls this guards to the python env.
980*523fa7a6SAndroid Build Coastguard Worker   auto call_guard = py::
981*523fa7a6SAndroid Build Coastguard Worker       call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();
982*523fa7a6SAndroid Build Coastguard Worker 
983*523fa7a6SAndroid Build Coastguard Worker   // Bind the verification enum to python.
984*523fa7a6SAndroid Build Coastguard Worker   py::enum_<Program::Verification>(m, "Verification")
985*523fa7a6SAndroid Build Coastguard Worker       .value("Minimal", Program::Verification::Minimal)
986*523fa7a6SAndroid Build Coastguard Worker       .value("InternalConsistency", Program::Verification::InternalConsistency);
987*523fa7a6SAndroid Build Coastguard Worker 
988*523fa7a6SAndroid Build Coastguard Worker   m.def(
989*523fa7a6SAndroid Build Coastguard Worker       "_load_for_executorch",
990*523fa7a6SAndroid Build Coastguard Worker       PyModule::load_from_file,
991*523fa7a6SAndroid Build Coastguard Worker       py::arg("path"),
992*523fa7a6SAndroid Build Coastguard Worker       py::arg("enable_etdump") = false,
993*523fa7a6SAndroid Build Coastguard Worker       py::arg("debug_buffer_size") = 0,
994*523fa7a6SAndroid Build Coastguard Worker       py::arg("program_verification") =
995*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency,
996*523fa7a6SAndroid Build Coastguard Worker       call_guard);
997*523fa7a6SAndroid Build Coastguard Worker   m.def(
998*523fa7a6SAndroid Build Coastguard Worker       "_load_for_executorch_from_buffer",
999*523fa7a6SAndroid Build Coastguard Worker       &PyModule::load_from_buffer,
1000*523fa7a6SAndroid Build Coastguard Worker       py::arg("buffer"),
1001*523fa7a6SAndroid Build Coastguard Worker       py::arg("enable_etdump") = false,
1002*523fa7a6SAndroid Build Coastguard Worker       py::arg("debug_buffer_size") = 0,
1003*523fa7a6SAndroid Build Coastguard Worker       py::arg("program_verification") =
1004*523fa7a6SAndroid Build Coastguard Worker           Program::Verification::InternalConsistency,
1005*523fa7a6SAndroid Build Coastguard Worker       call_guard);
1006*523fa7a6SAndroid Build Coastguard Worker   m.def(
1007*523fa7a6SAndroid Build Coastguard Worker       "_load_for_executorch_from_bundled_program",
1008*523fa7a6SAndroid Build Coastguard Worker       &PyModule::load_from_bundled_program,
1009*523fa7a6SAndroid Build Coastguard Worker       py::arg("ptr"),
1010*523fa7a6SAndroid Build Coastguard Worker       py::arg("enable_etdump") = false,
1011*523fa7a6SAndroid Build Coastguard Worker       py::arg("debug_buffer_size") = 0,
1012*523fa7a6SAndroid Build Coastguard Worker       call_guard);
1013*523fa7a6SAndroid Build Coastguard Worker   m.def(
1014*523fa7a6SAndroid Build Coastguard Worker       "_load_bundled_program_from_buffer",
1015*523fa7a6SAndroid Build Coastguard Worker       &PyBundledModule::load_from_buffer,
1016*523fa7a6SAndroid Build Coastguard Worker       py::arg("buffer"),
1017*523fa7a6SAndroid Build Coastguard Worker       py::arg("non_const_pool_size") = kDEFAULT_BUNDLED_INPUT_POOL_SIZE,
1018*523fa7a6SAndroid Build Coastguard Worker       call_guard);
1019*523fa7a6SAndroid Build Coastguard Worker   m.def(
1020*523fa7a6SAndroid Build Coastguard Worker       "_dump_profile_results",
1021*523fa7a6SAndroid Build Coastguard Worker       []() {
1022*523fa7a6SAndroid Build Coastguard Worker         prof_result_t prof_result;
1023*523fa7a6SAndroid Build Coastguard Worker         EXECUTORCH_DUMP_PROFILE_RESULTS(&prof_result);
1024*523fa7a6SAndroid Build Coastguard Worker         return py::bytes(
1025*523fa7a6SAndroid Build Coastguard Worker             reinterpret_cast<const char*>(prof_result.prof_data),
1026*523fa7a6SAndroid Build Coastguard Worker             prof_result.num_bytes);
1027*523fa7a6SAndroid Build Coastguard Worker       },
1028*523fa7a6SAndroid Build Coastguard Worker       call_guard);
1029*523fa7a6SAndroid Build Coastguard Worker   m.def("_get_operator_names", &get_operator_names);
1030*523fa7a6SAndroid Build Coastguard Worker   m.def("_create_profile_block", &create_profile_block, call_guard);
1031*523fa7a6SAndroid Build Coastguard Worker   m.def(
1032*523fa7a6SAndroid Build Coastguard Worker       "_reset_profile_results",
1033*523fa7a6SAndroid Build Coastguard Worker       []() { EXECUTORCH_RESET_PROFILE_RESULTS(); },
1034*523fa7a6SAndroid Build Coastguard Worker       call_guard);
1035*523fa7a6SAndroid Build Coastguard Worker 
1036*523fa7a6SAndroid Build Coastguard Worker   py::class_<PyModule>(m, "ExecuTorchModule")
1037*523fa7a6SAndroid Build Coastguard Worker       .def("load_bundled_input", &PyModule::load_bundled_input, call_guard)
1038*523fa7a6SAndroid Build Coastguard Worker       .def(
1039*523fa7a6SAndroid Build Coastguard Worker           "verify_result_with_bundled_expected_output",
1040*523fa7a6SAndroid Build Coastguard Worker           &PyModule::verify_result_with_bundled_expected_output,
1041*523fa7a6SAndroid Build Coastguard Worker           py::arg("bundle"),
1042*523fa7a6SAndroid Build Coastguard Worker           py::arg("method_name"),
1043*523fa7a6SAndroid Build Coastguard Worker           py::arg("testset_idx"),
1044*523fa7a6SAndroid Build Coastguard Worker           py::arg("rtol") = 1e-5,
1045*523fa7a6SAndroid Build Coastguard Worker           py::arg("atol") = 1e-8,
1046*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1047*523fa7a6SAndroid Build Coastguard Worker       .def(
1048*523fa7a6SAndroid Build Coastguard Worker           "plan_execute",
1049*523fa7a6SAndroid Build Coastguard Worker           &PyModule::plan_execute,
1050*523fa7a6SAndroid Build Coastguard Worker           py::arg("method_name"),
1051*523fa7a6SAndroid Build Coastguard Worker           py::arg("clone_outputs") = true,
1052*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1053*523fa7a6SAndroid Build Coastguard Worker       .def(
1054*523fa7a6SAndroid Build Coastguard Worker           "method_meta",
1055*523fa7a6SAndroid Build Coastguard Worker           &PyModule::method_meta,
1056*523fa7a6SAndroid Build Coastguard Worker           py::arg("method_name"),
1057*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1058*523fa7a6SAndroid Build Coastguard Worker       .def("method_names", &PyModule::method_names, call_guard)
1059*523fa7a6SAndroid Build Coastguard Worker       .def(
1060*523fa7a6SAndroid Build Coastguard Worker           "run_method",
1061*523fa7a6SAndroid Build Coastguard Worker           &PyModule::run_method,
1062*523fa7a6SAndroid Build Coastguard Worker           py::arg("method_name"),
1063*523fa7a6SAndroid Build Coastguard Worker           py::arg("inputs") = py::list(),
1064*523fa7a6SAndroid Build Coastguard Worker           py::arg("clone_outputs") = true,
1065*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1066*523fa7a6SAndroid Build Coastguard Worker       .def(
1067*523fa7a6SAndroid Build Coastguard Worker           "forward",
1068*523fa7a6SAndroid Build Coastguard Worker           &PyModule::forward,
1069*523fa7a6SAndroid Build Coastguard Worker           py::arg("inputs") = py::list(),
1070*523fa7a6SAndroid Build Coastguard Worker           py::arg("clone_outputs") = true,
1071*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1072*523fa7a6SAndroid Build Coastguard Worker       .def("has_etdump", &PyModule::has_etdump, call_guard)
1073*523fa7a6SAndroid Build Coastguard Worker       .def(
1074*523fa7a6SAndroid Build Coastguard Worker           "write_etdump_result_to_file",
1075*523fa7a6SAndroid Build Coastguard Worker           &PyModule::write_etdump_result_to_file,
1076*523fa7a6SAndroid Build Coastguard Worker           py::arg("path"),
1077*523fa7a6SAndroid Build Coastguard Worker           py::arg("debug_buffer_path") = py::none(),
1078*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1079*523fa7a6SAndroid Build Coastguard Worker       .def(
1080*523fa7a6SAndroid Build Coastguard Worker           "__call__",
1081*523fa7a6SAndroid Build Coastguard Worker           &PyModule::forward,
1082*523fa7a6SAndroid Build Coastguard Worker           py::arg("inputs") = py::list(),
1083*523fa7a6SAndroid Build Coastguard Worker           py::arg("clone_outputs") = true,
1084*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1085*523fa7a6SAndroid Build Coastguard Worker       .def(
1086*523fa7a6SAndroid Build Coastguard Worker           "__call__",
1087*523fa7a6SAndroid Build Coastguard Worker           &PyModule::forward_single_input,
1088*523fa7a6SAndroid Build Coastguard Worker           py::arg("inputs") = py::list(),
1089*523fa7a6SAndroid Build Coastguard Worker           py::arg("clone_outputs") = true,
1090*523fa7a6SAndroid Build Coastguard Worker           call_guard);
1091*523fa7a6SAndroid Build Coastguard Worker 
1092*523fa7a6SAndroid Build Coastguard Worker   py::class_<PyBundledModule>(m, "BundledModule");
1093*523fa7a6SAndroid Build Coastguard Worker   py::class_<PyTensorInfo>(m, "TensorInfo")
1094*523fa7a6SAndroid Build Coastguard Worker       .def("sizes", &PyTensorInfo::sizes, call_guard)
1095*523fa7a6SAndroid Build Coastguard Worker       .def("dtype", &PyTensorInfo::dtype, call_guard)
1096*523fa7a6SAndroid Build Coastguard Worker       .def("is_memory_planned", &PyTensorInfo::is_memory_planned, call_guard)
1097*523fa7a6SAndroid Build Coastguard Worker       .def("nbytes", &PyTensorInfo::nbytes, call_guard)
1098*523fa7a6SAndroid Build Coastguard Worker       .def("__repr__", &PyTensorInfo::repr, call_guard);
1099*523fa7a6SAndroid Build Coastguard Worker   py::class_<PyMethodMeta>(m, "MethodMeta")
1100*523fa7a6SAndroid Build Coastguard Worker       .def("name", &PyMethodMeta::name, call_guard)
1101*523fa7a6SAndroid Build Coastguard Worker       .def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
1102*523fa7a6SAndroid Build Coastguard Worker       .def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
1103*523fa7a6SAndroid Build Coastguard Worker       .def(
1104*523fa7a6SAndroid Build Coastguard Worker           "input_tensor_meta",
1105*523fa7a6SAndroid Build Coastguard Worker           &PyMethodMeta::input_tensor_meta,
1106*523fa7a6SAndroid Build Coastguard Worker           py::arg("index"),
1107*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1108*523fa7a6SAndroid Build Coastguard Worker       .def(
1109*523fa7a6SAndroid Build Coastguard Worker           "output_tensor_meta",
1110*523fa7a6SAndroid Build Coastguard Worker           &PyMethodMeta::output_tensor_meta,
1111*523fa7a6SAndroid Build Coastguard Worker           py::arg("index"),
1112*523fa7a6SAndroid Build Coastguard Worker           call_guard)
1113*523fa7a6SAndroid Build Coastguard Worker       .def("__repr__", &PyMethodMeta::repr, call_guard);
1114*523fa7a6SAndroid Build Coastguard Worker }
1115*523fa7a6SAndroid Build Coastguard Worker 
1116*523fa7a6SAndroid Build Coastguard Worker } // namespace pybindings
1117*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
1118*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
1119