1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/model.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "tensorflow/core/framework/cancellation.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/gtl/cleanup.h"
26 #include "tensorflow/core/lib/monitoring/cell_reader.h"
27 #include "tensorflow/core/platform/stringprintf.h"
28 #include "tensorflow/core/platform/test.h"
29
30 namespace tensorflow {
31 namespace data {
32 namespace model {
33 namespace {
34
35 using ::tensorflow::monitoring::testing::CellReader;
36 using ::testing::AllOf;
37 using ::testing::HasSubstr;
38
CountParametersOnNode(const string & node_name,const Model::ModelParameters & parameters)39 int64_t CountParametersOnNode(const string& node_name,
40 const Model::ModelParameters& parameters) {
41 int64_t cnt = 0;
42 for (const auto& pair : parameters) {
43 if (pair.first == node_name) {
44 cnt++;
45 }
46 }
47 return cnt;
48 }
49
// Value-parameterized fixture for the async interleave-many node test. The
// test body reads element 0 of the tuple as the parallelism and element 1 as
// the input time.
class AsyncInterleaveManyTest
    : public ::testing::TestWithParam<std::tuple<int64_t, double>> {};
52
// Exercises an async interleave-many node that has a meta source and two
// source inputs. For each (parallelism, input_time) parameter pair it checks
// the buffered-byte accounting, TotalProcessingTime() and the bounds on
// OutputTime() as processing time and elements are recorded.
TEST_P(AsyncInterleaveManyTest, Model) {
  const auto& test_param = GetParam();
  const int64_t parallelism = std::get<0>(test_param);
  const double input_time = std::get<1>(test_param);

  std::shared_ptr<Parameter> parallelism_parameter = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/parallelism, nullptr, nullptr),
      /*min=*/1, /*max=*/8);
  std::shared_ptr<Parameter> cycle_length_parameter =
      model::MakeParameter(kCycleLength, nullptr, /*min=*/1, /*max=*/1);
  std::shared_ptr<Node> interleave = model::MakeAsyncInterleaveManyNode(
      {0, "async_interleave_many", nullptr},
      {parallelism_parameter, cycle_length_parameter});

  // Every input that is attached below is detached again by a cleanup object
  // when the test body goes out of scope.
  std::shared_ptr<Node> meta_source =
      model::MakeSourceNode({1, "meta_source", interleave});
  interleave->add_input(meta_source);
  auto detach_meta_source = gtl::MakeCleanup(
      [interleave, meta_source]() { interleave->remove_input(meta_source); });

  std::shared_ptr<Node> first_source =
      model::MakeSourceNode({2, "source1", interleave});
  interleave->add_input(first_source);
  auto detach_first_source = gtl::MakeCleanup(
      [interleave, first_source]() { interleave->remove_input(first_source); });

  std::shared_ptr<Node> second_source =
      model::MakeSourceNode({3, "source2", interleave});
  interleave->add_input(second_source);
  auto detach_second_source = gtl::MakeCleanup([interleave, second_source]() {
    interleave->remove_input(second_source);
  });

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;

  // Shorthands for the two quantities that are re-evaluated after every step.
  const auto total_processing_time = [&interleave]() {
    return interleave->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const auto output_time = [&interleave, &input_times]() {
    return interleave->OutputTime(&input_times, nullptr);
  };

  // Buffer accounting.
  EXPECT_EQ(interleave->TotalBufferedBytes(), 0);
  EXPECT_EQ(interleave->TotalMaximumBufferedBytes(), 0);
  interleave->record_buffer_event(110, 10);
  EXPECT_EQ(interleave->TotalBufferedBytes(), 110);
  EXPECT_EQ(interleave->TotalMaximumBufferedBytes(), 110 * parallelism / 10);

  // Processing time is recorded but no element has been produced yet.
  interleave->add_processing_time(100);
  EXPECT_EQ(interleave->processing_time(), 100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);

  // First element produced by the interleave node itself.
  interleave->record_element();
  EXPECT_EQ(interleave->num_elements(), 1);
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_LE(output_time(), 100);
  EXPECT_GE(output_time(), 0);

  // Source processing time without source elements leaves the totals as is.
  first_source->add_processing_time(200);
  second_source->add_processing_time(300);
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_LE(output_time(), 100);
  EXPECT_GE(output_time(), 0);

  // Once both sources have produced an element their time is accounted for.
  first_source->record_element();
  second_source->record_element();
  EXPECT_EQ(total_processing_time(), 100 + 250);
  EXPECT_LE(output_time(), 100 + 250 / parallelism);
  EXPECT_GE(output_time(), 0);

  // A second interleave element changes the node's own contribution to 50.
  interleave->record_element();
  EXPECT_EQ(total_processing_time(), 50 + 250);
  EXPECT_LE(output_time(), 50 + 250 / parallelism);
  EXPECT_GE(output_time(), 0);
}
129
// Instantiates the suite for every combination of parallelism in {1, 2} and
// input time in {0, 50, 100, 200}.
INSTANTIATE_TEST_SUITE_P(Test, AsyncInterleaveManyTest,
                         ::testing::Combine(::testing::Values(1, 2),
                                            ::testing::Values(0, 50, 100,
                                                              200)));
134
// Value-parameterized fixture for the async known-ratio node test. The test
// body reads the tuple as (parallelism, input time, number of inputs per
// output).
class AsyncKnownRatioTest
    : public ::testing::TestWithParam<std::tuple<int64_t, double, int64_t>> {};
137
// Exercises an async known-ratio node with two source inputs. For each
// (parallelism, input_time, num_inputs_per_output) parameter triple it checks
// the buffered-byte accounting, TotalProcessingTime() and the bounds on
// OutputTime() as processing time and elements are recorded.
TEST_P(AsyncKnownRatioTest, Model) {
  const auto& test_param = GetParam();
  const int64_t parallelism = std::get<0>(test_param);
  const double input_time = std::get<1>(test_param);
  const int64_t num_inputs_per_output = std::get<2>(test_param);

  std::shared_ptr<Node> known_ratio = model::MakeAsyncKnownRatioNode(
      {0, "async_known_many", nullptr}, num_inputs_per_output,
      {model::MakeParameter(
          "parallelism",
          std::make_shared<SharedState>(/*value=*/parallelism, nullptr,
                                        nullptr),
          /*min=*/1, /*max=*/16)});
  std::shared_ptr<Node> first_source =
      model::MakeSourceNode({1, "source1", known_ratio});
  known_ratio->add_input(first_source);
  std::shared_ptr<Node> second_source =
      model::MakeSourceNode({2, "source2", known_ratio});
  known_ratio->add_input(second_source);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;

  // Shorthands for the two quantities that are re-evaluated after every step.
  const auto total_processing_time = [&known_ratio]() {
    return known_ratio->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const auto output_time = [&known_ratio, &input_times]() {
    return known_ratio->OutputTime(&input_times, nullptr);
  };

  // Buffer accounting. The expected maximum is only divided by the ratio when
  // the ratio is non-zero.
  EXPECT_EQ(known_ratio->TotalBufferedBytes(), 0);
  EXPECT_EQ(known_ratio->TotalMaximumBufferedBytes(), 0);
  known_ratio->record_buffer_event(110, 10);
  EXPECT_EQ(known_ratio->TotalBufferedBytes(), 110);
  double expected_max_buffered_bytes = 110.0 * parallelism / 10;
  if (num_inputs_per_output != 0) {
    expected_max_buffered_bytes /= num_inputs_per_output;
  }
  EXPECT_EQ(known_ratio->TotalMaximumBufferedBytes(),
            expected_max_buffered_bytes);

  // Source processing time without any source element yields zero totals.
  first_source->add_processing_time(100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);
  second_source->add_processing_time(200);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);

  // Source elements are recorded one at a time, alternating between sources.
  first_source->record_element();
  EXPECT_EQ(total_processing_time(), num_inputs_per_output * 100);
  EXPECT_LE(output_time(), num_inputs_per_output * 100);
  EXPECT_GE(output_time(), 0);

  second_source->record_element();
  EXPECT_EQ(total_processing_time(), num_inputs_per_output * (100 + 200));
  EXPECT_LE(output_time(), num_inputs_per_output * (100 + 200));
  EXPECT_GE(output_time(), 0);

  first_source->record_element();
  EXPECT_EQ(total_processing_time(), num_inputs_per_output * (50 + 200));
  EXPECT_LE(output_time(), num_inputs_per_output * (50 + 200));
  EXPECT_GE(output_time(), 0);

  second_source->record_element();
  const int64_t inputs_time = num_inputs_per_output * (50 + 100);
  EXPECT_EQ(total_processing_time(), inputs_time);
  EXPECT_LE(output_time(), inputs_time);
  EXPECT_GE(output_time(), 0);

  // The node's own processing time only shows up once it records an element.
  known_ratio->add_processing_time(128);
  EXPECT_EQ(total_processing_time(), inputs_time);
  EXPECT_LE(output_time(), inputs_time);
  EXPECT_GE(output_time(), 0);

  known_ratio->record_element();
  EXPECT_EQ(total_processing_time(), inputs_time + 128);
  EXPECT_LE(output_time(), inputs_time + 128 / parallelism);
  EXPECT_GE(output_time(), 0);

  known_ratio->record_element();
  EXPECT_EQ(total_processing_time(), inputs_time + 64);
  EXPECT_LE(output_time(), inputs_time + 64 / parallelism);
  EXPECT_GE(output_time(), 0);
}
216
// Instantiates the suite for every combination of parallelism in
// {1, 2, 4, 8}, input time in {0, 50, 100, 200} and number of inputs per
// output in {0, 1, 2, 4}.
INSTANTIATE_TEST_SUITE_P(Test, AsyncKnownRatioTest,
                         ::testing::Combine(::testing::Values(1, 2, 4, 8),
                                            ::testing::Values(0, 50, 100, 200),
                                            ::testing::Values(0, 1, 2, 4)));
221
// Exercises a (synchronous) interleave-many node with a meta source and two
// source inputs, checking processing_time(), TotalProcessingTime() and
// OutputTime() as processing time and elements are recorded. The input time
// is fixed at 0.
TEST(InterleaveManyTest, Model) {
  // The node receives its "cycle_length" parameter inline; no separate local
  // copy of the parameter is needed.
  std::shared_ptr<Node> interleave_many = model::MakeInterleaveManyNode(
      {0, "interleave_many", nullptr},
      {model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
  std::shared_ptr<Node> meta_source =
      model::MakeSourceNode({1, "meta_source", interleave_many});
  interleave_many->add_input(meta_source);
  std::shared_ptr<Node> source1 =
      model::MakeSourceNode({2, "source1", interleave_many});
  interleave_many->add_input(source1);
  std::shared_ptr<Node> source2 =
      model::MakeSourceNode({3, "source2", interleave_many});
  interleave_many->add_input(source2);
  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;

  // Processing time is recorded but no element has been produced yet.
  interleave_many->add_processing_time(100);
  EXPECT_EQ(interleave_many->processing_time(), 100);
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            0);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 0);

  // First element produced by the interleave node itself.
  interleave_many->record_element();
  EXPECT_EQ(interleave_many->num_elements(), 1);
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            100);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 100);

  // Source processing time without source elements leaves the totals as is.
  source1->add_processing_time(200);
  source2->add_processing_time(300);
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            100);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 100);

  // Once both sources have produced an element their time is accounted for.
  source1->record_element();
  source2->record_element();
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            350);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 350);

  // A second interleave element lowers the expected totals to 300.
  interleave_many->record_element();
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            300);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 300);
}
264
// Value-parameterized fixture for the known-ratio node test. The test body
// uses the parameter as the number of inputs per output.
class KnownRatioTest : public ::testing::TestWithParam<int64_t> {};
266
// Exercises a (synchronous) known-ratio node with two source inputs. For each
// value of num_inputs_per_output it checks that TotalProcessingTime() and
// OutputTime() agree with each other and with the expected value after every
// recorded processing time or element. The input time is fixed at 0.
TEST_P(KnownRatioTest, Model) {
  const int64_t num_inputs_per_output = GetParam();

  std::shared_ptr<Node> known_ratio = model::MakeKnownRatioNode(
      {0, "known_many", nullptr}, num_inputs_per_output);
  std::shared_ptr<Node> first_source =
      model::MakeSourceNode({1, "source1", known_ratio});
  known_ratio->add_input(first_source);
  std::shared_ptr<Node> second_source =
      model::MakeSourceNode({2, "source2", known_ratio});
  known_ratio->add_input(second_source);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;

  // Shorthands for the two quantities that are re-evaluated after every step.
  const auto total_processing_time = [&known_ratio]() {
    return known_ratio->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const auto output_time = [&known_ratio, &input_times]() {
    return known_ratio->OutputTime(&input_times, nullptr);
  };

  // Source processing time without any source element yields zero totals.
  first_source->add_processing_time(100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);
  second_source->add_processing_time(200);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);

  // Source elements are recorded one at a time, alternating between sources.
  first_source->record_element();
  EXPECT_EQ(total_processing_time(), num_inputs_per_output * 100);
  EXPECT_EQ(output_time(), num_inputs_per_output * 100);

  second_source->record_element();
  EXPECT_EQ(total_processing_time(), num_inputs_per_output * (100 + 200));
  EXPECT_EQ(output_time(), num_inputs_per_output * (100 + 200));

  first_source->record_element();
  EXPECT_EQ(total_processing_time(), num_inputs_per_output * (50 + 200));
  EXPECT_EQ(output_time(), num_inputs_per_output * (50 + 200));

  second_source->record_element();
  const int64_t inputs_time = num_inputs_per_output * (50 + 100);
  EXPECT_EQ(total_processing_time(), inputs_time);
  EXPECT_EQ(output_time(), inputs_time);

  // The node's own processing time only shows up once it records an element.
  known_ratio->add_processing_time(128);
  EXPECT_EQ(total_processing_time(), inputs_time);
  EXPECT_EQ(output_time(), inputs_time);

  known_ratio->record_element();
  EXPECT_EQ(total_processing_time(), inputs_time + 128);
  EXPECT_EQ(output_time(), inputs_time + 128);

  known_ratio->record_element();
  EXPECT_EQ(total_processing_time(), inputs_time + 64);
  EXPECT_EQ(output_time(), inputs_time + 64);
}
321
// Instantiates the suite once for each number of inputs per output in
// {0, 1, 2, 4}.
INSTANTIATE_TEST_SUITE_P(Test, KnownRatioTest, ::testing::Values(0, 1, 2, 4));
323
// Exercises a stand-alone source node: with 100 units of processing time
// recorded, TotalProcessingTime() and OutputTime() are expected to be 0 before
// any element is recorded, 100 after the first element and 50 after the
// second. The input time is fixed at 0.
TEST(SourceTest, Model) {
  std::shared_ptr<Node> source_node =
      model::MakeSourceNode({0, "source", nullptr});
  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;

  // Shorthands for the two quantities that are re-evaluated after every step.
  const auto total_processing_time = [&source_node]() {
    return source_node->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const auto output_time = [&source_node, &input_times]() {
    return source_node->OutputTime(&input_times, nullptr);
  };

  source_node->add_processing_time(100);
  EXPECT_EQ(source_node->processing_time(), 100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);

  source_node->record_element();
  EXPECT_EQ(source_node->num_elements(), 1);
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_EQ(output_time(), 100);

  source_node->record_element();
  EXPECT_EQ(source_node->num_elements(), 2);
  EXPECT_EQ(total_processing_time(), 50);
  EXPECT_EQ(output_time(), 50);
}
341
// Exercises a (synchronous) unknown-ratio node with two source inputs,
// checking processing_time(), TotalProcessingTime() and OutputTime() as
// processing time and elements are recorded on the node and on its inputs.
// The input time is fixed at 0.
TEST(UnknownRatioTest, Model) {
  std::shared_ptr<Node> unknown_ratio =
      model::MakeUnknownRatioNode({0, "unknown_many", nullptr});
  std::shared_ptr<Node> first_source =
      model::MakeSourceNode({1, "source1", unknown_ratio});
  unknown_ratio->add_input(first_source);
  std::shared_ptr<Node> second_source =
      model::MakeSourceNode({2, "source2", unknown_ratio});
  unknown_ratio->add_input(second_source);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;

  // Shorthands for the two quantities that are re-evaluated after every step.
  const auto total_processing_time = [&unknown_ratio]() {
    return unknown_ratio->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const auto output_time = [&unknown_ratio, &input_times]() {
    return unknown_ratio->OutputTime(&input_times, nullptr);
  };

  // Processing time is recorded but no element has been produced yet.
  unknown_ratio->add_processing_time(100);
  EXPECT_EQ(unknown_ratio->processing_time(), 100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);

  // First element produced by the node itself.
  unknown_ratio->record_element();
  EXPECT_EQ(unknown_ratio->num_elements(), 1);
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_EQ(output_time(), 100);

  // Source processing time without source elements leaves the totals as is.
  first_source->add_processing_time(100);
  second_source->add_processing_time(200);
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_EQ(output_time(), 100);

  // Once both sources have produced an element their time is accounted for.
  first_source->record_element();
  second_source->record_element();
  EXPECT_EQ(total_processing_time(), 400);
  EXPECT_EQ(output_time(), 400);

  // A second element on the node halves the expected totals.
  unknown_ratio->record_element();
  EXPECT_EQ(total_processing_time(), 200);
  EXPECT_EQ(output_time(), 200);
}
377
// Value-parameterized fixture for the async unknown-ratio node test. The test
// body reads element 0 of the tuple as the parallelism and element 1 as the
// input time.
class AsyncUnknownRatioTest
    : public ::testing::TestWithParam<std::tuple<int64_t, double>> {};
380
// Exercises an async unknown-ratio node with two source inputs. For each
// (parallelism, input_time) parameter pair it checks the buffered-byte
// accounting, TotalProcessingTime() and the bounds on OutputTime() while the
// ratio implied by the recorded element counts changes (1, 2, 1, then 2/3).
TEST_P(AsyncUnknownRatioTest, Model) {
  const auto& test_param = GetParam();
  const int64_t parallelism = std::get<0>(test_param);
  const double input_time = std::get<1>(test_param);

  std::shared_ptr<Node> unknown_ratio = model::MakeAsyncUnknownRatioNode(
      {0, "async_unknown_many", nullptr},
      {model::MakeParameter(
          "parallelism",
          std::make_shared<SharedState>(/*value=*/parallelism, nullptr,
                                        nullptr),
          /*min=*/1, /*max=*/16)});
  std::shared_ptr<Node> first_source =
      model::MakeSourceNode({1, "source1", unknown_ratio});
  unknown_ratio->add_input(first_source);
  std::shared_ptr<Node> second_source =
      model::MakeSourceNode({2, "source2", unknown_ratio});
  unknown_ratio->add_input(second_source);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;

  // Shorthands for the two quantities that are re-evaluated after every step.
  const auto total_processing_time = [&unknown_ratio]() {
    return unknown_ratio->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const auto output_time = [&unknown_ratio, &input_times]() {
    return unknown_ratio->OutputTime(&input_times, nullptr);
  };

  // Buffer accounting.
  EXPECT_EQ(unknown_ratio->TotalBufferedBytes(), 0);
  EXPECT_EQ(unknown_ratio->TotalMaximumBufferedBytes(), 0);
  unknown_ratio->record_buffer_event(110, 10);
  EXPECT_EQ(unknown_ratio->TotalBufferedBytes(), 110);
  EXPECT_EQ(unknown_ratio->TotalMaximumBufferedBytes(),
            110.0 * parallelism / 10);

  // Nothing is reported until the node itself has produced an element.
  first_source->add_processing_time(100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);
  second_source->add_processing_time(200);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);
  first_source->record_element();
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);

  unknown_ratio->record_element();
  // Estimated ratio is 1.
  double estimated_ratio = 1.0;
  EXPECT_EQ(total_processing_time(), estimated_ratio * 100);
  EXPECT_LE(output_time(), 100);
  EXPECT_GE(output_time(), 0);

  second_source->record_element();
  EXPECT_EQ(total_processing_time(), estimated_ratio * (100 + 200));
  EXPECT_LE(output_time(), estimated_ratio * (100 + 200));
  EXPECT_GE(output_time(), 0);

  second_source->record_element();
  EXPECT_EQ(total_processing_time(), estimated_ratio * (100 + 100));
  EXPECT_LE(output_time(), estimated_ratio * (100 + 100));
  EXPECT_GE(output_time(), 0);

  first_source->record_element();
  // Estimated ratio is 2.
  estimated_ratio = 2.0;
  EXPECT_EQ(total_processing_time(), estimated_ratio * (50 + 100));
  EXPECT_LE(output_time(), estimated_ratio * (50 + 100));
  EXPECT_GE(output_time(), 0);

  second_source->record_element();
  second_source->record_element();
  EXPECT_EQ(total_processing_time(), estimated_ratio * (50 + 50));
  EXPECT_LE(output_time(), estimated_ratio * (50 + 50));
  EXPECT_GE(output_time(), 0);

  // The node's own processing time now contributes as well.
  unknown_ratio->add_processing_time(128);
  EXPECT_EQ(total_processing_time(), estimated_ratio * (50 + 50) + 128);
  EXPECT_LE(output_time(),
            estimated_ratio * (50 + 50) + 128 / parallelism);
  EXPECT_GE(output_time(), 128 / parallelism);

  unknown_ratio->record_element();
  // Estimated ratio is 1.0.
  estimated_ratio = 1.0;
  EXPECT_EQ(total_processing_time(), estimated_ratio * (50 + 50) + 128 / 2);
  EXPECT_LE(output_time(),
            estimated_ratio * (50 + 50) + 128 / 2 / parallelism);
  EXPECT_GE(output_time(), 128 / 2 / parallelism);

  unknown_ratio->record_element();
  // Estimated ratio is 2/3; this is the one inexact expectation, hence the
  // floating-point comparison.
  estimated_ratio = 2.0 / 3.0;
  EXPECT_FLOAT_EQ(total_processing_time(),
                  estimated_ratio * (50 + 50) + 128 / 3.0);
  EXPECT_LE(output_time(),
            estimated_ratio * (50 + 50) + 128 / 3.0 / parallelism);
  EXPECT_GE(output_time(), 128 / 3.0 / parallelism);
}
485
// Instantiates the suite for every combination of parallelism in
// {1, 2, 4, 8} and input time in {0, 50, 100, 200}.
INSTANTIATE_TEST_SUITE_P(Test, AsyncUnknownRatioTest,
                         ::testing::Combine(::testing::Values(1, 2, 4, 8),
                                            ::testing::Values(0, 50, 100,
                                                              200)));
490
// Exercises an unknown node with two source inputs: its TotalProcessingTime()
// and OutputTime() are expected to follow the inputs only, and to stay
// unchanged when processing time or elements are recorded on the unknown node
// itself. The input time is fixed at 0.
TEST(UnknownTest, Model) {
  std::shared_ptr<Node> unknown_node =
      model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Node> first_source =
      model::MakeSourceNode({1, "source1", unknown_node});
  unknown_node->add_input(first_source);
  std::shared_ptr<Node> second_source =
      model::MakeSourceNode({2, "source2", unknown_node});
  unknown_node->add_input(second_source);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;

  // Shorthands for the two quantities that are re-evaluated after every step.
  const auto total_processing_time = [&unknown_node]() {
    return unknown_node->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const auto output_time = [&unknown_node, &input_times]() {
    return unknown_node->OutputTime(&input_times, nullptr);
  };

  // Source processing time without any source element yields zero totals.
  first_source->add_processing_time(100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);
  second_source->add_processing_time(100);
  EXPECT_EQ(total_processing_time(), 0);
  EXPECT_EQ(output_time(), 0);

  // Source elements are recorded one at a time, alternating between sources.
  first_source->record_element();
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_EQ(output_time(), 100);
  second_source->record_element();
  EXPECT_EQ(total_processing_time(), 200);
  EXPECT_EQ(output_time(), 200);
  first_source->record_element();
  EXPECT_EQ(total_processing_time(), 150);
  EXPECT_EQ(output_time(), 150);
  second_source->record_element();
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_EQ(output_time(), 100);

  // Processing time recorded on the unknown node itself must not change its
  // TotalProcessingTime() or OutputTime().
  unknown_node->add_processing_time(100);
  EXPECT_EQ(unknown_node->processing_time(), 100);
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_EQ(output_time(), 100);

  // Elements recorded on the unknown node itself must not change its
  // TotalProcessingTime() or OutputTime() either.
  unknown_node->record_element();
  EXPECT_EQ(unknown_node->num_elements(), 1);
  EXPECT_EQ(total_processing_time(), 100);
  EXPECT_EQ(output_time(), 100);
}
533
// Walks a single async interleave-many node through a sequence of buffer,
// element, byte and timing events, checking the node's accessors after every
// step. It then attaches an async known-ratio input and checks that the
// Total*BufferedBytes() values of the parent reflect the input as well.
TEST(BufferedBytesTest, Node) {
  std::shared_ptr<Node> parent = model::MakeAsyncInterleaveManyNode(
      {-1, "TestNode", nullptr},
      {model::MakeParameter(
           "parallelism",
           std::make_shared<SharedState>(/*value=*/3, nullptr, nullptr),
           /*min=*/1, /*max=*/7),
       model::MakeParameter(kCycleLength, nullptr, /*min=*/1, /*max=*/1)});

  // Takes one element of `bytes` bytes out of `producer`'s buffer and records
  // it as produced.
  const auto produce_buffered_element =
      [](const std::shared_ptr<Node>& producer, int64_t bytes) {
        producer->record_buffer_event(-bytes, -1);
        producer->record_element();
        producer->record_bytes_produced(bytes);
      };

  // Constructor arguments are reflected by the accessors.
  EXPECT_EQ(parent->id(), -1);
  EXPECT_EQ(parent->name(), "TestNode");
  EXPECT_EQ(parent->output(), nullptr);

  // A fresh node has nothing buffered.
  EXPECT_EQ(parent->buffered_bytes(), 0);
  EXPECT_EQ(parent->buffered_elements(), 0);
  EXPECT_EQ(parent->TotalBufferedBytes(), 0);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 0);

  // Three elements enter the buffer one after another.
  parent->record_buffer_event(20, 1);
  EXPECT_EQ(parent->buffered_bytes(), 20);
  EXPECT_EQ(parent->buffered_elements(), 1);
  EXPECT_EQ(parent->TotalBufferedBytes(), 20);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 60);

  parent->record_buffer_event(10, 1);
  EXPECT_EQ(parent->buffered_bytes(), 30);
  EXPECT_EQ(parent->buffered_elements(), 2);
  EXPECT_EQ(parent->TotalBufferedBytes(), 30);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 45);

  parent->record_buffer_event(18, 1);
  EXPECT_EQ(parent->buffered_bytes(), 48);
  EXPECT_EQ(parent->buffered_elements(), 3);
  EXPECT_EQ(parent->bytes_produced(), 0);
  EXPECT_EQ(parent->num_elements(), 0);
  EXPECT_EQ(parent->TotalBufferedBytes(), 48);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 48);

  // Two elements leave the buffer and are recorded as produced.
  produce_buffered_element(parent, 20);
  EXPECT_EQ(parent->buffered_bytes(), 28);
  EXPECT_EQ(parent->buffered_elements(), 2);
  EXPECT_EQ(parent->bytes_produced(), 20);
  EXPECT_EQ(parent->num_elements(), 1);
  EXPECT_EQ(parent->TotalBufferedBytes(), 28);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 51);

  produce_buffered_element(parent, 10);
  EXPECT_EQ(parent->buffered_bytes(), 18);
  EXPECT_EQ(parent->buffered_elements(), 1);
  EXPECT_EQ(parent->bytes_produced(), 30);
  EXPECT_EQ(parent->num_elements(), 2);
  EXPECT_EQ(parent->TotalBufferedBytes(), 18);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 49.5);

  // Processing time is only updated by record_stop() and by explicit
  // additions, not by record_start().
  EXPECT_EQ(parent->processing_time(), 0);
  parent->record_start(1);
  EXPECT_EQ(parent->processing_time(), 0);
  parent->record_stop(41);
  EXPECT_EQ(parent->processing_time(), 40);
  parent->add_processing_time(2);
  EXPECT_EQ(parent->processing_time(), 42);

  // Creating a node with `parent` as its output does not register it as an
  // input of `parent`; add_input() does.
  std::shared_ptr<Node> child = model::MakeAsyncKnownRatioNode(
      {0, "TestInput", parent}, 2,
      {model::MakeParameter("parallelism",
                            std::make_shared<SharedState>(5, nullptr, nullptr),
                            0, 6)});
  EXPECT_EQ(child->output(), parent.get());
  EXPECT_EQ(parent->inputs().size(), 0);
  parent->add_input(child);
  EXPECT_EQ(parent->inputs().size(), 1);
  EXPECT_EQ(parent->inputs().front(), child);

  // Bytes buffered in the input are included in the parent's totals.
  child->record_buffer_event(28, 1);
  EXPECT_EQ(parent->bytes_consumed(), 0);
  EXPECT_EQ(parent->TotalBufferedBytes(), 46);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 119.5);

  // The input hands its buffered element over to the parent.
  produce_buffered_element(child, 28);
  parent->record_bytes_consumed(28);
  EXPECT_EQ(parent->bytes_consumed(), 28);
  EXPECT_EQ(parent->TotalBufferedBytes(), 18);
  EXPECT_EQ(parent->TotalMaximumBufferedBytes(), 119.5);

  parent->remove_input(child);
  EXPECT_EQ(parent->inputs().size(), 0);
}
628
// Blends `prior` with the measured `processing_time`. With fewer than 30
// elements the prior is weighted by 1 / (2 << num_elements) and the
// measurement by the remainder; from 30 elements on, the measurement is
// returned unchanged.
double weighted_processing_time(int64_t num_elements, double processing_time,
                                double prior) {
  if (num_elements >= 30) {
    return processing_time;
  }
  const double weight_of_prior =
      1.0L / static_cast<double>(2 << num_elements);
  const long double weight_of_measurement = 1.0L - weight_of_prior;
  return weight_of_prior * prior + weight_of_measurement * processing_time;
}
639
// Records 100 elements on the single source input of an interleave-many node
// and checks that the node's TotalProcessingTime() stays between 0 and
// weighted_processing_time(100, 2, 0) + 100.
TEST(TestManyElements, Model) {
  std::shared_ptr<Node> interleave = model::MakeInterleaveManyNode(
      {0, "interleave_many", nullptr},
      {model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
  std::shared_ptr<Node> source =
      model::MakeSourceNode({1, "source1", interleave});
  interleave->add_input(source);

  interleave->add_processing_time(100);
  interleave->record_element();
  source->add_processing_time(200);
  constexpr int kSourceElements = 100;
  for (int recorded = 0; recorded != kSourceElements; ++recorded) {
    source->record_element();
  }

  const auto total_processing_time = [&interleave]() {
    return interleave->TotalProcessingTime(/*processing_times=*/nullptr);
  };
  const double upper_bound = weighted_processing_time(100, 2, 0) + 100;
  EXPECT_LE(total_processing_time(), upper_bound);
  EXPECT_GE(total_processing_time(), 0);
}
658
// An input whose "parallelism" parameter is set to kAutotune and which has
// recorded an element is expected to contribute exactly one tunable
// parameter; the unknown root node contributes none.
TEST(CollectAutotuneParametersWithElementsTest, Model) {
  std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> autotuned_parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> tunable_input = model::MakeAsyncKnownRatioNode(
      {1, "source", root}, 2, {autotuned_parallelism});
  tunable_input->record_element();
  root->add_input(tunable_input);

  const Model::ModelParameters collected = root->CollectTunableParameters();

  EXPECT_EQ(CountParametersOnNode(root->long_name(), collected), 0);
  EXPECT_EQ(CountParametersOnNode(tunable_input->long_name(), collected), 1);
  EXPECT_EQ(collected.size(), 1);
}
679
// A "parallelism" parameter with a fixed value (3 rather than kAutotune) is
// expected not to be collected as tunable, even though the input has recorded
// an element.
TEST(DontCollectNonAutotuneParametersTest, Model) {
  std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> fixed_parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/3, nullptr, nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> fixed_input = model::MakeAsyncKnownRatioNode(
      {1, "source", root}, 2, {fixed_parallelism});
  fixed_input->record_element();
  root->add_input(fixed_input);

  const Model::ModelParameters collected = root->CollectTunableParameters();

  EXPECT_EQ(collected.size(), 0);
}
695
// A kAutotune "parallelism" parameter is expected not to be collected when
// autotuning has been switched off on its node via set_autotune(false).
TEST(DontCollectAutotuneDisabledParametersTest, Model) {
  std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> autotuned_parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> disabled_input = model::MakeAsyncKnownRatioNode(
      {1, "source", root}, 2, {autotuned_parallelism});
  disabled_input->record_element();
  disabled_input->set_autotune(false);
  root->add_input(disabled_input);

  const Model::ModelParameters collected = root->CollectTunableParameters();

  EXPECT_EQ(collected.size(), 0);
}
713
// A kAutotune "parallelism" parameter is expected not to be collected while
// its node has not recorded any element.
TEST(DontCollectParametersWithoutElementsTest, Model) {
  std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> autotuned_parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> idle_input = model::MakeAsyncKnownRatioNode(
      {1, "source", root}, 2, {autotuned_parallelism});
  root->add_input(idle_input);

  const Model::ModelParameters collected = root->CollectTunableParameters();

  EXPECT_EQ(collected.size(), 0);
}
729
// Absolute tolerance used when an analytic gradient is compared against the
// corresponding finite-difference estimate.
constexpr double kComparisonPrecision = 0.1;

// Amount by which a parameter is perturbed to obtain the finite-difference
// estimate.
constexpr double kParameterStep = 1e-5;
735
// Compares the gradients reported by `OutputTime` for an async interleave node
// with finite-difference estimates, first for the node's own "parallelism"
// parameter and then for the "parallelism" parameter owned by one of its
// inputs.
TEST(AsyncInterleaveManyGradientTest, Model) {
  const double input_time = 100;

  auto root_parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  auto root = model::MakeAsyncInterleaveManyNode(
      {0, "async_interleave_many", nullptr},
      {root_parallelism,
       model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});

  // Every input is detached again by a cleanup when the test scope ends.
  auto meta_source = model::MakeSourceNode({1, "meta_source", root});
  root->add_input(meta_source);
  auto detach_meta = gtl::MakeCleanup(
      [root, meta_source]() { root->remove_input(meta_source); });

  auto inner_parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/7);
  auto inner = model::MakeAsyncInterleaveManyNode(
      {2, "async_interleave_many", root},
      {inner_parallelism,
       model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
  root->add_input(inner);
  auto detach_inner =
      gtl::MakeCleanup([root, inner]() { root->remove_input(inner); });

  auto leaf = model::MakeSourceNode({3, "source2", root});
  root->add_input(leaf);
  auto detach_leaf =
      gtl::MakeCleanup([root, leaf]() { root->remove_input(leaf); });

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  root->record_element();
  root->add_processing_time(100);
  inner->record_element();
  inner->add_processing_time(100);
  leaf->record_element();
  leaf->add_processing_time(300);

  root_parallelism->value = 1;
  inner_parallelism->value = 1;

  // Gradient with respect to the node's own parameter.
  Model::ParameterGradients gradients;
  const double base_output_time = root->OutputTime(&input_times, &gradients);
  root_parallelism->value += kParameterStep;
  double perturbed_output_time = root->OutputTime(&input_times, nullptr);
  const auto root_key =
      std::make_pair(root->long_name(), root_parallelism->name);
  EXPECT_NEAR(gradients[root_key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);

  // Gradient propagated from the input: undo the first perturbation and
  // perturb the input's parameter instead.
  root_parallelism->value -= kParameterStep;
  inner_parallelism->value += kParameterStep;
  perturbed_output_time = root->OutputTime(&input_times, nullptr);
  const auto inner_key =
      std::make_pair(inner->long_name(), inner_parallelism->name);
  EXPECT_NEAR(gradients[inner_key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);
}
806
807 class AsyncKnownRatioGradientTest : public ::testing::TestWithParam<string> {};
808
// Compares the gradients reported by `OutputTime` for an async known-ratio
// node with finite-difference estimates, for the node's own parameter and for
// the parameter of one of its inputs. The parameter name comes from the test
// parameter.
TEST_P(AsyncKnownRatioGradientTest, Model) {
  const string parameter_name = GetParam();
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;

  auto root_parameter = model::MakeParameter(
      parameter_name,
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  auto root = model::MakeAsyncKnownRatioNode(
      {0, "async_known_many", nullptr}, num_inputs_per_output,
      {root_parameter});

  auto inner_parameter = model::MakeParameter(
      parameter_name,
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/7);
  auto inner = model::MakeAsyncKnownRatioNode(
      {1, "source1", root}, num_inputs_per_output, {inner_parameter});
  root->add_input(inner);

  auto leaf = model::MakeSourceNode({2, "source2", root});
  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  root->add_input(leaf);

  inner->record_element();
  inner->add_processing_time(100);
  leaf->record_element();
  leaf->add_processing_time(100);
  root->record_element();
  root->add_processing_time(300);

  // Gradient with respect to the node's own parameter.
  Model::ParameterGradients gradients;
  root_parameter->value = 1;
  inner_parameter->value = 1;
  const double base_output_time = root->OutputTime(&input_times, &gradients);
  root_parameter->value += kParameterStep;
  double perturbed_output_time = root->OutputTime(&input_times, nullptr);
  const auto root_key =
      std::make_pair(root->long_name(), root_parameter->name);
  EXPECT_NEAR(gradients[root_key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);

  // Gradient propagated from the input: undo the first perturbation and
  // perturb the input's parameter instead.
  root_parameter->value -= kParameterStep;
  inner_parameter->value += kParameterStep;
  perturbed_output_time = root->OutputTime(&input_times, nullptr);
  const auto inner_key =
      std::make_pair(inner->long_name(), inner_parameter->name);
  EXPECT_NEAR(gradients[inner_key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);
}
865
866 INSTANTIATE_TEST_SUITE_P(Test, AsyncKnownRatioGradientTest,
867 ::testing::Values("parallelism", "buffer_size"));
868
// Checks that the gradient of an async input's "parallelism" parameter,
// as reported through an interleave-many node's `OutputTime`, matches a
// finite-difference estimate.
TEST(InterleaveManyGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;

  auto root = model::MakeInterleaveManyNode(
      {0, "interleave_many", nullptr},
      {model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
  auto parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  auto async_input = model::MakeAsyncKnownRatioNode(
      {1, "async_known_many", root}, num_inputs_per_output, {parallelism});
  auto source_input = model::MakeSourceNode({2, "source1", root});

  root->record_element();
  root->add_processing_time(100);
  root->add_input(source_input);
  root->add_input(async_input);
  async_input->record_element();
  async_input->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  parallelism->value = 1;
  const double base_output_time = root->OutputTime(&input_times, &gradients);
  parallelism->value += kParameterStep;
  const double perturbed_output_time = root->OutputTime(&input_times, nullptr);

  const auto key = std::make_pair(async_input->long_name(), parallelism->name);
  EXPECT_NEAR(gradients[key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);
}
904
// Checks that the gradient of an async input's "parallelism" parameter,
// as reported through a known-ratio node's `OutputTime`, matches a
// finite-difference estimate.
TEST(KnownRatioGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;

  auto root = model::MakeKnownRatioNode({0, "known_many", nullptr},
                                        num_inputs_per_output);
  auto parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  auto async_input = model::MakeAsyncKnownRatioNode(
      {1, "async_known_many", root}, num_inputs_per_output, {parallelism});

  root->record_element();
  root->add_processing_time(100);
  root->add_input(async_input);
  async_input->record_element();
  async_input->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  parallelism->value = 1;
  const double base_output_time = root->OutputTime(&input_times, &gradients);
  parallelism->value += kParameterStep;
  const double perturbed_output_time = root->OutputTime(&input_times, nullptr);

  const auto key = std::make_pair(async_input->long_name(), parallelism->name);
  EXPECT_NEAR(gradients[key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);
}
936
// Checks that the gradient of an async input's "parallelism" parameter,
// as reported through an unknown-ratio node's `OutputTime`, matches a
// finite-difference estimate.
TEST(UnknownRatioGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;

  auto root = model::MakeUnknownRatioNode({0, "unknown_many", nullptr});
  auto parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  auto async_input = model::MakeAsyncKnownRatioNode(
      {1, "async_known_many", root}, num_inputs_per_output, {parallelism});

  root->record_element();
  root->add_processing_time(100);
  root->add_input(async_input);
  async_input->record_element();
  async_input->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  parallelism->value = 1;
  const double base_output_time = root->OutputTime(&input_times, &gradients);
  parallelism->value += kParameterStep;
  const double perturbed_output_time = root->OutputTime(&input_times, nullptr);

  const auto key = std::make_pair(async_input->long_name(), parallelism->name);
  EXPECT_NEAR(gradients[key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);
}
968
// Checks that the gradient of an async input's "parallelism" parameter,
// as reported through an unknown node's `OutputTime`, matches a
// finite-difference estimate.
TEST(UnknownGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;

  auto root = model::MakeUnknownNode({0, "unknown", nullptr});
  auto parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  auto async_input = model::MakeAsyncKnownRatioNode(
      {1, "async_known_many", root}, num_inputs_per_output, {parallelism});

  root->record_element();
  root->add_processing_time(100);
  root->add_input(async_input);
  async_input->record_element();
  async_input->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  parallelism->value = 1;
  const double base_output_time = root->OutputTime(&input_times, &gradients);
  parallelism->value += kParameterStep;
  const double perturbed_output_time = root->OutputTime(&input_times, nullptr);

  const auto key = std::make_pair(async_input->long_name(), parallelism->name);
  EXPECT_NEAR(gradients[key],
              (perturbed_output_time - base_output_time) / kParameterStep,
              kComparisonPrecision);
}
1000
// Builds a linear chain of unknown nodes, snapshots it from the root and walks
// both chains in lockstep, checking that each clone is a distinct object with
// the same id, name, autotune flag and output linkage as its original.
TEST(SnapshotTest, Model) {
  const int64_t num_nodes = 20;
  std::shared_ptr<Node> root =
      model::MakeUnknownNode({0, std::to_string(0), nullptr});

  // Every node after the root becomes the single input of its predecessor and
  // gets a randomly chosen autotune flag.
  std::shared_ptr<Node> tail = root;
  for (int64_t id = 1; id < num_nodes; ++id) {
    std::shared_ptr<Node> next =
        model::MakeUnknownNode({id, std::to_string(id), tail});
    next->set_autotune(std::rand() % 2 == 1);
    tail->add_input(next);
    tail = next;
  }

  // `root` and `cloned_root` stay alive for the whole walk so that both chains
  // keep their owners while the cursors below move down.
  std::shared_ptr<Node> cloned_root = root->Snapshot();
  std::shared_ptr<Node> original = root;
  std::shared_ptr<Node> clone = cloned_root;

  for (int64_t depth = 0; depth < num_nodes; ++depth) {
    EXPECT_EQ(original->id(), clone->id());
    EXPECT_EQ(original->name(), clone->name());
    EXPECT_EQ(original->autotune(), clone->autotune());
    EXPECT_NE(original.get(), clone.get());

    if (depth == 0) {
      // Neither root has an output.
      EXPECT_EQ(original->output(), nullptr);
      EXPECT_EQ(clone->output(), nullptr);
    } else {
      EXPECT_EQ(original->output()->long_name(),
                clone->output()->long_name());
      EXPECT_EQ(original->output()->autotune(), clone->output()->autotune());
      EXPECT_NE(original->output(), clone->output());
    }

    // The last node has no input to descend into.
    if (depth + 1 < num_nodes) {
      original = original->inputs().front();
      clone = clone->inputs().front();
    }
  }
}
1042
// Builds a 20-node chain that cycles through six node kinds, saves the model
// together with a set of optimization parameters, loads it back and verifies
// that the optimization parameters and the nodes along the chain survive the
// round trip.
TEST(SaveModelTest, Model) {
  const std::string model_path = "/tmp/autotune_model_test";

  model::Model model;
  std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown0", nullptr});
  model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
                nullptr, &root);
  std::shared_ptr<Node> current = root;

  int64_t num_nodes = 20;
  for (int64_t i = 1; i < num_nodes; i++) {
    std::shared_ptr<Node> input;
    // Pick the node kind from the position in the chain.
    switch (i % 6) {
      case 0:
        input = model::MakeInterleaveManyNode(
            {i, "interleave_many" + std::to_string(i), current},
            {model::MakeParameter("cycle_length", nullptr, /*min=*/1,
                                  /*max=*/1)});
        break;
      case 1:
        input = model::MakeAsyncInterleaveManyNode(
            {i, "async_interleave_many", current},
            {model::MakeParameter(
                 "parallelism",
                 std::make_shared<SharedState>(
                     /*value=*/model::kAutotune, nullptr, nullptr),
                 /*min=*/1,
                 /*max=*/7),
             model::MakeParameter("cycle_length", nullptr, /*min=*/1,
                                  /*max=*/1)});
        break;
      case 2:
        input = model::MakeKnownRatioNode(
            {i, "known_many" + std::to_string(i), current}, 3);
        break;
      case 3:
        input = model::MakeAsyncKnownRatioNode(
            {i, "async_known_many", current}, 4,
            {model::MakeParameter(
                 "parallelism",
                 std::make_shared<SharedState>(
                     /*value=*/model::kAutotune, nullptr, nullptr),
                 /*min=*/1,
                 /*max=*/5),
             model::MakeParameter("cycle_length", nullptr, /*min=*/1,
                                  /*max=*/1)});
        break;
      case 4:
        input = model::MakeUnknownRatioNode(
            {i, "unknown_many" + std::to_string(i), current});
        break;
      default:
        input =
            model::MakeUnknownNode({i, "unknown" + std::to_string(i), current});
    }
    // Give every node distinct, position-dependent statistics.
    input->record_element();
    input->add_processing_time(i * 50);
    input->record_buffer_event(i * 33, i * 5);
    input->set_autotune(true);
    model.AddNode([&input](model::Node::Args args) { return input; },
                  input->name(), current, &input);
    current = input;
  }

  // Make Save->Load roundtrip.
  ModelProto::OptimizationParams optimization_params;
  optimization_params.set_algorithm(AutotuneAlgorithm::GRADIENT_DESCENT);
  optimization_params.set_cpu_budget(64);
  optimization_params.set_ram_budget(1024);
  optimization_params.set_model_input_time(43653.34534);
  TF_ASSERT_OK(
      model.Save(model_path, model.output()->Snapshot(), optimization_params));

  std::unique_ptr<model::Model> restored_model;
  ModelProto::OptimizationParams restored_optimization_params;
  TF_ASSERT_OK(
      model.Load(model_path, &restored_model, &restored_optimization_params));

  // Check optimization parameters.
  EXPECT_EQ(optimization_params.algorithm(),
            restored_optimization_params.algorithm());
  EXPECT_EQ(optimization_params.cpu_budget(),
            restored_optimization_params.cpu_budget());
  EXPECT_EQ(optimization_params.ram_budget(),
            restored_optimization_params.ram_budget());
  EXPECT_EQ(optimization_params.model_input_time(),
            restored_optimization_params.model_input_time());

  // Walk the original and the restored chain in lockstep. The loop stops when
  // either cursor reaches a node without inputs.
  std::shared_ptr<Node> restored_root = restored_model->output();
  std::shared_ptr<Node> restored_current = restored_root;
  current = root;
  EXPECT_EQ(current->output(), nullptr);
  EXPECT_EQ(restored_current->output(), nullptr);
  while (!current->inputs().empty() && !restored_current->inputs().empty()) {
    EXPECT_EQ(current->id(), restored_current->id());
    EXPECT_EQ(current->name(), restored_current->name());
    EXPECT_EQ(current->autotune(), restored_current->autotune());
    Model::NodeValues original_input_times, restored_input_times;
    EXPECT_EQ(current->OutputTime(&original_input_times, nullptr),
              restored_current->OutputTime(&restored_input_times, nullptr));
    EXPECT_EQ(current->TotalBufferedBytes(),
              restored_current->TotalBufferedBytes());
    EXPECT_EQ(current->TotalMaximumBufferedBytes(),
              restored_current->TotalMaximumBufferedBytes());
    EXPECT_EQ(current->Ratio(), restored_current->Ratio());
    EXPECT_NE(current.get(), restored_current.get());

    current = current->inputs().front();
    restored_current = restored_current->inputs().front();

    // The output linkage must match as well. The name comparison has to be
    // made against the restored node; comparing the original with itself
    // would pass unconditionally.
    EXPECT_EQ(current->output()->long_name(),
              restored_current->output()->long_name());
    EXPECT_EQ(current->output()->autotune(),
              restored_current->output()->autotune());
    EXPECT_NE(current->output(), restored_current->output());
  }
  // Both chains must end at the same depth.
  EXPECT_TRUE(current->inputs().empty());
  EXPECT_TRUE(restored_current->inputs().empty());
}
1161
1162 class ComputeWaitTimeTest
1163 : public ::testing::TestWithParam<std::tuple<double, double, double>> {};
1164
// Compares each derivative reported by `Node::ComputeWaitTime` with a forward
// finite difference and, when the argument is at least `kParameterStep`, with
// a backward finite difference as well.
TEST_P(ComputeWaitTimeTest, Model) {
  const double producer_time = std::get<0>(GetParam());
  const double consumer_time = std::get<1>(GetParam());
  const double buffer_size = std::get<2>(GetParam());

  double d_producer = 0.0;
  double d_consumer = 0.0;
  double d_buffer = 0.0;
  const double base_wait_time = model::Node::ComputeWaitTime(
      producer_time, consumer_time, buffer_size, &d_producer, &d_consumer,
      &d_buffer);

  // Evaluates the wait time at a perturbed point without asking for
  // derivatives.
  const auto wait_time_at = [](double producer, double consumer,
                               double buffer) {
    return model::Node::ComputeWaitTime(producer, consumer, buffer, nullptr,
                                        nullptr, nullptr);
  };

  // Producer time.
  double perturbed_wait_time = wait_time_at(producer_time + kParameterStep,
                                            consumer_time, buffer_size);
  EXPECT_NEAR(d_producer,
              (perturbed_wait_time - base_wait_time) / kParameterStep,
              kComparisonPrecision);
  if (producer_time >= kParameterStep) {
    perturbed_wait_time = wait_time_at(producer_time - kParameterStep,
                                       consumer_time, buffer_size);
    EXPECT_NEAR(d_producer,
                (base_wait_time - perturbed_wait_time) / kParameterStep,
                kComparisonPrecision);
  }

  // Consumer time.
  perturbed_wait_time = wait_time_at(
      producer_time, consumer_time + kParameterStep, buffer_size);
  EXPECT_NEAR(d_consumer,
              (perturbed_wait_time - base_wait_time) / kParameterStep,
              kComparisonPrecision);
  if (consumer_time >= kParameterStep) {
    perturbed_wait_time = wait_time_at(
        producer_time, consumer_time - kParameterStep, buffer_size);
    EXPECT_NEAR(d_consumer,
                (base_wait_time - perturbed_wait_time) / kParameterStep,
                kComparisonPrecision);
  }

  // Buffer size.
  perturbed_wait_time = wait_time_at(producer_time, consumer_time,
                                     buffer_size + kParameterStep);
  EXPECT_NEAR(d_buffer,
              (perturbed_wait_time - base_wait_time) / kParameterStep,
              kComparisonPrecision);
  if (buffer_size >= kParameterStep) {
    perturbed_wait_time = wait_time_at(producer_time, consumer_time,
                                       buffer_size - kParameterStep);
    EXPECT_NEAR(d_buffer,
                (base_wait_time - perturbed_wait_time) / kParameterStep,
                kComparisonPrecision);
  }
}
1226
1227 INSTANTIATE_TEST_SUITE_P(
1228 Test, ComputeWaitTimeTest,
1229 ::testing::Combine(::testing::Values(0, 20, 40, 80, 100),
1230 ::testing::Values(0, 20, 40, 80, 100),
1231 ::testing::Values(0, 1, 2, 4, 10, 20, 40)));
1232
1233 class SelfProcessingTimeTest : public ::testing::TestWithParam<int64_t> {};
1234
// Records processing times 0, 1, ..., add_times - 1 (one element each) on a
// source node and expects `SelfProcessingTime()` to equal their mean, or 0
// when nothing was recorded.
TEST_P(SelfProcessingTimeTest, Model) {
  const int64_t add_times = GetParam();
  auto source = model::MakeSourceNode({0, "source", nullptr});
  for (int sample = 0; sample < add_times; ++sample) {
    source->add_processing_time(sample);
    source->record_element();
  }

  // Mean of 0..add_times-1 is (add_times - 1) / 2.
  double expected_self_time = 0.0;
  if (add_times != 0) {
    expected_self_time = (static_cast<double>(add_times) - 1.0) / 2.0;
  }
  EXPECT_EQ(source->SelfProcessingTime(), expected_self_time);
}
1246
1247 INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
1248 ::testing::Values(0, 1, 2, 5, 10, 20, 40));
1249
1250 class OptimizeZeroRamBudgetTest
1251 : public ::testing::TestWithParam<model::AutotuneAlgorithm> {};
1252
// Builds a three-node chain of autotunable nodes, runs `Model::Optimize` and
// expects every tunable parameter to end up at its minimum value.
TEST_P(OptimizeZeroRamBudgetTest, Model) {
  const model::AutotuneAlgorithm algorithm = GetParam();

  // Creates an autotune shared state backed by its own mutex and condition
  // variable.
  const auto make_autotune_state = []() {
    return std::make_shared<SharedState>(
        /*value=*/model::kAutotune, std::make_shared<mutex>(),
        std::make_shared<condition_variable>());
  };

  std::shared_ptr<Node> node1 = model::MakeAsyncKnownRatioNode(
      {1, "1", nullptr}, 2,
      {model::MakeParameter("parallelism", make_autotune_state(),
                            /*min=*/1, /*max=*/5)});
  node1->record_buffer_event(1, 1);
  node1->record_element();

  std::shared_ptr<Node> node2 = model::MakeAsyncKnownRatioNode(
      {2, "2", node1}, 5,
      {model::MakeParameter("buffer_size", make_autotune_state(),
                            /*min=*/0, /*max=*/6)});
  node2->record_buffer_event(1, 1);
  node2->record_element();

  std::shared_ptr<Node> node3 = model::MakeAsyncInterleaveManyNode(
      {3, "3", node2},
      {model::MakeParameter("parallelism", make_autotune_state(),
                            /*min=*/1, /*max=*/7),
       model::MakeParameter(kCycleLength, nullptr, /*min=*/1, /*max=*/1)});
  node3->record_buffer_event(1, 1);
  node3->record_element();

  // Before optimization every tunable parameter still holds the autotune
  // marker.
  EXPECT_EQ(node1->parameter_value("parallelism"), model::kAutotune);
  EXPECT_EQ(node2->parameter_value("buffer_size"), model::kAutotune);
  EXPECT_EQ(node3->parameter_value("parallelism"), model::kAutotune);

  model::Model model;
  model.AddNode([&node1](model::Node::Args args) { return node1; }, "1",
                nullptr, &node1);
  model.AddNode([&node2](model::Node::Args args) { return node2; }, "2",
                node1, &node2);
  model.AddNode([&node3](model::Node::Args args) { return node3; }, "3",
                node2, &node3);

  // NOTE(review): one of the zero arguments is presumably the RAM budget the
  // test name refers to; confirm against the `Model::Optimize` declaration.
  CancellationManager cancellation_manager;
  model.Optimize(algorithm, 40, 0, 0, &cancellation_manager);

  // After optimization each parameter sits at its declared minimum.
  EXPECT_EQ(node1->parameter_value("parallelism"), 1);
  EXPECT_EQ(node2->parameter_value("buffer_size"), 0);
  EXPECT_EQ(node3->parameter_value("parallelism"), 1);
}
1313
1314 INSTANTIATE_TEST_SUITE_P(Test, OptimizeZeroRamBudgetTest,
1315 ::testing::Values(0, 1, 2, 3));
1316
// Checks that `is_recording()` is true only between `record_start` and
// `record_stop`.
TEST(RecordTimeTest, RecordTimeTest) {
  auto node = model::MakeSourceNode({});

  EXPECT_FALSE(node->is_recording());

  node->record_start(100);
  EXPECT_TRUE(node->is_recording());

  node->record_stop(200);
  EXPECT_FALSE(node->is_recording());
}
1325
// Adds a single root node to a model and checks that the metric cell read
// under the model's address (formatted as an integer) mentions that node.
TEST(ModelTest, ModelMetrics) {
  CellReader<std::string> metric_reader("/tensorflow/data/model");

  model::Model model;
  auto root = model::MakeUnknownNode({0, "unknown0", nullptr});
  model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
                nullptr, &root);

  // The cell label is the model's address.
  const std::string model_id =
      strings::StrCat(reinterpret_cast<uint64>(&model));
  const auto describes_root =
      AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
            HasSubstr("autotune: true"));
  EXPECT_THAT(metric_reader.Read(model_id), describes_root);
}
1337
1338 class ModelTimingTest : public ::testing::Test {
1339 public:
1340 // Builds a Model from its text proto.
BuildModelFromProto(const std::string & model_pbtxt)1341 void BuildModelFromProto(const std::string& model_pbtxt) {
1342 ModelProto model_proto;
1343 protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
1344 TF_CHECK_OK(Model::FromProto(model_proto, &model_));
1345 auto nodes = model_->output()->CollectNodes(
1346 TraversalOrder::BFS, [](const std::shared_ptr<Node>) { return true; });
1347 node_map_.clear();
1348 node_map_[model_->output()->id()] = model_->output().get();
1349 for (const auto& node : nodes) {
1350 node_map_[node->id()] = node.get();
1351 }
1352 }
1353
1354 // Computes the timing given a Model text proto.
ComputeModelTiming(const std::string & model_pbtxt)1355 void ComputeModelTiming(const std::string& model_pbtxt) {
1356 BuildModelFromProto(model_pbtxt);
1357 model_timing_ = std::make_unique<ModelTiming>(model_->output());
1358 }
1359
1360 // Gets the timing information of a node given its id.
GetNodeTiming(int64_t node_id) const1361 const ModelTiming::NodeTiming* GetNodeTiming(int64_t node_id) const {
1362 return model_timing_->GetTiming(node_map_.at(node_id));
1363 }
1364
1365 // Gets the node given its id.
GetNode(int64_t node_id) const1366 const Node* GetNode(int64_t node_id) const { return node_map_.at(node_id); }
MutableGetNode(int64_t node_id) const1367 Node* MutableGetNode(int64_t node_id) const { return node_map_.at(node_id); }
1368
1369 protected:
1370 std::unique_ptr<Model> model_;
1371 std::unique_ptr<ModelTiming> model_timing_;
1372 absl::flat_hash_map<int64_t, Node*> node_map_;
1373 };
1374
// Computes the timing of a Batch(1) -> Interleave(2) -> {Batch(3), Batch(4)}
// model and checks the pipeline ratio, self time and total time of each node.
TEST_F(ModelTimingTest, Interleave) {
  ComputeModelTiming(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "Batch"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 2
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Interleave"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: INTERLEAVE_MANY
        inputs: 3
        inputs: 4
        parameters: { name: "cycle_length" value: 2 tunable: false }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "Batch"
        autotune: true
        num_elements: 60
        processing_time: 1200
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    nodes: {
      key: 4
      value: {
        id: 4
        name: "Batch"
        autotune: true
        num_elements: 40
        processing_time: 800
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    output: 1
  )pb");

  // Expected timing values, one row per node.
  struct ExpectedTiming {
    int64_t node_id;
    double pipeline_ratio;
    double self_time_nsec;
    double total_time_nsec;
  };
  const ExpectedTiming expected_timings[] = {
      {/*node_id=*/1, /*pipeline_ratio=*/1, /*self=*/10, /*total=*/40},
      {/*node_id=*/2, /*pipeline_ratio=*/1, /*self=*/10, /*total=*/30},
      {/*node_id=*/3, /*pipeline_ratio=*/0.5, /*self=*/20, /*total=*/20},
      {/*node_id=*/4, /*pipeline_ratio=*/0.5, /*self=*/20, /*total=*/20},
  };

  for (const ExpectedTiming& expected : expected_timings) {
    EXPECT_DOUBLE_EQ(expected.pipeline_ratio,
                     GetNodeTiming(expected.node_id)->pipeline_ratio)
        << "node_id=" << expected.node_id;
    EXPECT_DOUBLE_EQ(expected.self_time_nsec,
                     GetNodeTiming(expected.node_id)->self_time_nsec)
        << "node_id=" << expected.node_id;
    EXPECT_DOUBLE_EQ(expected.total_time_nsec,
                     GetNodeTiming(expected.node_id)->total_time_nsec)
        << "node_id=" << expected.node_id;
  }
}
1446
1447 class ParallelInterleaveTimingTest
1448 : public ModelTimingTest,
1449 public ::testing::WithParamInterface<
1450 std::tuple<int32_t, int32_t, int32_t>> {};
1451
TEST_P(ParallelInterleaveTimingTest,ScenarioTest)1452 TEST_P(ParallelInterleaveTimingTest, ScenarioTest) {
1453 const int32_t parallelism = std::get<0>(GetParam());
1454 const int32_t deterministic = std::get<1>(GetParam());
1455 const int32_t cycle_length = std::get<2>(GetParam());
1456 ComputeModelTiming(strings::Printf(
1457 R"pb(
1458 nodes: {
1459 key: 1
1460 value: {
1461 id: 1
1462 name: "Batch"
1463 autotune: true
1464 num_elements: 100
1465 processing_time: 1000
1466 node_class: KNOWN_RATIO
1467 ratio: 1
1468 inputs: 2
1469 }
1470 }
1471 nodes: {
1472 key: 2
1473 value: {
1474 id: 2
1475 name: "ParallelInterleaveV4"
1476 autotune: true
1477 num_elements: 100
1478 processing_time: 2000
1479 node_class: ASYNC_INTERLEAVE_MANY
1480 inputs: 3
1481 inputs: 4
1482 inputs: 5
1483 inputs: 6
1484 inputs: 7
1485 parameters: {
1486 name: "parallelism"
1487 value: %d
1488 min: 1
1489 max: 10
1490 tunable: true
1491 }
1492 parameters: { name: "deterministic" value: %d tunable: false }
1493 parameters: { name: "cycle_length" value: %d tunable: false }
1494 }
1495 }
1496 nodes: {
1497 key: 3
1498 value: {
1499 id: 3
1500 name: "Batch"
1501 autotune: true
1502 num_elements: 60
1503 processing_time: 60
1504 node_class: KNOWN_RATIO
1505 ratio: 1
1506 }
1507 }
1508 nodes: {
1509 key: 4
1510 value: {
1511 id: 4
1512 name: "Batch"
1513 autotune: true
1514 num_elements: 60
1515 processing_time: 1200
1516 node_class: KNOWN_RATIO
1517 ratio: 1
1518 }
1519 }
1520 nodes: {
1521 key: 5
1522 value: {
1523 id: 5
1524 name: "Batch"
1525 autotune: true
1526 num_elements: 40
1527 processing_time: 1200
1528 node_class: KNOWN_RATIO
1529 ratio: 1
1530 }
1531 }
1532 nodes: {
1533 key: 6
1534 value: {
1535 id: 6
1536 name: "Batch"
1537 autotune: true
1538 num_elements: 60
1539 processing_time: 2400
1540 node_class: KNOWN_RATIO
1541 ratio: 1
1542 }
1543 }
1544 nodes: {
1545 key: 7
1546 value: {
1547 id: 7
1548 name: "Batch"
1549 autotune: false # Marked as an inactive input
1550 num_elements: 40
1551 processing_time: 2000
1552 node_class: KNOWN_RATIO
1553 ratio: 1
1554 }
1555 }
1556 output: 1
1557 )pb",
1558 parallelism, deterministic, cycle_length));
1559
1560 EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/1)->pipeline_ratio);
1561 EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/2)->pipeline_ratio);
1562 EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1563 GetNodeTiming(/*node_id=*/3)->pipeline_ratio);
1564 EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1565 GetNodeTiming(/*node_id=*/4)->pipeline_ratio);
1566 EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1567 GetNodeTiming(/*node_id=*/5)->pipeline_ratio);
1568 EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1569 GetNodeTiming(/*node_id=*/6)->pipeline_ratio);
1570 EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/7)->pipeline_ratio);
1571
1572 const double expected_self_time_1 = 1000.0 / 100.0;
1573 const double expected_self_time_2 = 2000.0 / 100.0 / parallelism;
1574 const double expected_self_time_3 = 60.0 / 60.0;
1575 const double expected_self_time_4 = 1200.0 / 60.0;
1576 const double expected_self_time_5 = 1200.0 / 40.0;
1577 const double expected_self_time_6 = 2400.0 / 60.0;
1578 const double expected_self_time_7 = 2000.0 / 40.0;
1579
1580 EXPECT_DOUBLE_EQ(expected_self_time_1,
1581 GetNodeTiming(/*node_id=*/1)->self_time_nsec);
1582 EXPECT_DOUBLE_EQ(expected_self_time_2,
1583 GetNodeTiming(/*node_id=*/2)->self_time_nsec);
1584 EXPECT_DOUBLE_EQ(expected_self_time_3,
1585 GetNodeTiming(/*node_id=*/3)->self_time_nsec);
1586 EXPECT_DOUBLE_EQ(expected_self_time_4,
1587 GetNodeTiming(/*node_id=*/4)->self_time_nsec);
1588 EXPECT_DOUBLE_EQ(expected_self_time_5,
1589 GetNodeTiming(/*node_id=*/5)->self_time_nsec);
1590 EXPECT_DOUBLE_EQ(expected_self_time_6,
1591 GetNodeTiming(/*node_id=*/6)->self_time_nsec);
1592 EXPECT_DOUBLE_EQ(expected_self_time_7,
1593 GetNodeTiming(/*node_id=*/7)->self_time_nsec);
1594
1595 EXPECT_DOUBLE_EQ(expected_self_time_1,
1596 GetNodeTiming(/*node_id=*/1)->total_time_nsec);
1597 EXPECT_DOUBLE_EQ(expected_self_time_3,
1598 GetNodeTiming(/*node_id=*/3)->total_time_nsec);
1599 EXPECT_DOUBLE_EQ(expected_self_time_4,
1600 GetNodeTiming(/*node_id=*/4)->total_time_nsec);
1601 EXPECT_DOUBLE_EQ(expected_self_time_5,
1602 GetNodeTiming(/*node_id=*/5)->total_time_nsec);
1603 EXPECT_DOUBLE_EQ(expected_self_time_6,
1604 GetNodeTiming(/*node_id=*/6)->total_time_nsec);
1605 EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/7)->total_time_nsec);
1606
1607 const double max_input_time = expected_self_time_6;
1608 double input_throughput = 1.0 / expected_self_time_4 +
1609 1.0 / expected_self_time_5 +
1610 1.0 / expected_self_time_6;
1611 const double active_inputs = 3.0;
1612 double expected_input_time;
1613 if (deterministic == 1) {
1614 expected_input_time = max_input_time / std::min(parallelism, cycle_length);
1615 } else {
1616 if (std::min(parallelism, cycle_length) < active_inputs) {
1617 input_throughput *= std::min(parallelism, cycle_length) / active_inputs;
1618 }
1619 expected_input_time = 1.0 / input_throughput;
1620 }
1621 EXPECT_DOUBLE_EQ(expected_input_time + expected_self_time_2,
1622 GetNodeTiming(/*node_id=*/2)->total_time_nsec);
1623 }
1624
1625 INSTANTIATE_TEST_SUITE_P(ParallelInterleaveTimingTest,
1626 ParallelInterleaveTimingTest,
1627 ::testing::Combine(::testing::Values(1, 2, 3),
1628 ::testing::Values(0, 1),
1629 ::testing::Values(1, 2, 3)));
1630
// Checks the timing computed for a pipeline of the form
//   Batch(1) <- ParallelInterleaveV4(2) <- {Batch(3), Batch(4), Batch(5)}
// where Batch(4) and Batch(5) have ratio 2 and are each fed by a tunable
// ParallelMapV2 (nodes 6 and 7). The interleave is declared with
// parallelism 2, cycle_length 2 and deterministic 0.
TEST_F(ModelTimingTest, ParallelInterleave_Batch_ParallelMap) {
  ComputeModelTiming(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "Batch"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 2
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "ParallelInterleaveV4"
        autotune: true
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_INTERLEAVE_MANY
        inputs: 3
        inputs: 4
        inputs: 5
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 10
          tunable: true
        }
        parameters: { name: "cycle_length" value: 2 tunable: false }
        parameters: { name: "deterministic" value: 0 tunable: false }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "Batch"
        autotune: true
        num_elements: 60
        processing_time: 60
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    nodes: {
      key: 4
      value: {
        id: 4
        name: "Batch"
        autotune: true
        num_elements: 60
        processing_time: 1200
        node_class: KNOWN_RATIO
        ratio: 2
        inputs: 6
      }
    }
    nodes: {
      key: 5
      value: {
        id: 5
        name: "Batch"
        autotune: true
        num_elements: 40
        processing_time: 800
        node_class: KNOWN_RATIO
        ratio: 2
        inputs: 7
      }
    }
    nodes: {
      key: 6
      value: {
        id: 6
        name: "ParallelMapV2"
        autotune: true
        num_elements: 120
        processing_time: 2400
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 7
      value: {
        id: 7
        name: "ParallelMapV2"
        autotune: true
        num_elements: 120
        processing_time: 2400
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    output: 1
  )pb");

  // One row per node: the three timing fields the model is expected to report.
  struct ExpectedTiming {
    int64_t node_id;
    double pipeline_ratio;
    double self_time_nsec;
    double total_time_nsec;
  };
  const ExpectedTiming kExpectedTimings[] = {
      {/*node_id=*/1, /*pipeline_ratio=*/1.0, /*self=*/10.0, /*total=*/10.0},
      {/*node_id=*/2, /*pipeline_ratio=*/1.0, /*self=*/10.0, /*total=*/20.0},
      {/*node_id=*/3, /*pipeline_ratio=*/0.5, /*self=*/1.0, /*total=*/1.0},
      {/*node_id=*/4, /*pipeline_ratio=*/0.5, /*self=*/20.0, /*total=*/20.0},
      {/*node_id=*/5, /*pipeline_ratio=*/0.5, /*self=*/20.0, /*total=*/20.0},
      {/*node_id=*/6, /*pipeline_ratio=*/1.0, /*self=*/10.0, /*total=*/10.0},
      {/*node_id=*/7, /*pipeline_ratio=*/1.0, /*self=*/10.0, /*total=*/10.0},
  };

  for (const ExpectedTiming& expected : kExpectedTimings) {
    // Identify the offending node in the failure output.
    SCOPED_TRACE(::testing::Message() << "node_id=" << expected.node_id);
    EXPECT_DOUBLE_EQ(expected.pipeline_ratio,
                     GetNodeTiming(expected.node_id)->pipeline_ratio);
    EXPECT_DOUBLE_EQ(expected.self_time_nsec,
                     GetNodeTiming(expected.node_id)->self_time_nsec);
    EXPECT_DOUBLE_EQ(expected.total_time_nsec,
                     GetNodeTiming(expected.node_id)->total_time_nsec);
  }
}
1772
1773 class BufferSizeTest : public ::testing::Test {
1774 public:
ReadModel(const std::string & model_pbtxt)1775 void ReadModel(const std::string& model_pbtxt) {
1776 ModelProto model_proto;
1777 protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
1778 TF_CHECK_OK(Model::FromProto(model_proto, &model_));
1779 auto nodes = model_->output()->CollectNodes(
1780 TraversalOrder::BFS, [](const std::shared_ptr<Node>) { return true; });
1781 node_map_.clear();
1782 node_map_[model_->output()->id()] = model_->output();
1783 for (const auto& node : nodes) {
1784 node_map_[node->id()] = node;
1785 }
1786 }
1787
1788 // Returns a node given its node id. If node id does not exist, it will fail.
GetNode(int64_t node_id) const1789 std::shared_ptr<Node> GetNode(int64_t node_id) const {
1790 return node_map_.at(node_id);
1791 }
1792
1793 protected:
1794 std::unique_ptr<Model> model_;
1795 absl::flat_hash_map<int64_t, std::shared_ptr<Node>> node_map_;
1796 };
1797
// Exercises `Model::OptimizeBuffers` on a chain of five tunable Prefetch nodes
// (1 <- 2 <- 3 <- 4 <- 5) with a budget (10000) that is large relative to the
// buffers involved. Each node is given different buffer watermarks so that the
// downsize, keep, and upsize outcomes are all covered.
TEST_F(BufferSizeTest, OptimizeBuffers_PlentyOfMemory) {
  ReadModel(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 2
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 3
          state_value: 3
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 3
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 4
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 4
      value: {
        id: 4
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 5
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 8
          tunable: true
        }
      }
    }
    nodes: {
      key: 5
      value: {
        id: 5
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 8
          state_value: 8
          min: 1
          max: 8
          tunable: true
        }
      }
    }
    output: 1
  )pb");

  std::shared_ptr<Node> node_1 = GetNode(1);
  std::shared_ptr<Node> node_2 = GetNode(2);
  std::shared_ptr<Node> node_3 = GetNode(3);
  std::shared_ptr<Node> node_4 = GetNode(4);
  std::shared_ptr<Node> node_5 = GetNode(5);
  // Set node 1 low watermark to 1 and high watermark to 2. Expect that it is
  // downsized to 2.
  node_1->record_buffer_event(100, 1);
  node_1->record_buffer_event(100, 1);
  EXPECT_EQ(1, node_1->buffered_elements_low());
  EXPECT_EQ(2, node_1->buffered_elements_high());
  // Set node 2 low watermark to 1 and high watermark to 5. Expect that it is
  // not changed.
  node_2->record_buffer_event(100, 1);
  node_2->record_buffer_event(400, 4);
  node_2->record_buffer_event(-100, -1);
  EXPECT_EQ(1, node_2->buffered_elements_low());
  EXPECT_EQ(5, node_2->buffered_elements_high());
  // Set node 3 low watermark to 0 and high watermark to 5. Expect that it is
  // upsized to 10.
  node_3->record_buffer_event(100, 1);
  node_3->record_buffer_event(-100, -1);
  node_3->record_buffer_event(500, 5);
  node_3->record_buffer_event(-100, -1);
  EXPECT_EQ(0, node_3->buffered_elements_low());
  EXPECT_EQ(5, node_3->buffered_elements_high());
  // Set node 4 low watermark to 0 and high watermark to 5. Its max buffer size
  // is set to 8. Expect that it is upsized to 8.
  node_4->record_buffer_event(100, 1);
  node_4->record_buffer_event(-100, -1);
  node_4->record_buffer_event(500, 5);
  node_4->record_buffer_event(-100, -1);
  EXPECT_EQ(0, node_4->buffered_elements_low());
  EXPECT_EQ(5, node_4->buffered_elements_high());
  // Set node 5 low watermark to 1 and high watermark to 2. Its current buffer
  // size is set to 8. The observed range only needs (2 - 1 + 1 = 2) slots, but
  // the assertion below expects 6, i.e. the downsize is limited rather than
  // applied in full.
  // NOTE(review): this comment previously claimed the buffer is halved to 8/2,
  // which contradicts the expected value of 6 -- confirm the exact cap against
  // the `Model::OptimizeBuffers` implementation.
  // Both events add one element and its bytes, mirroring node 1, so that the
  // byte and element deltas have the same sign.
  node_5->record_buffer_event(100, 1);
  node_5->record_buffer_event(100, 1);
  EXPECT_EQ(1, node_5->buffered_elements_low());
  EXPECT_EQ(2, node_5->buffered_elements_high());

  model_->OptimizeBuffers(node_1->Snapshot(), 10000);

  EXPECT_EQ(2, node_1->parameter_value(kBufferSize));
  EXPECT_EQ(5, node_2->parameter_value(kBufferSize));
  EXPECT_EQ(10, node_3->parameter_value(kBufferSize));
  EXPECT_EQ(8, node_4->parameter_value(kBufferSize));
  EXPECT_EQ(6, node_5->parameter_value(kBufferSize));
  // After optimization both watermarks of every node are expected to equal the
  // number of elements left in its buffer by the events above (2, 4, 4, 4, 2).
  EXPECT_EQ(2, node_1->buffered_elements_low());
  EXPECT_EQ(2, node_1->buffered_elements_high());
  EXPECT_EQ(4, node_2->buffered_elements_low());
  EXPECT_EQ(4, node_2->buffered_elements_high());
  EXPECT_EQ(4, node_3->buffered_elements_low());
  EXPECT_EQ(4, node_3->buffered_elements_high());
  EXPECT_EQ(4, node_4->buffered_elements_low());
  EXPECT_EQ(4, node_4->buffered_elements_high());
  EXPECT_EQ(2, node_5->buffered_elements_low());
  EXPECT_EQ(2, node_5->buffered_elements_high());
}
1972
// Exercises `Model::OptimizeBuffers` on a chain of four tunable Prefetch nodes
// (1 <- 2 <- 3 <- 4), each with buffer_size 5 and a low watermark of 0, using
// a budget of 3000. Every buffer is expected to end up at 7.
TEST_F(BufferSizeTest, OptimizeBuffers_TightMemory) {
  ReadModel(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 2
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 3
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 4
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 4
      value: {
        id: 4
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 8
          tunable: true
        }
      }
    }
    output: 1
  )pb");

  const auto prefetch_1 = GetNode(1);
  const auto prefetch_2 = GetNode(2);
  const auto prefetch_3 = GetNode(3);
  const auto prefetch_4 = GetNode(4);

  // Replays a sequence of buffer events on `node`. Every event in this test
  // moves 100 bytes per element, so only the element delta is listed.
  const auto replay = [](const std::shared_ptr<Node>& node,
                         std::initializer_list<int64_t> element_deltas) {
    for (const int64_t delta : element_deltas) {
      node->record_buffer_event(100 * delta, delta);
    }
  };

  // Drive every node to a low watermark of 0 and a high watermark of 5.
  replay(prefetch_1, {1, -1, 5});
  EXPECT_EQ(0, prefetch_1->buffered_elements_low());
  EXPECT_EQ(5, prefetch_1->buffered_elements_high());

  replay(prefetch_2, {1, 1, -1, -1, 4, 1});
  EXPECT_EQ(0, prefetch_2->buffered_elements_low());
  EXPECT_EQ(5, prefetch_2->buffered_elements_high());

  replay(prefetch_3, {1, -1, 5});
  EXPECT_EQ(0, prefetch_3->buffered_elements_low());
  EXPECT_EQ(5, prefetch_3->buffered_elements_high());

  replay(prefetch_4, {1, -1, 1, -1, 5, -1});
  EXPECT_EQ(0, prefetch_4->buffered_elements_low());
  EXPECT_EQ(5, prefetch_4->buffered_elements_high());

  model_->OptimizeBuffers(prefetch_1->Snapshot(), 3000);

  // All four buffers are expected to grow from 5 to 7.
  EXPECT_DOUBLE_EQ(7.0, prefetch_1->parameter_value(kBufferSize));
  EXPECT_DOUBLE_EQ(7.0, prefetch_2->parameter_value(kBufferSize));
  EXPECT_DOUBLE_EQ(7.0, prefetch_3->parameter_value(kBufferSize));
  EXPECT_DOUBLE_EQ(7.0, prefetch_4->parameter_value(kBufferSize));

  // Both watermarks are expected to equal the number of elements the replayed
  // events left in each buffer (5, 5, 5, 4).
  EXPECT_EQ(5, prefetch_1->buffered_elements_low());
  EXPECT_EQ(5, prefetch_1->buffered_elements_high());
  EXPECT_EQ(5, prefetch_2->buffered_elements_low());
  EXPECT_EQ(5, prefetch_2->buffered_elements_high());
  EXPECT_EQ(5, prefetch_3->buffered_elements_low());
  EXPECT_EQ(5, prefetch_3->buffered_elements_high());
  EXPECT_EQ(4, prefetch_4->buffered_elements_low());
  EXPECT_EQ(4, prefetch_4->buffered_elements_high());
}
2112
// Runs stage-based optimization on a pipeline with a single tunable node: a
// ParallelMapV2 (node 1, parallelism 4, range [1, 16]) on top of a synchronous
// Map (node 2) and an SSTable source (node 3). The parallelism is expected to
// become 5.
TEST_F(ModelTimingTest, OptimizeStageBased_OneStage) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 5000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Map"
        autotune: true
        num_elements: 100
        processing_time: 3000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 3
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb");

  CancellationManager cancellation_manager;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 20, 1000, 50,
                   &cancellation_manager);

  const auto parallel_map = GetNode(/*node_id=*/1);
  EXPECT_EQ(5, parallel_map->parameter_value("parallelism"));
}
2170
// Same pipeline as the one-stage case, except that the ParallelMapV2
// parallelism parameter declares max 3. Stage-based optimization must not push
// the value past that maximum.
TEST_F(ModelTimingTest, OptimizeStageBased_CappedByParameterMax) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 5000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: { name: "parallelism" value: 4 min: 1 max: 3 tunable: true }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Map"
        autotune: true
        num_elements: 100
        processing_time: 3000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 3
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb");

  CancellationManager cancellation_manager;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 20, 1000, 50,
                   &cancellation_manager);

  // Without the declared maximum of 3 the optimizer would settle on 5.
  const auto parallel_map = GetNode(/*node_id=*/1);
  EXPECT_EQ(3, parallel_map->parameter_value("parallelism"));
}
2223
// Runs stage-based optimization on a pipeline with two chained tunable
// ParallelMapV2 nodes (1 <- 2) above an SSTable source (node 3). Both start at
// parallelism 4 and are each expected to end at 5.
TEST_F(ModelTimingTest, OptimizeStageBased_TwoStages) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 25000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 20000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 3
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb");

  CancellationManager cancellation_manager;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 1000, 50,
                   &cancellation_manager);

  const auto outer_map = GetNode(/*node_id=*/1);
  EXPECT_EQ(5, outer_map->parameter_value("parallelism"));
  const auto inner_map = GetNode(/*node_id=*/2);
  EXPECT_EQ(5, inner_map->parameter_value("parallelism"));
}
2290
// Runs stage-based optimization three times on a two-stage ParallelMapV2
// pipeline, alternating between a RAM budget that is too small (100) and one
// that is large enough (100000). When RAM is insufficient the parallelism
// values must be left as they were before the call.
TEST_F(ModelTimingTest, OptimizeStageBased_TwoStages_RamBudgetExceeded) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 25000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          state_value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 20000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 3
        parameters: {
          name: "parallelism"
          value: 4
          state_value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb");

  // Looks the node up afresh on every call and reads its current parallelism.
  const auto parallelism_of = [this](int64_t node_id) {
    return GetNode(node_id)->parameter_value("parallelism");
  };

  CancellationManager cancellation_manager;

  // Round 1: RAM is insufficient, so the initial values (4, 4) must survive.
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 10, 100, 0,
                   &cancellation_manager);
  EXPECT_EQ(4, parallelism_of(/*node_id=*/1));
  EXPECT_EQ(4, parallelism_of(/*node_id=*/2));

  // Round 2: RAM is sufficient, so both values are expected to grow.
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 10, 100000, 0,
                   &cancellation_manager);
  EXPECT_EQ(12, parallelism_of(/*node_id=*/1));
  EXPECT_EQ(16, parallelism_of(/*node_id=*/2));

  // Round 3: RAM is insufficient again, so the values from round 2 must
  // survive.
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 10, 100, 0,
                   &cancellation_manager);
  EXPECT_EQ(12, parallelism_of(/*node_id=*/1));
  EXPECT_EQ(16, parallelism_of(/*node_id=*/2));
}
2369
// Runs stage-based optimization on a pipeline whose tunable root is a
// ParallelBatch with ratio 2 (node 1), fed by a synchronous Map (node 2) and
// an SSTable source (node 3). The parallelism is expected to end at 16, which
// is also the declared maximum of the parameter.
TEST_F(ModelTimingTest, OptimizeStageBased_PipelineRatio) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelBatch"
        autotune: true
        num_elements: 100
        processing_time: 5000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 2
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Map"
        autotune: true
        num_elements: 100
        processing_time: 3000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 3
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb");

  CancellationManager cancellation_manager;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 20, 10000, 50,
                   &cancellation_manager);

  const auto parallel_batch = GetNode(/*node_id=*/1);
  EXPECT_EQ(16, parallel_batch->parameter_value("parallelism"));
}
2427
// Records five gap times of 10, one of 1000 and one of 10000000, and expects
// the resulting target time (scaled by 1e-3) to be 10.
TEST_F(ModelTimingTest, ComputeTargetTime) {
  model_ = std::make_unique<Model>();

  constexpr int kNumRegularGaps = 5;
  for (int i = 0; i < kNumRegularGaps; ++i) {
    model_->RecordIteratorGapTime(10);
  }
  model_->RecordIteratorGapTime(1000);
  // Gap times that are >= 10 seconds are always dropped.
  model_->RecordIteratorGapTime(10000000);

  const double target_time = model_->ComputeTargetTimeNsec() * 1e-3;
  EXPECT_DOUBLE_EQ(10, target_time);
}
2442
// Records four gap times of 10, four of 20 and one of 10000000, and expects
// the resulting target time (scaled by 1e-3) to be 15.
TEST_F(ModelTimingTest, ComputeTargetTime_NoOutlier) {
  model_ = std::make_unique<Model>();

  constexpr int kRepeats = 4;
  for (int i = 0; i < kRepeats; ++i) {
    model_->RecordIteratorGapTime(10);
  }
  for (int i = 0; i < kRepeats; ++i) {
    model_->RecordIteratorGapTime(20);
  }
  // Gap times that are >= 10 seconds are always dropped.
  model_->RecordIteratorGapTime(10000000);

  const double target_time = model_->ComputeTargetTimeNsec() * 1e-3;
  EXPECT_DOUBLE_EQ(15.0, target_time);
}
2459
// Records 100 gap times of 20 followed by 100 gap times of 10 and expects the
// target time (scaled by 1e-3) to be 10, i.e. to reflect only the second
// batch.
TEST_F(ModelTimingTest, ComputeTargetTime_TestWindow) {
  model_ = std::make_unique<Model>();

  // The window size is 100. Only the last 100 gap times are used to compute the
  // target time.
  constexpr int kWindowSize = 100;
  int remaining = kWindowSize;
  while (remaining-- > 0) {
    model_->RecordIteratorGapTime(20);
  }
  remaining = kWindowSize;
  while (remaining-- > 0) {
    model_->RecordIteratorGapTime(10);
  }

  const double target_time = model_->ComputeTargetTimeNsec() * 1e-3;
  EXPECT_DOUBLE_EQ(10.0, target_time);
}
2474
// Checks `Node::ComputeSelfTime` before and after one extra element is
// recorded, for a ParallelMapV2 with parallelism 2 (node 1) and for a
// synchronous SSTable source (node 2).
TEST_F(ModelTimingTest, SelfTime) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 20000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 100000
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    output: 1
  )pb");

  // Node 1: self time is 100 initially and 110 after one element that took
  // 400 is recorded.
  auto parallel_map = MutableGetNode(/*node_id=*/1);
  EXPECT_DOUBLE_EQ(100, parallel_map->ComputeSelfTime());
  parallel_map->add_processing_time(400);
  parallel_map->record_element();
  EXPECT_DOUBLE_EQ(110, parallel_map->ComputeSelfTime());

  // Node 2: self time is 1000 initially and 910 after one element that took
  // 100 is recorded.
  auto source = MutableGetNode(/*node_id=*/2);
  EXPECT_DOUBLE_EQ(1000, source->ComputeSelfTime());
  source->add_processing_time(100);
  source->record_element();
  EXPECT_DOUBLE_EQ(910, source->ComputeSelfTime());
}
2523
2524 } // namespace
2525 } // namespace model
2526 } // namespace data
2527 } // namespace tensorflow
2528