1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/runner_util/inputs.h>
10
11 #include <executorch/extension/data_loader/file_data_loader.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/span.h>
14 #include <executorch/runtime/executor/method.h>
15 #include <executorch/runtime/executor/program.h>
16 #include <executorch/runtime/executor/test/managed_memory_manager.h>
17 #include <executorch/runtime/platform/runtime.h>
18 #include <gtest/gtest.h>
19
20 using namespace ::testing;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using executorch::extension::BufferCleanup;
24 using executorch::extension::FileDataLoader;
25 using executorch::extension::prepare_input_tensors;
26 using executorch::runtime::Error;
27 using executorch::runtime::EValue;
28 using executorch::runtime::MemoryAllocator;
29 using executorch::runtime::MemoryManager;
30 using executorch::runtime::Method;
31 using executorch::runtime::Program;
32 using executorch::runtime::Result;
33 using executorch::runtime::Span;
34 using executorch::runtime::Tag;
35 using executorch::runtime::testing::ManagedMemoryManager;
36
37 class InputsTest : public ::testing::Test {
38 protected:
SetUp()39 void SetUp() override {
40 torch::executor::runtime_init();
41
42 // Create a loader for the serialized ModuleAdd program.
43 const char* path = std::getenv("ET_MODULE_ADD_PATH");
44 Result<FileDataLoader> loader = FileDataLoader::from(path);
45 ASSERT_EQ(loader.error(), Error::Ok);
46 loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
47
48 // Use it to load the program.
49 Result<Program> program = Program::load(
50 loader_.get(), Program::Verification::InternalConsistency);
51 ASSERT_EQ(program.error(), Error::Ok);
52 program_ = std::make_unique<Program>(std::move(program.get()));
53
54 mmm_ = std::make_unique<ManagedMemoryManager>(
55 /*planned_memory_bytes=*/32 * 1024U,
56 /*method_allocator_bytes=*/32 * 1024U);
57
58 // Load the forward method.
59 Result<Method> method = program_->load_method("forward", &mmm_->get());
60 ASSERT_EQ(method.error(), Error::Ok);
61 method_ = std::make_unique<Method>(std::move(method.get()));
62 }
63
64 private:
65 // Must outlive method_, but tests shouldn't need to touch them.
66 std::unique_ptr<FileDataLoader> loader_;
67 std::unique_ptr<ManagedMemoryManager> mmm_;
68 std::unique_ptr<Program> program_;
69
70 protected:
71 std::unique_ptr<Method> method_;
72 };
73
// Prepares the method's input tensors with prepare_input_tensors(), executes
// the method, and checks that the single output looks like ModuleAdd applied
// to inputs whose elements are all 1.
TEST_F(InputsTest, Smoke) {
  // `input_buffers` stays in scope until the end of the test so that whatever
  // it cleans up remains alive while the method executes.
  Result<BufferCleanup> input_buffers = prepare_input_tensors(*method_);
  ASSERT_EQ(input_buffers.error(), Error::Ok);

  // We can't look at the input tensors, but we can check that the outputs make
  // sense after executing the method.
  Error status = method_->execute();
  ASSERT_EQ(status, Error::Ok);

  // Get the single output, which should be a floating-point Tensor. The
  // unsigned literal avoids a signed/unsigned comparison inside the macro.
  ASSERT_EQ(method_->outputs_size(), 1u);
  const EValue& output_value = method_->get_output(0);
  ASSERT_EQ(output_value.tag, Tag::Tensor);
  Tensor output = output_value.toTensor();
  ASSERT_EQ(output.scalar_type(), ScalarType::Float);

  // ModuleAdd adds its two inputs together, so if the input elements were set
  // to 1, the output elements should all be 2.
  Span<float> elements(output.mutable_data_ptr<float>(), output.numel());
  EXPECT_GT(elements.size(), 0u); // Make sure we're actually testing something.
  for (float e : elements) {
    // 1.0f + 1.0f is exactly representable, so an exact comparison against a
    // float literal is appropriate here.
    EXPECT_EQ(e, 2.0f);
  }

  // Although it's tough to test directly, ASAN should let us know if
  // BufferCleanup doesn't behave properly: either freeing too soon or leaking
  // the pointers.
}
102
// Exercises BufferCleanup ownership transfer: buffers handed to one
// BufferCleanup must stay writable after that object is moved from and
// destroyed, and must be released when the moved-to object is destroyed.
// The test relies on a sanitizer (e.g. ASAN) to report use-after-free,
// double-free, or leaks; it has no direct way to observe deallocation.
TEST(BufferCleanupTest, Smoke) {
  constexpr size_t kNumBuffers = 8;

  // Returns the size in bytes of the buffer at index `i`.
  auto test_buffer_size = [](size_t i) -> size_t {
    // Use multiples of OS page sizes. As this gets bigger, we're more
    // likely to allocate outside the main heap in a separate page, making
    // it easier to catch uses-after-free.
    return static_cast<size_t>(4096) << i;
  };

  // Create some buffers. They are allocated with malloc() because ownership is
  // handed to BufferCleanup below (presumably it releases them with free();
  // confirm against its implementation).
  void** buffers = static_cast<void**>(malloc(kNumBuffers * sizeof(void*)));
  ASSERT_NE(buffers, nullptr);
  for (size_t i = 0; i < kNumBuffers; ++i) {
    buffers[i] = malloc(test_buffer_size(i));
    ASSERT_NE(buffers[i], nullptr);
  }

  // Writes `value` over every byte of every buffer. If any buffer has already
  // been freed, this is where a segfault or sanitizer report would show up.
  auto fill_buffers = [&](int value) {
    for (size_t i = 0; i < kNumBuffers; ++i) {
      memset(buffers[i], value, test_buffer_size(i));
    }
  };
  fill_buffers(0x00);

  std::unique_ptr<BufferCleanup> bc2;
  {
    // bc1 should own `buffers` and the buffers that its entries point to.
    BufferCleanup bc1({buffers, kNumBuffers});

    // They're still alive; no segfaults or ASAN complaints if we write to them.
    fill_buffers(0xff);

    // Move ownership to a new object.
    bc2 = std::make_unique<BufferCleanup>(std::move(bc1));

    // Still alive.
    fill_buffers(0x00);

    // bc1 goes out of scope here. If it thinks it owns the buffers, it will
    // try to free them.
  }

  // bc2 should own the buffers now, and they should still be alive.
  fill_buffers(0xff);

  // Destroy bc2, which should destroy the buffers. There's no way for us to
  // check that it happened, but the sanitizer should complain if there's a
  // memory leak. And if bc1 freed them before, we should get a double-free
  // complaint. `buffers` must not be touched after this point.
  bc2.reset();
}
157