/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h"

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

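// Runs the AutotuneBufferSizes optimizer on `item`, with autotuning switched
// on or off through the optimizer's "autotune" parameter.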
Status OptimizeWithAutotuneBufferSizes(const GrapplerItem &item,
                                       GraphDef *output, bool autotune) {
  AutotuneBufferSizes optimizer;
  RewriterConfig_CustomGraphOptimizer config;
  if (autotune) {
    (*config.mutable_parameter_map())["autotune"].set_s("true");
  } else {
    (*config.mutable_parameter_map())["autotune"].set_s("false");
  }
  TF_RETURN_IF_ERROR(optimizer.Init(&config));
  return optimizer.Optimize(nullptr, item, output);
}

class SimpleInject : public ::testing::TestWithParam<string> {};

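// Each parameterized case builds a small pipeline that ends in one
// asynchronous dataset op ("map", "interleave", or "map_and_batch").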
TEST_P(SimpleInject, AutotuneBufferSizesTest) {
  const string async_dataset = GetParam();
  using test::function::NDef;
  GrapplerItem item;
  if (async_dataset == "map") {
    item.graph = test::function::GDef(
        {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
         NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
         NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
         NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
         NDef("num_parallel_calls", "Const", {},
              {{"value", 1}, {"dtype", DT_INT32}}),
         graph_tests_utils::MakeParallelMapNode(
             "map", "range", "num_parallel_calls", "XTimesTwo",
             /*sloppy=*/false)},
        // FunctionLib
        {
            test::function::XTimesTwo(),
        });
  } else if (async_dataset == "interleave") {
    item.graph = test::function::GDef(
        {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
         NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
         NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
         NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
         NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
         NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
         NDef("num_parallel_calls", "Const", {},
              {{"value", 1}, {"dtype", DT_INT32}}),
         graph_tests_utils::MakeParallelInterleaveV2Node(
             "interleave", "range", "cycle_length", "block_length",
             "num_parallel_calls", "XTimesTwo", /*sloppy=*/false)},
        // FunctionLib
        {
            test::function::XTimesTwo(),
        });
  } else if (async_dataset == "map_and_batch") {
    item.graph = test::function::GDef(
        {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
         NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
         NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
         NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
         NDef("batch_size", "Const", {}, {{"value", 32}, {"dtype", DT_INT64}}),
         NDef("num_parallel_calls", "Const", {},
              {{"value", 1}, {"dtype", DT_INT64}}),
         NDef("drop_remainder", "Const", {},
              {{"value", false}, {"dtype", DT_BOOL}}),
         graph_tests_utils::MakeMapAndBatchNode(
             "map_and_batch", "range", "batch_size", "num_parallel_calls",
             "drop_remainder", "XTimesTwo")},
        // FunctionLib
        {
            test::function::XTimesTwo(),
        });
  }

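  // With autotune enabled, a PrefetchDataset should be injected directly
  // after the asynchronous node, with no legacy_autotune attribute and a
  // buffer size constant of -1 (autotune).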
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, true));

  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
  int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
  const NodeDef prefetch_node = output.node(index);
  EXPECT_TRUE(prefetch_node.attr().find("legacy_autotune") ==
              prefetch_node.attr().end());
  EXPECT_EQ(prefetch_node.input_size(), 2);
  NodeDef async_node = output.node(
      graph_utils::FindGraphNodeWithName(prefetch_node.input(0), output));
  EXPECT_EQ(async_node.name(), async_dataset);
  NodeDef buffer_size_val = output.node(
      graph_utils::FindGraphNodeWithName(prefetch_node.input(1), output));
  EXPECT_EQ(buffer_size_val.attr().at("value").tensor().int64_val(0), -1);
}

INSTANTIATE_TEST_SUITE_P(Test, SimpleInject,
                         ::testing::Values("map", "interleave",
                                           "map_and_batch"));

class AutotuneSetting : public ::testing::TestWithParam<bool> {};

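// Verifies that the optimizer injects a PrefetchDataset only when autotuning
// is enabled.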
TEST_P(AutotuneSetting, AutotuneBufferSizesTest) {
  const bool autotune = GetParam();

  using test::function::NDef;
  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
       NDef("num_parallel_calls", "Const", {},
            {{"value", 1}, {"dtype", DT_INT32}}),
       graph_tests_utils::MakeParallelMapNode("map", "range",
                                              "num_parallel_calls", "XTimesTwo",
                                              /*sloppy=*/false)},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });

  GraphDef output;
  TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, autotune));
  EXPECT_EQ(graph_utils::ContainsNodeWithOp("PrefetchDataset", output),
            autotune);
}

class MultipleNodes
    : public ::testing::TestWithParam<std::tuple<bool, int64_t>> {};

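// Builds a pipeline of the form
//   range -> map1 (parallel) -> prefetch -> map2 (parallel) -> map3,
// parameterized by the existing prefetch node's legacy_autotune setting and
// its initial buffer size.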
TEST_P(MultipleNodes, AutotuneBufferSizesTest) {
  const bool legacy_autotune = std::get<0>(GetParam());
  const int64_t initial_buffer_size = std::get<1>(GetParam());

  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  NodeDef *start_val = graph_utils::AddScalarConstNode<int64_t>(0, &graph);
  NodeDef *stop_val = graph_utils::AddScalarConstNode<int64_t>(10, &graph);
  NodeDef *step_val = graph_utils::AddScalarConstNode<int64_t>(1, &graph);

  std::vector<string> range_inputs(3);
  range_inputs[0] = start_val->name();
  range_inputs[1] = stop_val->name();
  range_inputs[2] = step_val->name();
  std::vector<std::pair<string, AttrValue>> range_attrs;
  NodeDef *range_node = graph_utils::AddNode("range", "RangeDataset",
                                             range_inputs, range_attrs, &graph);

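  // First asynchronous map ("map1") consuming the range dataset.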
  NodeDef *parallelism_val =
      graph_utils::AddScalarConstNode<int64_t>(1, &graph);
  std::vector<string> map_inputs1(2);
  map_inputs1[0] = range_node->name();
  map_inputs1[1] = parallelism_val->name();
  std::vector<std::pair<string, AttrValue>> map_attrs(4);
  AttrValue attr_val;
  SetAttrValue("value", &attr_val);
  map_attrs[0] = std::make_pair("f", attr_val);
  map_attrs[1] = std::make_pair("Targuments", attr_val);
  map_attrs[2] = std::make_pair("output_types", attr_val);
  map_attrs[3] = std::make_pair("output_shapes", attr_val);
  NodeDef *map_node1 = graph_utils::AddNode("map1", "ParallelMapDatasetV2",
                                            map_inputs1, map_attrs, &graph);

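  // Pre-existing PrefetchDataset after map1, carrying legacy_autotune and
  // buffer_size_min attributes and an explicit buffer size constant.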
  NodeDef *buffer_size_val =
      graph_utils::AddScalarConstNode<int64_t>(initial_buffer_size, &graph);
  std::vector<string> prefetch_inputs(2);
  prefetch_inputs[0] = map_node1->name();
  prefetch_inputs[1] = buffer_size_val->name();
  std::vector<std::pair<string, AttrValue>> prefetch_attrs(4);
  AttrValue legacy_autotune_attr;
  SetAttrValue(legacy_autotune, &legacy_autotune_attr);
  AttrValue buffer_size_min_attr;
  SetAttrValue(0, &buffer_size_min_attr);
  prefetch_attrs[0] = std::make_pair("legacy_autotune", legacy_autotune_attr);
  prefetch_attrs[1] = std::make_pair("buffer_size_min", buffer_size_min_attr);
  prefetch_attrs[2] = std::make_pair("output_types", attr_val);
  prefetch_attrs[3] = std::make_pair("output_shapes", attr_val);
  NodeDef *prefetch_node = graph_utils::AddNode(
      "prefetch", "PrefetchDataset", prefetch_inputs, prefetch_attrs, &graph);

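  // Second asynchronous map ("map2") consuming the prefetch, followed by a
  // synchronous "map3".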
  std::vector<string> map_inputs2(2);
  map_inputs2[0] = prefetch_node->name();
  map_inputs2[1] = parallelism_val->name();
  NodeDef *map_node2 = graph_utils::AddNode("map2", "ParallelMapDatasetV2",
                                            map_inputs2, map_attrs, &graph);

  std::vector<string> map_inputs3(1);
  map_inputs3[0] = map_node2->name();
  graph_utils::AddNode("map3", "MapDataset", map_inputs3, map_attrs, &graph);

  GraphDef output;
  TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, true));

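  // Expect two PrefetchDataset nodes in the output: a newly injected one
  // between map2 and map3 (with an autotuned buffer size of -1), and the
  // pre-existing one after map1, which keeps its legacy_autotune setting,
  // has buffer_size_min set to the original buffer size (or 0 when the
  // original size was -1), and has its buffer size constant rewritten to -1.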
  std::vector<int> prefetch_indices =
      graph_utils::FindAllGraphNodesWithOp("PrefetchDataset", output);
  EXPECT_EQ(prefetch_indices.size(), 2);

  NodeDef new_map_node3 =
      output.node(graph_utils::FindGraphNodeWithName("map3", output));

  NodeDef new_prefetch_node2 = output.node(
      graph_utils::FindGraphNodeWithName(new_map_node3.input(0), output));
  EXPECT_EQ(new_prefetch_node2.op(), "PrefetchDataset");
  EXPECT_EQ(new_prefetch_node2.input_size(), 2);
  EXPECT_TRUE(new_prefetch_node2.attr().find("legacy_autotune") ==
              new_prefetch_node2.attr().end());
  EXPECT_TRUE(new_prefetch_node2.attr().find("buffer_size_min") ==
              new_prefetch_node2.attr().end());
  NodeDef new_buffer_size_val2 = output.node(
      graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(1), output));
  EXPECT_EQ(new_buffer_size_val2.attr().at("value").tensor().int64_val(0), -1);

  NodeDef new_map_node2 = output.node(
      graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(0), output));
  EXPECT_EQ(new_map_node2.name(), "map2");

  NodeDef new_prefetch_node1 = output.node(
      graph_utils::FindGraphNodeWithName(new_map_node2.input(0), output));
  EXPECT_EQ(new_prefetch_node1.op(), "PrefetchDataset");
  EXPECT_EQ(new_prefetch_node1.input_size(), 2);
  EXPECT_EQ(new_prefetch_node1.attr().at("legacy_autotune").b(),
            legacy_autotune);
  EXPECT_EQ(new_prefetch_node1.attr().at("buffer_size_min").i(),
            (initial_buffer_size == -1 ? 0 : initial_buffer_size));
  NodeDef new_buffer_size_val1 = output.node(
      graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(1), output));
  EXPECT_EQ(new_buffer_size_val1.attr().at("value").tensor().int64_val(0), -1);

  NodeDef new_map_node1 = output.node(
      graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(0), output));
  EXPECT_EQ(new_map_node1.name(), "map1");
}

INSTANTIATE_TEST_SUITE_P(Test, MultipleNodes,
                         ::testing::Combine(::testing::Values(true, false),
                                            ::testing::Values(-1, 3)));

INSTANTIATE_TEST_SUITE_P(Test, AutotuneSetting, ::testing::Values(false, true));

}  // namespace
}  // namespace grappler
}  // namespace tensorflow