1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/module/module.h>
10
11 #include <executorch/extension/data_loader/file_data_loader.h>
12 #include <executorch/extension/data_loader/mmap_data_loader.h>
13 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
14 #include <executorch/runtime/platform/runtime.h>
15
/**
 * Unwraps a Result and moves its value into a newly allocated
 * std::unique_ptr.
 *
 * When the Result is not ok(), the enclosing function returns the Result's
 * error() right away. Otherwise the whole expression evaluates to a
 * std::unique_ptr that owns the value moved out of the Result.
 *
 * Note: Because the macro may `return` an error, a function using
 * ET_UNWRAP_UNIQUE must itself return a Result or Error. The macro is written
 * as a `({ ... })` statement expression, which is a compiler extension.
 *
 * @param[in] result__ Expression yielding the Result to unwrap; it is
 * evaluated exactly once.
 */
#define ET_UNWRAP_UNIQUE(result__)                              \
  ({                                                            \
    auto et_result__ = (result__);                              \
    if (!et_result__.ok()) {                                    \
      return et_result__.error();                               \
    }                                                           \
    using et_value_type__ =                                     \
        std::remove_reference_t<decltype(*et_result__)>;        \
    std::make_unique<et_value_type__>(std::move(*et_result__)); \
  })
35
36 namespace executorch {
37 namespace extension {
38
Module(const std::string & file_path,const LoadMode load_mode,std::unique_ptr<runtime::EventTracer> event_tracer)39 Module::Module(
40 const std::string& file_path,
41 const LoadMode load_mode,
42 std::unique_ptr<runtime::EventTracer> event_tracer)
43 : file_path_(file_path),
44 load_mode_(load_mode),
45 memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
46 temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47 event_tracer_(std::move(event_tracer)) {
48 runtime::runtime_init();
49 }
50
Module(std::unique_ptr<runtime::DataLoader> data_loader,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)51 Module::Module(
52 std::unique_ptr<runtime::DataLoader> data_loader,
53 std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
54 std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55 std::unique_ptr<runtime::EventTracer> event_tracer)
56 : data_loader_(std::move(data_loader)),
57 memory_allocator_(
58 memory_allocator ? std::move(memory_allocator)
59 : std::make_unique<MallocMemoryAllocator>()),
60 temp_allocator_(
61 temp_allocator ? std::move(temp_allocator)
62 : std::make_unique<MallocMemoryAllocator>()),
63 event_tracer_(std::move(event_tracer)) {
64 runtime::runtime_init();
65 }
66
Module(std::shared_ptr<runtime::Program> program,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)67 Module::Module(
68 std::shared_ptr<runtime::Program> program,
69 std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70 std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71 std::unique_ptr<runtime::EventTracer> event_tracer)
72 : program_(std::move(program)),
73 memory_allocator_(
74 memory_allocator ? std::move(memory_allocator)
75 : std::make_unique<MallocMemoryAllocator>()),
76 temp_allocator_(
77 temp_allocator ? std::move(temp_allocator)
78 : std::make_unique<MallocMemoryAllocator>()),
79 event_tracer_(std::move(event_tracer)) {
80 runtime::runtime_init();
81 }
82
load(const runtime::Program::Verification verification)83 runtime::Error Module::load(const runtime::Program::Verification verification) {
84 if (!is_loaded()) {
85 if (!data_loader_) {
86 switch (load_mode_) {
87 case LoadMode::File:
88 data_loader_ =
89 ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path_.c_str()));
90 break;
91 case LoadMode::Mmap:
92 data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
93 file_path_.c_str(), MmapDataLoader::MlockConfig::NoMlock));
94 break;
95 case LoadMode::MmapUseMlock:
96 data_loader_ =
97 ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path_.c_str()));
98 break;
99 case LoadMode::MmapUseMlockIgnoreErrors:
100 data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
101 file_path_.c_str(),
102 MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103 break;
104 }
105 };
106 auto program = ET_UNWRAP_UNIQUE(
107 runtime::Program::load(data_loader_.get(), verification));
108 program_ = std::shared_ptr<runtime::Program>(
109 program.release(), [](runtime::Program* pointer) { delete pointer; });
110 }
111 return runtime::Error::Ok;
112 }
113
method_names()114 runtime::Result<std::unordered_set<std::string>> Module::method_names() {
115 ET_CHECK_OK_OR_RETURN_ERROR(load());
116 const auto method_count = program_->num_methods();
117 std::unordered_set<std::string> result;
118 result.reserve(method_count);
119
120 for (auto index = 0; index < method_count; ++index) {
121 result.emplace(program_->get_method_name(index).get());
122 }
123 return result;
124 }
125
load_method(const std::string & method_name,torch::executor::EventTracer * event_tracer)126 runtime::Error Module::load_method(
127 const std::string& method_name,
128 torch::executor::EventTracer* event_tracer) {
129 if (!is_method_loaded(method_name)) {
130 ET_CHECK_OK_OR_RETURN_ERROR(load());
131
132 MethodHolder method_holder;
133 const auto method_metadata =
134 ET_UNWRAP(program_->method_meta(method_name.c_str()));
135 const auto planned_buffersCount =
136 method_metadata.num_memory_planned_buffers();
137 method_holder.planned_buffers.reserve(planned_buffersCount);
138 method_holder.planned_spans.reserve(planned_buffersCount);
139
140 for (auto index = 0; index < planned_buffersCount; ++index) {
141 const auto buffer_size =
142 method_metadata.memory_planned_buffer_size(index).get();
143 method_holder.planned_buffers.emplace_back(buffer_size);
144 method_holder.planned_spans.emplace_back(
145 method_holder.planned_buffers.back().data(), buffer_size);
146 }
147 method_holder.planned_memory =
148 std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
149 method_holder.planned_spans.data(),
150 method_holder.planned_spans.size()));
151 method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
152 memory_allocator_.get(),
153 method_holder.planned_memory.get(),
154 temp_allocator_.get());
155 method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156 method_name.c_str(),
157 method_holder.memory_manager.get(),
158 event_tracer ? event_tracer : this->event_tracer()));
159 method_holder.inputs.resize(method_holder.method->inputs_size());
160 methods_.emplace(method_name, std::move(method_holder));
161 }
162 return runtime::Error::Ok;
163 }
164
method_meta(const std::string & method_name)165 runtime::Result<runtime::MethodMeta> Module::method_meta(
166 const std::string& method_name) {
167 ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
168 return methods_.at(method_name).method->method_meta();
169 }
170
execute(const std::string & method_name,const std::vector<runtime::EValue> & input_values)171 runtime::Result<std::vector<runtime::EValue>> Module::execute(
172 const std::string& method_name,
173 const std::vector<runtime::EValue>& input_values) {
174 ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
175 auto& method = methods_.at(method_name).method;
176 auto& inputs = methods_.at(method_name).inputs;
177
178 for (size_t i = 0; i < input_values.size(); ++i) {
179 if (!input_values[i].isNone()) {
180 inputs[i] = input_values[i];
181 }
182 }
183 for (size_t i = 0; i < inputs.size(); ++i) {
184 ET_CHECK_OR_RETURN_ERROR(
185 !inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
186 }
187 ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
188 exec_aten::ArrayRef<runtime::EValue>(inputs.data(), inputs.size())));
189 ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
190
191 const auto outputs_size = method->outputs_size();
192 std::vector<runtime::EValue> outputs(outputs_size);
193 ET_CHECK_OK_OR_RETURN_ERROR(
194 method->get_outputs(outputs.data(), outputs_size));
195
196 return outputs;
197 }
198
set_input(const std::string & method_name,const runtime::EValue & input_value,size_t input_index)199 runtime::Error Module::set_input(
200 const std::string& method_name,
201 const runtime::EValue& input_value,
202 size_t input_index) {
203 ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
204 methods_.at(method_name).inputs.at(input_index) = input_value;
205 return runtime::Error::Ok;
206 }
207
set_inputs(const std::string & method_name,const std::vector<runtime::EValue> & input_values)208 runtime::Error Module::set_inputs(
209 const std::string& method_name,
210 const std::vector<runtime::EValue>& input_values) {
211 ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
212 auto& inputs = methods_.at(method_name).inputs;
213 ET_CHECK_OR_RETURN_ERROR(
214 inputs.size() == input_values.size(),
215 InvalidArgument,
216 "input size: %zu does not match method input size: %zu",
217 input_values.size(),
218 inputs.size());
219 inputs = input_values;
220 return runtime::Error::Ok;
221 }
222
set_output(const std::string & method_name,runtime::EValue output_value,size_t output_index)223 runtime::Error Module::set_output(
224 const std::string& method_name,
225 runtime::EValue output_value,
226 size_t output_index) {
227 ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
228 auto& method = methods_.at(method_name).method;
229 ET_CHECK_OR_RETURN_ERROR(
230 output_value.isTensor(),
231 InvalidArgument,
232 "output type: %zu is not tensor",
233 (size_t)output_value.tag);
234 const auto& output_tensor = output_value.toTensor();
235 return method->set_output_data_ptr(
236 output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
237 }
238
239 } // namespace extension
240 } // namespace executorch
241