xref: /aosp_15_r20/external/executorch/extension/runner_util/test/inputs_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/runner_util/inputs.h>
10 
11 #include <executorch/extension/data_loader/file_data_loader.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/span.h>
14 #include <executorch/runtime/executor/method.h>
15 #include <executorch/runtime/executor/program.h>
16 #include <executorch/runtime/executor/test/managed_memory_manager.h>
17 #include <executorch/runtime/platform/runtime.h>
18 #include <gtest/gtest.h>
19 
20 using namespace ::testing;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using executorch::extension::BufferCleanup;
24 using executorch::extension::FileDataLoader;
25 using executorch::extension::prepare_input_tensors;
26 using executorch::runtime::Error;
27 using executorch::runtime::EValue;
28 using executorch::runtime::MemoryAllocator;
29 using executorch::runtime::MemoryManager;
30 using executorch::runtime::Method;
31 using executorch::runtime::Program;
32 using executorch::runtime::Result;
33 using executorch::runtime::Span;
34 using executorch::runtime::Tag;
35 using executorch::runtime::testing::ManagedMemoryManager;
36 
37 class InputsTest : public ::testing::Test {
38  protected:
SetUp()39   void SetUp() override {
40     torch::executor::runtime_init();
41 
42     // Create a loader for the serialized ModuleAdd program.
43     const char* path = std::getenv("ET_MODULE_ADD_PATH");
44     Result<FileDataLoader> loader = FileDataLoader::from(path);
45     ASSERT_EQ(loader.error(), Error::Ok);
46     loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
47 
48     // Use it to load the program.
49     Result<Program> program = Program::load(
50         loader_.get(), Program::Verification::InternalConsistency);
51     ASSERT_EQ(program.error(), Error::Ok);
52     program_ = std::make_unique<Program>(std::move(program.get()));
53 
54     mmm_ = std::make_unique<ManagedMemoryManager>(
55         /*planned_memory_bytes=*/32 * 1024U,
56         /*method_allocator_bytes=*/32 * 1024U);
57 
58     // Load the forward method.
59     Result<Method> method = program_->load_method("forward", &mmm_->get());
60     ASSERT_EQ(method.error(), Error::Ok);
61     method_ = std::make_unique<Method>(std::move(method.get()));
62   }
63 
64  private:
65   // Must outlive method_, but tests shouldn't need to touch them.
66   std::unique_ptr<FileDataLoader> loader_;
67   std::unique_ptr<ManagedMemoryManager> mmm_;
68   std::unique_ptr<Program> program_;
69 
70  protected:
71   std::unique_ptr<Method> method_;
72 };
73 
// Prepares the inputs of ModuleAdd's "forward" with prepare_input_tensors(),
// runs it, and validates the outputs (the inputs themselves aren't observable).
TEST_F(InputsTest, Smoke) {
  // The returned BufferCleanup holds the prepared input buffers; keep it alive
  // for the whole test so they are not released before execute().
  Result<BufferCleanup> input_buffers = prepare_input_tensors(*method_);
  ASSERT_EQ(input_buffers.error(), Error::Ok);

  // We can't look at the input tensors, but we can check that the outputs make
  // sense after executing the method.
  Error status = method_->execute();
  ASSERT_EQ(status, Error::Ok);

  // Get the single output, which should be a floating-point Tensor. size_t
  // literals keep the comparisons free of signed/unsigned mismatches.
  ASSERT_EQ(method_->outputs_size(), size_t{1});
  const EValue& output_value = method_->get_output(0);
  ASSERT_EQ(output_value.tag, Tag::Tensor);
  Tensor output = output_value.toTensor();
  ASSERT_EQ(output.scalar_type(), ScalarType::Float);

  // ModuleAdd adds its two inputs together, so if the input elements were set
  // to 1, the output elements should all be 2.
  Span<float> elements(output.mutable_data_ptr<float>(), output.numel());
  // Make sure we're actually testing something.
  EXPECT_GT(elements.size(), size_t{0});
  for (float e : elements) {
    // Floating-point comparison: use the ULP-tolerant float matcher.
    EXPECT_FLOAT_EQ(e, 2.0f);
  }

  // Although it's tough to test directly, ASAN should let us know if
  // BufferCleanup doesn't behave properly: either freeing too soon or leaking
  // the pointers.
}
102 
// Checks that BufferCleanup owns a malloc'd array of malloc'd buffers, that
// moving it transfers ownership (the moved-from object must not free them),
// and that the final owner releases them. Misbehavior is surfaced by ASAN.
TEST(BufferCleanupTest, Smoke) {
  // Returns the size in bytes of the buffer at index `i`. Use multiples of OS
  // page sizes. As this gets bigger, we're more likely to allocate outside the
  // main heap in a separate page, making it easier to catch uses-after-free.
  auto test_buffer_size = [](size_t i) -> size_t { return size_t{4096} << i; };

  // Create some buffers.
  constexpr size_t kNumBuffers = 8;
  void** buffers = static_cast<void**>(malloc(kNumBuffers * sizeof(void*)));
  ASSERT_NE(buffers, nullptr);
  for (size_t i = 0; i < kNumBuffers; i++) {
    const size_t nbytes = test_buffer_size(i);
    buffers[i] = malloc(nbytes);
    ASSERT_NE(buffers[i], nullptr);
    memset(buffers[i], 0x00, nbytes);
  }

  // Writes `value` to every byte of every buffer. If any buffer has already
  // been freed, this is a use-after-free that ASAN will report.
  auto fill_buffers = [&](int value) {
    for (size_t i = 0; i < kNumBuffers; i++) {
      memset(buffers[i], value, test_buffer_size(i));
    }
  };

  std::unique_ptr<BufferCleanup> bc2;
  {
    // bc1 should own `buffers` and the buffers that its entries point to.
    BufferCleanup bc1({buffers, kNumBuffers});

    // They're still alive; no segfaults or ASAN complaints if we write to them.
    fill_buffers(0xff);

    // Move ownership to a new object.
    bc2 = std::make_unique<BufferCleanup>(std::move(bc1));

    // Still alive.
    fill_buffers(0x00);

    // bc1 goes out of scope here. If it thinks it owns the buffers, it will
    // try to free them.
  }

  // bc2 should own the buffers now, and they should still be alive.
  fill_buffers(0xff);

  // Destroy bc2, which should destroy the buffers. There's no way for us to
  // check that it happened, but the sanitizer should complain if there's a
  // memory leak. And if bc1 freed them before, we should get a double-free
  // complaint.
  bc2.reset();
}
157