xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
17 
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_tests_utils {
27 
MakeBatchV2Node(StringPiece name,StringPiece input_node_name,StringPiece batch_size_node_name,StringPiece drop_remainder_node_name,bool parallel_copy)28 NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
29                         StringPiece batch_size_node_name,
30                         StringPiece drop_remainder_node_name,
31                         bool parallel_copy) {
32   return test::function::NDef(
33       name, "BatchDatasetV2",
34       {string(input_node_name), string(batch_size_node_name),
35        string(drop_remainder_node_name)},
36       {{"parallel_copy", parallel_copy},
37        {"output_shapes", gtl::ArraySlice<TensorShape>{}},
38        {"output_types", gtl::ArraySlice<DataType>{}}});
39 }
40 
MakeParallelBatchNode(StringPiece name,StringPiece input_node_name,StringPiece batch_size_node_name,StringPiece num_parallel_calls_node_name,StringPiece drop_remainder_node_name,StringPiece deterministic)41 NodeDef MakeParallelBatchNode(StringPiece name, StringPiece input_node_name,
42                               StringPiece batch_size_node_name,
43                               StringPiece num_parallel_calls_node_name,
44                               StringPiece drop_remainder_node_name,
45                               StringPiece deterministic) {
46   return test::function::NDef(
47       name, "ParallelBatchDataset",
48       {string(input_node_name), string(batch_size_node_name),
49        string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
50       {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
51        {"output_types", gtl::ArraySlice<DataType>{}},
52        {"deterministic", string(deterministic)}});
53 }
54 
MakeCacheV2Node(StringPiece name,StringPiece input_node_name,StringPiece filename_node_name,StringPiece cache_node_name)55 NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name,
56                         StringPiece filename_node_name,
57                         StringPiece cache_node_name) {
58   return test::function::NDef(
59       name, "CacheDatasetV2",
60       {
61           string(input_node_name),
62           string(filename_node_name),
63           string(cache_node_name),
64       },
65       {
66           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
67           {"output_types", gtl::ArraySlice<DataType>{}},
68       });
69 }
70 
MakeFilterNode(StringPiece name,StringPiece input_node_name,StringPiece function_name)71 NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
72                        StringPiece function_name) {
73   return test::function::NDef(
74       name, "FilterDataset", {string(input_node_name)},
75       {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
76        {"Targuments", {}},
77        {"output_shapes", gtl::ArraySlice<TensorShape>{}},
78        {"output_types", gtl::ArraySlice<DataType>{}}});
79 }
80 
MakeMapAndBatchNode(StringPiece name,StringPiece input_node_name,StringPiece batch_size_node_name,StringPiece num_parallel_calls_node_name,StringPiece drop_remainder_node_name,StringPiece function_name)81 NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
82                             StringPiece batch_size_node_name,
83                             StringPiece num_parallel_calls_node_name,
84                             StringPiece drop_remainder_node_name,
85                             StringPiece function_name) {
86   return test::function::NDef(
87       name, "MapAndBatchDataset",
88       {string(input_node_name), string(batch_size_node_name),
89        string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
90       {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
91        {"Targuments", {}},
92        {"output_shapes", gtl::ArraySlice<TensorShape>{}},
93        {"output_types", gtl::ArraySlice<DataType>{}}});
94 }
95 
MakeMapNode(StringPiece name,StringPiece input_node_name,StringPiece function_name)96 NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
97                     StringPiece function_name) {
98   return test::function::NDef(
99       name, "MapDataset", {string(input_node_name)},
100       {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
101        {"Targuments", {}},
102        {"output_shapes", gtl::ArraySlice<TensorShape>{}},
103        {"output_types", gtl::ArraySlice<DataType>{}}});
104 }
105 
MakeParallelInterleaveV2Node(StringPiece name,StringPiece input_node_name,StringPiece cycle_length_node_name,StringPiece block_length_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,bool sloppy)106 NodeDef MakeParallelInterleaveV2Node(StringPiece name,
107                                      StringPiece input_node_name,
108                                      StringPiece cycle_length_node_name,
109                                      StringPiece block_length_node_name,
110                                      StringPiece num_parallel_calls_node_name,
111                                      StringPiece function_name, bool sloppy) {
112   return test::function::NDef(
113       name, "ParallelInterleaveDatasetV2",
114       {string(input_node_name), string(cycle_length_node_name),
115        string(block_length_node_name), string(num_parallel_calls_node_name)},
116       {
117           {"f", FunctionDefHelper::FunctionRef(string(function_name))},
118           {"Targuments", {}},
119           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
120           {"output_types", gtl::ArraySlice<DataType>{}},
121           {"sloppy", sloppy},
122       });
123 }
124 
MakeParallelInterleaveV4Node(StringPiece name,StringPiece input_node_name,StringPiece cycle_length_node_name,StringPiece block_length_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,StringPiece deterministic)125 NodeDef MakeParallelInterleaveV4Node(StringPiece name,
126                                      StringPiece input_node_name,
127                                      StringPiece cycle_length_node_name,
128                                      StringPiece block_length_node_name,
129                                      StringPiece num_parallel_calls_node_name,
130                                      StringPiece function_name,
131                                      StringPiece deterministic) {
132   return test::function::NDef(
133       name, "ParallelInterleaveDatasetV4",
134       {string(input_node_name), string(cycle_length_node_name),
135        string(block_length_node_name), string(num_parallel_calls_node_name)},
136       {
137           {"f", FunctionDefHelper::FunctionRef(string(function_name))},
138           {"Targuments", {}},
139           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
140           {"output_types", gtl::ArraySlice<DataType>{}},
141           {"deterministic", string(deterministic)},
142       });
143 }
144 
MakeParallelMapNode(StringPiece name,StringPiece input_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,bool sloppy)145 NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name,
146                             StringPiece num_parallel_calls_node_name,
147                             StringPiece function_name, bool sloppy) {
148   return test::function::NDef(
149       name, "ParallelMapDataset",
150       {string(input_node_name), string(num_parallel_calls_node_name)},
151       {
152           {"f", FunctionDefHelper::FunctionRef(string(function_name))},
153           {"Targuments", {}},
154           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
155           {"output_types", gtl::ArraySlice<DataType>{}},
156           {"sloppy", sloppy},
157       });
158 }
159 
MakeParallelMapV2Node(StringPiece name,StringPiece input_node_name,StringPiece num_parallel_calls_node_name,StringPiece function_name,StringPiece deterministic)160 NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name,
161                               StringPiece num_parallel_calls_node_name,
162                               StringPiece function_name,
163                               StringPiece deterministic) {
164   return test::function::NDef(
165       name, "ParallelMapDatasetV2",
166       {string(input_node_name), string(num_parallel_calls_node_name)},
167       {
168           {"f", FunctionDefHelper::FunctionRef(string(function_name))},
169           {"Targuments", {}},
170           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
171           {"output_types", gtl::ArraySlice<DataType>{}},
172           {"deterministic", string(deterministic)},
173       });
174 }
175 
MakeParseExampleNode(StringPiece name,StringPiece input_node_name,StringPiece num_parallel_calls_node_name,bool sloppy)176 NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name,
177                              StringPiece num_parallel_calls_node_name,
178                              bool sloppy) {
179   return test::function::NDef(
180       name, "ParseExampleDataset",
181       {string(input_node_name), string(num_parallel_calls_node_name)},
182       {
183           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
184           {"output_types", gtl::ArraySlice<DataType>{}},
185           {"sloppy", sloppy},
186       });
187 }
188 
MakeShuffleV2Node(StringPiece name,StringPiece input_node_name,StringPiece buffer_size_node_name,StringPiece seed_generator_node_name)189 NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name,
190                           StringPiece buffer_size_node_name,
191                           StringPiece seed_generator_node_name) {
192   return test::function::NDef(
193       name, "ShuffleDatasetV2",
194       {
195           string(input_node_name),
196           string(buffer_size_node_name),
197           string(seed_generator_node_name),
198       },
199       {
200           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
201           {"output_types", gtl::ArraySlice<DataType>{}},
202       });
203 }
204 
MakeTakeNode(StringPiece name,StringPiece input_node_name,StringPiece count_node_name)205 NodeDef MakeTakeNode(StringPiece name, StringPiece input_node_name,
206                      StringPiece count_node_name) {
207   return test::function::NDef(
208       name, "TakeDataset",
209       {
210           string(input_node_name),
211           string(count_node_name),
212       },
213       {
214           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
215           {"output_types", gtl::ArraySlice<DataType>{}},
216       });
217 }
218 
MakeTensorSliceNode(StringPiece name,StringPiece tensor_node_name,bool replicate_on_split)219 NodeDef MakeTensorSliceNode(StringPiece name, StringPiece tensor_node_name,
220                             bool replicate_on_split) {
221   return test::function::NDef(
222       name, "TensorSliceDataset",
223       {
224           string(tensor_node_name),
225       },
226       {
227           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
228           {"output_types", gtl::ArraySlice<DataType>{}},
229           {"replicate_on_split", replicate_on_split},
230       });
231 }
232 
MakeSkipNode(StringPiece name,StringPiece input_node_name,StringPiece count_node_name)233 NodeDef MakeSkipNode(StringPiece name, StringPiece input_node_name,
234                      StringPiece count_node_name) {
235   return test::function::NDef(
236       name, "SkipDataset",
237       {
238           string(input_node_name),
239           string(count_node_name),
240       },
241       {
242           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
243           {"output_types", gtl::ArraySlice<DataType>{}},
244       });
245 }
246 
MakeShardNode(StringPiece name,StringPiece input_node_name,StringPiece num_shards_node_name,StringPiece index_node_name)247 NodeDef MakeShardNode(StringPiece name, StringPiece input_node_name,
248                       StringPiece num_shards_node_name,
249                       StringPiece index_node_name) {
250   return test::function::NDef(
251       name, "ShardDataset",
252       {
253           string(input_node_name),
254           string(num_shards_node_name),
255           string(index_node_name),
256       },
257       {
258           {"output_shapes", gtl::ArraySlice<TensorShape>{}},
259           {"output_types", gtl::ArraySlice<DataType>{}},
260       });
261 }
262 
MakePrefetchNode(StringPiece name,StringPiece input_node_name,StringPiece buffer_size)263 NodeDef MakePrefetchNode(StringPiece name, StringPiece input_node_name,
264                          StringPiece buffer_size) {
265   return test::function::NDef(
266       name, "PrefetchDataset", {string(input_node_name), string(buffer_size)},
267       {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
268        {"output_types", gtl::ArraySlice<DataType>{}},
269        {"slack_period", 0},
270        {"legacy_autotune", true},
271        {"buffer_size_min", 0}});
272 }
273 
274 }  // namespace graph_tests_utils
275 }  // namespace grappler
276 }  // namespace tensorflow
277