xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tflite_wrapper_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "fcp/client/engine/tflite_wrapper.h"

#include <cstdint>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/test_helpers.h"
#include "fcp/testing/testing.h"
26 
27 namespace fcp {
28 namespace client {
29 namespace engine {
30 namespace {
31 
32 const absl::string_view kAssetsPath = "fcp/client/engine/data/";
33 const absl::string_view kJoinModelFile = "join_model.flatbuffer";
34 
35 const int32_t kNumThreads = 4;
36 
37 class TfLiteWrapperTest : public testing::Test {
38  protected:
ReadFileAsString(const std::string & path)39   absl::StatusOr<std::string> ReadFileAsString(const std::string& path) {
40     std::ifstream input_istream(path);
41     if (!input_istream) {
42       return absl::InternalError("Failed to create input stream.");
43     }
44     std::stringstream output_stream;
45     output_stream << input_istream.rdbuf();
46     return output_stream.str();
47   }
48 
49   MockLogManager mock_log_manager_;
50   InterruptibleRunner::TimingConfig default_timing_config_ =
51       InterruptibleRunner::TimingConfig{
52           .polling_period = absl::Milliseconds(1000),
53           .graceful_shutdown_period = absl::Milliseconds(1000),
54           .extended_shutdown_period = absl::Milliseconds(2000),
55       };
56   std::vector<std::string> output_names_ = {"Identity"};
57   TfLiteInterpreterOptions options_ = {
58       .ensure_dynamic_tensors_are_released = true,
59       .large_tensor_threshold_for_dynamic_allocation = 1000};
60 };
61 
// Bytes that are not a TFLite flatbuffer must be rejected by Create with
// INVALID_ARGUMENT.
TEST_F(TfLiteWrapperTest, InvalidModel) {
  auto never_abort = []() { return false; };
  auto no_inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  auto result = TfLiteWrapper::Create(
      "INVALID_FLATBUFFER", never_abort, default_timing_config_,
      &mock_log_manager_, std::move(no_inputs), output_names_, options_,
      kNumThreads);
  EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
}
71 
TEST_F(TfLiteWrapperTest, InputNotSet) {
  auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(plan);
  // The plan used here joins two strings and requires two string tensors as
  // input. No inputs are supplied, so Create is expected to reject the plan
  // with INVALID_ARGUMENT.
  EXPECT_THAT(
      TfLiteWrapper::Create(
          *plan, []() { return false; }, default_timing_config_,
          &mock_log_manager_,
          std::make_unique<absl::flat_hash_map<std::string, std::string>>(),
          output_names_, options_, kNumThreads),
      IsCode(INVALID_ARGUMENT));
}
86 
TEST_F(TfLiteWrapperTest, WrongNumberOfOutputs) {
  auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(plan);
  // Supply the two inputs the join model needs, so that the only thing wrong
  // with this call is the extra output name. Without them, Create could fail
  // on the missing inputs (see InputNotSet) and this test would pass without
  // ever exercising the output-count check.
  auto inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  (*inputs)["x"] = "abc";
  (*inputs)["y"] = "def";
  auto wrapper = TfLiteWrapper::Create(
      *plan, []() { return false; }, default_timing_config_, &mock_log_manager_,
      std::move(inputs), {"Identity", "EXTRA"}, options_, kNumThreads);
  // The mismatch between the requested output names and the model's outputs
  // must be reported as INVALID_ARGUMENT. Accept detection either at creation
  // time or when the outputs are constructed during Run.
  if (!wrapper.ok()) {
    EXPECT_THAT(wrapper, IsCode(INVALID_ARGUMENT));
    return;
  }
  EXPECT_THAT((*wrapper)->Run(), IsCode(INVALID_ARGUMENT));
}
101 
// A should_abort callback that always fires must make Run report CANCELLED,
// even though the wrapper itself is created successfully.
TEST_F(TfLiteWrapperTest, Aborted) {
  auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(plan);

  auto join_inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  (*join_inputs)["x"] = "abc";
  (*join_inputs)["y"] = "def";

  auto always_abort = []() { return true; };
  auto wrapper = TfLiteWrapper::Create(
      *plan, always_abort, default_timing_config_, &mock_log_manager_,
      std::move(join_inputs), output_names_, options_, kNumThreads);
  ASSERT_OK(wrapper);

  auto run_result = (*wrapper)->Run();
  EXPECT_THAT(run_result, IsCode(CANCELLED));
}
117 
// Running the join model on "abc" and "def" yields a single output "abcdef".
TEST_F(TfLiteWrapperTest, Success) {
  auto plan = ReadFileAsString(absl::StrCat(kAssetsPath, kJoinModelFile));
  ASSERT_OK(plan);
  auto inputs =
      std::make_unique<absl::flat_hash_map<std::string, std::string>>();
  (*inputs)["x"] = "abc";
  (*inputs)["y"] = "def";
  auto wrapper = TfLiteWrapper::Create(
      *plan, []() { return false; }, default_timing_config_, &mock_log_manager_,
      std::move(inputs), output_names_, options_, kNumThreads);
  // ASSERT rather than EXPECT: `*wrapper` is dereferenced below, and doing so
  // on a non-OK StatusOr would crash the test binary instead of reporting a
  // failure.
  ASSERT_OK(wrapper);
  auto outputs = (*wrapper)->Run();
  ASSERT_OK(outputs);
  EXPECT_EQ(outputs->output_tensor_names.size(), 1);
  // Guard the .at(0) access so a missing tensor is reported as a test failure
  // rather than an uncaught std::out_of_range.
  ASSERT_EQ(outputs->output_tensors.size(), 1);
  EXPECT_EQ(
      *static_cast<tensorflow::tstring*>(outputs->output_tensors.at(0).data()),
      "abcdef");
}
136 
137 }  // anonymous namespace
138 }  // namespace engine
139 }  // namespace client
140 }  // namespace fcp
141