1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h"
16
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cerrno>
#include <memory>
#include <string>
23
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "absl/strings/str_format.h"
27 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
28 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedded_mobilenet_model.h"
29 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_test_helper.h"
30 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
31 #include "tensorflow/lite/model_builder.h"
32 #include "tensorflow/lite/schema/schema_generated.h"
33
34 namespace tflite {
35 namespace acceleration {
36 namespace {
37
38 using ::testing::IsNull;
39 using ::testing::Not;
40 using ::testing::WhenDynamicCastTo;
41
42 class ModelLoaderTest : public ::testing::Test {
43 protected:
SetUp()44 void SetUp() override {
45 model_path_ = MiniBenchmarkTestHelper::DumpToTempFile(
46 "mobilenet_quant.tflite",
47 g_tflite_acceleration_embedded_mobilenet_model,
48 g_tflite_acceleration_embedded_mobilenet_model_len);
49 }
50 std::string model_path_;
51 };
52
// A PathModelLoader constructed from the dumped model's path initializes
// successfully.
TEST_F(ModelLoaderTest, CreateFromModelPath) {
  std::unique_ptr<PathModelLoader> loader =
      std::make_unique<PathModelLoader>(model_path_);
  ASSERT_NE(loader, nullptr);

  const auto init_status = loader->Init();
  EXPECT_THAT(init_status, kMinibenchmarkSuccess);
}
59
// An MmapModelLoader given a descriptor, offset 0 and the full file size
// initializes successfully. Note that this test closes its own descriptor
// right after constructing the loader and before calling Init().
TEST_F(ModelLoaderTest, CreateFromFdPath) {
  const int model_fd = open(model_path_.c_str(), O_RDONLY);
  ASSERT_GE(model_fd, 0);

  struct stat file_info = {};
  ASSERT_EQ(fstat(model_fd, &file_info), 0);
  const auto file_size = file_info.st_size;

  std::unique_ptr<MmapModelLoader> loader =
      std::make_unique<MmapModelLoader>(model_fd, 0, file_size);
  close(model_fd);

  ASSERT_NE(loader, nullptr);
  EXPECT_THAT(loader->Init(), kMinibenchmarkSuccess);
}
72
// A PipeModelLoader reading a serialized model that another process writes
// into a pipe initializes successfully and yields the same model description.
TEST_F(ModelLoaderTest, CreateFromPipePath) {
  // Setup.
  // Read the model and re-serialize it, so the exact bytes (and byte count)
  // sent through the pipe are known to both processes.
  auto model = FlatBufferModel::BuildFromFile(model_path_.c_str());
  ASSERT_NE(model, nullptr);
  flatbuffers::FlatBufferBuilder fbb;
  ModelT model_obj;
  model->GetModel()->UnPackTo(&model_obj);
  std::string model_description = model_obj.description;
  fbb.Finish(CreateModel(fbb, &model_obj));

  int pipe_fds[2];
  ASSERT_EQ(pipe(pipe_fds), 0);
  const pid_t child_pid = fork();
  if (child_pid == -1) {
    close(pipe_fds[0]);
    close(pipe_fds[1]);
    FAIL() << "fork() failed";
  }

  // Child process: write the whole serialized model into the pipe, then exit.
  // gtest assertions must not be used here. A failing ASSERT_* only returns
  // from the test body, which would let the forked child carry on running the
  // rest of the test binary. Success is reported via the exit status instead.
  if (child_pid == 0) {
    close(pipe_fds[0]);
    const uint8_t* buffer = fbb.GetBufferPointer();
    size_t remaining_bytes = fbb.GetSize();
    while (remaining_bytes > 0) {
      const ssize_t written_bytes = write(pipe_fds[1], buffer, remaining_bytes);
      if (written_bytes < 0 && errno == EINTR) continue;  // Retry on signal.
      if (written_bytes <= 0) break;                      // Write error.
      remaining_bytes -= static_cast<size_t>(written_bytes);
      buffer += written_bytes;
    }
    close(pipe_fds[1]);
    _exit(remaining_bytes == 0 ? 0 : 1);
  }

  // Execute.
  // Parent process: close the write end so only the child holds it.
  close(pipe_fds[1]);
  auto model_loader =
      std::make_unique<PipeModelLoader>(pipe_fds[0], fbb.GetSize());
  ASSERT_NE(model_loader, nullptr);
  const auto init_status = model_loader->Init();

  // Reap the writer only after Init() has consumed the pipe. The model may be
  // larger than the pipe buffer, so waiting earlier could block the child's
  // write() forever.
  int child_status = 0;
  pid_t waited_pid;
  do {
    waited_pid = waitpid(child_pid, &child_status, 0);
  } while (waited_pid == -1 && errno == EINTR);
  ASSERT_EQ(waited_pid, child_pid);
  EXPECT_TRUE(WIFEXITED(child_status) && WEXITSTATUS(child_status) == 0)
      << "writer process failed to send the whole model";

  // Verify.
  // Assert (not expect) success: GetModel() is dereferenced below.
  ASSERT_THAT(init_status, kMinibenchmarkSuccess);
  EXPECT_EQ(model_loader->GetModel()->GetModel()->description()->string_view(),
            model_description);
}
114
// For a path that does not name a model file, Init() reports
// kMinibenchmarkModelBuildFailed.
TEST_F(ModelLoaderTest, InvalidModelPath) {
  std::unique_ptr<PathModelLoader> loader =
      std::make_unique<PathModelLoader>("invalid/path");
  ASSERT_NE(loader, nullptr);

  const auto init_status = loader->Init();
  EXPECT_THAT(init_status, kMinibenchmarkModelBuildFailed);
}
121
// An MmapModelLoader over descriptor 0 (normally stdin rather than a model
// file) with offset 5 and size 10 fails to read a model:
// Init() reports kMinibenchmarkModelReadFailed.
TEST_F(ModelLoaderTest, InvalidFd) {
  constexpr int kFd = 0;
  constexpr int kOffset = 5;
  constexpr int kSize = 10;
  std::unique_ptr<MmapModelLoader> loader =
      std::make_unique<MmapModelLoader>(kFd, kOffset, kSize);
  ASSERT_NE(loader, nullptr);

  EXPECT_THAT(loader->Init(), kMinibenchmarkModelReadFailed);
}
128
// A PipeModelLoader given the negative descriptor -1 cannot read a model:
// Init() reports kMinibenchmarkModelReadFailed.
TEST_F(ModelLoaderTest, InvalidPipe) {
  constexpr int kBadPipeFd = -1;
  constexpr int kModelSize = 10;
  std::unique_ptr<PipeModelLoader> loader =
      std::make_unique<PipeModelLoader>(kBadPipeFd, kModelSize);
  ASSERT_NE(loader, nullptr);

  EXPECT_THAT(loader->Init(), kMinibenchmarkModelReadFailed);
}
135
// CreateModelLoaderFromPath() selects the concrete loader from the spec:
// a plain path yields a PathModelLoader, "fd:a:b:c" an MmapModelLoader and
// "pipe:a:b:c" a PipeModelLoader. Each loader lives in its own scope so that
// no two loaders are alive at the same time.
TEST_F(ModelLoaderTest, CreateModelLoaderFromValidPath) {
  {
    const auto loader = CreateModelLoaderFromPath("a/b/c");
    EXPECT_THAT(loader.get(),
                WhenDynamicCastTo<PathModelLoader*>(Not(IsNull())));
  }
  {
    const auto loader = CreateModelLoaderFromPath("fd:1:2:3");
    EXPECT_THAT(loader.get(),
                WhenDynamicCastTo<MmapModelLoader*>(Not(IsNull())));
  }
  {
    const auto loader = CreateModelLoaderFromPath("pipe:1:2:3");
    EXPECT_THAT(loader.get(),
                WhenDynamicCastTo<PipeModelLoader*>(Not(IsNull())));
  }
}
144
// "fd:" and "pipe:" specs with the wrong number of fields (one or four,
// rather than the three used in the valid-path test) are rejected with
// a null loader.
TEST_F(ModelLoaderTest, CreateModelLoaderFromInvalidPath) {
  for (const char* malformed_spec :
       {"fd:1", "fd:1:2:3:4", "pipe:1", "pipe:1:2:3:4"}) {
    EXPECT_EQ(CreateModelLoaderFromPath(malformed_spec), nullptr);
  }
}
152
153 } // namespace
154 } // namespace acceleration
155 } // namespace tflite
156