1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/runtime/executor/program.h>

#include <cctype>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <memory>

#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/schema/program_generated.h>
#include <executorch/test/utils/DeathTest.h>

#include <gtest/gtest.h>
26
27 using namespace ::testing;
28 using executorch::runtime::DataLoader;
29 using executorch::runtime::Error;
30 using executorch::runtime::FreeableBuffer;
31 using executorch::runtime::Program;
32 using executorch::runtime::Result;
33 using torch::executor::util::BufferDataLoader;
34 using torch::executor::util::FileDataLoader;
35
36 // Verification level to use for tests not specifically focused on verification.
37 // Use the highest level to exercise it more.
38 constexpr Program::Verification kDefaultVerification =
39 Program::Verification::InternalConsistency;
40
41 class ProgramTest : public ::testing::Test {
42 protected:
SetUp()43 void SetUp() override {
44 // Since these tests cause ET_LOG to be called, the PAL must be initialized
45 // first.
46 executorch::runtime::runtime_init();
47
48 // Load the serialized ModuleAdd data.
49 const char* path = std::getenv("ET_MODULE_ADD_PATH");
50 Result<FileDataLoader> loader = FileDataLoader::from(path);
51 ASSERT_EQ(loader.error(), Error::Ok);
52
53 // This file should always be compatible.
54 Result<FreeableBuffer> header = loader->load(
55 /*offset=*/0,
56 Program::kMinHeadBytes,
57 /*segment_info=*/
58 DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
59 ASSERT_EQ(header.error(), Error::Ok);
60 EXPECT_EQ(
61 Program::check_header(header->data(), header->size()),
62 Program::HeaderStatus::CompatibleVersion);
63
64 add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
65
66 // Load the serialized ModuleMultiEntry data.
67 path = std::getenv("ET_MODULE_MULTI_ENTRY_PATH");
68 Result<FileDataLoader> multi_loader = FileDataLoader::from(path);
69 ASSERT_EQ(multi_loader.error(), Error::Ok);
70
71 // This file should always be compatible.
72 Result<FreeableBuffer> multi_header = multi_loader->load(
73 /*offset=*/0,
74 Program::kMinHeadBytes,
75 /*segment_info=*/
76 DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
77 ASSERT_EQ(multi_header.error(), Error::Ok);
78 EXPECT_EQ(
79 Program::check_header(multi_header->data(), multi_header->size()),
80 Program::HeaderStatus::CompatibleVersion);
81
82 multi_loader_ =
83 std::make_unique<FileDataLoader>(std::move(multi_loader.get()));
84 }
85
86 std::unique_ptr<FileDataLoader> add_loader_;
87 std::unique_ptr<FileDataLoader> multi_loader_;
88 };
89
90 namespace executorch {
91 namespace runtime {
92 namespace testing {
93 // Provides access to private Program methods.
94 class ProgramTestFriend final {
95 public:
LoadSegment(const Program * program,const DataLoader::SegmentInfo & segment_info)96 ET_NODISCARD static Result<FreeableBuffer> LoadSegment(
97 const Program* program,
98 const DataLoader::SegmentInfo& segment_info) {
99 return program->LoadSegment(segment_info);
100 }
101
load_mutable_subsegment_into(const Program * program,size_t mutable_data_segments_index,size_t offset_index,size_t size,void * buffer)102 ET_NODISCARD static Error load_mutable_subsegment_into(
103 const Program* program,
104 size_t mutable_data_segments_index,
105 size_t offset_index,
106 size_t size,
107 void* buffer) {
108 return program->load_mutable_subsegment_into(
109 mutable_data_segments_index, offset_index, size, buffer);
110 }
111
GetInternalProgram(const Program * program)112 const static executorch_flatbuffer::Program* GetInternalProgram(
113 const Program* program) {
114 return program->internal_program_;
115 }
116 };
117 } // namespace testing
118 } // namespace runtime
119 } // namespace executorch
120
121 using executorch::runtime::testing::ProgramTestFriend;
122
TEST_F(ProgramTest, DataParsesWithMinimalVerification) {
  // The ModuleAdd program should be accepted when only minimal checks run.
  const Program::Verification level = Program::Verification::Minimal;
  const Result<Program> parsed = Program::load(add_loader_.get(), level);
  EXPECT_EQ(parsed.error(), Error::Ok);
}
131
TEST_F(ProgramTest, DataParsesWithInternalConsistencyVerification) {
  // The ModuleAdd program should also be accepted at the strictest
  // verification level.
  EXPECT_EQ(
      Program::load(
          add_loader_.get(), Program::Verification::InternalConsistency)
          .error(),
      Error::Ok);
}
140
TEST_F(ProgramTest, BadMagicFailsToLoad) {
  // Copy the entire program into a buffer we own so its header can be edited.
  const size_t size = add_loader_->size().get();
  std::unique_ptr<char[]> bytes = std::make_unique<char[]>(size);
  {
    Result<FreeableBuffer> original = add_loader_->load(
        /*offset=*/0,
        size,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(original.error(), Error::Ok);
    ASSERT_EQ(original->size(), size);
    std::memcpy(bytes.get(), original->data(), size);
    // `original` releases its data when this scope closes.
  }

  // Bytes 4 and 5 hold "ET"; overwrite them with "XY" to break the magic.
  EXPECT_EQ(bytes[4], 'E');
  EXPECT_EQ(bytes[5], 'T');
  bytes[4] = 'X';
  bytes[5] = 'Y';

  // The loader views `bytes` directly, so later edits are visible through it.
  BufferDataLoader tampered_loader(bytes.get(), size);
  auto load_minimal = [&tampered_loader]() {
    return Program::load(&tampered_loader, Program::Verification::Minimal);
  };

  // Even minimal verification must notice the broken header.
  ASSERT_EQ(load_minimal().error(), Error::InvalidProgram);

  // Put the original magic back; loading should now succeed.
  bytes[4] = 'E';
  bytes[5] = 'T';
  ASSERT_EQ(load_minimal().error(), Error::Ok);
}
189
TEST_F(ProgramTest, VerificationCatchesTruncation) {
  // Read the complete program into memory.
  const size_t full_size = add_loader_->size().get();
  Result<FreeableBuffer> whole = add_loader_->load(
      /*offset=*/0,
      full_size,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(whole.error(), Error::Ok);

  // Present only the leading half of those bytes through a second loader.
  const size_t half_size = full_size / 2;
  BufferDataLoader truncated_loader(whole->data(), half_size);

  // Internal-consistency verification has to reject the truncated program.
  Result<Program> attempt = Program::load(
      &truncated_loader, Program::Verification::InternalConsistency);
  ASSERT_EQ(attempt.error(), Error::InvalidProgram);
}
208
TEST_F(ProgramTest, VerificationCatchesCorruption) {
  // Take a private, writable copy of the whole program.
  const size_t total = add_loader_->size().get();
  std::unique_ptr<char[]> copy = std::make_unique<char[]>(total);
  {
    Result<FreeableBuffer> original = add_loader_->load(
        /*offset=*/0,
        total,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(original.error(), Error::Ok);
    ASSERT_EQ(original->size(), total);
    std::memcpy(copy.get(), original->data(), total);
    // `original` releases its data when this scope closes.
  }

  // Overwrite everything from the midpoint to the end with 0x55 bytes.
  const size_t midpoint = total / 2;
  std::memset(copy.get() + midpoint, 0x55, total - midpoint);

  BufferDataLoader corrupted_loader(copy.get(), total);

  // Full verification must detect the garbage and refuse the program.
  Result<Program> attempt = Program::load(
      &corrupted_loader, Program::Verification::InternalConsistency);
  ASSERT_EQ(attempt.error(), Error::InvalidProgram);
}
236
TEST_F(ProgramTest, UnalignedProgramDataFails) {
  // Allocate one spare byte so the program can be placed at an odd address.
  const size_t size = add_loader_->size().get();
  std::unique_ptr<char[]> storage = std::make_unique<char[]>(size + 1);
  char* const misaligned = storage.get() + 1;
  {
    Result<FreeableBuffer> original = add_loader_->load(
        /*offset=*/0,
        size,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(original.error(), Error::Ok);
    ASSERT_EQ(original->size(), size);
    std::memcpy(misaligned, original->data(), size);
    // `original` releases its data when this scope closes.
  }

  // Serve the program from the shifted start address.
  BufferDataLoader misaligned_loader(misaligned, size);

  // Loading must not succeed when the data is not suitably aligned.
  Result<Program> attempt =
      Program::load(&misaligned_loader, Program::Verification::Minimal);
  ASSERT_NE(attempt.error(), Error::Ok);
}
261
TEST_F(ProgramTest, LoadSegmentWithNoSegments) {
  // Load a program with no appended segments.
  Result<Program> program =
      Program::load(add_loader_.get(), kDefaultVerification);
  // Must be fatal: program.get() below is only valid when the load succeeded.
  ASSERT_EQ(program.error(), Error::Ok);

  // Loading a non-program segment should fail.
  const auto segment_info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Backend,
      /*segment_index=*/0,
      "some-backend");
  Result<FreeableBuffer> segment =
      ProgramTestFriend::LoadSegment(&program.get(), segment_info);
  EXPECT_NE(segment.error(), Error::Ok);
}
277
TEST_F(ProgramTest, ShortDataHeader) {
  // Read exactly the minimum number of header bytes from the file.
  Result<FreeableBuffer> head = add_loader_->load(
      /*offset=*/0,
      Program::kMinHeadBytes,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(head.error(), Error::Ok);

  // Report one byte fewer than the minimum; check_header must flag the
  // input as too short.
  const size_t too_short = Program::kMinHeadBytes - 1;
  const Program::HeaderStatus status =
      Program::check_header(head->data(), too_short);
  EXPECT_EQ(status, Program::HeaderStatus::ShortData);
}
291
TEST_F(ProgramTest, IncompatibleHeader) {
  // Make a local copy of the header.
  size_t data_len = Program::kMinHeadBytes;
  auto data = std::make_unique<char[]>(data_len);
  {
    Result<FreeableBuffer> src = add_loader_->load(
        /*offset=*/0,
        data_len,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(src.error(), Error::Ok);
    ASSERT_EQ(src->size(), data_len);
    memcpy(data.get(), src->data(), data_len);
    // FreeableBuffer goes out of scope and frees its data.
  }

  // std::isdigit() has undefined behavior for negative values other than EOF,
  // and plain `char` may be signed, so widen through unsigned char first.
  const auto is_digit = [](char c) {
    return std::isdigit(static_cast<unsigned char>(c)) != 0;
  };

  // Change the number part of the magic value to a different value.
  EXPECT_EQ(data[4], 'E');
  EXPECT_EQ(data[5], 'T');
  // Fatal: the digit arithmetic below is meaningless on non-digit input.
  ASSERT_TRUE(is_digit(data[6])) << "Not a digit: " << data[6];
  ASSERT_TRUE(is_digit(data[7])) << "Not a digit: " << data[7];

  // Modify the tens digit, wrapping '9' back to '0' so it remains a digit.
  if (data[6] == '9') {
    data[6] = '0';
  } else {
    data[6] += 1;
  }
  EXPECT_TRUE(is_digit(data[6])) << "Not a digit: " << data[6];

  // Should count as present but incompatible.
  EXPECT_EQ(
      Program::check_header(data.get(), data_len),
      Program::HeaderStatus::IncompatibleVersion);
}
327
TEST_F(ProgramTest, HeaderNotPresent) {
  // Copy only the leading header bytes into a buffer we can modify.
  const size_t header_len = Program::kMinHeadBytes;
  std::unique_ptr<char[]> header = std::make_unique<char[]>(header_len);
  {
    Result<FreeableBuffer> loaded = add_loader_->load(
        /*offset=*/0,
        header_len,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
    ASSERT_EQ(loaded.error(), Error::Ok);
    ASSERT_EQ(loaded->size(), header_len);
    std::memcpy(header.get(), loaded->data(), header_len);
    // `loaded` releases its data when this scope closes.
  }

  // Replace the "ET" magic bytes at offsets 4 and 5 with "XY".
  EXPECT_EQ(header[4], 'E');
  EXPECT_EQ(header[5], 'T');
  header[4] = 'X';
  header[5] = 'Y';

  // Without a recognizable magic value the header counts as absent.
  const Program::HeaderStatus status =
      Program::check_header(header.get(), header_len);
  EXPECT_EQ(status, Program::HeaderStatus::NotPresent);
}
355
TEST_F(ProgramTest, getMethods) {
  // Parse the Program from the data.
  Result<Program> program_res =
      Program::load(multi_loader_.get(), kDefaultVerification);
  // Must be fatal: program_res.get() below is only valid on success.
  ASSERT_EQ(program_res.error(), Error::Ok);

  Program program(std::move(program_res.get()));

  // Method calls should succeed without hitting ET_CHECK.
  EXPECT_EQ(program.num_methods(), 2);

  // Each name Result is checked fatally before get() is called on it;
  // EXPECT_STREQ also prints both strings on mismatch, unlike strcmp() == 0.
  auto first_name = program.get_method_name(0);
  ASSERT_TRUE(first_name.ok());
  EXPECT_STREQ(first_name.get(), "forward");

  auto second_name = program.get_method_name(1);
  ASSERT_TRUE(second_name.ok());
  EXPECT_STREQ(second_name.get(), "forward2");
}
373
374 // Test that the deprecated Load method (capital 'L') still works.
TEST_F(ProgramTest,DEPRECATEDLoad)375 TEST_F(ProgramTest, DEPRECATEDLoad) {
376 // Parse the Program from the data.
377 // NOLINTNEXTLINE(facebook-hte-Deprecated)
378 Result<Program> program_res = Program::Load(multi_loader_.get());
379 EXPECT_EQ(program_res.error(), Error::Ok);
380 }
381
TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
  Result<Program> program =
      Program::load(add_loader_.get(), kDefaultVerification);
  ASSERT_EQ(program.error(), Error::Ok);
  const Program* const loaded = &program.get();

  // Requesting constant segment 0 from this program is expected to fail.
  const DataLoader::SegmentInfo constant_info(
      DataLoader::SegmentInfo::Type::Constant,
      /*segment_index=*/0);
  Result<FreeableBuffer> constant_data =
      ProgramTestFriend::LoadSegment(loaded, constant_info);
  EXPECT_NE(constant_data.error(), Error::Ok);

  const executorch_flatbuffer::Program* fb =
      ProgramTestFriend::GetInternalProgram(loaded);

  // No constants are stored inline in the flatbuffer.
  EXPECT_EQ(fb->constant_buffer()->size(), 0);

  // A single segment is present: the placeholder for non-const tensors.
  EXPECT_EQ(fb->segments()->size(), 1);
}
404
TEST_F(ProgramTest, LoadConstantSegment) {
  // Load the serialized ModuleLinear data, with constants in the segment.
  const char* linear_path = std::getenv("ET_MODULE_LINEAR_PATH");
  // getenv() returns nullptr when unset; fail with a clear message.
  ASSERT_NE(linear_path, nullptr) << "ET_MODULE_LINEAR_PATH is not set";
  Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
  ASSERT_EQ(linear_loader.error(), Error::Ok);

  // This file should always be compatible.
  Result<FreeableBuffer> linear_header = linear_loader->load(
      /*offset=*/0,
      Program::kMinHeadBytes,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(linear_header.error(), Error::Ok);
  EXPECT_EQ(
      Program::check_header(linear_header->data(), linear_header->size()),
      Program::HeaderStatus::CompatibleVersion);

  Result<Program> program = Program::load(&linear_loader.get());
  ASSERT_EQ(program.error(), Error::Ok);

  // Load constant segment data, which is currently always in segment index
  // zero.
  const auto segment_info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Constant,
      /*segment_index=*/0);
  Result<FreeableBuffer> segment =
      ProgramTestFriend::LoadSegment(&program.get(), segment_info);
  EXPECT_EQ(segment.error(), Error::Ok);

  const executorch_flatbuffer::Program* flatbuffer_program =
      ProgramTestFriend::GetInternalProgram(&program.get());

  // Expect one segment containing the constants.
  EXPECT_EQ(flatbuffer_program->segments()->size(), 1);

  // The constant buffer should be empty.
  EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0);

  // Check constant segment offsets. FlatBuffers accessors return nullptr for
  // absent fields, so guard before dereferencing to get a failure instead of
  // a crash.
  const auto* constant_segment = flatbuffer_program->constant_segment();
  ASSERT_NE(constant_segment, nullptr);
  EXPECT_EQ(constant_segment->segment_index(), 0);
  ASSERT_NE(constant_segment->offsets(), nullptr);
  EXPECT_GE(constant_segment->offsets()->size(), 1);
}
447
TEST_F(ProgramTest, LoadConstantSegmentWhenConstantBufferExists) {
  // Load the serialized ModuleLinear data, with constants in the flatbuffer and
  // no constants in the segment.
  const char* linear_path =
      std::getenv("DEPRECATED_ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH");
  // getenv() returns nullptr when unset; fail with a clear message.
  ASSERT_NE(linear_path, nullptr)
      << "DEPRECATED_ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH is not set";
  Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
  ASSERT_EQ(linear_loader.error(), Error::Ok);

  // This file should always be compatible.
  Result<FreeableBuffer> linear_header = linear_loader->load(
      /*offset=*/0,
      Program::kMinHeadBytes,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(linear_header.error(), Error::Ok);
  EXPECT_EQ(
      Program::check_header(linear_header->data(), linear_header->size()),
      Program::HeaderStatus::CompatibleVersion);

  Result<Program> program = Program::load(&linear_loader.get());
  ASSERT_EQ(program.error(), Error::Ok);

  const executorch_flatbuffer::Program* flatbuffer_program =
      ProgramTestFriend::GetInternalProgram(&program.get());

  // Expect no segments. FlatBuffers returns nullptr for an omitted vector,
  // which also means "no segments", so count it as zero instead of
  // dereferencing it.
  const auto* segments = flatbuffer_program->segments();
  const size_t num_segments = segments == nullptr ? 0 : segments->size();
  EXPECT_EQ(num_segments, 0);

  // The constant buffer should exist.
  ASSERT_NE(flatbuffer_program->constant_buffer(), nullptr);
  EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1);
}
479
TEST_F(ProgramTest, LoadFromMutableSegment) {
  // Load the serialized ModuleSimpleTrain data.
  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
  // getenv() returns nullptr when unset; fail with a clear message.
  ASSERT_NE(path, nullptr) << "ET_MODULE_SIMPLE_TRAIN_PATH is not set";
  Result<FileDataLoader> training_loader = FileDataLoader::from(path);
  ASSERT_EQ(training_loader.error(), Error::Ok);

  // This file should always be compatible.
  Result<FreeableBuffer> training_header = training_loader->load(
      /*offset=*/0,
      Program::kMinHeadBytes,
      /*segment_info=*/
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ASSERT_EQ(training_header.error(), Error::Ok);
  EXPECT_EQ(
      Program::check_header(training_header->data(), training_header->size()),
      Program::HeaderStatus::CompatibleVersion);

  Result<Program> program = Program::load(&training_loader.get());
  ASSERT_EQ(program.error(), Error::Ok);

  // Dummy buffers to load into.
  uint8_t buffer[1] = {0};
  uint8_t buffer2[1] = {0};

  // Load some mutable segment data.
  Error err = ProgramTestFriend::load_mutable_subsegment_into(
      &program.get(),
      /*mutable_data_segments_index=*/0,
      /*offset_index=*/1,
      /*size=*/1,
      buffer);
  EXPECT_EQ(err, Error::Ok);

  // Check that the data loaded correctly, and then mutate it. 232 comes from
  // inspecting the file itself; the file is seeded so this value should be
  // stable.
  EXPECT_EQ(buffer[0], 232);
  buffer[0] = 0;

  // Load the same mutable segment data from file into a different buffer.
  err = ProgramTestFriend::load_mutable_subsegment_into(
      &program.get(),
      /*mutable_data_segments_index=*/0,
      /*offset_index=*/1,
      /*size=*/1,
      buffer2);
  EXPECT_EQ(err, Error::Ok);

  // Check that new data loaded from the file does not reflect the change to
  // buffer.
  EXPECT_EQ(buffer2[0], 232);

  const executorch_flatbuffer::Program* flatbuffer_program =
      ProgramTestFriend::GetInternalProgram(&program.get());

  // Expect 2 segments. 1 mutable segment and 1 constant segment.
  EXPECT_EQ(flatbuffer_program->segments()->size(), 2);

  // Expect a mutable data segment. Fatal checks: Get(0) below requires it.
  const auto* mutable_segments = flatbuffer_program->mutable_data_segments();
  ASSERT_NE(mutable_segments, nullptr);
  ASSERT_EQ(mutable_segments->size(), 1);

  // Expect the 0 index to be reserved and the offsets for weight and bias of
  // linear to be indices 1 and 2.
  const auto* offsets = mutable_segments->Get(0)->offsets();
  ASSERT_NE(offsets, nullptr);
  ASSERT_EQ(offsets->size(), 3);
  EXPECT_EQ(offsets->Get(0), 0);
  EXPECT_EQ(offsets->Get(1), 0);
  EXPECT_EQ(offsets->Get(2), 36);

  // Loading beyond file should fail. The destination is sized to the request
  // so that, if the load wrongly succeeded, it could not write past the end
  // of a stack buffer.
  constexpr size_t kOversizedLoad = 500;
  uint8_t oversized_buffer[kOversizedLoad] = {0};
  err = ProgramTestFriend::load_mutable_subsegment_into(
      &program.get(),
      /*mutable_data_segments_index=*/0,
      /*offset_index=*/1,
      /*size=*/kOversizedLoad,
      oversized_buffer);
  EXPECT_NE(err, Error::Ok);

  // Loading beyond offsets should fail.
  err = ProgramTestFriend::load_mutable_subsegment_into(
      &program.get(),
      /*mutable_data_segments_index=*/0,
      /*offset_index=*/500,
      /*size=*/1,
      buffer);
  EXPECT_NE(err, Error::Ok);

  // Loading beyond segments should fail.
  err = ProgramTestFriend::load_mutable_subsegment_into(
      &program.get(),
      /*mutable_data_segments_index=*/500,
      /*offset_index=*/1,
      /*size=*/1,
      buffer);
  EXPECT_NE(err, Error::Ok);
}
565