xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/graph_test_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
18 
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/lib/core/stringpiece.h"
21 
22 namespace tensorflow {
23 namespace grappler {
24 namespace graph_tests_utils {
25 
26 // Creates a test NodeDef for BatchDatasetV2.
27 NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
28                         StringPiece batch_size_node_name,
29                         StringPiece drop_remainder_node_name,
30                         bool parallel_copy);
31 
32 // Creates a test NodeDef for ParallelBatchDataset.
33 NodeDef MakeParallelBatchNode(StringPiece name, StringPiece input_node_name,
34                               StringPiece batch_size_node_name,
35                               StringPiece num_parallel_calls_node_name,
36                               StringPiece drop_remainder_node_name,
37                               StringPiece deterministic);
38 
39 // Creates a test NodeDef for ShuffleDatasetV2.
40 NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name,
41                         StringPiece filename_node_name,
42                         StringPiece cache_node_name);
43 
44 // Creates a test NodeDef for FilterDataset.
45 NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
46                        StringPiece function_name = "IsZero");
47 
48 // Creates a test NodeDef for MapDataset.
49 NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
50                     StringPiece function_name = "XTimesTwo");
51 
52 // Creates a test NodeDef for MapAndBatchDataset.
53 NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
54                             StringPiece batch_size_node_name,
55                             StringPiece num_parallel_calls_node_name,
56                             StringPiece drop_remainder_node_name,
57                             StringPiece function_name = "XTimesTwo");
58 
59 // Creates a test NodeDef for ParallelInterleaveDatasetV2.
60 NodeDef MakeParallelInterleaveV2Node(StringPiece name,
61                                      StringPiece input_node_name,
62                                      StringPiece cycle_length_node_name,
63                                      StringPiece block_length_node_name,
64                                      StringPiece num_parallel_calls_node_name,
65                                      StringPiece function_name, bool sloppy);
66 
67 // Creates a test NodeDef for ParallelInterleaveDatasetV4.
68 NodeDef MakeParallelInterleaveV4Node(StringPiece name,
69                                      StringPiece input_node_name,
70                                      StringPiece cycle_length_node_name,
71                                      StringPiece block_length_node_name,
72                                      StringPiece num_parallel_calls_node_name,
73                                      StringPiece function_name,
74                                      StringPiece deterministic);
75 
76 // Creates a test NodeDef for ParallelMapDataset.
77 NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name,
78                             StringPiece num_parallel_calls_node_name,
79                             StringPiece function_name, bool sloppy);
80 
81 // Creates a test NodeDef for ParallelMapDatasetV2.
82 NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name,
83                               StringPiece num_parallel_calls_node_name,
84                               StringPiece function_name,
85                               StringPiece deterministic);
86 
87 // Creates a test NodeDef for ParseExampleDataset.
88 NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name,
89                              StringPiece num_parallel_calls_node_name,
90                              bool sloppy);
91 
92 // Creates a test NodeDef for ShuffleDatasetV2.
93 NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name,
94                           StringPiece buffer_size_node_name,
95                           StringPiece seed_generator_node_name);
96 
97 // Creates a test NodeDef for TakeDataset.
98 NodeDef MakeTakeNode(StringPiece name, StringPiece input_node_name,
99                      StringPiece count_node_name);
100 
101 // Creates a test NodeDef for TensorSliceDataset.
102 NodeDef MakeTensorSliceNode(StringPiece name, StringPiece tensor_node_name,
103                             bool replicate_on_split);
104 
105 // Creates a test NodeDef for SkipDataset.
106 NodeDef MakeSkipNode(StringPiece name, StringPiece input_node_name,
107                      StringPiece count_node_name);
108 
109 // Creates a test NodeDef for ShardDataset.
110 NodeDef MakeShardNode(StringPiece name, StringPiece input_node_name,
111                       StringPiece num_shards_node_name,
112                       StringPiece index_node_name);
113 
114 // Creates a test NodeDef for PrefetchDataset.
115 NodeDef MakePrefetchNode(StringPiece name, StringPiece input_node_name,
116                          StringPiece buffer_size);
117 
118 }  // namespace graph_tests_utils
119 }  // namespace grappler
120 }  // namespace tensorflow
121 
122 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
123