1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h"
17 
18 #include "tensorflow/core/framework/attr_value_util.h"
19 #include "tensorflow/core/framework/function_testlib.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
23 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 namespace {
30 
OptimizeWithAutotuneBufferSizes(const GrapplerItem & item,GraphDef * output,bool autotune)31 Status OptimizeWithAutotuneBufferSizes(const GrapplerItem &item,
32                                        GraphDef *output, bool autotune) {
33   AutotuneBufferSizes optimizer;
34   RewriterConfig_CustomGraphOptimizer config;
35   if (autotune) {
36     (*config.mutable_parameter_map())["autotune"].set_s("true");
37   } else {
38     (*config.mutable_parameter_map())["autotune"].set_s("false");
39   }
40   TF_RETURN_IF_ERROR(optimizer.Init(&config));
41   return optimizer.Optimize(nullptr, item, output);
42 }
43 
44 class SimpleInject : public ::testing::TestWithParam<string> {};
45 
TEST_P(SimpleInject,AutotuneBufferSizesTest)46 TEST_P(SimpleInject, AutotuneBufferSizesTest) {
47   const string async_dataset = GetParam();
48   using test::function::NDef;
49   GrapplerItem item;
50   if (async_dataset == "map") {
51     item.graph = test::function::GDef(
52         {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
53          NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
54          NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
55          NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
56          NDef("num_parallel_calls", "Const", {},
57               {{"value", 1}, {"dtype", DT_INT32}}),
58          graph_tests_utils::MakeParallelMapNode(
59              "map", "range", "num_parallel_calls", "XTimesTwo",
60              /*sloppy=*/false)},
61         // FunctionLib
62         {
63             test::function::XTimesTwo(),
64         });
65   } else if (async_dataset == "interleave") {
66     item.graph = test::function::GDef(
67         {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
68          NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
69          NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
70          NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
71          NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
72          NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
73          NDef("num_parallel_calls", "Const", {},
74               {{"value", 1}, {"dtype", DT_INT32}}),
75          graph_tests_utils::MakeParallelInterleaveV2Node(
76              "interleave", "range", "cycle_length", "block_length",
77              "num_parallel_calls", "XTimesTwo", /*sloppy=*/false)},
78         // FunctionLib
79         {
80             test::function::XTimesTwo(),
81         });
82   } else if (async_dataset == "map_and_batch") {
83     item.graph = test::function::GDef(
84         {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
85          NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
86          NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
87          NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
88          NDef("batch_size", "Const", {}, {{"value", 32}, {"dtype", DT_INT64}}),
89          NDef("num_parallel_calls", "Const", {},
90               {{"value", 1}, {"dtype", DT_INT64}}),
91          NDef("drop_remainder", "Const", {},
92               {{"value", false}, {"dtype", DT_BOOL}}),
93          graph_tests_utils::MakeMapAndBatchNode(
94              "map_and_batch", "range", "batch_size", "num_parallel_calls",
95              "drop_remainder", "XTimesTwo")},
96         // FunctionLib
97         {
98             test::function::XTimesTwo(),
99         });
100   }
101 
102   GraphDef output;
103   TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, true));
104 
105   EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
106   int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
107   const NodeDef prefetch_node = output.node(index);
108   EXPECT_TRUE(prefetch_node.attr().find("legacy_autotune") ==
109               prefetch_node.attr().end());
110   EXPECT_EQ(prefetch_node.input_size(), 2);
111   NodeDef async_node = output.node(
112       graph_utils::FindGraphNodeWithName(prefetch_node.input(0), output));
113   EXPECT_EQ(async_node.name(), async_dataset);
114   NodeDef buffer_size_val = output.node(
115       graph_utils::FindGraphNodeWithName(prefetch_node.input(1), output));
116   EXPECT_EQ(buffer_size_val.attr().at("value").tensor().int64_val(0), -1);
117 }
118 
119 INSTANTIATE_TEST_SUITE_P(Test, SimpleInject,
120                          ::testing::Values("map", "interleave",
121                                            "map_and_batch"));
122 
123 class AutotuneSetting : public ::testing::TestWithParam<bool> {};
124 
TEST_P(AutotuneSetting,AutotuneBufferSizesTest)125 TEST_P(AutotuneSetting, AutotuneBufferSizesTest) {
126   const bool autotune = GetParam();
127 
128   using test::function::NDef;
129   GrapplerItem item;
130   item.graph = test::function::GDef(
131       {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
132        NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
133        NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
134        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
135        NDef("num_parallel_calls", "Const", {},
136             {{"value", 1}, {"dtype", DT_INT32}}),
137        graph_tests_utils::MakeParallelMapNode("map", "range",
138                                               "num_parallel_calls", "XTimesTwo",
139                                               /*sloppy=*/false)},
140       // FunctionLib
141       {
142           test::function::XTimesTwo(),
143       });
144 
145   GraphDef output;
146   TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, autotune));
147   EXPECT_EQ(graph_utils::ContainsNodeWithOp("PrefetchDataset", output),
148             autotune);
149 }
150 
151 class MultipleNodes
152     : public ::testing::TestWithParam<std::tuple<bool, int64_t>> {};
153 
TEST_P(MultipleNodes,AutotuneBufferSizesTest)154 TEST_P(MultipleNodes, AutotuneBufferSizesTest) {
155   const bool legacy_autotune = std::get<0>(GetParam());
156   const int64_t initial_buffer_size = std::get<1>(GetParam());
157 
158   GrapplerItem item;
159   MutableGraphView graph(&item.graph);
160 
161   NodeDef *start_val = graph_utils::AddScalarConstNode<int64_t>(0, &graph);
162   NodeDef *stop_val = graph_utils::AddScalarConstNode<int64_t>(10, &graph);
163   NodeDef *step_val = graph_utils::AddScalarConstNode<int64_t>(1, &graph);
164 
165   std::vector<string> range_inputs(3);
166   range_inputs[0] = start_val->name();
167   range_inputs[1] = stop_val->name();
168   range_inputs[2] = step_val->name();
169   std::vector<std::pair<string, AttrValue>> range_attrs;
170   NodeDef *range_node = graph_utils::AddNode("range", "RangeDataset",
171                                              range_inputs, range_attrs, &graph);
172 
173   NodeDef *parallelism_val =
174       graph_utils::AddScalarConstNode<int64_t>(1, &graph);
175   std::vector<string> map_inputs1(2);
176   map_inputs1[0] = range_node->name();
177   map_inputs1[1] = parallelism_val->name();
178   std::vector<std::pair<string, AttrValue>> map_attrs(4);
179   AttrValue attr_val;
180   SetAttrValue("value", &attr_val);
181   map_attrs[0] = std::make_pair("f", attr_val);
182   map_attrs[1] = std::make_pair("Targuments", attr_val);
183   map_attrs[2] = std::make_pair("output_types", attr_val);
184   map_attrs[3] = std::make_pair("output_shapes", attr_val);
185   NodeDef *map_node1 = graph_utils::AddNode("map1", "ParallelMapDatasetV2",
186                                             map_inputs1, map_attrs, &graph);
187 
188   NodeDef *buffer_size_val =
189       graph_utils::AddScalarConstNode<int64_t>(initial_buffer_size, &graph);
190   std::vector<string> prefetch_inputs(2);
191   prefetch_inputs[0] = map_node1->name();
192   prefetch_inputs[1] = buffer_size_val->name();
193   std::vector<std::pair<string, AttrValue>> prefetch_attrs(4);
194   AttrValue legacy_autotune_attr;
195   SetAttrValue(legacy_autotune, &legacy_autotune_attr);
196   AttrValue buffer_size_min_attr;
197   SetAttrValue(0, &buffer_size_min_attr);
198   prefetch_attrs[0] = std::make_pair("legacy_autotune", legacy_autotune_attr);
199   prefetch_attrs[1] = std::make_pair("buffer_size_min", buffer_size_min_attr);
200   prefetch_attrs[2] = std::make_pair("output_types", attr_val);
201   prefetch_attrs[3] = std::make_pair("output_shapes", attr_val);
202   NodeDef *prefetch_node = graph_utils::AddNode(
203       "prefetch", "PrefetchDataset", prefetch_inputs, prefetch_attrs, &graph);
204 
205   std::vector<string> map_inputs2(2);
206   map_inputs2[0] = prefetch_node->name();
207   map_inputs2[1] = parallelism_val->name();
208   NodeDef *map_node2 = graph_utils::AddNode("map2", "ParallelMapDatasetV2",
209                                             map_inputs2, map_attrs, &graph);
210 
211   std::vector<string> map_inputs3(1);
212   map_inputs3[0] = map_node2->name();
213   graph_utils::AddNode("map3", "MapDataset", map_inputs3, map_attrs, &graph);
214 
215   GraphDef output;
216   TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, true));
217 
218   std::vector<int> prefetch_indices =
219       graph_utils::FindAllGraphNodesWithOp("PrefetchDataset", output);
220   EXPECT_EQ(prefetch_indices.size(), 2);
221 
222   NodeDef new_map_node3 =
223       output.node(graph_utils::FindGraphNodeWithName("map3", output));
224 
225   NodeDef new_prefetch_node2 = output.node(
226       graph_utils::FindGraphNodeWithName(new_map_node3.input(0), output));
227   EXPECT_EQ(new_prefetch_node2.op(), "PrefetchDataset");
228   EXPECT_EQ(new_prefetch_node2.input_size(), 2);
229   EXPECT_TRUE(new_prefetch_node2.attr().find("legacy_autotune") ==
230               new_prefetch_node2.attr().end());
231   EXPECT_TRUE(new_prefetch_node2.attr().find("buffer_size_min") ==
232               new_prefetch_node2.attr().end());
233   NodeDef new_buffer_size_val2 = output.node(
234       graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(1), output));
235   EXPECT_EQ(new_buffer_size_val2.attr().at("value").tensor().int64_val(0), -1);
236 
237   NodeDef new_map_node2 = output.node(
238       graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(0), output));
239   EXPECT_EQ(new_map_node2.name(), "map2");
240 
241   NodeDef new_prefetch_node1 = output.node(
242       graph_utils::FindGraphNodeWithName(new_map_node2.input(0), output));
243   EXPECT_EQ(new_prefetch_node1.op(), "PrefetchDataset");
244   EXPECT_EQ(new_prefetch_node1.input_size(), 2);
245   EXPECT_EQ(new_prefetch_node1.attr().at("legacy_autotune").b(),
246             legacy_autotune);
247   EXPECT_EQ(new_prefetch_node1.attr().at("buffer_size_min").i(),
248             (initial_buffer_size == -1 ? 0 : initial_buffer_size));
249   NodeDef new_buffer_size_val1 = output.node(
250       graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(1), output));
251   EXPECT_EQ(new_buffer_size_val1.attr().at("value").tensor().int64_val(0), -1);
252 
253   NodeDef new_map_node1 = output.node(
254       graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(0), output));
255   EXPECT_EQ(new_map_node1.name(), "map1");
256 }
257 
258 INSTANTIATE_TEST_SUITE_P(Test, MultipleNodes,
259                          ::testing::Combine(::testing::Values(true, false),
260                                             ::testing::Values(-1, 3)));
261 
262 INSTANTIATE_TEST_SUITE_P(Test, AutotuneSetting, ::testing::Values(false, true));
263 
264 }  // namespace
265 }  // namespace grappler
266 }  // namespace tensorflow
267