1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
17
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_tests_utils {
27
MakeBatchV2Node(StringPiece name,StringPiece input_node_name,StringPiece batch_size_node_name,StringPiece drop_remainder_node_name,bool parallel_copy)28 NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
29 StringPiece batch_size_node_name,
30 StringPiece drop_remainder_node_name,
31 bool parallel_copy) {
32 return test::function::NDef(
33 name, "BatchDatasetV2",
34 {string(input_node_name), string(batch_size_node_name),
35 string(drop_remainder_node_name)},
36 {{"parallel_copy", parallel_copy},
37 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
38 {"output_types", gtl::ArraySlice<DataType>{}}});
39 }
40
MakeParallelBatchNode(StringPiece name,StringPiece input_node_name,StringPiece batch_size_node_name,StringPiece num_parallel_calls_node_name,StringPiece drop_remainder_node_name,StringPiece deterministic)41 NodeDef MakeParallelBatchNode(StringPiece name, StringPiece input_node_name,
42 StringPiece batch_size_node_name,
43 StringPiece num_parallel_calls_node_name,
44 StringPiece drop_remainder_node_name,
45 StringPiece deterministic) {
46 return test::function::NDef(
47 name, "ParallelBatchDataset",
48 {string(input_node_name), string(batch_size_node_name),
49 string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
50 {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
51 {"output_types", gtl::ArraySlice<DataType>{}},
52 {"deterministic", string(deterministic)}});
53 }
54
MakeCacheV2Node(StringPiece name,StringPiece input_node_name,StringPiece filename_node_name,StringPiece cache_node_name)55 NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name,
56 StringPiece filename_node_name,
57 StringPiece cache_node_name) {
58 return test::function::NDef(
59 name, "CacheDatasetV2",
60 {
61 string(input_node_name),
62 string(filename_node_name),
63 string(cache_node_name),
64 },
65 {
66 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
67 {"output_types", gtl::ArraySlice<DataType>{}},
68 });
69 }
70
MakeFilterNode(StringPiece name,StringPiece input_node_name,StringPiece function_name)71 NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
72 StringPiece function_name) {
73 return test::function::NDef(
74 name, "FilterDataset", {string(input_node_name)},
75 {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
76 {"Targuments", {}},
77 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
78 {"output_types", gtl::ArraySlice<DataType>{}}});
79 }
80
MakeMapAndBatchNode(StringPiece name,StringPiece input_node_name,StringPiece batch_size_node_name,StringPiece num_parallel_calls_node_name,StringPiece drop_remainder_node_name,StringPiece function_name)81 NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
82 StringPiece batch_size_node_name,
83 StringPiece num_parallel_calls_node_name,
84 StringPiece drop_remainder_node_name,
85 StringPiece function_name) {
86 return test::function::NDef(
87 name, "MapAndBatchDataset",
88 {string(input_node_name), string(batch_size_node_name),
89 string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
90 {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
91 {"Targuments", {}},
92 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
93 {"output_types", gtl::ArraySlice<DataType>{}}});
94 }
95
MakeMapNode(StringPiece name,StringPiece input_node_name,StringPiece function_name)96 NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
97 StringPiece function_name) {
98 return test::function::NDef(
99 name, "MapDataset", {string(input_node_name)},
100 {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
101 {"Targuments", {}},
102 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
103 {"output_types", gtl::ArraySlice<DataType>{}}});
104 }
105
MakeParallelInterleaveV2Node(StringPiece name,StringPiece input_node_name,StringPiece cycle_length_node_name,StringPiece block_length_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,bool sloppy)106 NodeDef MakeParallelInterleaveV2Node(StringPiece name,
107 StringPiece input_node_name,
108 StringPiece cycle_length_node_name,
109 StringPiece block_length_node_name,
110 StringPiece num_parallel_calls_node_name,
111 StringPiece function_name, bool sloppy) {
112 return test::function::NDef(
113 name, "ParallelInterleaveDatasetV2",
114 {string(input_node_name), string(cycle_length_node_name),
115 string(block_length_node_name), string(num_parallel_calls_node_name)},
116 {
117 {"f", FunctionDefHelper::FunctionRef(string(function_name))},
118 {"Targuments", {}},
119 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
120 {"output_types", gtl::ArraySlice<DataType>{}},
121 {"sloppy", sloppy},
122 });
123 }
124
MakeParallelInterleaveV4Node(StringPiece name,StringPiece input_node_name,StringPiece cycle_length_node_name,StringPiece block_length_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,StringPiece deterministic)125 NodeDef MakeParallelInterleaveV4Node(StringPiece name,
126 StringPiece input_node_name,
127 StringPiece cycle_length_node_name,
128 StringPiece block_length_node_name,
129 StringPiece num_parallel_calls_node_name,
130 StringPiece function_name,
131 StringPiece deterministic) {
132 return test::function::NDef(
133 name, "ParallelInterleaveDatasetV4",
134 {string(input_node_name), string(cycle_length_node_name),
135 string(block_length_node_name), string(num_parallel_calls_node_name)},
136 {
137 {"f", FunctionDefHelper::FunctionRef(string(function_name))},
138 {"Targuments", {}},
139 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
140 {"output_types", gtl::ArraySlice<DataType>{}},
141 {"deterministic", string(deterministic)},
142 });
143 }
144
MakeParallelMapNode(StringPiece name,StringPiece input_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,bool sloppy)145 NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name,
146 StringPiece num_parallel_calls_node_name,
147 StringPiece function_name, bool sloppy) {
148 return test::function::NDef(
149 name, "ParallelMapDataset",
150 {string(input_node_name), string(num_parallel_calls_node_name)},
151 {
152 {"f", FunctionDefHelper::FunctionRef(string(function_name))},
153 {"Targuments", {}},
154 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
155 {"output_types", gtl::ArraySlice<DataType>{}},
156 {"sloppy", sloppy},
157 });
158 }
159
MakeParallelMapV2Node(StringPiece name,StringPiece input_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,StringPiece deterministic)160 NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name,
161 StringPiece num_parallel_calls_node_name,
162 StringPiece function_name,
163 StringPiece deterministic) {
164 return test::function::NDef(
165 name, "ParallelMapDatasetV2",
166 {string(input_node_name), string(num_parallel_calls_node_name)},
167 {
168 {"f", FunctionDefHelper::FunctionRef(string(function_name))},
169 {"Targuments", {}},
170 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
171 {"output_types", gtl::ArraySlice<DataType>{}},
172 {"deterministic", string(deterministic)},
173 });
174 }
175
MakeParseExampleNode(StringPiece name,StringPiece input_node_name,StringPiece num_parallel_calls_node_name,bool sloppy)176 NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name,
177 StringPiece num_parallel_calls_node_name,
178 bool sloppy) {
179 return test::function::NDef(
180 name, "ParseExampleDataset",
181 {string(input_node_name), string(num_parallel_calls_node_name)},
182 {
183 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
184 {"output_types", gtl::ArraySlice<DataType>{}},
185 {"sloppy", sloppy},
186 });
187 }
188
MakeShuffleV2Node(StringPiece name,StringPiece input_node_name,StringPiece buffer_size_node_name,StringPiece seed_generator_node_name)189 NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name,
190 StringPiece buffer_size_node_name,
191 StringPiece seed_generator_node_name) {
192 return test::function::NDef(
193 name, "ShuffleDatasetV2",
194 {
195 string(input_node_name),
196 string(buffer_size_node_name),
197 string(seed_generator_node_name),
198 },
199 {
200 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
201 {"output_types", gtl::ArraySlice<DataType>{}},
202 });
203 }
204
MakeTakeNode(StringPiece name,StringPiece input_node_name,StringPiece count_node_name)205 NodeDef MakeTakeNode(StringPiece name, StringPiece input_node_name,
206 StringPiece count_node_name) {
207 return test::function::NDef(
208 name, "TakeDataset",
209 {
210 string(input_node_name),
211 string(count_node_name),
212 },
213 {
214 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
215 {"output_types", gtl::ArraySlice<DataType>{}},
216 });
217 }
218
MakeTensorSliceNode(StringPiece name,StringPiece tensor_node_name,bool replicate_on_split)219 NodeDef MakeTensorSliceNode(StringPiece name, StringPiece tensor_node_name,
220 bool replicate_on_split) {
221 return test::function::NDef(
222 name, "TensorSliceDataset",
223 {
224 string(tensor_node_name),
225 },
226 {
227 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
228 {"output_types", gtl::ArraySlice<DataType>{}},
229 {"replicate_on_split", replicate_on_split},
230 });
231 }
232
MakeSkipNode(StringPiece name,StringPiece input_node_name,StringPiece count_node_name)233 NodeDef MakeSkipNode(StringPiece name, StringPiece input_node_name,
234 StringPiece count_node_name) {
235 return test::function::NDef(
236 name, "SkipDataset",
237 {
238 string(input_node_name),
239 string(count_node_name),
240 },
241 {
242 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
243 {"output_types", gtl::ArraySlice<DataType>{}},
244 });
245 }
246
MakeShardNode(StringPiece name,StringPiece input_node_name,StringPiece num_shards_node_name,StringPiece index_node_name)247 NodeDef MakeShardNode(StringPiece name, StringPiece input_node_name,
248 StringPiece num_shards_node_name,
249 StringPiece index_node_name) {
250 return test::function::NDef(
251 name, "ShardDataset",
252 {
253 string(input_node_name),
254 string(num_shards_node_name),
255 string(index_node_name),
256 },
257 {
258 {"output_shapes", gtl::ArraySlice<TensorShape>{}},
259 {"output_types", gtl::ArraySlice<DataType>{}},
260 });
261 }
262
MakePrefetchNode(StringPiece name,StringPiece input_node_name,StringPiece buffer_size)263 NodeDef MakePrefetchNode(StringPiece name, StringPiece input_node_name,
264 StringPiece buffer_size) {
265 return test::function::NDef(
266 name, "PrefetchDataset", {string(input_node_name), string(buffer_size)},
267 {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
268 {"output_types", gtl::ArraySlice<DataType>{}},
269 {"slack_period", 0},
270 {"legacy_autotune", true},
271 {"buffer_size_min", 0}});
272 }
273
274 } // namespace graph_tests_utils
275 } // namespace grappler
276 } // namespace tensorflow
277