xref: /aosp_15_r20/external/executorch/extension/module/module.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/module/module.h>
10 
11 #include <executorch/extension/data_loader/file_data_loader.h>
12 #include <executorch/extension/data_loader/mmap_data_loader.h>
13 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
14 #include <executorch/runtime/platform/runtime.h>
15 
/**
 * Pull the value out of a Result and hand it back owned by a
 * std::unique_ptr.
 *
 * The Result expression is evaluated exactly once. If it holds an error, the
 * macro issues a `return` of that error from the *enclosing* function;
 * otherwise the contained value is moved into a freshly allocated object and
 * the macro evaluates to the owning unique_ptr.
 *
 * Note: because of the embedded `return`, a function using ET_UNWRAP_UNIQUE
 * should itself return a Result or Error.
 *
 * @param[in] result__ Expression yielding the result to unwrap.
 */
#define ET_UNWRAP_UNIQUE(result__)                                  \
  ({                                                                \
    auto et_unwrapped__ = (result__);                               \
    if (!et_unwrapped__.ok()) {                                     \
      return et_unwrapped__.error();                                \
    }                                                               \
    using et_value_type__ =                                         \
        std::remove_reference_t<decltype(*et_unwrapped__)>;         \
    std::make_unique<et_value_type__>(std::move(*et_unwrapped__));  \
  })
35 
36 namespace executorch {
37 namespace extension {
38 
Module(const std::string & file_path,const LoadMode load_mode,std::unique_ptr<runtime::EventTracer> event_tracer)39 Module::Module(
40     const std::string& file_path,
41     const LoadMode load_mode,
42     std::unique_ptr<runtime::EventTracer> event_tracer)
43     : file_path_(file_path),
44       load_mode_(load_mode),
45       memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
46       temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47       event_tracer_(std::move(event_tracer)) {
48   runtime::runtime_init();
49 }
50 
Module(std::unique_ptr<runtime::DataLoader> data_loader,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)51 Module::Module(
52     std::unique_ptr<runtime::DataLoader> data_loader,
53     std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
54     std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55     std::unique_ptr<runtime::EventTracer> event_tracer)
56     : data_loader_(std::move(data_loader)),
57       memory_allocator_(
58           memory_allocator ? std::move(memory_allocator)
59                            : std::make_unique<MallocMemoryAllocator>()),
60       temp_allocator_(
61           temp_allocator ? std::move(temp_allocator)
62                          : std::make_unique<MallocMemoryAllocator>()),
63       event_tracer_(std::move(event_tracer)) {
64   runtime::runtime_init();
65 }
66 
Module(std::shared_ptr<runtime::Program> program,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)67 Module::Module(
68     std::shared_ptr<runtime::Program> program,
69     std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70     std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71     std::unique_ptr<runtime::EventTracer> event_tracer)
72     : program_(std::move(program)),
73       memory_allocator_(
74           memory_allocator ? std::move(memory_allocator)
75                            : std::make_unique<MallocMemoryAllocator>()),
76       temp_allocator_(
77           temp_allocator ? std::move(temp_allocator)
78                          : std::make_unique<MallocMemoryAllocator>()),
79       event_tracer_(std::move(event_tracer)) {
80   runtime::runtime_init();
81 }
82 
load(const runtime::Program::Verification verification)83 runtime::Error Module::load(const runtime::Program::Verification verification) {
84   if (!is_loaded()) {
85     if (!data_loader_) {
86       switch (load_mode_) {
87         case LoadMode::File:
88           data_loader_ =
89               ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path_.c_str()));
90           break;
91         case LoadMode::Mmap:
92           data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
93               file_path_.c_str(), MmapDataLoader::MlockConfig::NoMlock));
94           break;
95         case LoadMode::MmapUseMlock:
96           data_loader_ =
97               ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path_.c_str()));
98           break;
99         case LoadMode::MmapUseMlockIgnoreErrors:
100           data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
101               file_path_.c_str(),
102               MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103           break;
104       }
105     };
106     auto program = ET_UNWRAP_UNIQUE(
107         runtime::Program::load(data_loader_.get(), verification));
108     program_ = std::shared_ptr<runtime::Program>(
109         program.release(), [](runtime::Program* pointer) { delete pointer; });
110   }
111   return runtime::Error::Ok;
112 }
113 
method_names()114 runtime::Result<std::unordered_set<std::string>> Module::method_names() {
115   ET_CHECK_OK_OR_RETURN_ERROR(load());
116   const auto method_count = program_->num_methods();
117   std::unordered_set<std::string> result;
118   result.reserve(method_count);
119 
120   for (auto index = 0; index < method_count; ++index) {
121     result.emplace(program_->get_method_name(index).get());
122   }
123   return result;
124 }
125 
load_method(const std::string & method_name,torch::executor::EventTracer * event_tracer)126 runtime::Error Module::load_method(
127     const std::string& method_name,
128     torch::executor::EventTracer* event_tracer) {
129   if (!is_method_loaded(method_name)) {
130     ET_CHECK_OK_OR_RETURN_ERROR(load());
131 
132     MethodHolder method_holder;
133     const auto method_metadata =
134         ET_UNWRAP(program_->method_meta(method_name.c_str()));
135     const auto planned_buffersCount =
136         method_metadata.num_memory_planned_buffers();
137     method_holder.planned_buffers.reserve(planned_buffersCount);
138     method_holder.planned_spans.reserve(planned_buffersCount);
139 
140     for (auto index = 0; index < planned_buffersCount; ++index) {
141       const auto buffer_size =
142           method_metadata.memory_planned_buffer_size(index).get();
143       method_holder.planned_buffers.emplace_back(buffer_size);
144       method_holder.planned_spans.emplace_back(
145           method_holder.planned_buffers.back().data(), buffer_size);
146     }
147     method_holder.planned_memory =
148         std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
149             method_holder.planned_spans.data(),
150             method_holder.planned_spans.size()));
151     method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
152         memory_allocator_.get(),
153         method_holder.planned_memory.get(),
154         temp_allocator_.get());
155     method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156         method_name.c_str(),
157         method_holder.memory_manager.get(),
158         event_tracer ? event_tracer : this->event_tracer()));
159     method_holder.inputs.resize(method_holder.method->inputs_size());
160     methods_.emplace(method_name, std::move(method_holder));
161   }
162   return runtime::Error::Ok;
163 }
164 
method_meta(const std::string & method_name)165 runtime::Result<runtime::MethodMeta> Module::method_meta(
166     const std::string& method_name) {
167   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
168   return methods_.at(method_name).method->method_meta();
169 }
170 
execute(const std::string & method_name,const std::vector<runtime::EValue> & input_values)171 runtime::Result<std::vector<runtime::EValue>> Module::execute(
172     const std::string& method_name,
173     const std::vector<runtime::EValue>& input_values) {
174   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
175   auto& method = methods_.at(method_name).method;
176   auto& inputs = methods_.at(method_name).inputs;
177 
178   for (size_t i = 0; i < input_values.size(); ++i) {
179     if (!input_values[i].isNone()) {
180       inputs[i] = input_values[i];
181     }
182   }
183   for (size_t i = 0; i < inputs.size(); ++i) {
184     ET_CHECK_OR_RETURN_ERROR(
185         !inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
186   }
187   ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
188       exec_aten::ArrayRef<runtime::EValue>(inputs.data(), inputs.size())));
189   ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
190 
191   const auto outputs_size = method->outputs_size();
192   std::vector<runtime::EValue> outputs(outputs_size);
193   ET_CHECK_OK_OR_RETURN_ERROR(
194       method->get_outputs(outputs.data(), outputs_size));
195 
196   return outputs;
197 }
198 
set_input(const std::string & method_name,const runtime::EValue & input_value,size_t input_index)199 runtime::Error Module::set_input(
200     const std::string& method_name,
201     const runtime::EValue& input_value,
202     size_t input_index) {
203   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
204   methods_.at(method_name).inputs.at(input_index) = input_value;
205   return runtime::Error::Ok;
206 }
207 
set_inputs(const std::string & method_name,const std::vector<runtime::EValue> & input_values)208 runtime::Error Module::set_inputs(
209     const std::string& method_name,
210     const std::vector<runtime::EValue>& input_values) {
211   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
212   auto& inputs = methods_.at(method_name).inputs;
213   ET_CHECK_OR_RETURN_ERROR(
214       inputs.size() == input_values.size(),
215       InvalidArgument,
216       "input size: %zu does not match method input size: %zu",
217       input_values.size(),
218       inputs.size());
219   inputs = input_values;
220   return runtime::Error::Ok;
221 }
222 
set_output(const std::string & method_name,runtime::EValue output_value,size_t output_index)223 runtime::Error Module::set_output(
224     const std::string& method_name,
225     runtime::EValue output_value,
226     size_t output_index) {
227   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
228   auto& method = methods_.at(method_name).method;
229   ET_CHECK_OR_RETURN_ERROR(
230       output_value.isTensor(),
231       InvalidArgument,
232       "output type: %zu is not tensor",
233       (size_t)output_value.tag);
234   const auto& output_tensor = output_value.toTensor();
235   return method->set_output_data_ptr(
236       output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
237 }
238 
239 } // namespace extension
240 } // namespace executorch
241