1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/runtime/executor/tensor_parser.h>

#include <cstddef>
#include <cstdlib>
#include <memory>
#include <utility>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/schema/program_generated.h>

#include <gtest/gtest.h>
17
using namespace ::testing;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::deserialization::parseTensor;
using executorch::runtime::testing::ManagedMemoryManager;
using torch::executor::util::FileDataLoader;

// Pool sizes passed to ManagedMemoryManager in the tests below; 32 KiB each
// is ample for the small test modules these tests load.
constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;
32
33 class TensorParserTest : public ::testing::Test {
34 protected:
SetUp()35 void SetUp() override {
36 // Load the serialized ModuleAdd data.
37 const char* path = std::getenv("ET_MODULE_ADD_PATH");
38 Result<FileDataLoader> float_loader = FileDataLoader::from(path);
39 ASSERT_EQ(float_loader.error(), Error::Ok);
40 float_loader_ =
41 std::make_unique<FileDataLoader>(std::move(float_loader.get()));
42
43 // Load the serialized ModuleAddHalf data.
44 const char* half_path = std::getenv("ET_MODULE_ADD_HALF_PATH");
45 Result<FileDataLoader> half_loader = FileDataLoader::from(half_path);
46 ASSERT_EQ(half_loader.error(), Error::Ok);
47 half_loader_ =
48 std::make_unique<FileDataLoader>(std::move(half_loader.get()));
49 }
50
51 std::unique_ptr<FileDataLoader> float_loader_;
52 std::unique_ptr<FileDataLoader> half_loader_;
53 };
54
55 namespace executorch {
56 namespace runtime {
57 namespace testing {
58 // Provides access to private Program methods.
59 class ProgramTestFriend final {
60 public:
GetInternalProgram(const Program * program)61 const static executorch_flatbuffer::Program* GetInternalProgram(
62 const Program* program) {
63 return program->internal_program_;
64 }
65 };
66 } // namespace testing
67 } // namespace runtime
68 } // namespace executorch
69
70 using executorch::runtime::testing::ProgramTestFriend;
71
test_module_add(std::unique_ptr<FileDataLoader> & loader,ScalarType scalar_type,int type_size)72 void test_module_add(
73 std::unique_ptr<FileDataLoader>& loader,
74 ScalarType scalar_type,
75 int type_size) {
76 Result<Program> program =
77 Program::load(loader.get(), Program::Verification::Minimal);
78 EXPECT_EQ(program.error(), Error::Ok);
79
80 const Program* program_ = &program.get();
81 ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
82
83 const executorch_flatbuffer::Program* internal_program =
84 ProgramTestFriend::GetInternalProgram(program_);
85 executorch_flatbuffer::ExecutionPlan* execution_plan =
86 internal_program->execution_plan()->GetMutableObject(0);
87 auto flatbuffer_values = execution_plan->values();
88
89 int tensor_count = 0;
90 int double_count = 0;
91 for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
92 auto serialization_value = flatbuffer_values->Get(i);
93 if (serialization_value->val_type() ==
94 executorch_flatbuffer::KernelTypes::Tensor) {
95 tensor_count++;
96 Result<Tensor> tensor = parseTensor(
97 program_, &mmm.get(), serialization_value->val_as_Tensor());
98 Tensor t = tensor.get();
99 ASSERT_EQ(scalar_type, t.scalar_type());
100 ASSERT_EQ(2, t.dim()); // [2, 2]
101 ASSERT_EQ(4, t.numel());
102 ASSERT_EQ(type_size * t.numel(), t.nbytes());
103 } else if (
104 serialization_value->val_type() ==
105 executorch_flatbuffer::KernelTypes::Double) {
106 double_count++;
107 ASSERT_EQ(1.0, serialization_value->val_as_Double()->double_val());
108 }
109 }
110 ASSERT_EQ(3, tensor_count); // input x2, output
111 ASSERT_EQ(2, double_count); // alpha x2
112 }
113
// Exercises tensor parsing end to end on the float32 ModuleAdd program.
TEST_F(TensorParserTest, TestModuleAddFloat) {
  test_module_add(float_loader_, ScalarType::Float, sizeof(float));
}
117
// Exercises tensor parsing end to end on the fp16 (Half) ModuleAdd program.
TEST_F(TensorParserTest, TestModuleAddHalf) {
  test_module_add(half_loader_, ScalarType::Half, sizeof(exec_aten::Half));
}
121
// Verifies mutable-state semantics of parseTensor: a mutable tensor (one with
// allocation_info and a serialized initial value) lives in runtime-managed
// memory, re-parsing with the same memory manager resets the data in place,
// and parsing with a different memory manager yields distinct storage holding
// the original serialized value.
TEST_F(TensorParserTest, TestMutableState) {
  // Load the serialized ModuleSimpleTrain data.
  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
  ASSERT_NE(path, nullptr) << "ET_MODULE_SIMPLE_TRAIN_PATH is not set";
  Result<FileDataLoader> train_loader = FileDataLoader::from(path);
  ASSERT_EQ(train_loader.error(), Error::Ok);

  Result<Program> program =
      Program::load(&train_loader.get(), Program::Verification::Minimal);
  // ASSERT (not EXPECT): the rest of the test dereferences the Result.
  ASSERT_EQ(program.error(), Error::Ok);

  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  ManagedMemoryManager mmm_copy(
      kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);

  const executorch_flatbuffer::Program* internal_program =
      ProgramTestFriend::GetInternalProgram(&program.get());
  executorch_flatbuffer::ExecutionPlan* execution_plan =
      internal_program->execution_plan()->GetMutableObject(0);
  auto flatbuffer_values = execution_plan->values();

  size_t num_mutable_tensors = 0;
  for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
    auto serialization_value = flatbuffer_values->Get(i);
    // Mutable state tensors carry an allocation_info (runtime-allocated
    // storage) plus a positive data_buffer_idx (serialized initial value).
    if (serialization_value->val_type() ==
            executorch_flatbuffer::KernelTypes::Tensor &&
        serialization_value->val_as_Tensor()->allocation_info() != nullptr &&
        serialization_value->val_as_Tensor()->data_buffer_idx() > 0) {
      num_mutable_tensors++;
      Result<torch::executor::Tensor> tensor = parseTensor(
          &program.get(), &mmm.get(), serialization_value->val_as_Tensor());
      ASSERT_EQ(tensor.error(), Error::Ok);
      torch::executor::Tensor t = tensor.get();
      float loaded_value = t.const_data_ptr<float>()[0];
      ASSERT_NE(nullptr, t.const_data_ptr());
      ASSERT_NE(t.mutable_data_ptr<float>()[0], 0.5);
      t.mutable_data_ptr<float>()[0] = 0.5;
      ASSERT_EQ(
          t.mutable_data_ptr<float>()[0],
          0.5); // 0.5 can be represented perfectly by float so EQ and NE work
                // fine here. Any power of 2 rational can be perfectly
                // represented. See dyadic rationals for more info.

      // Load the same tensor using the same mem manager and show the value is
      // updated again.
      Result<torch::executor::Tensor> tensor1_alias = parseTensor(
          &program.get(), &mmm.get(), serialization_value->val_as_Tensor());
      ASSERT_EQ(tensor1_alias.error(), Error::Ok);
      // Bug fix: inspect the freshly parsed alias, not the original Result.
      torch::executor::Tensor t2 = tensor1_alias.get();
      ASSERT_NE(t2.mutable_data_ptr<float>()[0], 0.5);

      // Show the tensors are equivalent
      ASSERT_EQ(t.const_data_ptr(), t2.const_data_ptr());
      // Set mutable tensor value back to 0.5 since it got overwritten by second
      // parse.
      t.mutable_data_ptr<float>()[0] = 0.5;

      // Load the same tensor using a different mem manager and show the value
      // is not the same as t.
      Result<torch::executor::Tensor> tensor_new = parseTensor(
          &program.get(),
          &mmm_copy.get(),
          serialization_value->val_as_Tensor());
      ASSERT_EQ(tensor_new.error(), Error::Ok);
      torch::executor::Tensor t3 = tensor_new.get();
      ASSERT_NE(t3.mutable_data_ptr<float>()[0], 0.5);
      ASSERT_NE(t3.const_data_ptr(), t.const_data_ptr());
      ASSERT_EQ(loaded_value, t3.const_data_ptr<float>()[0]);
    }
  }
  ASSERT_EQ(num_mutable_tensors, 2);
}
190