1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/service/test_util.h"
17
18 #include <functional>
19 #include <string>
20 #include <vector>
21
22 #include "absl/strings/string_view.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/core/data/dataset_test_base.h"
25 #include "tensorflow/core/data/service/common.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/function.pb.h"
28 #include "tensorflow/core/framework/function_testlib.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/framework/tensor_testutil.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/path.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/statusor.h"
40 #include "tensorflow/core/platform/tstring.h"
41 #include "tensorflow/core/platform/types.h"
42
43 namespace tensorflow {
44 namespace data {
45 namespace testing {
46 namespace {
47
48 using ::tensorflow::test::AsScalar;
49 using ::tensorflow::test::function::GDef;
50 using ::tensorflow::test::function::NDef;
51
// Value written into both the "num_shards" and "index" Const nodes of the
// graph built by RangeDatasetWithShardHint().
constexpr int64_t kShardHint = -1;

// Directory holding the serialized test graphs, joined with the file name
// below before being read.
constexpr char kTestdataDir[] = "tensorflow/core/data/service/testdata";

// Text-format graph file loaded by InterleaveTextlineDataset().
constexpr char kInterleaveTextlineDatasetFile[] =
    "interleave_textline_dataset.pbtxt";
57
GetMapNode(absl::string_view name,absl::string_view input_node_name,absl::string_view function_name)58 NodeDef GetMapNode(absl::string_view name, absl::string_view input_node_name,
59 absl::string_view function_name) {
60 return NDef(
61 name, /*op=*/"MapDataset", {std::string(input_node_name)},
62 {{"f", FunctionDefHelper::FunctionRef(std::string(function_name))},
63 {"Targuments", {}},
64 {"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
65 {"output_types", gtl::ArraySlice<DataType>{DT_INT64}}});
66 }
67
XTimesX()68 FunctionDef XTimesX() {
69 return FunctionDefHelper::Create(
70 /*function_name=*/"XTimesX",
71 /*in_def=*/{"x: int64"},
72 /*out_def=*/{"y: int64"},
73 /*attr_def=*/{},
74 /*node_def=*/{{{"y"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
75 /*ret_def=*/{{"y", "y:z:0"}});
76 }
77
CreateTestFiles(const std::vector<tstring> & filenames,const std::vector<tstring> & contents)78 Status CreateTestFiles(const std::vector<tstring>& filenames,
79 const std::vector<tstring>& contents) {
80 if (filenames.size() != contents.size()) {
81 return errors::InvalidArgument(
82 "The number of files does not match with the contents.");
83 }
84 for (int i = 0; i < filenames.size(); ++i) {
85 TF_RETURN_IF_ERROR(WriteDataToFile(filenames[i], contents[i].data()));
86 }
87 return OkStatus();
88 }
89 } // namespace
90
RangeDataset(int64_t range)91 DatasetDef RangeDataset(int64_t range) {
92 DatasetDef dataset_def;
93 *dataset_def.mutable_graph() = GDef(
94 {NDef("start", "Const", /*inputs=*/{},
95 {{"value", AsScalar<int64_t>(0)}, {"dtype", DT_INT64}}),
96 NDef("stop", "Const", /*inputs=*/{},
97 {{"value", AsScalar<int64_t>(range)}, {"dtype", DT_INT64}}),
98 NDef("step", "Const", /*inputs=*/{},
99 {{"value", AsScalar<int64_t>(1)}, {"dtype", DT_INT64}}),
100 NDef("range", "RangeDataset", /*inputs=*/{"start", "stop", "step"},
101 {{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
102 {"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
103 NDef("dataset", "_Retval", /*inputs=*/{"range"},
104 {{"T", DT_VARIANT}, {"index", 0}})},
105 {});
106 return dataset_def;
107 }
108
RangeSquareDataset(const int64_t range)109 DatasetDef RangeSquareDataset(const int64_t range) {
110 DatasetDef dataset_def;
111 *dataset_def.mutable_graph() = GDef(
112 {NDef("start", "Const", /*inputs=*/{},
113 {{"value", AsScalar<int64_t>(0)}, {"dtype", DT_INT64}}),
114 NDef("stop", "Const", /*inputs=*/{},
115 {{"value", AsScalar<int64_t>(range)}, {"dtype", DT_INT64}}),
116 NDef("step", "Const", /*inputs=*/{},
117 {{"value", AsScalar<int64_t>(1)}, {"dtype", DT_INT64}}),
118 NDef("range", "RangeDataset", /*inputs=*/{"start", "stop", "step"},
119 {{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
120 {"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
121 GetMapNode("map", "range", "XTimesX"),
122 NDef("dataset", "_Retval", /*inputs=*/{"map"},
123 {{"T", DT_VARIANT}, {"index", 0}})},
124 {XTimesX()});
125 return dataset_def;
126 }
127
RangeDatasetWithShardHint(const int64_t range)128 DatasetDef RangeDatasetWithShardHint(const int64_t range) {
129 DatasetDef dataset_def;
130 *dataset_def.mutable_graph() = GDef(
131 {NDef("start", "Const", /*inputs=*/{},
132 {{"value", AsScalar<int64_t>(0)}, {"dtype", DT_INT64}}),
133 NDef("stop", "Const", /*inputs=*/{},
134 {{"value", AsScalar<int64_t>(range)}, {"dtype", DT_INT64}}),
135 NDef("step", "Const", /*inputs=*/{},
136 {{"value", AsScalar<int64_t>(1)}, {"dtype", DT_INT64}}),
137 NDef("range", "RangeDataset", /*inputs=*/{"start", "stop", "step"},
138 {{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
139 {"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
140 NDef("num_shards", "Const", /*inputs=*/{},
141 {{"value", AsScalar<int64_t>(kShardHint)}, {"dtype", DT_INT64}}),
142 NDef("index", "Const", /*inputs=*/{},
143 {{"value", AsScalar<int64_t>(kShardHint)}, {"dtype", DT_INT64}}),
144 NDef("ShardDataset", "ShardDataset",
145 /*inputs=*/{"range", "num_shards", "index"},
146 {{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
147 {"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
148 NDef("dataset", "_Retval", /*inputs=*/{"ShardDataset"},
149 {{"T", DT_VARIANT}, {"index", 0}})},
150 /*funcs=*/{});
151 return dataset_def;
152 }
153
InfiniteDataset()154 DatasetDef InfiniteDataset() {
155 DatasetDef dataset_def;
156 *dataset_def.mutable_graph() = GDef(
157 {NDef("start", "Const", /*inputs=*/{},
158 {{"value", AsScalar<int64_t>(0)}, {"dtype", DT_INT64}}),
159 NDef("stop", "Const", /*inputs=*/{},
160 {{"value", AsScalar<int64_t>(100000000)}, {"dtype", DT_INT64}}),
161 NDef("step", "Const", /*inputs=*/{},
162 {{"value", AsScalar<int64_t>(1)}, {"dtype", DT_INT64}}),
163 NDef("range", "RangeDataset", /*inputs=*/{"start", "stop", "step"},
164 {{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
165 {"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
166 NDef("count", "Const", /*inputs=*/{},
167 {{"value", AsScalar<int64_t>(-1)}, {"dtype", DT_INT64}}),
168 NDef("repeat", "RepeatDataset", /*inputs=*/{"range", "count"},
169 {{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
170 {"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
171 NDef("dataset", "_Retval", /*inputs=*/{"repeat"},
172 {{"T", DT_VARIANT}, {"index", 0}})},
173 {});
174 return dataset_def;
175 }
176
InterleaveTextlineDataset(const std::vector<tstring> & filenames,const std::vector<tstring> & contents)177 StatusOr<DatasetDef> InterleaveTextlineDataset(
178 const std::vector<tstring>& filenames,
179 const std::vector<tstring>& contents) {
180 TF_RETURN_IF_ERROR(CreateTestFiles(filenames, contents));
181 DatasetDef dataset;
182 std::string graph_file =
183 io::JoinPath(kTestdataDir, kInterleaveTextlineDatasetFile);
184 TF_RETURN_IF_ERROR(
185 ReadTextProto(Env::Default(), graph_file, dataset.mutable_graph()));
186
187 Tensor filenames_tensor = test::AsTensor<tstring>(
188 filenames, TensorShape({static_cast<int64_t>(filenames.size())}));
189 filenames_tensor.AsProtoTensorContent(
190 (*dataset.mutable_graph()->mutable_node(0)->mutable_attr())["value"]
191 .mutable_tensor());
192 return dataset;
193 }
194
WaitWhile(std::function<StatusOr<bool> ()> f)195 Status WaitWhile(std::function<StatusOr<bool>()> f) {
196 while (true) {
197 TF_ASSIGN_OR_RETURN(bool result, f());
198 if (!result) {
199 return OkStatus();
200 }
201 Env::Default()->SleepForMicroseconds(10 * 1000); // 10ms.
202 }
203 }
204
205 } // namespace testing
206 } // namespace data
207 } // namespace tensorflow
208