1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 /**
18 * Tests for the ExternalDataset op.
19 *
20 * In make_external_dataset_test_graph.py, we generate a GraphDef that connects
21 * an ExternalDataset (producing serialized tf.Example protos), through
22 * tf.parse_example (parsing each to an int64_t scalar), to a Reduce (sum).
23 *
 * Here, we load that graph and run it, supplying an ExternalDataset (by
 * registering a provider and feeding its token) when we call Session::Run.
 * A single Run suffices, since Reduce consumes the entire dataset iterator.
27 */
28
#include <fcntl.h>
#include <stdint.h>

#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "fcp/tensorflow/external_dataset.h"
#include "fcp/tensorflow/test_selector.pb.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session.h"
49
50 namespace fcp {
51
52 using ::testing::Eq;
53
54 //
55 // Constants related to the GraphDef we test with
56 // See make_external_dataset_test_graph.py
57 //
58
59 char const* const kExampleGraphPath =
60 "fcp/tensorflow/external_dataset_test.pbtxt";
61 char const* const kFeatureName = "val";
62 char const* const kTokenPlaceholderName = "token";
63 char const* const kSelectorPlaceholderName = "selector";
64 char const* const kOutputName = "total:0";
65
66 //
67 // TensorFlow boilerplate
68 //
69
LoadExampleGraph()70 tensorflow::GraphDef LoadExampleGraph() {
71 int fd = open(kExampleGraphPath, O_RDONLY);
72 FCP_CHECK(fd != -1) << "Failed to open the example graph, using path "
73 << kExampleGraphPath;
74
75 google::protobuf::io::FileInputStream fs(fd);
76 fs.SetCloseOnDelete(true);
77
78 tensorflow::GraphDef graph;
79 bool parsed = google::protobuf::TextFormat::Parse(&fs, &graph);
80 FCP_CHECK(parsed) << "Invalid text-format GraphDef";
81
82 return graph;
83 }
84
PrepareExampleGraphSession()85 std::unique_ptr<tensorflow::Session> PrepareExampleGraphSession() {
86 tensorflow::GraphDef graph = LoadExampleGraph();
87
88 std::unique_ptr<tensorflow::Session> session;
89 {
90 tensorflow::SessionOptions options;
91 tensorflow::Session* raw_session = nullptr;
92 tensorflow::Status session_new_status =
93 tensorflow::NewSession(options, &raw_session);
94 TF_CHECK_OK(session_new_status);
95 session = std::unique_ptr<tensorflow::Session>(raw_session);
96 }
97
98 tensorflow::Status graph_build_status = session->Create(graph);
99 TF_CHECK_OK(graph_build_status);
100 return session;
101 }
102
MakeExample(int64_t value)103 tensorflow::Example MakeExample(int64_t value) {
104 tensorflow::Example example;
105 tensorflow::AppendFeatureValues({value}, kFeatureName, &example);
106 return example;
107 }
108
SerializeExample(int64_t value)109 std::string SerializeExample(int64_t value) {
110 std::string serialized;
111 FCP_CHECK(MakeExample(value).SerializeToString(&serialized));
112 return serialized;
113 }
114
RunSession(tensorflow::Session * session,RandomToken dataset_token,tensorflow::Tensor selector,tensorflow::Tensor * output)115 tensorflow::Status RunSession(tensorflow::Session* session,
116 RandomToken dataset_token,
117 tensorflow::Tensor selector,
118 tensorflow::Tensor* output) {
119 auto token_tensor =
120 tensorflow::test::AsScalar<tensorflow::tstring>(dataset_token.ToString());
121
122 std::vector<tensorflow::Tensor> outputs;
123 tensorflow::Status run_status =
124 session->Run({{kTokenPlaceholderName, token_tensor},
125 {kSelectorPlaceholderName, selector}},
126 {kOutputName}, {}, &outputs);
127
128 if (run_status.ok() && output) {
129 FCP_CHECK(outputs.size() == 1);
130 *output = outputs[0];
131 }
132
133 return run_status;
134 }
135
RunSession(tensorflow::Session * session,RandomToken dataset_token,TestSelector const & selector,tensorflow::Tensor * output)136 tensorflow::Status RunSession(tensorflow::Session* session,
137 RandomToken dataset_token,
138 TestSelector const& selector,
139 tensorflow::Tensor* output) {
140 std::string selector_str;
141 FCP_CHECK(selector.SerializeToString(&selector_str));
142 auto selector_tensor =
143 tensorflow::test::AsScalar<tensorflow::tstring>(selector_str);
144 return RunSession(session, dataset_token, selector_tensor, output);
145 }
146
RunSessionAndGetOutput(tensorflow::Session * session,RandomToken dataset_token,TestSelector const & selector)147 tensorflow::Tensor RunSessionAndGetOutput(tensorflow::Session* session,
148 RandomToken dataset_token,
149 TestSelector const& selector) {
150 tensorflow::Tensor output;
151 tensorflow::Status run_status =
152 RunSession(session, dataset_token, selector, &output);
153 TF_CHECK_OK(run_status);
154 return output;
155 }
156
157 //
158 // ExternalDataset host object implementations for testing
159 //
160
161 class TestDatasetIterator : public ExternalDatasetIterator {
162 public:
TestDatasetIterator(std::shared_ptr<std::vector<int64_t> const> examples,int64_t lower_inclusive,int64_t upper_inclusive)163 explicit TestDatasetIterator(
164 std::shared_ptr<std::vector<int64_t> const> examples,
165 int64_t lower_inclusive, int64_t upper_inclusive)
166 : examples_(std::move(examples)),
167 lower_inclusive_(lower_inclusive),
168 upper_inclusive_(upper_inclusive) {}
169
GetNext()170 absl::StatusOr<std::string> GetNext() final {
171 while (index_ < examples_->size()) {
172 int64_t ex = examples_->at(index_);
173 index_++;
174
175 if (ex >= lower_inclusive_ && ex < upper_inclusive_) {
176 return SerializeExample(ex);
177 }
178 }
179
180 return absl::OutOfRangeError("");
181 }
182
183 private:
184 std::shared_ptr<std::vector<int64_t> const> examples_;
185 int index_ = 0;
186 int64_t lower_inclusive_;
187 int64_t upper_inclusive_;
188 };
189
190 class TestDatasetProvider
191 : public ExternalDatasetProvider::UsingProtoSelector<TestSelector> {
192 public:
TestDatasetProvider(std::vector<int64_t> examples)193 explicit TestDatasetProvider(std::vector<int64_t> examples) {
194 auto ex = std::make_shared<std::vector<int64_t>>(std::move(examples));
195 examples_ = std::move(ex);
196 }
197
MakeDataset(TestSelector selector)198 absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
199 TestSelector selector) final {
200 int64_t lower = selector.has_lower_inclusive()
201 ? selector.lower_inclusive().value()
202 : std::numeric_limits<int64_t>::min();
203 int64_t upper = selector.has_upper_inclusive()
204 ? selector.upper_inclusive().value()
205 : std::numeric_limits<int64_t>::max();
206 auto examples = examples_;
207 return ExternalDataset::FromFunction([examples, lower, upper]() {
208 return std::make_unique<TestDatasetIterator>(examples, lower, upper);
209 });
210 }
211
212 private:
213 std::shared_ptr<std::vector<int64_t> const> examples_;
214 };
215
216 class FailingIterator : public ExternalDatasetIterator {
217 public:
GetNext()218 absl::StatusOr<std::string> GetNext() final {
219 return absl::NotFoundError("");
220 }
221 };
222
223 class FailingIteratorDatasetProvider
224 : public ExternalDatasetProvider::UsingProtoSelector<TestSelector> {
225 public:
MakeDataset(TestSelector selector)226 absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
227 TestSelector selector) final {
228 return ExternalDataset::FromFunction(
229 []() { return std::make_unique<FailingIterator>(); });
230 }
231 };
232
233 //
234 // Actual tests
235 //
236
TEST(ExternalDatasetOpTest, RunExampleGraph) {
  // A default (empty) selector applies no filtering, so the graph should sum
  // every example.
  const tensorflow::Tensor expected =
      tensorflow::test::AsTensor<tensorflow::int64>(
          {123 + 456 + 789}, tensorflow::TensorShape({1}));

  auto provider = std::make_shared<TestDatasetProvider>(
      std::vector<int64_t>{123, 456, 789});
  auto registration = ExternalDatasetProviderRegistry::Register(provider);

  const TestSelector default_selector;
  auto session = PrepareExampleGraphSession();
  const tensorflow::Tensor output = RunSessionAndGetOutput(
      session.get(), registration.token(), default_selector);

  tensorflow::test::ExpectTensorEqual<tensorflow::int64>(output, expected);
}
255
TEST(ExternalDatasetOpTest, RunExampleGraph_SelectorFilter) {
  // Only 456 and 789 lie within [124, 1023]; 123 and 1024 must be skipped.
  TestSelector range_selector;
  range_selector.mutable_lower_inclusive()->set_value(124);
  range_selector.mutable_upper_inclusive()->set_value(1023);

  const tensorflow::Tensor expected =
      tensorflow::test::AsTensor<tensorflow::int64>(
          {456 + 789}, tensorflow::TensorShape({1}));

  auto provider = std::make_shared<TestDatasetProvider>(
      std::vector<int64_t>{123, 456, 789, 1024});
  auto registration = ExternalDatasetProviderRegistry::Register(provider);

  auto session = PrepareExampleGraphSession();
  const tensorflow::Tensor output = RunSessionAndGetOutput(
      session.get(), registration.token(), range_selector);

  tensorflow::test::ExpectTensorEqual<tensorflow::int64>(output, expected);
}
276
TEST(ExternalDatasetOpTest, TokenNotFound) {
  // A freshly generated token is (presumably) not registered with any
  // provider, so the run is expected to fail with InvalidArgument.
  const TestSelector default_selector;
  auto session = PrepareExampleGraphSession();
  const RandomToken unregistered_token = RandomToken::Generate();

  tensorflow::Status status = RunSession(session.get(), unregistered_token,
                                         default_selector, nullptr);
  // Remove the cast after TF 2.12 is released and used in FCP.
  EXPECT_THAT(
      status.code(),
      Eq(static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument)));
}
287
TEST(ExternalDatasetOpTest, FailingIterator) {
  auto stub = std::make_shared<FailingIteratorDatasetProvider>();
  auto stub_reg = ExternalDatasetProviderRegistry::Register(stub);

  TestSelector selector;

  auto session = PrepareExampleGraphSession();
  tensorflow::Status status =
      RunSession(session.get(), stub_reg.token(), selector, nullptr);
  // The iterator's NotFound status code must propagate out of Session::Run.
  // Spelled via absl::StatusCode, like the other tests in this file, so the
  // TF 2.12 migration is the same cast removal everywhere.
  // Remove the cast after TF 2.12 is released and used in FCP.
  EXPECT_THAT(status.code(),
              Eq(static_cast<tsl::errors::Code>(absl::StatusCode::kNotFound)));
}
299
TEST(ExternalDatasetOpTest, RunExampleGraph_InvalidSelector) {
  // A lone 0xFF byte is a truncated varint: its MSB says another byte
  // follows, but there isn't one, so the selector cannot be parsed.
  const std::string truncated_varint = "\xFF";
  const tensorflow::Tensor bad_selector_tensor =
      tensorflow::test::AsScalar<tensorflow::tstring>(truncated_varint);

  auto provider =
      std::make_shared<TestDatasetProvider>(std::vector<int64_t>{123});
  auto registration = ExternalDatasetProviderRegistry::Register(provider);

  auto session = PrepareExampleGraphSession();
  tensorflow::Status status = RunSession(session.get(), registration.token(),
                                         bad_selector_tensor, nullptr);
  // Remove the cast after TF 2.12 is released and used in FCP.
  EXPECT_THAT(
      status.code(),
      Eq(static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument)));
}
319
320 } // namespace fcp
321