xref: /aosp_15_r20/external/executorch/runtime/executor/test/tensor_parser_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/executor/tensor_parser.h>
10 
11 #include <executorch/extension/data_loader/file_data_loader.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/executor/test/managed_memory_manager.h>
14 #include <executorch/schema/program_generated.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using executorch::runtime::Error;
22 using executorch::runtime::EValue;
23 using executorch::runtime::FreeableBuffer;
24 using executorch::runtime::Program;
25 using executorch::runtime::Result;
26 using executorch::runtime::deserialization::parseTensor;
27 using executorch::runtime::testing::ManagedMemoryManager;
28 using torch::executor::util::FileDataLoader;
29 
// Byte sizes of the two pools each ManagedMemoryManager in these tests is
// constructed with (non-constant/planned memory, runtime memory): 32 KiB each.
constexpr size_t kDefaultNonConstMemBytes = size_t{32} << 10;
constexpr size_t kDefaultRuntimeMemBytes = size_t{32} << 10;
32 
33 class TensorParserTest : public ::testing::Test {
34  protected:
SetUp()35   void SetUp() override {
36     // Load the serialized ModuleAdd data.
37     const char* path = std::getenv("ET_MODULE_ADD_PATH");
38     Result<FileDataLoader> float_loader = FileDataLoader::from(path);
39     ASSERT_EQ(float_loader.error(), Error::Ok);
40     float_loader_ =
41         std::make_unique<FileDataLoader>(std::move(float_loader.get()));
42 
43     // Load the serialized ModuleAddHalf data.
44     const char* half_path = std::getenv("ET_MODULE_ADD_HALF_PATH");
45     Result<FileDataLoader> half_loader = FileDataLoader::from(half_path);
46     ASSERT_EQ(half_loader.error(), Error::Ok);
47     half_loader_ =
48         std::make_unique<FileDataLoader>(std::move(half_loader.get()));
49   }
50 
51   std::unique_ptr<FileDataLoader> float_loader_;
52   std::unique_ptr<FileDataLoader> half_loader_;
53 };
54 
55 namespace executorch {
56 namespace runtime {
57 namespace testing {
58 // Provides access to private Program methods.
59 class ProgramTestFriend final {
60  public:
GetInternalProgram(const Program * program)61   const static executorch_flatbuffer::Program* GetInternalProgram(
62       const Program* program) {
63     return program->internal_program_;
64   }
65 };
66 } // namespace testing
67 } // namespace runtime
68 } // namespace executorch
69 
70 using executorch::runtime::testing::ProgramTestFriend;
71 
test_module_add(std::unique_ptr<FileDataLoader> & loader,ScalarType scalar_type,int type_size)72 void test_module_add(
73     std::unique_ptr<FileDataLoader>& loader,
74     ScalarType scalar_type,
75     int type_size) {
76   Result<Program> program =
77       Program::load(loader.get(), Program::Verification::Minimal);
78   EXPECT_EQ(program.error(), Error::Ok);
79 
80   const Program* program_ = &program.get();
81   ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
82 
83   const executorch_flatbuffer::Program* internal_program =
84       ProgramTestFriend::GetInternalProgram(program_);
85   executorch_flatbuffer::ExecutionPlan* execution_plan =
86       internal_program->execution_plan()->GetMutableObject(0);
87   auto flatbuffer_values = execution_plan->values();
88 
89   int tensor_count = 0;
90   int double_count = 0;
91   for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
92     auto serialization_value = flatbuffer_values->Get(i);
93     if (serialization_value->val_type() ==
94         executorch_flatbuffer::KernelTypes::Tensor) {
95       tensor_count++;
96       Result<Tensor> tensor = parseTensor(
97           program_, &mmm.get(), serialization_value->val_as_Tensor());
98       Tensor t = tensor.get();
99       ASSERT_EQ(scalar_type, t.scalar_type());
100       ASSERT_EQ(2, t.dim()); // [2, 2]
101       ASSERT_EQ(4, t.numel());
102       ASSERT_EQ(type_size * t.numel(), t.nbytes());
103     } else if (
104         serialization_value->val_type() ==
105         executorch_flatbuffer::KernelTypes::Double) {
106       double_count++;
107       ASSERT_EQ(1.0, serialization_value->val_as_Double()->double_val());
108     }
109   }
110   ASSERT_EQ(3, tensor_count); // input x2, output
111   ASSERT_EQ(2, double_count); // alpha x2
112 }
113 
TEST_F(TensorParserTest, TestModuleAddFloat) {
  // The fp32 ModuleAdd program: every tensor should be Float, with elements
  // of sizeof(float) bytes.
  const auto element_size = sizeof(float);
  test_module_add(float_loader_, ScalarType::Float, element_size);
}
117 
TEST_F(TensorParserTest, TestModuleAddHalf) {
  // The half-precision ModuleAdd program: every tensor should be Half, with
  // elements of sizeof(exec_aten::Half) bytes.
  const auto element_size = sizeof(exec_aten::Half);
  test_module_add(half_loader_, ScalarType::Half, element_size);
}
121 
TEST_F(TensorParserTest,TestMutableState)122 TEST_F(TensorParserTest, TestMutableState) {
123   // Load the serialized ModuleSimpleTrain data.
124   const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
125   Result<FileDataLoader> train_loader = FileDataLoader::from(path);
126   ASSERT_EQ(train_loader.error(), Error::Ok);
127 
128   Result<Program> program =
129       Program::load(&train_loader.get(), Program::Verification::Minimal);
130   EXPECT_EQ(program.error(), Error::Ok);
131 
132   ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
133   ManagedMemoryManager mmm_copy(
134       kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
135 
136   const executorch_flatbuffer::Program* internal_program =
137       ProgramTestFriend::GetInternalProgram(&program.get());
138   executorch_flatbuffer::ExecutionPlan* execution_plan =
139       internal_program->execution_plan()->GetMutableObject(0);
140   auto flatbuffer_values = execution_plan->values();
141 
142   size_t num_mutable_tensors = 0;
143   for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
144     auto serialization_value = flatbuffer_values->Get(i);
145     if (serialization_value->val_type() ==
146             executorch_flatbuffer::KernelTypes::Tensor &&
147         serialization_value->val_as_Tensor()->allocation_info() != nullptr &&
148         serialization_value->val_as_Tensor()->data_buffer_idx() > 0) {
149       num_mutable_tensors++;
150       Result<torch::executor::Tensor> tensor = parseTensor(
151           &program.get(), &mmm.get(), serialization_value->val_as_Tensor());
152       torch::executor::Tensor t = tensor.get();
153       float loaded_value = t.const_data_ptr<float>()[0];
154       ASSERT_NE(nullptr, t.const_data_ptr());
155       ASSERT_NE(t.mutable_data_ptr<float>()[0], 0.5);
156       t.mutable_data_ptr<float>()[0] = 0.5;
157       ASSERT_EQ(
158           t.mutable_data_ptr<float>()[0],
159           0.5); // 0.5 can be represented perfectly by float so EQ and NE work
160                 // fine here. Any power of 2 rational can be perfectly
161                 // represented. See dyadic rationals for more info.
162 
163       // Load the same tensor using the same mem manager and show the value is
164       // updated again.
165       Result<torch::executor::Tensor> tensor1_alias = parseTensor(
166           &program.get(), &mmm.get(), serialization_value->val_as_Tensor());
167       torch::executor::Tensor t2 = tensor.get();
168       ASSERT_NE(t2.mutable_data_ptr<float>()[0], 0.5);
169 
170       // Show the tensors are equivalent
171       ASSERT_EQ(t.const_data_ptr(), t2.const_data_ptr());
172       // Set mutable tensor value back to 0.5 since it got overwritten by second
173       // parse.
174       t.mutable_data_ptr<float>()[0] = 0.5;
175 
176       // Load the same tensor using a different mem manager and show the value
177       // is not the same as t.
178       Result<torch::executor::Tensor> tensor_new = parseTensor(
179           &program.get(),
180           &mmm_copy.get(),
181           serialization_value->val_as_Tensor());
182       torch::executor::Tensor t3 = tensor_new.get();
183       ASSERT_NE(t3.mutable_data_ptr<float>()[0], 0.5);
184       ASSERT_NE(t3.const_data_ptr(), t.const_data_ptr());
185       ASSERT_EQ(loaded_value, t3.const_data_ptr<float>()[0]);
186     }
187   }
188   ASSERT_EQ(num_mutable_tensors, 2);
189 }
190