1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "fcp/client/engine/tflite_wrapper.h"

#include <cstdint>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/test_helpers.h"
#include "fcp/testing/testing.h"
26
27 namespace fcp {
28 namespace client {
29 namespace engine {
30 namespace {
31
32 const absl::string_view kAssetsPath = "fcp/client/engine/data/";
33 const absl::string_view kJoinModelFile = "join_model.flatbuffer";
34
35 const int32_t kNumThreads = 4;
36
37 class TfLiteWrapperTest : public testing::Test {
38 protected:
ReadFileAsString(const std::string & path)39 absl::StatusOr<std::string> ReadFileAsString(const std::string& path) {
40 std::ifstream input_istream(path);
41 if (!input_istream) {
42 return absl::InternalError("Failed to create input stream.");
43 }
44 std::stringstream output_stream;
45 output_stream << input_istream.rdbuf();
46 return output_stream.str();
47 }
48
49 MockLogManager mock_log_manager_;
50 InterruptibleRunner::TimingConfig default_timing_config_ =
51 InterruptibleRunner::TimingConfig{
52 .polling_period = absl::Milliseconds(1000),
53 .graceful_shutdown_period = absl::Milliseconds(1000),
54 .extended_shutdown_period = absl::Milliseconds(2000),
55 };
56 std::vector<std::string> output_names_ = {"Identity"};
57 TfLiteInterpreterOptions options_ = {
58 .ensure_dynamic_tensors_are_released = true,
59 .large_tensor_threshold_for_dynamic_allocation = 1000};
60 };
61
// A buffer that is not a valid flatbuffer model must be rejected by Create()
// with INVALID_ARGUMENT.
TEST_F(TfLiteWrapperTest, InvalidModel) {
  auto never_abort = []() { return false; };
  auto empty_inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();

  auto wrapper = TfLiteWrapper::Create(
      "INVALID_FLATBUFFER", never_abort, default_timing_config_,
      &mock_log_manager_, std::move(empty_inputs), output_names_, options_,
      kNumThreads);

  EXPECT_THAT(wrapper, IsCode(INVALID_ARGUMENT));
}
71
// Creating a wrapper for the join model without supplying any inputs must fail
// with INVALID_ARGUMENT.
TEST_F(TfLiteWrapperTest, InputNotSet) {
  const auto model =
      ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(model);

  // The inputs "x" and "y" that the other tests provide for this model are
  // deliberately left out: the input map stays empty.
  auto no_inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  auto never_abort = []() { return false; };

  auto wrapper = TfLiteWrapper::Create(
      *model, never_abort, default_timing_config_, &mock_log_manager_,
      std::move(no_inputs), output_names_, options_, kNumThreads);

  EXPECT_THAT(wrapper, IsCode(INVALID_ARGUMENT));
}
86
// Requesting more output tensors than the fixture's default single output
// ("Identity") must make Create() fail with INVALID_ARGUMENT.
TEST_F(TfLiteWrapperTest, WrongNumberOfOutputs) {
  auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(plan);
  // Provide both string inputs so that the only invalid part of the request is
  // the list of output names. With an empty input map the expected error could
  // be produced by the missing inputs instead of the output mismatch.
  auto inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  (*inputs)["x"] = "abc";
  (*inputs)["y"] = "def";
  // "EXTRA" is an additional output name on top of "Identity".
  EXPECT_THAT(
      TfLiteWrapper::Create(
          *plan, []() { return false; }, default_timing_config_,
          &mock_log_manager_, std::move(inputs), {"Identity", "EXTRA"},
          options_, kNumThreads),
      IsCode(INVALID_ARGUMENT));
}
101
// When the should_abort callback reports true, Run() must return CANCELLED.
TEST_F(TfLiteWrapperTest, Aborted) {
  const auto model =
      ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(model);

  auto string_inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  string_inputs->emplace("x", "abc");
  string_inputs->emplace("y", "def");

  // This callback always asks for an abort, so creation is expected to succeed
  // while running the plan is expected to be cancelled.
  auto always_abort = []() { return true; };
  auto wrapper = TfLiteWrapper::Create(
      *model, always_abort, default_timing_config_, &mock_log_manager_,
      std::move(string_inputs), output_names_, options_, kNumThreads);
  ASSERT_OK(wrapper);

  auto run_result = (*wrapper)->Run();
  EXPECT_THAT(run_result, IsCode(CANCELLED));
}
117
// Happy path: with both string inputs set and no abort requested, running the
// join model yields one output tensor holding the concatenated string.
TEST_F(TfLiteWrapperTest, Success) {
  auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(plan);
  auto inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  (*inputs)["x"] = "abc";
  (*inputs)["y"] = "def";
  auto wrapper = TfLiteWrapper::Create(
      *plan, []() { return false; }, default_timing_config_, &mock_log_manager_,
      std::move(inputs), output_names_, options_, kNumThreads);
  // Fatal assertion: the wrapper is dereferenced on the next line, which must
  // not happen when creation failed.
  ASSERT_OK(wrapper);
  auto outputs = (*wrapper)->Run();
  ASSERT_OK(outputs);
  EXPECT_EQ(outputs->output_tensor_names.size(), 1);
  // Stop here if there is no tensor, so that at(0) below is never out of range.
  ASSERT_FALSE(outputs->output_tensors.empty());
  EXPECT_EQ(
      *static_cast<tensorflow::tstring*>(outputs->output_tensors.at(0).data()),
      "abcdef");
}
136
137 } // anonymous namespace
138 } // namespace engine
139 } // namespace client
140 } // namespace fcp
141