// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/dataset_utils_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <functional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/determinism_test_util.h"
#include "tensorflow/core/util/work_sharder.h"

37 namespace tensorflow {
38 namespace data {
39 namespace {
40 
// MatchesAnyVersion(prefix, name) accepts the unversioned op name and its
// direct version suffixes (V2, V3, ...), but rejects names where the prefix
// only appears as a substring or the version marker sits elsewhere.
TEST(DatasetUtilsTest, MatchesAnyVersion) {
  constexpr char kPrefix[] = "BatchDataset";

  // The exact name and its directly version-suffixed variants match.
  for (const char* op_name :
       {"BatchDataset", "BatchDatasetV2", "BatchDatasetV3"}) {
    EXPECT_TRUE(MatchesAnyVersion(kPrefix, op_name)) << op_name;
  }

  // Extra characters before the suffix, an embedded version marker, or a
  // longer op that merely contains the prefix must not match.
  for (const char* op_name :
       {"BatchDatasetXV3", "BatchV2Dataset", "PaddedBatchDataset"}) {
    EXPECT_FALSE(MatchesAnyVersion(kPrefix, op_name)) << op_name;
  }
}

// AddToFunctionLibrary() must override a function whose body changed, leave
// an identical function alone, keep functions absent from the added library,
// and add new functions. Both overloads are checked: one taking a
// FunctionDefLibrary and one taking a FunctionLibraryDefinition.
TEST(DatasetUtilsTest, AddToFunctionLibrary) {
  // Builds `ret = Identity(arg)`.
  auto make_fn_a = [](const string& fn_name) {
    return FunctionDefHelper::Create(
        /*function_name=*/fn_name,
        /*in_def=*/{"arg: int64"},
        /*out_def=*/{"ret: int64"},
        /*attr_def=*/{},
        /*node_def=*/{{{"node"}, "Identity", {"arg"}, {{"T", DT_INT64}}}},
        /*ret_def=*/{{"ret", "node:output:0"}});
  };

  // Builds `ret = Identity(Identity(arg))`: same signature as make_fn_a,
  // different body.
  auto make_fn_b = [](const string& fn_name) {
    return FunctionDefHelper::Create(
        /*function_name=*/fn_name,
        /*in_def=*/{"arg: int64"},
        /*out_def=*/{"ret: int64"},
        /*attr_def=*/{},
        /*node_def=*/
        {{{"node"}, "Identity", {"arg"}, {{"T", DT_INT64}}},
         {{"node2"}, "Identity", {"node:output:0"}, {{"T", DT_INT64}}}},
        /*ret_def=*/{{"ret", "node2:output:0"}});
  };

  FunctionDefLibrary fdef_base;
  *fdef_base.add_function() = make_fn_a("0");
  *fdef_base.add_function() = make_fn_a("1");
  *fdef_base.add_function() = make_fn_a("2");

  FunctionDefLibrary fdef_to_add;
  *fdef_to_add.add_function() = make_fn_b("0");  // Override
  *fdef_to_add.add_function() = make_fn_a("1");  // Do nothing
  *fdef_to_add.add_function() = make_fn_b("3");  // Add new function

  FunctionLibraryDefinition flib_0(OpRegistry::Global(), fdef_base);
  TF_ASSERT_OK(AddToFunctionLibrary(&flib_0, fdef_to_add));

  FunctionLibraryDefinition flib_1(OpRegistry::Global(), fdef_base);
  FunctionLibraryDefinition flib_to_add(OpRegistry::Global(), fdef_to_add);
  TF_ASSERT_OK(AddToFunctionLibrary(&flib_1, flib_to_add));

  // Looks up `name` in `flib` and compares it with `expected`. A missing
  // function is reported as a test failure instead of dereferencing the null
  // pointer Find() returns for it.
  auto expect_function = [](const FunctionLibraryDefinition& flib,
                            const string& name, const FunctionDef& expected) {
    const FunctionDef* actual = flib.Find(name);
    ASSERT_NE(actual, nullptr) << "Function '" << name << "' not found";
    EXPECT_TRUE(FunctionDefsEqual(*actual, expected))
        << "Function '" << name << "' differs from the expected definition";
  };

  // Iterate over pointers: `{flib_0, flib_1}` would copy both libraries into
  // a temporary initializer_list.
  for (const FunctionLibraryDefinition* flib : {&flib_0, &flib_1}) {
    expect_function(*flib, "0", make_fn_b("0"));  // Overridden.
    expect_function(*flib, "1", make_fn_a("1"));  // Unchanged.
    expect_function(*flib, "2", make_fn_a("2"));  // Only in the base library.
    expect_function(*flib, "3", make_fn_b("3"));  // Newly added.
  }
}

// Adding a function whose name already exists with a different definition
// must fail with INVALID_ARGUMENT, through either AddToFunctionLibrary()
// overload.
TEST(DatasetUtilsTest, AddToFunctionLibraryWithConflictingSignatures) {
  // Base library: "0" has a single int64 output.
  FunctionDefLibrary fdef_base;
  *fdef_base.add_function() = FunctionDefHelper::Create(
      /*function_name=*/"0",
      /*in_def=*/{"arg: int64"},
      /*out_def=*/{"ret: int64"},
      /*attr_def=*/{},
      /*node_def=*/{},
      /*ret_def=*/{{"ret", "arg"}});

  // Library to add: "0" now has two int64 outputs.
  FunctionDefLibrary fdef_to_add;
  *fdef_to_add.add_function() = FunctionDefHelper::Create(
      /*function_name=*/"0",
      /*in_def=*/{"arg: int64"},
      /*out_def=*/{"ret: int64", "ret2: int64"},
      /*attr_def=*/{},
      /*node_def=*/{},
      /*ret_def=*/{{"ret", "arg"}, {"ret2", "arg"}});

  const string expected_message =
      "Cannot add function '0' because a different function with the same "
      "signature already exists.";

  // Overload taking a FunctionDefLibrary.
  FunctionLibraryDefinition flib_0(OpRegistry::Global(), fdef_base);
  const Status status_from_fdef_library =
      AddToFunctionLibrary(&flib_0, fdef_to_add);
  EXPECT_EQ(error::Code::INVALID_ARGUMENT, status_from_fdef_library.code());
  EXPECT_EQ(expected_message, status_from_fdef_library.error_message());

  // Overload taking a FunctionLibraryDefinition.
  FunctionLibraryDefinition flib_1(OpRegistry::Global(), fdef_base);
  FunctionLibraryDefinition flib_to_add(OpRegistry::Global(), fdef_to_add);
  const Status status_from_flib = AddToFunctionLibrary(&flib_1, flib_to_add);
  EXPECT_EQ(error::Code::INVALID_ARGUMENT, status_from_flib.code());
  EXPECT_EQ(expected_message, status_from_flib.error_message());
}

// StripDevicePlacement() clears the device assigned to a function's node.
TEST(DatasetUtilsTest, StripDevicePlacement) {
  constexpr char kDevice[] = "device:CPU:0";

  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Create(
      /*function_name=*/"0",
      /*in_def=*/{"arg: int64"},
      /*out_def=*/{"ret: int64"},
      /*attr_def=*/{},
      /*node_def=*/
      {{{"node"},
        "Identity",
        {"arg"},
        {{"T", DT_INT64}},
        /*dep=*/{},
        /*device=*/kDevice}},
      /*ret_def=*/{{"ret", "arg"}});

  // Re-reads the device from the proto on every call, so the check after
  // stripping sees the updated value.
  auto first_node_device = [&flib]() {
    return flib.function(0).node_def(0).device();
  };

  EXPECT_EQ(first_node_device(), kDevice);
  StripDevicePlacement(&flib);
  EXPECT_TRUE(first_node_device().empty());
}

// A runner wrapped by RunnerWithMaxParallelism() must expose the configured
// limit through GetPerThreadMaxParallelism() while the closure runs.
TEST(DatasetUtilsTest, RunnerWithMaxParallelism) {
  // The base runner invokes the closure synchronously on the calling thread.
  auto runner =
      RunnerWithMaxParallelism([](const std::function<void()> fn) { fn(); }, 2);

  bool fn_ran = false;
  auto fn = [&fn_ran]() {
    fn_ran = true;
    EXPECT_EQ(GetPerThreadMaxParallelism(), 2);
  };
  runner(fn);

  // Guard against a vacuous pass: if the wrapped runner never invoked `fn`,
  // the parallelism check above would silently be skipped.
  EXPECT_TRUE(fn_ran);
}

// Each accepted string parses to the matching DeterminismPolicy.
TEST(DatasetUtilsTest, ParseDeterminismPolicy) {
  DeterminismPolicy from_true;
  TF_ASSERT_OK(DeterminismPolicy::FromString("true", &from_true));
  EXPECT_TRUE(from_true.IsDeterministic());

  DeterminismPolicy from_false;
  TF_ASSERT_OK(DeterminismPolicy::FromString("false", &from_false));
  EXPECT_TRUE(from_false.IsNondeterministic());

  DeterminismPolicy from_default;
  TF_ASSERT_OK(DeterminismPolicy::FromString("default", &from_default));
  EXPECT_TRUE(from_default.IsDefault());
}

// Parsing a policy with FromString() and printing it back with String() must
// round-trip every accepted value.
TEST(DatasetUtilsTest, DeterminismString) {
  for (const char* s : {"true", "false", "default"}) {
    DeterminismPolicy determinism;
    TF_ASSERT_OK(DeterminismPolicy::FromString(s, &determinism));
    // EXPECT_EQ (rather than EXPECT_TRUE(a == b)) prints both values when the
    // round trip fails.
    EXPECT_EQ(s, determinism.String());
  }
}

// Constructing from a bool yields exactly one of the two non-default policies.
TEST(DatasetUtilsTest, BoolConstructor) {
  DeterminismPolicy deterministic(true);
  EXPECT_TRUE(deterministic.IsDeterministic());
  EXPECT_FALSE(deterministic.IsNondeterministic());
  EXPECT_FALSE(deterministic.IsDefault());

  DeterminismPolicy nondeterministic(false);
  EXPECT_TRUE(nondeterministic.IsNondeterministic());
  EXPECT_FALSE(nondeterministic.IsDeterministic());
  EXPECT_FALSE(nondeterministic.IsDefault());
}

190 REGISTER_DATASET_EXPERIMENT("test_only_experiment_0", 0);
191 REGISTER_DATASET_EXPERIMENT("test_only_experiment_1", 1);
192 REGISTER_DATASET_EXPERIMENT("test_only_experiment_5", 5);
193 REGISTER_DATASET_EXPERIMENT("test_only_experiment_10", 10);
194 REGISTER_DATASET_EXPERIMENT("test_only_experiment_50", 50);
195 REGISTER_DATASET_EXPERIMENT("test_only_experiment_99", 99);
196 REGISTER_DATASET_EXPERIMENT("test_only_experiment_100", 100);
197 
198 struct GetExperimentsHashTestCase {
199   uint64 hash;
200   std::vector<string> expected_in;
201   std::vector<string> expected_out;
202 };
203 
204 class GetExperimentsHashTest
205     : public ::testing::TestWithParam<GetExperimentsHashTestCase> {};
206 
// Pins the job hash to the test case's value and checks which registered
// experiments GetExperiments() selects.
TEST_P(GetExperimentsHashTest, DatasetUtils) {
  const GetExperimentsHashTestCase& test_case = GetParam();
  const uint64 hash_result = test_case.hash;

  // Ignore the job name and always report the pinned hash.
  auto fixed_hash = [hash_result](const string& /*job_name*/) {
    return hash_result;
  };
  const auto experiments = GetExperiments("job", fixed_hash);
  const absl::flat_hash_set<string> selected(experiments.begin(),
                                             experiments.end());

  for (const string& name : test_case.expected_in) {
    EXPECT_TRUE(selected.contains(name))
        << "experiment=" << name << " hash=" << hash_result;
  }
  for (const string& name : test_case.expected_out) {
    EXPECT_FALSE(selected.contains(name))
        << "experiment=" << name << " hash=" << hash_result;
  }
}

226 INSTANTIATE_TEST_SUITE_P(
227     Test, GetExperimentsHashTest,
228     ::testing::Values<GetExperimentsHashTestCase>(
229         GetExperimentsHashTestCase{
230             /*hash=*/0,
231             /*expected_in=*/
232             {"test_only_experiment_1", "test_only_experiment_5",
233              "test_only_experiment_10", "test_only_experiment_50",
234              "test_only_experiment_99", "test_only_experiment_100"},
235             /*expected_out=*/{"test_only_experiment_0"},
236         },
237         GetExperimentsHashTestCase{
238             /*hash=*/5,
239             /*expected_in=*/
240             {"test_only_experiment_10", "test_only_experiment_50",
241              "test_only_experiment_99", "test_only_experiment_100"},
242             /*expected_out=*/
243             {
244                 "test_only_experiment_0",
245                 "test_only_experiment_1",
246                 "test_only_experiment_5",
247             },
248         },
249         GetExperimentsHashTestCase{
250             /*hash=*/95,
251             /*expected_in=*/
252             {"test_only_experiment_99", "test_only_experiment_100"},
253             /*expected_out=*/
254             {"test_only_experiment_0", "test_only_experiment_1",
255              "test_only_experiment_5", "test_only_experiment_10",
256              "test_only_experiment_50"},
257         },
258         GetExperimentsHashTestCase{
259             /*hash=*/99,
260             /*expected_in=*/{"test_only_experiment_100"},
261             /*expected_out=*/
262             {"test_only_experiment_0", "test_only_experiment_1",
263              "test_only_experiment_5", "test_only_experiment_10",
264              "test_only_experiment_50", "test_only_experiment_99"},
265         },
266         GetExperimentsHashTestCase{
267             /*hash=*/100,
268             /*expected_in=*/
269             {"test_only_experiment_1", "test_only_experiment_5",
270              "test_only_experiment_10", "test_only_experiment_50",
271              "test_only_experiment_99", "test_only_experiment_100"},
272             /*expected_out=*/{"test_only_experiment_0"},
273         },
274         GetExperimentsHashTestCase{
275             /*hash=*/105,
276             /*expected_in=*/
277             {"test_only_experiment_10", "test_only_experiment_50",
278              "test_only_experiment_99", "test_only_experiment_100"},
279             /*expected_out=*/
280             {
281                 "test_only_experiment_0",
282                 "test_only_experiment_1",
283                 "test_only_experiment_5",
284             },
285         },
286         GetExperimentsHashTestCase{
287             /*hash=*/195,
288             /*expected_in=*/
289             {"test_only_experiment_99", "test_only_experiment_100"},
290             /*expected_out=*/
291             {"test_only_experiment_0", "test_only_experiment_1",
292              "test_only_experiment_5", "test_only_experiment_10",
293              "test_only_experiment_50"},
294         }));
295 
296 struct GetExperimentsOptTestCase {
297   std::vector<string> opt_ins;
298   std::vector<string> opt_outs;
299   std::vector<string> expected_in;
300   std::vector<string> expected_out;
301 };
302 
303 class GetExperimentsOptTest
304     : public ::testing::TestWithParam<GetExperimentsOptTestCase> {};
305 
// Drives GetExperiments() through the TF_DATA_EXPERIMENT_OPT_IN /
// TF_DATA_EXPERIMENT_OPT_OUT environment variables with the hash pinned to 0.
TEST_P(GetExperimentsOptTest, DatasetUtils) {
  const GetExperimentsOptTestCase& test_case = GetParam();
  const std::vector<string>& opt_ins = test_case.opt_ins;
  const std::vector<string>& opt_outs = test_case.opt_outs;

  // Sets `var` to the comma-joined `values`, or removes it when `values` is
  // empty. Writing or clearing both variables keeps the test hermetic: an
  // empty list must not pick up a value inherited from the process
  // environment.
  auto set_or_clear_env = [](const char* var,
                             const std::vector<string>& values) {
    if (values.empty()) {
      unsetenv(var);
    } else {
      setenv(var, str_util::Join(values, ",").c_str(), /*overwrite=*/1);
    }
  };
  set_or_clear_env("TF_DATA_EXPERIMENT_OPT_IN", opt_ins);
  set_or_clear_env("TF_DATA_EXPERIMENT_OPT_OUT", opt_outs);

  auto job_name = "job";
  // Pin the hash so only the opt-in / opt-out lists vary between cases.
  auto hash_func = [](const string& /*job_name*/) -> uint64 { return 0; };
  auto experiments = GetExperiments(job_name, hash_func);

  absl::flat_hash_set<string> experiment_set(experiments.begin(),
                                             experiments.end());
  for (const auto& experiment : test_case.expected_in) {
    EXPECT_TRUE(experiment_set.find(experiment) != experiment_set.end())
        << "experiment=" << experiment << " opt_ins={"
        << str_util::Join(opt_ins, ",") << "} opt_outs={"
        << str_util::Join(opt_outs, ",") << "}";
  }
  for (const auto& experiment : test_case.expected_out) {
    EXPECT_TRUE(experiment_set.find(experiment) == experiment_set.end())
        << "experiment=" << experiment << " opt_ins={"
        << str_util::Join(opt_ins, ",") << "} opt_outs={"
        << str_util::Join(opt_outs, ",") << "}";
  }

  // Leave no overrides behind for later tests in the same process.
  unsetenv("TF_DATA_EXPERIMENT_OPT_IN");
  unsetenv("TF_DATA_EXPERIMENT_OPT_OUT");
}

345 INSTANTIATE_TEST_SUITE_P(
346     Test, GetExperimentsOptTest,
347     ::testing::Values<GetExperimentsOptTestCase>(
348         GetExperimentsOptTestCase{
349             /*opt_ins=*/{"all"},
350             /*opt_outs=*/{"all"},
351             /*expected_in=*/{},
352             /*expected_out=*/
353             {"test_only_experiment_0", "test_only_experiment_1",
354              "test_only_experiment_5", "test_only_experiment_10",
355              "test_only_experiment_50", "test_only_experiment_99",
356              "test_only_experiment_100"}},
357         GetExperimentsOptTestCase{
358             /*opt_ins=*/{"all"},
359             /*opt_outs=*/{},
360             /*expected_in=*/
361             {"test_only_experiment_0", "test_only_experiment_1",
362              "test_only_experiment_5", "test_only_experiment_10",
363              "test_only_experiment_50", "test_only_experiment_99",
364              "test_only_experiment_100"},
365             /*expected_out=*/{}},
366         GetExperimentsOptTestCase{
367             /*opt_ins=*/{"all"},
368             /*opt_outs=*/{"test_only_experiment_1", "test_only_experiment_99"},
369             /*expected_in=*/
370             {"test_only_experiment_0", "test_only_experiment_5",
371              "test_only_experiment_10", "test_only_experiment_50",
372              "test_only_experiment_100"},
373             /*expected_out=*/
374             {"test_only_experiment_1", "test_only_experiment_99"}},
375         GetExperimentsOptTestCase{
376             /*opt_ins=*/{},
377             /*opt_outs=*/{"all"},
378             /*expected_in=*/{},
379             /*expected_out=*/
380             {"test_only_experiment_0", "test_only_experiment_1",
381              "test_only_experiment_5", "test_only_experiment_10",
382              "test_only_experiment_50", "test_only_experiment_99",
383              "test_only_experiment_100"}},
384         GetExperimentsOptTestCase{
385             /*opt_ins=*/{},
386             /*opt_outs=*/{},
387             /*expected_in=*/
388             {"test_only_experiment_1", "test_only_experiment_5",
389              "test_only_experiment_10", "test_only_experiment_50",
390              "test_only_experiment_99", "test_only_experiment_100"},
391             /*expected_out=*/{"test_only_experiment_0"}},
392         GetExperimentsOptTestCase{
393             /*opt_ins=*/{},
394             /*opt_outs=*/{"test_only_experiment_1", "test_only_experiment_99"},
395             /*expected_in=*/
396             {"test_only_experiment_5", "test_only_experiment_10",
397              "test_only_experiment_50", "test_only_experiment_100"},
398             /*expected_out=*/
399             {"test_only_experiment_0", "test_only_experiment_1",
400              "test_only_experiment_99"}},
401         GetExperimentsOptTestCase{
402             /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
403             /*opt_outs=*/{"all"},
404             /*expected_in=*/{},
405             /*expected_out=*/
406             {"test_only_experiment_0", "test_only_experiment_1",
407              "test_only_experiment_5", "test_only_experiment_10",
408              "test_only_experiment_50", "test_only_experiment_99",
409              "test_only_experiment_100"}},
410         GetExperimentsOptTestCase{
411             /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
412             /*opt_outs=*/{"all_except_opt_in"},
413             /*expected_in=*/
414             {"test_only_experiment_0", "test_only_experiment_100"},
415             /*expected_out=*/
416             {"test_only_experiment_1", "test_only_experiment_5",
417              "test_only_experiment_10", "test_only_experiment_50",
418              "test_only_experiment_99"}},
419         GetExperimentsOptTestCase{
420             /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
421             /*opt_outs=*/{},
422             /*expected_in=*/
423             {"test_only_experiment_0", "test_only_experiment_1",
424              "test_only_experiment_5", "test_only_experiment_10",
425              "test_only_experiment_50", "test_only_experiment_99",
426              "test_only_experiment_100"},
427             /*expected_out=*/{}},
428         GetExperimentsOptTestCase{
429             /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
430             /*opt_outs=*/{"test_only_experiment_1", "test_only_experiment_99"},
431             /*expected_in=*/
432             {"test_only_experiment_0", "test_only_experiment_5",
433              "test_only_experiment_10", "test_only_experiment_50",
434              "test_only_experiment_100"},
435             /*expected_out=*/
436             {"test_only_experiment_1", "test_only_experiment_99"}}));
437 
438 struct GetExperimentsJobNameTestCase {
439   string job_name;
440   std::vector<string> expected_in;
441   std::vector<string> expected_out;
442 };
443 
444 class GetExperimentsJobNameTest
445     : public ::testing::TestWithParam<GetExperimentsJobNameTestCase> {};
446 
// Checks experiment selection for a given job name, with the hash pinned to 0.
TEST_P(GetExperimentsJobNameTest, DatasetUtils) {
  const GetExperimentsJobNameTestCase& test_case = GetParam();
  const string& job_name = test_case.job_name;

  auto zero_hash = [](const string& /*job_name*/) { return 0; };
  const auto experiments = GetExperiments(job_name, zero_hash);
  const absl::flat_hash_set<string> selected(experiments.begin(),
                                             experiments.end());

  for (const string& name : test_case.expected_in) {
    EXPECT_TRUE(selected.contains(name))
        << "experiment=" << name << " job_name=" << job_name;
  }
  for (const string& name : test_case.expected_out) {
    EXPECT_FALSE(selected.contains(name))
        << "experiment=" << name << " job_name=" << job_name;
  }
}

465 INSTANTIATE_TEST_SUITE_P(
466     Test, GetExperimentsJobNameTest,
467     ::testing::Values(GetExperimentsJobNameTestCase{
468                           /*job_name=*/"",
469                           /*expected_in=*/{},
470                           /*expected_out=*/
471                           {"test_only_experiment_0", "test_only_experiment_1",
472                            "test_only_experiment_5", "test_only_experiment_10",
473                            "test_only_experiment_50", "test_only_experiment_99",
474                            "test_only_experiment_100"}},
475                       GetExperimentsJobNameTestCase{
476                           /*job_name=*/"job_name",
477                           /*expected_in=*/
478                           {"test_only_experiment_1", "test_only_experiment_5",
479                            "test_only_experiment_10", "test_only_experiment_50",
480                            "test_only_experiment_99",
481                            "test_only_experiment_100"},
482                           /*expected_out=*/{"test_only_experiment_0"}}));
483 
484 struct GetOptimizationsTestCase {
485   Options options;
486   std::vector<string> expected_enabled;
487   std::vector<string> expected_disabled;
488   std::vector<string> expected_default;
489 };
490 
491 // Tests the default.
GetOptimizationTestCase1()492 GetOptimizationsTestCase GetOptimizationTestCase1() {
493   return {
494       /*options=*/Options(),
495       /*expected_enabled=*/{},
496       /*expected_disabled=*/{},
497       /*expected_default=*/
498       {"noop_elimination", "map_and_batch_fusion", "shuffle_and_repeat_fusion",
499        "map_parallelization", "parallel_batch"}};
500 }
501 
502 // Tests disabling application of default optimizations.
GetOptimizationTestCase2()503 GetOptimizationsTestCase GetOptimizationTestCase2() {
504   Options options;
505   options.mutable_optimization_options()->set_apply_default_optimizations(
506       false);
507   return {options, /*expected_enabled=*/{}, /*expected_disabled=*/{},
508           /*expected_default=*/{}};
509 }
510 
511 // Tests explicitly enabling / disabling some default and non-default
512 // optimizations.
GetOptimizationTestCase3()513 GetOptimizationsTestCase GetOptimizationTestCase3() {
514   Options options;
515   options.set_deterministic(false);
516   options.mutable_optimization_options()->set_map_and_batch_fusion(true);
517   options.mutable_optimization_options()->set_map_parallelization(false);
518   options.mutable_optimization_options()->set_parallel_batch(false);
519   return {options,
520           /*expected_enabled=*/{"make_sloppy", "map_and_batch_fusion"},
521           /*expected_disabled=*/{"parallel_batch", "map_parallelization"},
522           /*expected_default=*/
523           {"noop_elimination", "shuffle_and_repeat_fusion"}};
524 }
525 
526 // Test enabling all / most available optimizations.
GetOptimizationTestCase4()527 GetOptimizationsTestCase GetOptimizationTestCase4() {
528   Options options;
529   options.set_deterministic(false);
530   options.mutable_optimization_options()->set_filter_fusion(true);
531   options.mutable_optimization_options()->set_filter_parallelization(true);
532   options.mutable_optimization_options()->set_map_and_batch_fusion(true);
533   options.mutable_optimization_options()->set_map_and_filter_fusion(true);
534   options.mutable_optimization_options()->set_map_fusion(true);
535   options.mutable_optimization_options()->set_map_parallelization(true);
536   options.mutable_optimization_options()->set_noop_elimination(true);
537   options.mutable_optimization_options()->set_parallel_batch(true);
538   options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true);
539   options.mutable_optimization_options()->set_inject_prefetch(true);
540   options.set_slack(true);
541   return {options,
542           /*expected_enabled=*/
543           {"filter_fusion", "filter_parallelization", "make_sloppy",
544            "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
545            "map_parallelization", "noop_elimination", "parallel_batch",
546            "shuffle_and_repeat_fusion", "slack", "inject_prefetch"},
547           /*expected_disabled=*/{},
548           /*expected_default=*/{}};
549 }
550 
551 class GetOptimizationsTest
552     : public ::testing::TestWithParam<GetOptimizationsTestCase> {};
553 
TEST_P(GetOptimizationsTest,DatasetUtils)554 TEST_P(GetOptimizationsTest, DatasetUtils) {
555   const GetOptimizationsTestCase test_case = GetParam();
556   auto options = test_case.options;
557 
558   absl::flat_hash_set<tstring> actual_enabled, actual_disabled, actual_default;
559   GetOptimizations(options, &actual_enabled, &actual_disabled, &actual_default);
560 
561   EXPECT_THAT(std::vector<string>(actual_enabled.begin(), actual_enabled.end()),
562               ::testing::UnorderedElementsAreArray(test_case.expected_enabled));
563   EXPECT_THAT(
564       std::vector<string>(actual_disabled.begin(), actual_disabled.end()),
565       ::testing::UnorderedElementsAreArray(test_case.expected_disabled));
566   EXPECT_THAT(std::vector<string>(actual_default.begin(), actual_default.end()),
567               ::testing::UnorderedElementsAreArray(test_case.expected_default));
568 }
569 
570 INSTANTIATE_TEST_SUITE_P(Test, GetOptimizationsTest,
571                          ::testing::Values(GetOptimizationTestCase1(),
572                                            GetOptimizationTestCase2(),
573                                            GetOptimizationTestCase3(),
574                                            GetOptimizationTestCase4()));
575 
// With deterministic ops enabled, GetOptimizations() reports exactly
// "make_deterministic" as enabled and nothing as disabled, even though
// options.deterministic is false.
TEST(DeterministicOpsTest, GetOptimizations) {
  // Must stay alive across the GetOptimizations() call; the test relies on it
  // turning deterministic ops on.
  test::DeterministicOpsScope det_scope;
  Options options;
  // options.deterministic should be ignored when deterministic ops are enabled.
  options.set_deterministic(false);
  absl::flat_hash_set<tstring> actual_enabled, actual_disabled, actual_default;
  GetOptimizations(options, &actual_enabled, &actual_disabled, &actual_default);
  EXPECT_THAT(std::vector<string>(actual_enabled.begin(), actual_enabled.end()),
              ::testing::UnorderedElementsAreArray({"make_deterministic"}));
  // Compare against an unsigned zero: EXPECT_EQ(size(), 0) mixes size_t with
  // int and triggers -Wsign-compare inside gtest's comparison helper.
  EXPECT_EQ(actual_disabled.size(), size_t{0});
}

588 REGISTER_DATASET_EXPERIMENT("test_only_experiment", 42);
589 
TEST(DatasetUtilsTest,DatasetExperimentRegistry)590 TEST(DatasetUtilsTest, DatasetExperimentRegistry) {
591   auto experiments = DatasetExperimentRegistry::Experiments();
592   EXPECT_TRUE(experiments.find("test_only_experiment") != experiments.end());
593   EXPECT_TRUE(experiments.find("non_existing_experiment") == experiments.end());
594 }
595 
596 }  // namespace
597 }  // namespace data
598 }  // namespace tensorflow
599