1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h"
16 
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cerrno>
#include <memory>
#include <string>
23 
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "absl/strings/str_format.h"
27 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
28 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedded_mobilenet_model.h"
29 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_test_helper.h"
30 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
31 #include "tensorflow/lite/model_builder.h"
32 #include "tensorflow/lite/schema/schema_generated.h"
33 
34 namespace tflite {
35 namespace acceleration {
36 namespace {
37 
38 using ::testing::IsNull;
39 using ::testing::Not;
40 using ::testing::WhenDynamicCastTo;
41 
42 class ModelLoaderTest : public ::testing::Test {
43  protected:
SetUp()44   void SetUp() override {
45     model_path_ = MiniBenchmarkTestHelper::DumpToTempFile(
46         "mobilenet_quant.tflite",
47         g_tflite_acceleration_embedded_mobilenet_model,
48         g_tflite_acceleration_embedded_mobilenet_model_len);
49   }
50   std::string model_path_;
51 };
52 
TEST_F(ModelLoaderTest, CreateFromModelPath) {
  // Loading straight from the file dumped by SetUp() is expected to succeed.
  const std::string& path = model_path_;
  std::unique_ptr<PathModelLoader> loader =
      std::make_unique<PathModelLoader>(path);

  ASSERT_TRUE(loader != nullptr);
  const auto init_status = loader->Init();
  EXPECT_EQ(init_status, kMinibenchmarkSuccess);
}
59 
TEST_F(ModelLoaderTest, CreateFromFdPath) {
  // Open the dumped model and query its size so the loader can map the whole
  // file starting at offset 0.
  int fd = open(model_path_.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0);
  struct stat stat_buf = {};
  const bool stat_ok = fstat(fd, &stat_buf) == 0;
  std::unique_ptr<MmapModelLoader> model_loader;
  if (stat_ok) {
    model_loader = std::make_unique<MmapModelLoader>(fd, 0, stat_buf.st_size);
  }
  // Close our descriptor on every path (including a failed fstat) so a failing
  // assertion below cannot leak it. Init() runs after this close, so the test
  // also checks that the loader does not need the caller to keep `fd` open.
  close(fd);

  ASSERT_TRUE(stat_ok);
  ASSERT_NE(model_loader, nullptr);
  EXPECT_THAT(model_loader->Init(), kMinibenchmarkSuccess);
}
72 
TEST_F(ModelLoaderTest, CreateFromPipePath) {
  // Setup.
  // Read the model and re-serialize it so the exact bytes streamed through the
  // pipe (and their count) are known up front.
  auto model = FlatBufferModel::BuildFromFile(model_path_.c_str());
  ASSERT_NE(model, nullptr);
  flatbuffers::FlatBufferBuilder fbb;
  ModelT model_obj;
  model->GetModel()->UnPackTo(&model_obj);
  std::string model_description = model_obj.description;
  fbb.Finish(CreateModel(fbb, &model_obj));
  int pipe_fds[2];
  ASSERT_EQ(pipe(pipe_fds), 0);
  const pid_t child_pid = fork();
  if (child_pid < 0) {
    // fork() failed: release both pipe ends before failing the test.
    close(pipe_fds[0]);
    close(pipe_fds[1]);
    FAIL() << "fork() failed";
  }
  if (child_pid == 0) {
    // Child process: write the serialized model into the pipe.
    // No gtest assertions here: a failing ASSERT_* only returns from this test
    // body, so the child would go on running the rest of the test binary.
    // The outcome is reported to the parent through the exit status instead.
    close(pipe_fds[0]);
    size_t remaining_bytes = fbb.GetSize();
    const uint8_t* buffer = fbb.GetBufferPointer();
    while (remaining_bytes > 0) {
      const ssize_t written_bytes = write(pipe_fds[1], buffer, remaining_bytes);
      if (written_bytes < 0 && errno == EINTR) continue;  // Interrupted; retry.
      if (written_bytes <= 0) break;
      remaining_bytes -= static_cast<size_t>(written_bytes);
      buffer += written_bytes;
    }
    close(pipe_fds[1]);
    _exit(remaining_bytes == 0 ? 0 : 1);
  }

  // Execute.
  // Parent process: close the write end so only the child holds it.
  close(pipe_fds[1]);
  auto model_loader =
      std::make_unique<PipeModelLoader>(pipe_fds[0], fbb.GetSize());
  ASSERT_NE(model_loader, nullptr);

  // Verify.
  // Init() must succeed before GetModel() can safely be dereferenced.
  ASSERT_THAT(model_loader->Init(), kMinibenchmarkSuccess);
  ASSERT_NE(model_loader->GetModel(), nullptr);
  EXPECT_EQ(model_loader->GetModel()->GetModel()->description()->string_view(),
            model_description);

  // Reap the writer (so no zombie is left behind) and confirm it wrote the
  // whole buffer successfully.
  int child_status = 0;
  ASSERT_EQ(waitpid(child_pid, &child_status, 0), child_pid);
  EXPECT_TRUE(WIFEXITED(child_status));
  EXPECT_EQ(WEXITSTATUS(child_status), 0);
}
114 
TEST_F(ModelLoaderTest, InvalidModelPath) {
  // A relative path that presumably does not exist must make Init() report a
  // model build failure.
  const std::string bogus_path = "invalid/path";
  auto loader = std::make_unique<PathModelLoader>(bogus_path);

  ASSERT_TRUE(loader != nullptr);
  EXPECT_EQ(loader->Init(), kMinibenchmarkModelBuildFailed);
}
121 
TEST_F(ModelLoaderTest, InvalidFd) {
  // Descriptor 0 (normally stdin) with an arbitrary offset/size is not
  // expected to be a mappable model, so Init() should report a read failure.
  constexpr int kFd = 0;
  constexpr int kOffset = 5;
  constexpr int kSize = 10;
  std::unique_ptr<MmapModelLoader> loader =
      std::make_unique<MmapModelLoader>(kFd, kOffset, kSize);

  ASSERT_NE(loader.get(), nullptr);
  const auto init_status = loader->Init();
  EXPECT_EQ(init_status, kMinibenchmarkModelReadFailed);
}
128 
TEST_F(ModelLoaderTest, InvalidPipe) {
  // A negative descriptor can never be read from, so Init() must fail with a
  // read error.
  const int bad_pipe_fd = -1;
  std::unique_ptr<PipeModelLoader> loader =
      std::make_unique<PipeModelLoader>(bad_pipe_fd, 10);

  ASSERT_NE(loader.get(), nullptr);
  EXPECT_THAT(loader->Init(), kMinibenchmarkModelReadFailed);
}
135 
TEST_F(ModelLoaderTest, CreateModelLoaderFromValidPath) {
  // Each well-formed spec must yield a loader of the matching concrete type.
  {
    // A plain filesystem path.
    auto loader = CreateModelLoaderFromPath("a/b/c");
    EXPECT_THAT(loader.get(),
                WhenDynamicCastTo<PathModelLoader*>(Not(IsNull())));
  }
  {
    // An `fd:` spec with three numeric fields.
    auto loader = CreateModelLoaderFromPath("fd:1:2:3");
    EXPECT_THAT(loader.get(),
                WhenDynamicCastTo<MmapModelLoader*>(Not(IsNull())));
  }
  {
    // A `pipe:` spec with three numeric fields.
    auto loader = CreateModelLoaderFromPath("pipe:1:2:3");
    EXPECT_THAT(loader.get(),
                WhenDynamicCastTo<PipeModelLoader*>(Not(IsNull())));
  }
}
144 
TEST_F(ModelLoaderTest, CreateModelLoaderFromInvalidPath) {
  // `fd:` and `pipe:` specs with too few or too many ':'-separated fields are
  // rejected by returning nullptr.
  for (const char* spec : {"fd:1", "fd:1:2:3:4", "pipe:1", "pipe:1:2:3:4"}) {
    EXPECT_EQ(CreateModelLoaderFromPath(spec), nullptr);
  }
}
152 
153 }  // namespace
154 }  // namespace acceleration
155 }  // namespace tflite
156