xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/external_dataset_op_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /**
18  * Tests for the ExternalDataset op.
19  *
20  * In make_external_dataset_test_graph.py, we generate a GraphDef that connects
21  * an ExternalDataset (producing serialized tf.Example protos), through
22  * tf.parse_example (parsing each to an int64_t scalar), to a Reduce (sum).
23  *
24  * Here, we load that graph and try to run it, with some ExternalDataset
25  * provided as we call Session::Run. We just need to Run once since Reduce
26  * should consume the entire dataset iterator.
27  */
28 
#include <fcntl.h>
#include <stdint.h>

#include <cstddef>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "fcp/tensorflow/external_dataset.h"
#include "fcp/tensorflow/test_selector.pb.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session.h"
49 
50 namespace fcp {
51 
52 using ::testing::Eq;
53 
54 //
55 // Constants related to the GraphDef we test with
56 // See make_external_dataset_test_graph.py
57 //
58 
59 char const* const kExampleGraphPath =
60     "fcp/tensorflow/external_dataset_test.pbtxt";
61 char const* const kFeatureName = "val";
62 char const* const kTokenPlaceholderName = "token";
63 char const* const kSelectorPlaceholderName = "selector";
64 char const* const kOutputName = "total:0";
65 
66 //
67 // TensorFlow boilerplate
68 //
69 
LoadExampleGraph()70 tensorflow::GraphDef LoadExampleGraph() {
71   int fd = open(kExampleGraphPath, O_RDONLY);
72   FCP_CHECK(fd != -1) << "Failed to open the example graph, using path "
73                       << kExampleGraphPath;
74 
75   google::protobuf::io::FileInputStream fs(fd);
76   fs.SetCloseOnDelete(true);
77 
78   tensorflow::GraphDef graph;
79   bool parsed = google::protobuf::TextFormat::Parse(&fs, &graph);
80   FCP_CHECK(parsed) << "Invalid text-format GraphDef";
81 
82   return graph;
83 }
84 
PrepareExampleGraphSession()85 std::unique_ptr<tensorflow::Session> PrepareExampleGraphSession() {
86   tensorflow::GraphDef graph = LoadExampleGraph();
87 
88   std::unique_ptr<tensorflow::Session> session;
89   {
90     tensorflow::SessionOptions options;
91     tensorflow::Session* raw_session = nullptr;
92     tensorflow::Status session_new_status =
93         tensorflow::NewSession(options, &raw_session);
94     TF_CHECK_OK(session_new_status);
95     session = std::unique_ptr<tensorflow::Session>(raw_session);
96   }
97 
98   tensorflow::Status graph_build_status = session->Create(graph);
99   TF_CHECK_OK(graph_build_status);
100   return session;
101 }
102 
MakeExample(int64_t value)103 tensorflow::Example MakeExample(int64_t value) {
104   tensorflow::Example example;
105   tensorflow::AppendFeatureValues({value}, kFeatureName, &example);
106   return example;
107 }
108 
SerializeExample(int64_t value)109 std::string SerializeExample(int64_t value) {
110   std::string serialized;
111   FCP_CHECK(MakeExample(value).SerializeToString(&serialized));
112   return serialized;
113 }
114 
RunSession(tensorflow::Session * session,RandomToken dataset_token,tensorflow::Tensor selector,tensorflow::Tensor * output)115 tensorflow::Status RunSession(tensorflow::Session* session,
116                               RandomToken dataset_token,
117                               tensorflow::Tensor selector,
118                               tensorflow::Tensor* output) {
119   auto token_tensor =
120       tensorflow::test::AsScalar<tensorflow::tstring>(dataset_token.ToString());
121 
122   std::vector<tensorflow::Tensor> outputs;
123   tensorflow::Status run_status =
124       session->Run({{kTokenPlaceholderName, token_tensor},
125                     {kSelectorPlaceholderName, selector}},
126                    {kOutputName}, {}, &outputs);
127 
128   if (run_status.ok() && output) {
129     FCP_CHECK(outputs.size() == 1);
130     *output = outputs[0];
131   }
132 
133   return run_status;
134 }
135 
RunSession(tensorflow::Session * session,RandomToken dataset_token,TestSelector const & selector,tensorflow::Tensor * output)136 tensorflow::Status RunSession(tensorflow::Session* session,
137                               RandomToken dataset_token,
138                               TestSelector const& selector,
139                               tensorflow::Tensor* output) {
140   std::string selector_str;
141   FCP_CHECK(selector.SerializeToString(&selector_str));
142   auto selector_tensor =
143       tensorflow::test::AsScalar<tensorflow::tstring>(selector_str);
144   return RunSession(session, dataset_token, selector_tensor, output);
145 }
146 
RunSessionAndGetOutput(tensorflow::Session * session,RandomToken dataset_token,TestSelector const & selector)147 tensorflow::Tensor RunSessionAndGetOutput(tensorflow::Session* session,
148                                           RandomToken dataset_token,
149                                           TestSelector const& selector) {
150   tensorflow::Tensor output;
151   tensorflow::Status run_status =
152       RunSession(session, dataset_token, selector, &output);
153   TF_CHECK_OK(run_status);
154   return output;
155 }
156 
157 //
158 // ExternalDataset host object implementations for testing
159 //
160 
161 class TestDatasetIterator : public ExternalDatasetIterator {
162  public:
TestDatasetIterator(std::shared_ptr<std::vector<int64_t> const> examples,int64_t lower_inclusive,int64_t upper_inclusive)163   explicit TestDatasetIterator(
164       std::shared_ptr<std::vector<int64_t> const> examples,
165       int64_t lower_inclusive, int64_t upper_inclusive)
166       : examples_(std::move(examples)),
167         lower_inclusive_(lower_inclusive),
168         upper_inclusive_(upper_inclusive) {}
169 
GetNext()170   absl::StatusOr<std::string> GetNext() final {
171     while (index_ < examples_->size()) {
172       int64_t ex = examples_->at(index_);
173       index_++;
174 
175       if (ex >= lower_inclusive_ && ex < upper_inclusive_) {
176         return SerializeExample(ex);
177       }
178     }
179 
180     return absl::OutOfRangeError("");
181   }
182 
183  private:
184   std::shared_ptr<std::vector<int64_t> const> examples_;
185   int index_ = 0;
186   int64_t lower_inclusive_;
187   int64_t upper_inclusive_;
188 };
189 
190 class TestDatasetProvider
191     : public ExternalDatasetProvider::UsingProtoSelector<TestSelector> {
192  public:
TestDatasetProvider(std::vector<int64_t> examples)193   explicit TestDatasetProvider(std::vector<int64_t> examples) {
194     auto ex = std::make_shared<std::vector<int64_t>>(std::move(examples));
195     examples_ = std::move(ex);
196   }
197 
MakeDataset(TestSelector selector)198   absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
199       TestSelector selector) final {
200     int64_t lower = selector.has_lower_inclusive()
201                         ? selector.lower_inclusive().value()
202                         : std::numeric_limits<int64_t>::min();
203     int64_t upper = selector.has_upper_inclusive()
204                         ? selector.upper_inclusive().value()
205                         : std::numeric_limits<int64_t>::max();
206     auto examples = examples_;
207     return ExternalDataset::FromFunction([examples, lower, upper]() {
208       return std::make_unique<TestDatasetIterator>(examples, lower, upper);
209     });
210   }
211 
212  private:
213   std::shared_ptr<std::vector<int64_t> const> examples_;
214 };
215 
216 class FailingIterator : public ExternalDatasetIterator {
217  public:
GetNext()218   absl::StatusOr<std::string> GetNext() final {
219     return absl::NotFoundError("");
220   }
221 };
222 
223 class FailingIteratorDatasetProvider
224     : public ExternalDatasetProvider::UsingProtoSelector<TestSelector> {
225  public:
MakeDataset(TestSelector selector)226   absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
227       TestSelector selector) final {
228     return ExternalDataset::FromFunction(
229         []() { return std::make_unique<FailingIterator>(); });
230   }
231 };
232 
233 //
234 // Actual tests
235 //
236 
// With a default (empty) selector nothing is filtered, so the graph's output
// is the sum of every example.
TEST(ExternalDatasetOpTest, RunExampleGraph) {
  auto provider = std::make_shared<TestDatasetProvider>(
      std::vector<int64_t>{123, 456, 789});
  auto registration = ExternalDatasetProviderRegistry::Register(provider);

  auto session = PrepareExampleGraphSession();
  const TestSelector no_filter;
  tensorflow::Tensor actual =
      RunSessionAndGetOutput(session.get(), registration.token(), no_filter);

  tensorflow::Tensor expected = tensorflow::test::AsTensor<tensorflow::int64>(
      {123 + 456 + 789}, tensorflow::TensorShape({1}));
  tensorflow::test::ExpectTensorEqual<tensorflow::int64>(actual, expected);
}
255 
// The selector bounds [124, 1023] exclude 123 and 1024, so only the remaining
// examples contribute to the sum.
TEST(ExternalDatasetOpTest, RunExampleGraph_SelectorFilter) {
  TestSelector bounds;
  bounds.mutable_lower_inclusive()->set_value(124);
  bounds.mutable_upper_inclusive()->set_value(1023);

  auto provider = std::make_shared<TestDatasetProvider>(
      std::vector<int64_t>{123, 456, 789, 1024});
  auto registration = ExternalDatasetProviderRegistry::Register(provider);

  auto session = PrepareExampleGraphSession();
  tensorflow::Tensor actual =
      RunSessionAndGetOutput(session.get(), registration.token(), bounds);

  // Only 456 and 789 fall inside the selected range.
  tensorflow::Tensor expected = tensorflow::test::AsTensor<tensorflow::int64>(
      {456 + 789}, tensorflow::TensorShape({1}));
  tensorflow::test::ExpectTensorEqual<tensorflow::int64>(actual, expected);
}
276 
// A freshly generated token (no provider registered for it) is expected to
// make the run fail with INVALID_ARGUMENT.
TEST(ExternalDatasetOpTest, TokenNotFound) {
  auto session = PrepareExampleGraphSession();
  const tensorflow::Status status = RunSession(
      session.get(), RandomToken::Generate(), TestSelector(), nullptr);
  // Remove the cast after TF 2.12 is released and used in FCP.
  EXPECT_THAT(
      status.code(),
      Eq(static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument)));
}
287 
// The NotFound error raised by the dataset's iterator is expected to surface
// as the status of the session run.
TEST(ExternalDatasetOpTest, FailingIterator) {
  auto provider = std::make_shared<FailingIteratorDatasetProvider>();
  auto registration = ExternalDatasetProviderRegistry::Register(provider);

  auto session = PrepareExampleGraphSession();
  const tensorflow::Status status = RunSession(
      session.get(), registration.token(), TestSelector(), nullptr);
  EXPECT_THAT(status.code(), Eq(tensorflow::error::NOT_FOUND));
}
299 
// Feeding bytes that are not a valid serialized TestSelector is expected to
// make the run fail with INVALID_ARGUMENT.
TEST(ExternalDatasetOpTest, RunExampleGraph_InvalidSelector) {
  auto provider =
      std::make_shared<TestDatasetProvider>(std::vector<int64_t>{123});
  auto registration = ExternalDatasetProviderRegistry::Register(provider);

  // A lone 0xFF byte is read as the start of a varint. Its MSB is set, so the
  // decoder expects another byte, but the input ends there.
  const std::string truncated_varint = "\xFF";
  tensorflow::Tensor bad_selector =
      tensorflow::test::AsScalar<tensorflow::tstring>(truncated_varint);

  auto session = PrepareExampleGraphSession();
  const tensorflow::Status status =
      RunSession(session.get(), registration.token(), bad_selector, nullptr);
  // Remove the cast after TF 2.12 is released and used in FCP.
  EXPECT_THAT(
      status.code(),
      Eq(static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument)));
}
319 
320 }  // namespace fcp
321