1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/module/module.h>

#include <array>
#include <cstdlib>
#include <memory>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

#include <gtest/gtest.h>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/tensor/tensor.h>
18
19 using namespace ::executorch::extension;
20 using namespace ::executorch::runtime;
21
22 class ModuleTest : public ::testing::Test {
23 protected:
SetUpTestSuite()24 static void SetUpTestSuite() {
25 model_path_ = std::getenv("RESOURCES_PATH") + std::string("/add.pte");
26 }
27
28 static std::string model_path_;
29 };
30
31 std::string ModuleTest::model_path_;
32
// A Module is not loaded on construction; an explicit load() of a valid model
// succeeds and flips is_loaded().
TEST_F(ModuleTest, TestLoad) {
  Module fresh_module(model_path_);
  EXPECT_FALSE(fresh_module.is_loaded());

  EXPECT_EQ(fresh_module.load(), Error::Ok);
  EXPECT_TRUE(fresh_module.is_loaded());
}
41
// Loading from a path that does not exist reports an error and leaves the
// module unloaded.
TEST_F(ModuleTest, TestLoadNonExistent) {
  Module missing_module("/path/to/nonexistent/file.pte");

  EXPECT_NE(missing_module.load(), Error::Ok);
  EXPECT_FALSE(missing_module.is_loaded());
}
49
// "/dev/null" reads as an empty file, which is not a valid program; loading it
// must fail and leave the module unloaded.
TEST_F(ModuleTest, TestLoadCorruptedFile) {
  Module empty_file_module("/dev/null");

  EXPECT_NE(empty_file_module.load(), Error::Ok);
  EXPECT_FALSE(empty_file_module.is_loaded());
}
57
// The add.pte test model exposes exactly one method, "forward".
TEST_F(ModuleTest, TestMethodNames) {
  Module module(model_path_);

  const auto method_names = module.method_names();
  // Reading the value of an errored Result is invalid, so stop the test here
  // (ASSERT) rather than continuing into get().
  ASSERT_EQ(method_names.error(), Error::Ok);
  EXPECT_EQ(method_names.get(), std::unordered_set<std::string>{"forward"});
}
65
// Querying method names on a module whose file is missing reports an error.
TEST_F(ModuleTest, TestNonExistentMethodNames) {
  Module missing_module("/path/to/nonexistent/file.pte");
  EXPECT_NE(missing_module.method_names().error(), Error::Ok);
}
72
// Loading a single method succeeds and also leaves the program loaded.
TEST_F(ModuleTest, TestLoadMethod) {
  Module module(model_path_);
  EXPECT_FALSE(module.is_method_loaded("forward"));

  EXPECT_EQ(module.load_method("forward"), Error::Ok);

  EXPECT_TRUE(module.is_method_loaded("forward"));
  EXPECT_TRUE(module.is_loaded());
}
82
// Requesting a method the program does not define fails. The program itself
// still ends up loaded.
TEST_F(ModuleTest, TestLoadNonExistentMethod) {
  Module module(model_path_);
  const std::string missing_method = "backward";

  EXPECT_NE(module.load_method(missing_method), Error::Ok);
  EXPECT_FALSE(module.is_method_loaded(missing_method));
  EXPECT_TRUE(module.is_loaded());
}
91
// Checks the metadata of "forward": two tensor inputs and one tensor output,
// each a 1-D float tensor of size 1.
TEST_F(ModuleTest, TestMethodMeta) {
  Module module(model_path_);

  const auto meta = module.method_meta("forward");
  // Every check below dereferences `meta`; bail out if it is an error.
  ASSERT_EQ(meta.error(), Error::Ok);
  EXPECT_STREQ(meta->name(), "forward");
  // The counts must be right before input/output index 0 is looked up and the
  // returned Results are dereferenced.
  ASSERT_EQ(meta->num_inputs(), 2);
  EXPECT_EQ(*(meta->input_tag(0)), Tag::Tensor);
  ASSERT_EQ(meta->num_outputs(), 1);
  EXPECT_EQ(*(meta->output_tag(0)), Tag::Tensor);

  const auto input_meta = meta->input_tensor_meta(0);
  ASSERT_EQ(input_meta.error(), Error::Ok);
  EXPECT_EQ(input_meta->scalar_type(), exec_aten::ScalarType::Float);
  // Guard sizes()[0] against an empty sizes span.
  ASSERT_EQ(input_meta->sizes().size(), 1);
  EXPECT_EQ(input_meta->sizes()[0], 1);

  const auto output_meta = meta->output_tensor_meta(0);
  ASSERT_EQ(output_meta.error(), Error::Ok);
  EXPECT_EQ(output_meta->scalar_type(), exec_aten::ScalarType::Float);
  ASSERT_EQ(output_meta->sizes().size(), 1);
  EXPECT_EQ(output_meta->sizes()[0], 1);
}
115
// Method metadata cannot be produced when the model file is missing.
TEST_F(ModuleTest, TestNonExistentMethodMeta) {
  Module missing_module("/path/to/nonexistent/file.pte");
  EXPECT_NE(missing_module.method_meta("forward").error(), Error::Ok);
}
122
// execute() on a fresh module loads the program and the method on demand,
// and add(1, 1) yields 2.
TEST_F(ModuleTest, TestExecute) {
  Module module(model_path_);
  auto tensor = make_tensor_ptr({1.f});

  const auto result = module.execute("forward", {tensor, tensor});
  // The output is dereferenced below, so a failed execute must end the test.
  ASSERT_EQ(result.error(), Error::Ok);

  EXPECT_TRUE(module.is_loaded());
  EXPECT_TRUE(module.is_method_loaded("forward"));

  const auto data = result->at(0).toTensor().const_data_ptr<float>();

  EXPECT_NEAR(data[0], 2, 1e-5);
}
137
// Explicitly loading the program before execute() still produces add(1, 1) == 2.
TEST_F(ModuleTest, TestExecutePreload) {
  Module module(model_path_);

  const auto error = module.load();
  ASSERT_EQ(error, Error::Ok);

  auto tensor = make_tensor_ptr({1.f});

  const auto result = module.execute("forward", {tensor, tensor});
  // The output is dereferenced below, so a failed execute must end the test.
  ASSERT_EQ(result.error(), Error::Ok);

  const auto data = result->at(0).toTensor().const_data_ptr<float>();

  EXPECT_NEAR(data[0], 2, 1e-5);
}
153
// Explicitly loading "forward" before execute() still produces add(1, 1) == 2.
TEST_F(ModuleTest, TestExecutePreload_method) {
  Module module(model_path_);

  const auto error = module.load_method("forward");
  ASSERT_EQ(error, Error::Ok);

  auto tensor = make_tensor_ptr({1.f});

  const auto result = module.execute("forward", {tensor, tensor});
  // The output is dereferenced below, so a failed execute must end the test.
  ASSERT_EQ(result.error(), Error::Ok);

  const auto data = result->at(0).toTensor().const_data_ptr<float>();

  EXPECT_NEAR(data[0], 2, 1e-5);
}
169
// Preloading both the program and the method, then executing, yields
// add(1, 1) == 2.
TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) {
  Module module(model_path_);

  const auto load_error = module.load();
  ASSERT_EQ(load_error, Error::Ok);

  const auto load_method_error = module.load_method("forward");
  ASSERT_EQ(load_method_error, Error::Ok);

  auto tensor = make_tensor_ptr({1.f});

  const auto result = module.execute("forward", {tensor, tensor});
  // The output is dereferenced below, so a failed execute must end the test.
  ASSERT_EQ(result.error(), Error::Ok);

  const auto data = result->at(0).toTensor().const_data_ptr<float>();

  EXPECT_NEAR(data[0], 2, 1e-5);
}
188
// Executing a method of a module whose file is missing reports an error.
TEST_F(ModuleTest, TestExecuteOnNonExistent) {
  Module missing_module("/path/to/nonexistent/file.pte");
  EXPECT_NE(missing_module.execute("forward").error(), Error::Ok);
}
196
// Executing against an empty file ("/dev/null") reports an error.
TEST_F(ModuleTest, TestExecuteOnCurrupted) {
  Module empty_file_module("/dev/null");
  EXPECT_NE(empty_file_module.execute("forward").error(), Error::Ok);
}
204
// get() returns the single output of the method directly: add(1, 1) == 2.
TEST_F(ModuleTest, TestGet) {
  Module module(model_path_);

  auto tensor = make_tensor_ptr({1.f});

  const auto result = module.get("forward", {tensor, tensor});
  // The output is dereferenced below, so a failed get must end the test.
  ASSERT_EQ(result.error(), Error::Ok);
  const auto data = result->toTensor().const_data_ptr<float>();
  EXPECT_NEAR(data[0], 2, 1e-5);
}
215
// forward() can be called repeatedly on the same module; each call must
// produce the sum of its own inputs.
TEST_F(ModuleTest, TestForward) {
  auto module = std::make_unique<Module>(model_path_);
  auto tensor = make_tensor_ptr({21.f});

  const auto result = module->forward({tensor, tensor});
  ASSERT_EQ(result.error(), Error::Ok);

  const auto data = result->at(0).toTensor().const_data_ptr<float>();

  EXPECT_NEAR(data[0], 42, 1e-5);

  auto tensor2 = make_tensor_ptr({2.f});
  const auto result2 = module->forward({tensor2, tensor2});
  ASSERT_EQ(result2.error(), Error::Ok);

  // Read the second call's own output (result2, not result) so this check
  // actually validates the second forward() call.
  const auto data2 = result2->at(0).toTensor().const_data_ptr<float>();

  EXPECT_NEAR(data2[0], 4, 1e-5);
}
235
// Passing a default-constructed EValue as the only input to forward() fails.
TEST_F(ModuleTest, TestForwardWithInvalidInputs) {
  Module module(model_path_);
  EXPECT_NE(module.forward(EValue()).error(), Error::Ok);
}
243
// Two Modules can share one Program. They report the same method names but
// track loaded methods independently.
TEST_F(ModuleTest, TestProgramSharingBetweenModules) {
  Module module1(model_path_);
  EXPECT_FALSE(module1.is_loaded());

  const auto load_error = module1.load();
  // module2 is built from module1's program, so the load must have succeeded.
  ASSERT_EQ(load_error, Error::Ok);
  EXPECT_TRUE(module1.is_loaded());

  Module module2(module1.program());
  EXPECT_TRUE(module2.is_loaded());

  // get() is called on both results below; an errored Result cannot be read.
  const auto method_names1 = module1.method_names();
  ASSERT_EQ(method_names1.error(), Error::Ok);

  const auto method_names2 = module2.method_names();
  ASSERT_EQ(method_names2.error(), Error::Ok);
  EXPECT_EQ(method_names1.get(), method_names2.get());

  // Loading a method on one module does not mark it loaded on the other.
  const auto load_method_error = module1.load_method("forward");
  EXPECT_EQ(load_method_error, Error::Ok);
  EXPECT_TRUE(module1.is_method_loaded("forward"));
  EXPECT_FALSE(module2.is_method_loaded("forward"));

  const auto load_method_error2 = module2.load_method("forward");
  EXPECT_EQ(load_method_error2, Error::Ok);
  EXPECT_TRUE(module2.is_method_loaded("forward"));
}
271
// A module built from another module's program keeps executing after the
// original module (which owned the data loader) has been replaced.
TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
  auto loader = FileDataLoader::from(model_path_.c_str());
  // loader.get() below is only valid when the loader was created successfully.
  ASSERT_EQ(loader.error(), Error::Ok);
  auto data_loader = std::make_unique<FileDataLoader>(std::move(loader.get()));

  auto module1 = std::make_unique<Module>(std::move(data_loader));

  const auto load_error = module1->load();
  // module2 is built from module1's program, so the load must have succeeded.
  ASSERT_EQ(load_error, Error::Ok);
  EXPECT_TRUE(module1->is_loaded());

  auto tensor = make_tensor_ptr({1.f});

  const auto result1 = module1->execute("forward", {tensor, tensor});
  EXPECT_EQ(result1.error(), Error::Ok);

  auto module2 = std::make_unique<Module>(module1->program());

  const auto result2 = module2->execute("forward", {tensor, tensor});
  EXPECT_EQ(result2.error(), Error::Ok);

  // Destroy the original module (and the data loader moved into it) by
  // replacing it. module2 must still be able to execute afterwards.
  module1 = std::make_unique<Module>("/path/to/nonexistent/file.pte");
  EXPECT_FALSE(module1->is_loaded());

  const auto result3 = module2->execute("forward", {tensor, tensor});
  EXPECT_EQ(result3.error(), Error::Ok);
}
299
// A Program obtained from a Module stays usable after that Module goes out of
// scope. A new Module built from it executes add(1, 1) == 2.
TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
  std::shared_ptr<Program> shared_program;

  {
    auto loader = FileDataLoader::from(model_path_.c_str());
    // loader.get() below is only valid when the loader was created successfully.
    ASSERT_EQ(loader.error(), Error::Ok);
    auto data_loader =
        std::make_unique<FileDataLoader>(std::move(loader.get()));

    Module module(std::move(data_loader));

    const auto load_error = module.load();
    ASSERT_EQ(load_error, Error::Ok);
    EXPECT_TRUE(module.is_loaded());

    shared_program = module.program();
  }

  // The Module that produced the program has been destroyed. Stop here if no
  // program was obtained, since it is used to build a new Module below.
  ASSERT_NE(shared_program, nullptr);

  Module module(shared_program);

  EXPECT_EQ(module.program(), shared_program);

  auto tensor = make_tensor_ptr({1.f});

  const auto result = module.execute("forward", {tensor, tensor});
  // The output is dereferenced below, so a failed execute must end the test.
  ASSERT_EQ(result.error(), Error::Ok);

  const auto data = result->at(0).toTensor().const_data_ptr<float>();

  EXPECT_NEAR(data[0], 2, 1e-5);
}
337
// Several Modules built from one shared Program run forward() concurrently,
// each on its own thread with its own input, and each gets 2 * input.
TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
  std::shared_ptr<Program> program;
  {
    Module module(model_path_);
    EXPECT_FALSE(module.is_loaded());

    const auto load_error = module.load();
    ASSERT_EQ(load_error, Error::Ok);
    EXPECT_TRUE(module.is_loaded());

    program = module.program();
  }
  // Every worker constructs a Module from this pointer, so it must be set.
  ASSERT_NE(program, nullptr);

  // Runs forward(x, x) on a private Module and checks the output is x * 2.
  // std::thread copies its arguments, so `input` refers to the thread's own
  // copy of the array and stays valid for the whole call.
  auto run_forward = [](std::shared_ptr<Program> shared_program,
                        const std::array<float, 1>& input) {
    Module module(shared_program);
    // from_blob takes a non-const void*. Make the const removal explicit
    // instead of hiding it in a C-style cast.
    void* const input_data = const_cast<float*>(input.data());
    auto tensor = from_blob(input_data, {1});

    const auto result = module.forward({tensor, tensor});
    // The output is dereferenced below, so a failed forward must stop here.
    ASSERT_EQ(result.error(), Error::Ok);

    const auto data = result->at(0).toTensor().const_data_ptr<float>();
    EXPECT_NEAR(data[0], (input[0] * 2), 1e-5);
  };

  std::vector<std::thread> workers;
  for (const float value : {1.f, 2.f, 3.f, 4.f, 5.f}) {
    workers.emplace_back(run_forward, program, std::array<float, 1>{value});
  }
  for (auto& worker : workers) {
    worker.join();
  }
}
376
// Inputs supplied via set_inputs() are used by an argument-less forward():
// 4 + 5 == 9.
TEST_F(ModuleTest, TestSetInputsBeforeExecute) {
  Module module(model_path_);

  auto tensor1 = make_tensor_ptr({4.f});
  auto tensor2 = make_tensor_ptr({5.f});

  ASSERT_EQ(module.set_inputs({tensor1, tensor2}), Error::Ok);

  const auto result = module.forward();
  // The output is dereferenced below, so a failed forward must end the test.
  ASSERT_EQ(result.error(), Error::Ok);

  const auto data = result->at(0).toTensor().const_data_ptr<float>();
  EXPECT_NEAR(data[0], 9, 1e-5);
}
391
// Input 1 is set ahead of time with set_input() and input 0 is passed to
// forward(); the two combine: 2 + 3 == 5.
TEST_F(ModuleTest, TestSetInputCombinedWithExecute) {
  Module module(model_path_);

  auto tensor1 = make_tensor_ptr({2.f});
  auto tensor2 = make_tensor_ptr({3.f});

  ASSERT_EQ(module.set_input(tensor2, 1), Error::Ok);

  const auto result = module.forward(tensor1);
  // The output is dereferenced below, so a failed forward must end the test.
  ASSERT_EQ(result.error(), Error::Ok);

  const auto data = result->at(0).toTensor().const_data_ptr<float>();
  EXPECT_NEAR(data[0], 5, 1e-5);
}
406
// The add model takes two inputs. Setting only input 0 and then calling
// forward() with no arguments must fail.
TEST_F(ModuleTest, TestPartiallySetInputs) {
  Module module(model_path_);
  auto first_input = make_tensor_ptr({1.f});

  EXPECT_EQ(module.set_input(first_input, 0), Error::Ok);
  EXPECT_NE(module.forward().error(), Error::Ok);
}
417
// forward() with no arguments and no previously set inputs must fail.
TEST_F(ModuleTest, TestUnsetInputs) {
  Module module(model_path_);
  EXPECT_NE(module.forward().error(), Error::Ok);
}
424
// The add model has a single output, so binding a buffer to output index 1
// is rejected.
TEST_F(ModuleTest, TestSetOutputInvalidIndex) {
  Module module(model_path_);
  auto out_buffer = empty({1});

  const auto status = module.set_output(out_buffer, 1);
  EXPECT_NE(status, Error::Ok);
}
432
// A default-constructed EValue is not an acceptable output buffer.
TEST_F(ModuleTest, TestSetOutputInvalidType) {
  Module module(model_path_);
  const auto status = module.set_output(EValue());
  EXPECT_NE(status, Error::Ok);
}
438