xref: /aosp_15_r20/external/executorch/runtime/executor/test/program_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/executor/program.h>
10 
11 #include <cctype>
12 #include <filesystem>
13 
14 #include <cstring>
15 #include <memory>
16 
17 #include <executorch/extension/data_loader/buffer_data_loader.h>
18 #include <executorch/extension/data_loader/file_data_loader.h>
19 #include <executorch/runtime/core/error.h>
20 #include <executorch/runtime/core/result.h>
21 #include <executorch/runtime/platform/runtime.h>
22 #include <executorch/schema/program_generated.h>
23 #include <executorch/test/utils/DeathTest.h>
24 
25 #include <gtest/gtest.h>
26 
27 using namespace ::testing;
28 using executorch::runtime::DataLoader;
29 using executorch::runtime::Error;
30 using executorch::runtime::FreeableBuffer;
31 using executorch::runtime::Program;
32 using executorch::runtime::Result;
33 using torch::executor::util::BufferDataLoader;
34 using torch::executor::util::FileDataLoader;
35 
36 // Verification level to use for tests not specifically focused on verification.
37 // Use the highest level to exercise it more.
38 constexpr Program::Verification kDefaultVerification =
39     Program::Verification::InternalConsistency;
40 
41 class ProgramTest : public ::testing::Test {
42  protected:
SetUp()43   void SetUp() override {
44     // Since these tests cause ET_LOG to be called, the PAL must be initialized
45     // first.
46     executorch::runtime::runtime_init();
47 
48     // Load the serialized ModuleAdd data.
49     const char* path = std::getenv("ET_MODULE_ADD_PATH");
50     Result<FileDataLoader> loader = FileDataLoader::from(path);
51     ASSERT_EQ(loader.error(), Error::Ok);
52 
53     // This file should always be compatible.
54     Result<FreeableBuffer> header = loader->load(
55         /*offset=*/0,
56         Program::kMinHeadBytes,
57         /*segment_info=*/
58         DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
59     ASSERT_EQ(header.error(), Error::Ok);
60     EXPECT_EQ(
61         Program::check_header(header->data(), header->size()),
62         Program::HeaderStatus::CompatibleVersion);
63 
64     add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
65 
66     // Load the serialized ModuleMultiEntry data.
67     path = std::getenv("ET_MODULE_MULTI_ENTRY_PATH");
68     Result<FileDataLoader> multi_loader = FileDataLoader::from(path);
69     ASSERT_EQ(multi_loader.error(), Error::Ok);
70 
71     // This file should always be compatible.
72     Result<FreeableBuffer> multi_header = multi_loader->load(
73         /*offset=*/0,
74         Program::kMinHeadBytes,
75         /*segment_info=*/
76         DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
77     ASSERT_EQ(multi_header.error(), Error::Ok);
78     EXPECT_EQ(
79         Program::check_header(multi_header->data(), multi_header->size()),
80         Program::HeaderStatus::CompatibleVersion);
81 
82     multi_loader_ =
83         std::make_unique<FileDataLoader>(std::move(multi_loader.get()));
84   }
85 
86   std::unique_ptr<FileDataLoader> add_loader_;
87   std::unique_ptr<FileDataLoader> multi_loader_;
88 };
89 
90 namespace executorch {
91 namespace runtime {
92 namespace testing {
93 // Provides access to private Program methods.
94 class ProgramTestFriend final {
95  public:
LoadSegment(const Program * program,const DataLoader::SegmentInfo & segment_info)96   ET_NODISCARD static Result<FreeableBuffer> LoadSegment(
97       const Program* program,
98       const DataLoader::SegmentInfo& segment_info) {
99     return program->LoadSegment(segment_info);
100   }
101 
load_mutable_subsegment_into(const Program * program,size_t mutable_data_segments_index,size_t offset_index,size_t size,void * buffer)102   ET_NODISCARD static Error load_mutable_subsegment_into(
103       const Program* program,
104       size_t mutable_data_segments_index,
105       size_t offset_index,
106       size_t size,
107       void* buffer) {
108     return program->load_mutable_subsegment_into(
109         mutable_data_segments_index, offset_index, size, buffer);
110   }
111 
GetInternalProgram(const Program * program)112   const static executorch_flatbuffer::Program* GetInternalProgram(
113       const Program* program) {
114     return program->internal_program_;
115   }
116 };
117 } // namespace testing
118 } // namespace runtime
119 } // namespace executorch
120 
121 using executorch::runtime::testing::ProgramTestFriend;
122 
TEST_F(ProgramTest, DataParsesWithMinimalVerification) {
  // The ModuleAdd program must load cleanly under the lightest verification
  // level.
  const Result<Program> parsed =
      Program::load(add_loader_.get(), Program::Verification::Minimal);
  EXPECT_EQ(parsed.error(), Error::Ok);
}
131 
TEST_F(ProgramTest, DataParsesWithInternalConsistencyVerification) {
  // The ModuleAdd program must also survive the strictest verification level.
  const Result<Program> parsed = Program::load(
      add_loader_.get(), Program::Verification::InternalConsistency);
  EXPECT_EQ(parsed.error(), Error::Ok);
}
140 
TEST_F(ProgramTest, BadMagicFailsToLoad) {
  // Copy the whole ModuleAdd file into memory owned by the test so it can be
  // modified.
  const size_t size = add_loader_->size().get();
  auto bytes = std::make_unique<char[]>(size);
  {
    Result<FreeableBuffer> original = add_loader_->load(
        /*offset=*/0,
        size,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(original.error(), Error::Ok);
    ASSERT_EQ(original->size(), size);
    std::memcpy(bytes.get(), original->data(), size);
    // `original` is released when this scope ends.
  }

  // Offsets 4 and 5 are expected to hold 'E' and 'T'. Overwrite them with
  // 'X' and 'Y' so the magic no longer matches.
  EXPECT_EQ(bytes[4], 'E');
  bytes[4] = 'X';
  EXPECT_EQ(bytes[5], 'T');
  bytes[5] = 'Y';

  BufferDataLoader loader(bytes.get(), size);

  // Even Minimal verification must reject the damaged header.
  {
    Result<Program> rejected =
        Program::load(&loader, Program::Verification::Minimal);
    ASSERT_EQ(rejected.error(), Error::InvalidProgram);
  }

  // Put the original bytes back; the same loader should now succeed.
  bytes[4] = 'E';
  bytes[5] = 'T';
  {
    Result<Program> accepted =
        Program::load(&loader, Program::Verification::Minimal);
    ASSERT_EQ(accepted.error(), Error::Ok);
  }
}
189 
TEST_F(ProgramTest, VerificationCatchesTruncation) {
  // Read the complete ModuleAdd file.
  const size_t full_size = add_loader_->size().get();
  Result<FreeableBuffer> whole_file = add_loader_->load(
      /*offset=*/0,
      full_size,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(whole_file.error(), Error::Ok);

  // Expose only the first half of those bytes to Program::load().
  const size_t truncated_size = full_size / 2;
  BufferDataLoader truncated_loader(whole_file->data(), truncated_size);

  // InternalConsistency verification must notice the missing data.
  Result<Program> program = Program::load(
      &truncated_loader, Program::Verification::InternalConsistency);
  ASSERT_EQ(program.error(), Error::InvalidProgram);
}
208 
TEST_F(ProgramTest, VerificationCatchesCorruption) {
  // Copy the whole ModuleAdd file into memory owned by the test.
  const size_t size = add_loader_->size().get();
  auto bytes = std::make_unique<char[]>(size);
  {
    Result<FreeableBuffer> original = add_loader_->load(
        /*offset=*/0,
        size,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(original.error(), Error::Ok);
    ASSERT_EQ(original->size(), size);
    std::memcpy(bytes.get(), original->data(), size);
    // `original` is released when this scope ends.
  }

  // Fill everything from the midpoint to the end with the byte 0x55.
  const size_t half = size / 2;
  char* const second_half = bytes.get() + half;
  std::memset(second_half, 0x55, size - half);

  // Full verification must refuse to parse the garbled data.
  BufferDataLoader corrupted_loader(bytes.get(), size);
  Result<Program> program = Program::load(
      &corrupted_loader, Program::Verification::InternalConsistency);
  ASSERT_EQ(program.error(), Error::InvalidProgram);
}
236 
TEST_F(ProgramTest, UnalignedProgramDataFails) {
  // Allocate one spare byte so the program can be placed one byte past the
  // start of the allocation, i.e. at an odd offset.
  const size_t size = add_loader_->size().get();
  auto storage = std::make_unique<char[]>(size + 1);
  char* const shifted = storage.get() + 1;
  {
    Result<FreeableBuffer> original = add_loader_->load(
        /*offset=*/0,
        size,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(original.error(), Error::Ok);
    ASSERT_EQ(original->size(), size);
    std::memcpy(shifted, original->data(), size);
    // `original` is released when this scope ends.
  }

  // Program::load() must reject the misaligned data. Any error code counts
  // as a pass.
  BufferDataLoader shifted_loader(shifted, size);
  Result<Program> program =
      Program::load(&shifted_loader, Program::Verification::Minimal);
  ASSERT_NE(program.error(), Error::Ok);
}
261 
TEST_F(ProgramTest, LoadSegmentWithNoSegments) {
  // Load a program with no appended segments.
  Result<Program> program =
      Program::load(add_loader_.get(), kDefaultVerification);
  // Fatal assertion: program.get() below must not run on an error Result,
  // which would abort the whole test binary instead of failing this test.
  ASSERT_EQ(program.error(), Error::Ok);

  // Loading a non-program segment should fail.
  const auto segment_info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Backend,
      /*segment_index=*/0,
      "some-backend");
  Result<FreeableBuffer> segment =
      ProgramTestFriend::LoadSegment(&program.get(), segment_info);
  EXPECT_NE(segment.error(), Error::Ok);
}
277 
TEST_F(ProgramTest, ShortDataHeader) {
  // Read exactly Program::kMinHeadBytes bytes from the start of the file.
  Result<FreeableBuffer> head = add_loader_->load(
      /*offset=*/0,
      Program::kMinHeadBytes,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(head.error(), Error::Ok);

  // Report one byte fewer than the minimum; check_header() must say ShortData.
  const size_t too_short = Program::kMinHeadBytes - 1;
  EXPECT_EQ(
      Program::check_header(head->data(), too_short),
      Program::HeaderStatus::ShortData);
}
291 
TEST_F(ProgramTest, IncompatibleHeader) {
  // Make a local copy of the header.
  size_t data_len = Program::kMinHeadBytes;
  auto data = std::make_unique<char[]>(data_len);
  {
    Result<FreeableBuffer> src = add_loader_->load(
        /*offset=*/0,
        data_len,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(src.error(), Error::Ok);
    ASSERT_EQ(src->size(), data_len);
    memcpy(data.get(), src->data(), data_len);
    // FreeableBuffer goes out of scope and frees its data.
  }

  // std::isdigit() is undefined for negative arguments other than EOF, and
  // `char` may be signed. Convert through unsigned char so arbitrary header
  // bytes are safe to classify.
  const auto is_digit = [](char c) {
    return std::isdigit(static_cast<unsigned char>(c)) != 0;
  };

  // Change the number part of the magic value to a different value.
  EXPECT_EQ(data[4], 'E');
  EXPECT_EQ(data[5], 'T');
  EXPECT_TRUE(is_digit(data[6])) << "Not a digit: " << data[6];
  EXPECT_TRUE(is_digit(data[7])) << "Not a digit: " << data[7];

  // Modify the tens digit, wrapping '9' back to '0' so it remains a digit.
  if (data[6] == '9') {
    data[6] = '0';
  } else {
    data[6] += 1;
  }
  EXPECT_TRUE(is_digit(data[6])) << "Not a digit: " << data[6];

  // Should count as present but incompatible.
  EXPECT_EQ(
      Program::check_header(data.get(), data_len),
      Program::HeaderStatus::IncompatibleVersion);
}
327 
TEST_F(ProgramTest, HeaderNotPresent) {
  // Copy just the first Program::kMinHeadBytes bytes into a writable buffer.
  const size_t head_size = Program::kMinHeadBytes;
  auto head = std::make_unique<char[]>(head_size);
  {
    Result<FreeableBuffer> original = add_loader_->load(
        /*offset=*/0,
        head_size,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(original.error(), Error::Ok);
    ASSERT_EQ(original->size(), head_size);
    std::memcpy(head.get(), original->data(), head_size);
    // `original` is released when this scope ends.
  }

  // Offsets 4 and 5 are expected to hold 'E' and 'T'. Replace them with
  // 'X' and 'Y' so no header magic is recognized.
  EXPECT_EQ(head[4], 'E');
  head[4] = 'X';
  EXPECT_EQ(head[5], 'T');
  head[5] = 'Y';

  // With the magic gone, check_header() must report NotPresent rather than
  // an incompatible version.
  EXPECT_EQ(
      Program::check_header(head.get(), head_size),
      Program::HeaderStatus::NotPresent);
}
355 
TEST_F(ProgramTest, getMethods) {
  // Parse the Program from the data.
  Result<Program> program_res =
      Program::load(multi_loader_.get(), kDefaultVerification);
  // Fatal assertion: calling get() on an error Result would abort the whole
  // test binary instead of failing this test.
  ASSERT_EQ(program_res.error(), Error::Ok);

  Program program(std::move(program_res.get()));

  // Method calls should succeed without hitting ET_CHECK.
  EXPECT_EQ(program.num_methods(), 2);

  auto res = program.get_method_name(0);
  ASSERT_TRUE(res.ok());
  EXPECT_STREQ(res.get(), "forward");

  auto res2 = program.get_method_name(1);
  ASSERT_TRUE(res2.ok());
  EXPECT_STREQ(res2.get(), "forward2");
}
373 
374 // Test that the deprecated Load method (capital 'L') still works.
TEST_F(ProgramTest,DEPRECATEDLoad)375 TEST_F(ProgramTest, DEPRECATEDLoad) {
376   // Parse the Program from the data.
377   // NOLINTNEXTLINE(facebook-hte-Deprecated)
378   Result<Program> program_res = Program::Load(multi_loader_.get());
379   EXPECT_EQ(program_res.error(), Error::Ok);
380 }
381 
TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
  Result<Program> program =
      Program::load(add_loader_.get(), kDefaultVerification);
  ASSERT_EQ(program.error(), Error::Ok);

  // Load constant segment data should fail.
  const auto segment_info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Constant,
      /*segment_index=*/0);
  Result<FreeableBuffer> segment =
      ProgramTestFriend::LoadSegment(&program.get(), segment_info);
  EXPECT_NE(segment.error(), Error::Ok);

  const executorch_flatbuffer::Program* flatbuffer_program =
      ProgramTestFriend::GetInternalProgram(&program.get());
  ASSERT_NE(flatbuffer_program, nullptr);

  // Flatbuffer vector accessors return nullptr when the field is absent.
  // Guard each one so a malformed file fails the test instead of crashing it.
  ASSERT_NE(flatbuffer_program->constant_buffer(), nullptr);
  // The constant buffer should be empty.
  EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0);

  ASSERT_NE(flatbuffer_program->segments(), nullptr);
  // Expect 1 constant segment, placeholder for non-const tensors.
  EXPECT_EQ(flatbuffer_program->segments()->size(), 1);
}
404 
TEST_F(ProgramTest, LoadConstantSegment) {
  // Load the serialized ModuleLinear data, with constants in the segment.
  const char* linear_path = std::getenv("ET_MODULE_LINEAR_PATH");
  ASSERT_NE(linear_path, nullptr) << "ET_MODULE_LINEAR_PATH is not set";
  Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
  ASSERT_EQ(linear_loader.error(), Error::Ok);

  // This file should always be compatible.
  Result<FreeableBuffer> linear_header = linear_loader->load(
      /*offset=*/0,
      Program::kMinHeadBytes,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(linear_header.error(), Error::Ok);
  EXPECT_EQ(
      Program::check_header(linear_header->data(), linear_header->size()),
      Program::HeaderStatus::CompatibleVersion);

  Result<Program> program = Program::load(&linear_loader.get());
  ASSERT_EQ(program.error(), Error::Ok);

  // Load constant segment data, which is currently always in segment index
  // zero.
  const auto segment_info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Constant,
      /*segment_index=*/0);
  Result<FreeableBuffer> segment =
      ProgramTestFriend::LoadSegment(&program.get(), segment_info);
  EXPECT_EQ(segment.error(), Error::Ok);

  const executorch_flatbuffer::Program* flatbuffer_program =
      ProgramTestFriend::GetInternalProgram(&program.get());
  ASSERT_NE(flatbuffer_program, nullptr);

  // Flatbuffer accessors return nullptr for absent fields. Guard each one so
  // a malformed file fails the test instead of crashing it.
  ASSERT_NE(flatbuffer_program->segments(), nullptr);
  // Expect one segment containing the constants.
  EXPECT_EQ(flatbuffer_program->segments()->size(), 1);

  ASSERT_NE(flatbuffer_program->constant_buffer(), nullptr);
  // The constant buffer should be empty.
  EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0);

  // Check constant segment offsets.
  const auto* constant_segment = flatbuffer_program->constant_segment();
  ASSERT_NE(constant_segment, nullptr);
  EXPECT_EQ(constant_segment->segment_index(), 0);
  ASSERT_NE(constant_segment->offsets(), nullptr);
  EXPECT_GE(constant_segment->offsets()->size(), 1);
}
447 
TEST_F(ProgramTest, LoadConstantSegmentWhenConstantBufferExists) {
  // Load the serialized ModuleLinear data, with constants in the flatbuffer and
  // no constants in the segment.
  const char* linear_path =
      std::getenv("DEPRECATED_ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH");
  ASSERT_NE(linear_path, nullptr)
      << "DEPRECATED_ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH is not set";
  Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
  ASSERT_EQ(linear_loader.error(), Error::Ok);

  // This file should always be compatible.
  Result<FreeableBuffer> linear_header = linear_loader->load(
      /*offset=*/0,
      Program::kMinHeadBytes,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(linear_header.error(), Error::Ok);
  EXPECT_EQ(
      Program::check_header(linear_header->data(), linear_header->size()),
      Program::HeaderStatus::CompatibleVersion);

  Result<Program> program = Program::load(&linear_loader.get());
  ASSERT_EQ(program.error(), Error::Ok);

  const executorch_flatbuffer::Program* flatbuffer_program =
      ProgramTestFriend::GetInternalProgram(&program.get());
  ASSERT_NE(flatbuffer_program, nullptr);

  // Flatbuffer vector accessors return nullptr when the field is absent.
  // Guard each one so a malformed file fails the test instead of crashing it.
  ASSERT_NE(flatbuffer_program->segments(), nullptr);
  // Expect no segments.
  EXPECT_EQ(flatbuffer_program->segments()->size(), 0);

  ASSERT_NE(flatbuffer_program->constant_buffer(), nullptr);
  // The constant buffer should exist.
  EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1);
}
479 
TEST_F(ProgramTest,LoadFromMutableSegment)480 TEST_F(ProgramTest, LoadFromMutableSegment) {
481   // Load the serialized ModuleSimpleTrain data.
482   auto path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
483   Result<FileDataLoader> training_loader = FileDataLoader::from(path);
484   ASSERT_EQ(training_loader.error(), Error::Ok);
485 
486   // This file should always be compatible.
487   Result<FreeableBuffer> training_header = training_loader->load(
488       /*offset=*/0,
489       Program::kMinHeadBytes,
490       DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
491   ASSERT_EQ(training_header.error(), Error::Ok);
492   EXPECT_EQ(
493       Program::check_header(training_header->data(), training_header->size()),
494       Program::HeaderStatus::CompatibleVersion);
495 
496   Result<Program> program = Program::load(&training_loader.get());
497   ASSERT_EQ(program.error(), Error::Ok);
498 
499   // dummy buffers to load into
500   uint8_t buffer[1] = {0};
501   uint8_t buffer2[1] = {0};
502 
503   // Load some mutable segment data
504   Error err = ProgramTestFriend::load_mutable_subsegment_into(
505       &program.get(), 0, 1, 1, buffer);
506   EXPECT_EQ(err, Error::Ok);
507 
508   // Check that the data loaded correctly, and then mutate it
509   EXPECT_EQ(buffer[0], 232); // 232 comes from inspecting the file itself. The
510                              // file is seeded so this value should be stable.
511   buffer[0] = 0;
512 
513   // Load the same mutable segment data from file into a different buffer.
514   err = ProgramTestFriend::load_mutable_subsegment_into(
515       &program.get(),
516       0, // mutable_data_segments_index
517       1, // offset_index
518       1, // size
519       buffer2);
520   EXPECT_EQ(err, Error::Ok);
521 
522   // Check that new data loaded from the file does not reflect the change to
523   // buffer.
524   EXPECT_EQ(buffer2[0], 232);
525 
526   const executorch_flatbuffer::Program* flatbuffer_program =
527       ProgramTestFriend::GetInternalProgram(&program.get());
528 
529   // Expect 2 segments. 1 mutable segment and 1 constant segment.
530   EXPECT_EQ(flatbuffer_program->segments()->size(), 2);
531 
532   // Expect a mutable data segment.
533   EXPECT_EQ(flatbuffer_program->mutable_data_segments()->size(), 1);
534 
535   // Expect the 0 index to be reserved and the offsets for weight and bias of
536   // linear to be indices 1 and 2.
537   EXPECT_EQ(
538       flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->size(),
539       3);
540   EXPECT_EQ(
541       flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(0),
542       0);
543   EXPECT_EQ(
544       flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(1),
545       0);
546   EXPECT_EQ(
547       flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(2),
548       36);
549 
550   // Loading beyond file should fail
551   err = ProgramTestFriend::load_mutable_subsegment_into(
552       &program.get(), 0, 1, 500, buffer);
553   EXPECT_NE(err, Error::Ok);
554 
555   // Loading beyond offsets should fail
556   err = ProgramTestFriend::load_mutable_subsegment_into(
557       &program.get(), 0, 500, 1, buffer);
558   EXPECT_NE(err, Error::Ok);
559 
560   // Loading beyond segments should fail
561   err = ProgramTestFriend::load_mutable_subsegment_into(
562       &program.get(), 500, 1, 1, buffer);
563   EXPECT_NE(err, Error::Ok);
564 }
565