xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/model_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/model.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "tensorflow/core/framework/cancellation.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/gtl/cleanup.h"
26 #include "tensorflow/core/lib/monitoring/cell_reader.h"
27 #include "tensorflow/core/platform/stringprintf.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace tensorflow {
31 namespace data {
32 namespace model {
33 namespace {
34 
35 using ::tensorflow::monitoring::testing::CellReader;
36 using ::testing::AllOf;
37 using ::testing::HasSubstr;
38 
CountParametersOnNode(const string & node_name,const Model::ModelParameters & parameters)39 int64_t CountParametersOnNode(const string& node_name,
40                               const Model::ModelParameters& parameters) {
41   int64_t cnt = 0;
42   for (const auto& pair : parameters) {
43     if (pair.first == node_name) {
44       cnt++;
45     }
46   }
47   return cnt;
48 }
49 
50 class AsyncInterleaveManyTest
51     : public ::testing::TestWithParam<std::tuple<int64_t, double>> {};
52 
TEST_P(AsyncInterleaveManyTest,Model)53 TEST_P(AsyncInterleaveManyTest, Model) {
54   const int64_t parallelism = std::get<0>(GetParam());
55   const double input_time = std::get<1>(GetParam());
56   std::shared_ptr<Node> async_interleave_many =
57       model::MakeAsyncInterleaveManyNode(
58           {0, "async_interleave_many", nullptr},
59           {model::MakeParameter("parallelism",
60                                 std::make_shared<SharedState>(
61                                     /*value=*/parallelism, nullptr, nullptr),
62                                 /*min=*/1,
63                                 /*max=*/8),
64            model::MakeParameter(kCycleLength, nullptr,
65                                 /*min=*/1,
66                                 /*max=*/1)});
67   std::shared_ptr<Node> meta_source =
68       model::MakeSourceNode({1, "meta_source", async_interleave_many});
69   async_interleave_many->add_input(meta_source);
70   auto cleanup_meta = gtl::MakeCleanup([async_interleave_many, meta_source]() {
71     async_interleave_many->remove_input(meta_source);
72   });
73   std::shared_ptr<Node> source1 =
74       model::MakeSourceNode({2, "source1", async_interleave_many});
75   async_interleave_many->add_input(source1);
76   auto cleanup1 = gtl::MakeCleanup([async_interleave_many, source1]() {
77     async_interleave_many->remove_input(source1);
78   });
79   std::shared_ptr<Node> source2 =
80       model::MakeSourceNode({3, "source2", async_interleave_many});
81   async_interleave_many->add_input(source2);
82   auto cleanup2 = gtl::MakeCleanup([async_interleave_many, source2]() {
83     async_interleave_many->remove_input(source2);
84   });
85   Model::NodeValues input_times;
86   input_times[kModelInputTimeKey] = input_time;
87   EXPECT_EQ(async_interleave_many->TotalBufferedBytes(), 0);
88   EXPECT_EQ(async_interleave_many->TotalMaximumBufferedBytes(), 0);
89   async_interleave_many->record_buffer_event(110, 10);
90   EXPECT_EQ(async_interleave_many->TotalBufferedBytes(), 110);
91   EXPECT_EQ(async_interleave_many->TotalMaximumBufferedBytes(),
92             110 * parallelism / 10);
93   async_interleave_many->add_processing_time(100);
94   EXPECT_EQ(async_interleave_many->processing_time(), 100);
95   EXPECT_EQ(
96       async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
97       0);
98   EXPECT_EQ(async_interleave_many->OutputTime(&input_times, nullptr), 0);
99   async_interleave_many->record_element();
100   EXPECT_EQ(async_interleave_many->num_elements(), 1);
101   EXPECT_EQ(
102       async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
103       100);
104   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr), 100);
105   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
106   source1->add_processing_time(200);
107   source2->add_processing_time(300);
108   EXPECT_EQ(
109       async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
110       100);
111   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr), 100);
112   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
113   source1->record_element();
114   source2->record_element();
115   EXPECT_EQ(
116       async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
117       100 + 250);
118   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr),
119             100 + 250 / parallelism);
120   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
121   async_interleave_many->record_element();
122   EXPECT_EQ(
123       async_interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
124       50 + 250);
125   EXPECT_LE(async_interleave_many->OutputTime(&input_times, nullptr),
126             50 + 250 / parallelism);
127   EXPECT_GE(async_interleave_many->OutputTime(&input_times, nullptr), 0);
128 }
129 
130 INSTANTIATE_TEST_SUITE_P(Test, AsyncInterleaveManyTest,
131                          ::testing::Combine(::testing::Values(1, 2),
132                                             ::testing::Values(0, 50, 100,
133                                                               200)));
134 
135 class AsyncKnownRatioTest
136     : public ::testing::TestWithParam<std::tuple<int64_t, double, int64_t>> {};
137 
TEST_P(AsyncKnownRatioTest,Model)138 TEST_P(AsyncKnownRatioTest, Model) {
139   const int64_t parallelism = std::get<0>(GetParam());
140   const double input_time = std::get<1>(GetParam());
141   const int64_t num_inputs_per_output = std::get<2>(GetParam());
142   std::shared_ptr<Node> async_known_many = model::MakeAsyncKnownRatioNode(
143       {0, "async_known_many", nullptr}, num_inputs_per_output,
144       {model::MakeParameter("parallelism",
145                             std::make_shared<SharedState>(/*value=*/parallelism,
146                                                           nullptr, nullptr),
147                             /*min=*/1,
148                             /*max=*/16)});
149   std::shared_ptr<Node> source1 =
150       model::MakeSourceNode({1, "source1", async_known_many});
151   async_known_many->add_input(source1);
152   std::shared_ptr<Node> source2 =
153       model::MakeSourceNode({2, "source2", async_known_many});
154   async_known_many->add_input(source2);
155   Model::NodeValues input_times;
156   input_times[kModelInputTimeKey] = input_time;
157   EXPECT_EQ(async_known_many->TotalBufferedBytes(), 0);
158   EXPECT_EQ(async_known_many->TotalMaximumBufferedBytes(), 0);
159   async_known_many->record_buffer_event(110, 10);
160   EXPECT_EQ(async_known_many->TotalBufferedBytes(), 110);
161   EXPECT_EQ(async_known_many->TotalMaximumBufferedBytes(),
162             num_inputs_per_output == 0
163                 ? 110.0 * parallelism / 10
164                 : 110.0 * parallelism / 10 / num_inputs_per_output);
165   source1->add_processing_time(100);
166   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
167             0);
168   EXPECT_EQ(async_known_many->OutputTime(&input_times, nullptr), 0);
169   source2->add_processing_time(200);
170   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
171             0);
172   EXPECT_EQ(async_known_many->OutputTime(&input_times, nullptr), 0);
173   source1->record_element();
174   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
175             num_inputs_per_output * 100);
176   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
177             num_inputs_per_output * 100);
178   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
179   source2->record_element();
180   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
181             num_inputs_per_output * (100 + 200));
182   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
183             num_inputs_per_output * (100 + 200));
184   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
185   source1->record_element();
186   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
187             num_inputs_per_output * (50 + 200));
188   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
189             num_inputs_per_output * (50 + 200));
190   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
191   source2->record_element();
192   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
193             num_inputs_per_output * (50 + 100));
194   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
195             num_inputs_per_output * (50 + 100));
196   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
197   async_known_many->add_processing_time(128);
198   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
199             num_inputs_per_output * (50 + 100));
200   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
201             num_inputs_per_output * (50 + 100));
202   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
203   async_known_many->record_element();
204   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
205             num_inputs_per_output * (50 + 100) + 128);
206   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
207             num_inputs_per_output * (50 + 100) + 128 / parallelism);
208   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
209   async_known_many->record_element();
210   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
211             num_inputs_per_output * (50 + 100) + 64);
212   EXPECT_LE(async_known_many->OutputTime(&input_times, nullptr),
213             num_inputs_per_output * (50 + 100) + 64 / parallelism);
214   EXPECT_GE(async_known_many->OutputTime(&input_times, nullptr), 0);
215 }
216 
217 INSTANTIATE_TEST_SUITE_P(Test, AsyncKnownRatioTest,
218                          ::testing::Combine(::testing::Values(1, 2, 4, 8),
219                                             ::testing::Values(0, 50, 100, 200),
220                                             ::testing::Values(0, 1, 2, 4)));
221 
// Checks processing/output times of an interleave-many node with a meta source
// and two regular sources. At every step the test expects OutputTime() to equal
// TotalProcessingTime().
TEST(InterleaveManyTest, Model) {
  // (An unused `parameter` local that duplicated the cycle_length parameter
  // below was removed; it was constructed and never read.)
  std::shared_ptr<Node> interleave_many = model::MakeInterleaveManyNode(
      {0, "interleave_many", nullptr},
      {model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
  std::shared_ptr<Node> meta_source =
      model::MakeSourceNode({1, "meta_source", interleave_many});
  interleave_many->add_input(meta_source);
  std::shared_ptr<Node> source1 =
      model::MakeSourceNode({2, "source1", interleave_many});
  interleave_many->add_input(source1);
  std::shared_ptr<Node> source2 =
      model::MakeSourceNode({3, "source2", interleave_many});
  interleave_many->add_input(source2);
  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;
  interleave_many->add_processing_time(100);
  EXPECT_EQ(interleave_many->processing_time(), 100);
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            0);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 0);
  interleave_many->record_element();
  EXPECT_EQ(interleave_many->num_elements(), 1);
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            100);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 100);
  source1->add_processing_time(200);
  source2->add_processing_time(300);
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            100);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 100);
  source1->record_element();
  source2->record_element();
  // Expected 350 = 100 (own) + 250, where 250 = (200 + 300) / 2.
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            350);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 350);
  interleave_many->record_element();
  // Expected 300 = 100 / 2 (own, now over two elements) + 250.
  EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr),
            300);
  EXPECT_EQ(interleave_many->OutputTime(&input_times, nullptr), 300);
}
264 
265 class KnownRatioTest : public ::testing::TestWithParam<int64_t> {};
266 
TEST_P(KnownRatioTest,Model)267 TEST_P(KnownRatioTest, Model) {
268   const int64_t num_inputs_per_output = GetParam();
269   std::shared_ptr<Node> known_many = model::MakeKnownRatioNode(
270       {0, "known_many", nullptr}, num_inputs_per_output);
271   std::shared_ptr<Node> source1 =
272       model::MakeSourceNode({1, "source1", known_many});
273   known_many->add_input(source1);
274   std::shared_ptr<Node> source2 =
275       model::MakeSourceNode({2, "source2", known_many});
276   known_many->add_input(source2);
277   Model::NodeValues input_times;
278   input_times[kModelInputTimeKey] = 0.0;
279   source1->add_processing_time(100);
280   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
281   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr), 0);
282   source2->add_processing_time(200);
283   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
284   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr), 0);
285   source1->record_element();
286   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
287             num_inputs_per_output * 100);
288   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
289             num_inputs_per_output * 100);
290   source2->record_element();
291   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
292             num_inputs_per_output * (100 + 200));
293   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
294             num_inputs_per_output * (100 + 200));
295   source1->record_element();
296   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
297             num_inputs_per_output * (50 + 200));
298   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
299             num_inputs_per_output * (50 + 200));
300   source2->record_element();
301   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
302             num_inputs_per_output * (50 + 100));
303   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
304             num_inputs_per_output * (50 + 100));
305   known_many->add_processing_time(128);
306   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
307             num_inputs_per_output * (50 + 100));
308   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
309             num_inputs_per_output * (50 + 100));
310   known_many->record_element();
311   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
312             num_inputs_per_output * (50 + 100) + 128);
313   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
314             num_inputs_per_output * (50 + 100) + 128);
315   known_many->record_element();
316   EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr),
317             num_inputs_per_output * (50 + 100) + 64);
318   EXPECT_EQ(known_many->OutputTime(&input_times, nullptr),
319             num_inputs_per_output * (50 + 100) + 64);
320 }
321 
322 INSTANTIATE_TEST_SUITE_P(Test, KnownRatioTest, ::testing::Values(0, 1, 2, 4));
323 
// Checks a standalone source node: the test expects its per-element time to be
// the accumulated processing time divided by the elements recorded so far.
TEST(SourceTest, Model) {
  std::shared_ptr<Node> node = model::MakeSourceNode({0, "source", nullptr});
  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;
  auto expect_times = [&node, &input_times](auto expected) {
    EXPECT_EQ(node->TotalProcessingTime(/*processing_times=*/nullptr),
              expected);
    EXPECT_EQ(node->OutputTime(&input_times, nullptr), expected);
  };

  node->add_processing_time(100);
  EXPECT_EQ(node->processing_time(), 100);
  expect_times(0);

  node->record_element();
  EXPECT_EQ(node->num_elements(), 1);
  expect_times(100);

  node->record_element();
  EXPECT_EQ(node->num_elements(), 2);
  expect_times(50);
}
341 
// Checks processing/output times of an unknown-ratio node with two sources;
// the test expects OutputTime() to equal TotalProcessingTime() throughout.
TEST(UnknownRatioTest, Model) {
  std::shared_ptr<Node> node =
      model::MakeUnknownRatioNode({0, "unknown_many", nullptr});
  std::shared_ptr<Node> source1 = model::MakeSourceNode({1, "source1", node});
  node->add_input(source1);
  std::shared_ptr<Node> source2 = model::MakeSourceNode({2, "source2", node});
  node->add_input(source2);
  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;
  auto expect_times = [&node, &input_times](auto expected) {
    EXPECT_EQ(node->TotalProcessingTime(/*processing_times=*/nullptr),
              expected);
    EXPECT_EQ(node->OutputTime(&input_times, nullptr), expected);
  };

  node->add_processing_time(100);
  EXPECT_EQ(node->processing_time(), 100);
  expect_times(0);
  node->record_element();
  EXPECT_EQ(node->num_elements(), 1);
  expect_times(100);

  source1->add_processing_time(100);
  source2->add_processing_time(200);
  expect_times(100);
  source1->record_element();
  source2->record_element();
  expect_times(400);

  node->record_element();
  expect_times(200);
}
377 
378 class AsyncUnknownRatioTest
379     : public ::testing::TestWithParam<std::tuple<int64_t, double>> {};
380 
TEST_P(AsyncUnknownRatioTest,Model)381 TEST_P(AsyncUnknownRatioTest, Model) {
382   const int64_t parallelism = std::get<0>(GetParam());
383   const double input_time = std::get<1>(GetParam());
384   std::shared_ptr<Node> async_unknown_many = model::MakeAsyncUnknownRatioNode(
385       {0, "async_unknown_many", nullptr},
386       {model::MakeParameter("parallelism",
387                             std::make_shared<SharedState>(/*value=*/parallelism,
388                                                           nullptr, nullptr),
389                             /*min=*/1,
390                             /*max=*/16)});
391   std::shared_ptr<Node> source1 =
392       model::MakeSourceNode({1, "source1", async_unknown_many});
393   async_unknown_many->add_input(source1);
394   std::shared_ptr<Node> source2 =
395       model::MakeSourceNode({2, "source2", async_unknown_many});
396   async_unknown_many->add_input(source2);
397   Model::NodeValues input_times;
398   input_times[kModelInputTimeKey] = input_time;
399   EXPECT_EQ(async_unknown_many->TotalBufferedBytes(), 0);
400   EXPECT_EQ(async_unknown_many->TotalMaximumBufferedBytes(), 0);
401   async_unknown_many->record_buffer_event(110, 10);
402   EXPECT_EQ(async_unknown_many->TotalBufferedBytes(), 110);
403   EXPECT_EQ(async_unknown_many->TotalMaximumBufferedBytes(),
404             110.0 * parallelism / 10);
405   source1->add_processing_time(100);
406   EXPECT_EQ(
407       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
408   EXPECT_EQ(async_unknown_many->OutputTime(&input_times, nullptr), 0);
409   source2->add_processing_time(200);
410   EXPECT_EQ(
411       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
412   EXPECT_EQ(async_unknown_many->OutputTime(&input_times, nullptr), 0);
413   source1->record_element();
414   EXPECT_EQ(
415       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr), 0);
416   EXPECT_EQ(async_unknown_many->OutputTime(&input_times, nullptr), 0);
417   async_unknown_many->record_element();
418   // Estimated ratio is 1.
419   double ratio = 1.0;
420   EXPECT_EQ(
421       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
422       ratio * 100);
423   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr), 100);
424   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr), 0);
425   source2->record_element();
426   EXPECT_EQ(
427       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
428       ratio * (100 + 200));
429   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr),
430             ratio * (100 + 200));
431   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr), 0);
432   source2->record_element();
433   EXPECT_EQ(
434       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
435       ratio * (100 + 100));
436   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr),
437             ratio * (100 + 100));
438   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr), 0);
439   source1->record_element();
440   // Estimated ratio is 2
441   ratio = 2.0;
442   EXPECT_EQ(
443       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
444       ratio * (50 + 100));
445   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr),
446             ratio * (50 + 100));
447   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr), 0);
448   source2->record_element();
449   source2->record_element();
450   EXPECT_EQ(
451       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
452       ratio * (50 + 50));
453   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr),
454             ratio * (50 + 50));
455   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr), 0);
456   async_unknown_many->add_processing_time(128);
457   EXPECT_EQ(
458       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
459       ratio * (50 + 50) + 128);
460   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr),
461             ratio * (50 + 50) + 128 / parallelism);
462   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr),
463             128 / parallelism);
464   async_unknown_many->record_element();
465   // Estimated ratio is 1.0
466   ratio = 1.0;
467   EXPECT_EQ(
468       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
469       ratio * (50 + 50) + 128 / 2);
470   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr),
471             ratio * (50 + 50) + 128 / 2 / parallelism);
472   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr),
473             128 / 2 / parallelism);
474   async_unknown_many->record_element();
475   // Estimated ratio is 2/3
476   ratio = 2.0 / 3.0;
477   EXPECT_FLOAT_EQ(
478       async_unknown_many->TotalProcessingTime(/*processing_times=*/nullptr),
479       ratio * (50 + 50) + 128 / 3.0);
480   EXPECT_LE(async_unknown_many->OutputTime(&input_times, nullptr),
481             ratio * (50 + 50) + 128 / 3.0 / parallelism);
482   EXPECT_GE(async_unknown_many->OutputTime(&input_times, nullptr),
483             128 / 3.0 / parallelism);
484 }
485 
486 INSTANTIATE_TEST_SUITE_P(Test, AsyncUnknownRatioTest,
487                          ::testing::Combine(::testing::Values(1, 2, 4, 8),
488                                             ::testing::Values(0, 50, 100,
489                                                               200)));
490 
// Checks an unknown node with two sources. The test expects its estimates to
// come from its inputs only, not from its own processing time or element
// count.
TEST(UnknownTest, Model) {
  std::shared_ptr<Node> node = model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Node> source1 = model::MakeSourceNode({1, "source1", node});
  node->add_input(source1);
  std::shared_ptr<Node> source2 = model::MakeSourceNode({2, "source2", node});
  node->add_input(source2);
  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = 0.0;
  auto expect_times = [&node, &input_times](auto expected) {
    EXPECT_EQ(node->TotalProcessingTime(/*processing_times=*/nullptr),
              expected);
    EXPECT_EQ(node->OutputTime(&input_times, nullptr), expected);
  };

  source1->add_processing_time(100);
  expect_times(0);
  source2->add_processing_time(100);
  expect_times(0);
  source1->record_element();
  expect_times(100);
  source2->record_element();
  expect_times(200);
  source1->record_element();
  expect_times(150);
  source2->record_element();
  expect_times(100);

  // The node's own processing time must not change the estimates.
  node->add_processing_time(100);
  EXPECT_EQ(node->processing_time(), 100);
  expect_times(100);

  // Nor must the number of elements the node itself has recorded.
  node->record_element();
  EXPECT_EQ(node->num_elements(), 1);
  expect_times(100);
}
533 
TEST(BufferedBytesTest,Node)534 TEST(BufferedBytesTest, Node) {
535   std::shared_ptr<Node> node = model::MakeAsyncInterleaveManyNode(
536       {-1, "TestNode", nullptr},
537       {model::MakeParameter(
538            "parallelism",
539            std::make_shared<SharedState>(/*value=*/3, nullptr, nullptr),
540            /*min=*/1, /*max=*/7),
541        model::MakeParameter(kCycleLength, nullptr,
542                             /*min=*/1,
543                             /*max=*/1)});
544   EXPECT_EQ(node->id(), -1);
545   EXPECT_EQ(node->name(), "TestNode");
546   EXPECT_EQ(node->output(), nullptr);
547 
548   EXPECT_EQ(node->buffered_bytes(), 0);
549   EXPECT_EQ(node->buffered_elements(), 0);
550   EXPECT_EQ(node->TotalBufferedBytes(), 0);
551   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
552 
553   node->record_buffer_event(20, 1);
554   EXPECT_EQ(node->buffered_bytes(), 20);
555   EXPECT_EQ(node->buffered_elements(), 1);
556   EXPECT_EQ(node->TotalBufferedBytes(), 20);
557   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 60);
558 
559   node->record_buffer_event(10, 1);
560   EXPECT_EQ(node->buffered_bytes(), 30);
561   EXPECT_EQ(node->buffered_elements(), 2);
562   EXPECT_EQ(node->TotalBufferedBytes(), 30);
563   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 45);
564 
565   node->record_buffer_event(18, 1);
566   EXPECT_EQ(node->buffered_bytes(), 48);
567   EXPECT_EQ(node->buffered_elements(), 3);
568   EXPECT_EQ(node->bytes_produced(), 0);
569   EXPECT_EQ(node->num_elements(), 0);
570   EXPECT_EQ(node->TotalBufferedBytes(), 48);
571   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 48);
572 
573   node->record_buffer_event(-20, -1);
574   node->record_element();
575   node->record_bytes_produced(20);
576   EXPECT_EQ(node->buffered_bytes(), 28);
577   EXPECT_EQ(node->buffered_elements(), 2);
578   EXPECT_EQ(node->bytes_produced(), 20);
579   EXPECT_EQ(node->num_elements(), 1);
580   EXPECT_EQ(node->TotalBufferedBytes(), 28);
581   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 51);
582 
583   node->record_buffer_event(-10, -1);
584   node->record_element();
585   node->record_bytes_produced(10);
586   EXPECT_EQ(node->buffered_bytes(), 18);
587   EXPECT_EQ(node->buffered_elements(), 1);
588   EXPECT_EQ(node->bytes_produced(), 30);
589   EXPECT_EQ(node->num_elements(), 2);
590   EXPECT_EQ(node->TotalBufferedBytes(), 18);
591   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 49.5);
592 
593   EXPECT_EQ(node->processing_time(), 0);
594   node->record_start(1);
595   EXPECT_EQ(node->processing_time(), 0);
596   node->record_stop(41);
597   EXPECT_EQ(node->processing_time(), 40);
598   node->add_processing_time(2);
599   EXPECT_EQ(node->processing_time(), 42);
600 
601   std::shared_ptr<Node> input = model::MakeAsyncKnownRatioNode(
602       {0, "TestInput", node}, 2,
603       {model::MakeParameter("parallelism",
604                             std::make_shared<SharedState>(5, nullptr, nullptr),
605                             0, 6)});
606   EXPECT_EQ(input->output(), node.get());
607   EXPECT_EQ(node->inputs().size(), 0);
608   node->add_input(input);
609   EXPECT_EQ(node->inputs().size(), 1);
610   EXPECT_EQ(node->inputs().front(), input);
611 
612   input->record_buffer_event(28, 1);
613   EXPECT_EQ(node->bytes_consumed(), 0);
614   EXPECT_EQ(node->TotalBufferedBytes(), 46);
615   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5);
616 
617   input->record_buffer_event(-28, -1);
618   input->record_element();
619   input->record_bytes_produced(28);
620   node->record_bytes_consumed(28);
621   EXPECT_EQ(node->bytes_consumed(), 28);
622   EXPECT_EQ(node->TotalBufferedBytes(), 18);
623   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5);
624 
625   node->remove_input(input);
626   EXPECT_EQ(node->inputs().size(), 0);
627 }
628 
// Returns a weighted average of `prior` and the observed `processing_time`.
//
// The prior's weight is 1 / 2^(num_elements + 1): 1/2 with no elements,
// halving with every additional element. From 30 elements on, the prior is
// ignored and `processing_time` is returned unchanged. A negative
// `num_elements` is treated as 0; the original `2 << num_elements` would have
// been a negative shift, which is undefined behaviour.
double weighted_processing_time(int64_t num_elements, double processing_time,
                                double prior) {
  constexpr int64_t kMaxWeightedElements = 30;
  if (num_elements >= kMaxWeightedElements) {
    return processing_time;
  }
  const int64_t count = std::max<int64_t>(num_elements, 0);
  // 2 << count == 2^(count + 1); count < 30 keeps this well inside int64_t.
  const double prior_weight =
      1.0 / static_cast<double>(int64_t{2} << count);
  return prior_weight * prior + (1.0 - prior_weight) * processing_time;
}
639 
// Checks that after many source elements the interleave node's total
// processing time is non-negative and at most the prior-weighted per-element
// time of its source plus the node's own 100.
TEST(TestManyElements, Model) {
  constexpr int kNumSourceElements = 100;
  std::shared_ptr<Node> interleave = model::MakeInterleaveManyNode(
      {0, "interleave_many", nullptr},
      {model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
  std::shared_ptr<Node> source =
      model::MakeSourceNode({1, "source1", interleave});
  interleave->add_input(source);
  interleave->add_processing_time(100);
  interleave->record_element();
  source->add_processing_time(200);
  for (int produced = 0; produced < kNumSourceElements; ++produced) {
    source->record_element();
  }

  // 200 units of source time over 100 elements is 2 per element.
  const double upper_bound =
      weighted_processing_time(kNumSourceElements, 2, 0) + 100;
  EXPECT_LE(interleave->TotalProcessingTime(/*processing_times=*/nullptr),
            upper_bound);
  EXPECT_GE(interleave->TotalProcessingTime(/*processing_times=*/nullptr), 0);
}
658 
// Collects an autotune parameter from a node that has produced an element.
// The parameter must be attributed to the node that owns it, not to that
// node's consumer.
TEST(CollectAutotuneParametersWithElementsTest, Model) {
  std::shared_ptr<Node> unknown =
      model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_ratio = model::MakeAsyncKnownRatioNode(
      {1, "source", unknown}, 2, {parallelism});
  async_known_ratio->record_element();
  unknown->add_input(async_known_ratio);

  const Model::ModelParameters parameters = unknown->CollectTunableParameters();

  EXPECT_EQ(parameters.size(), 1);
  EXPECT_EQ(
      CountParametersOnNode(async_known_ratio->long_name(), parameters), 1);
  EXPECT_EQ(CountParametersOnNode(unknown->long_name(), parameters), 0);
}
679 
// A parameter whose value is fixed at 3 (not kAutotune) is not reported as
// tunable, even though its node has produced an element.
TEST(DontCollectNonAutotuneParametersTest, Model) {
  std::shared_ptr<Node> unknown =
      model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> fixed_parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/3, nullptr, nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_ratio = model::MakeAsyncKnownRatioNode(
      {1, "source", unknown}, 2, {fixed_parallelism});
  async_known_ratio->record_element();
  unknown->add_input(async_known_ratio);

  EXPECT_EQ(unknown->CollectTunableParameters().size(), 0);
}
695 
// An autotune parameter is not reported as tunable once autotuning has been
// switched off on its node, even though the node has produced an element.
TEST(DontCollectAutotuneDisabledParametersTest, Model) {
  std::shared_ptr<Node> unknown =
      model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_ratio = model::MakeAsyncKnownRatioNode(
      {1, "source", unknown}, 2, {parallelism});
  async_known_ratio->record_element();
  async_known_ratio->set_autotune(false);
  unknown->add_input(async_known_ratio);

  const Model::ModelParameters parameters = unknown->CollectTunableParameters();
  EXPECT_TRUE(parameters.empty());
}
713 
// An autotune parameter on a node that has not produced any element yet is not
// reported as tunable.
TEST(DontCollectParametersWithoutElementsTest, Model) {
  std::shared_ptr<Node> unknown =
      model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> parallelism = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_ratio = model::MakeAsyncKnownRatioNode(
      {1, "source", unknown}, 2, {parallelism});
  unknown->add_input(async_known_ratio);

  EXPECT_TRUE(unknown->CollectTunableParameters().empty());
}
729 
// Absolute tolerance used when comparing an analytically reported gradient
// against its finite-difference estimate (the relative output time change).
constexpr double kComparisonPrecision = 0.1;

// Amount by which a parameter is perturbed to measure that output time change.
constexpr double kParameterStep = 0.00001;
735 
TEST(AsyncInterleaveManyGradientTest,Model)736 TEST(AsyncInterleaveManyGradientTest, Model) {
737   const double input_time = 100;
738   std::shared_ptr<Parameter> interleave_parameter =
739       model::MakeParameter("parallelism",
740                            std::make_shared<SharedState>(
741                                /*value=*/model::kAutotune, nullptr, nullptr),
742                            /*min=*/1, /*max=*/5);
743   std::shared_ptr<Node> async_interleave_many =
744       model::MakeAsyncInterleaveManyNode(
745           {0, "async_interleave_many", nullptr},
746           {interleave_parameter, model::MakeParameter("cycle_length", nullptr,
747                                                       /*min=*/1, /*max=*/1)});
748   std::shared_ptr<Node> meta_source =
749       model::MakeSourceNode({1, "meta_source", async_interleave_many});
750   async_interleave_many->add_input(meta_source);
751   auto cleanup_meta = gtl::MakeCleanup([async_interleave_many, meta_source]() {
752     async_interleave_many->remove_input(meta_source);
753   });
754   std::shared_ptr<Parameter> source1_parameter =
755       model::MakeParameter("parallelism",
756                            std::make_shared<SharedState>(
757                                /*value=*/model::kAutotune, nullptr, nullptr),
758                            /*min=*/1,
759                            /*max=*/7);
760   std::shared_ptr<Node> source1 = model::MakeAsyncInterleaveManyNode(
761       {2, "async_interleave_many", async_interleave_many},
762       {source1_parameter,
763        model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
764   async_interleave_many->add_input(source1);
765   auto cleanup1 = gtl::MakeCleanup([async_interleave_many, source1]() {
766     async_interleave_many->remove_input(source1);
767   });
768   std::shared_ptr<Node> source2 =
769       model::MakeSourceNode({3, "source2", async_interleave_many});
770   async_interleave_many->add_input(source2);
771   auto cleanup2 = gtl::MakeCleanup([async_interleave_many, source2]() {
772     async_interleave_many->remove_input(source2);
773   });
774   Model::NodeValues input_times;
775   input_times[kModelInputTimeKey] = input_time;
776   async_interleave_many->record_element();
777   async_interleave_many->add_processing_time(100);
778   source1->record_element();
779   source1->add_processing_time(100);
780   source2->record_element();
781   source2->add_processing_time(300);
782 
783   interleave_parameter->value = 1;
784   source1_parameter->value = 1;
785 
786   // Test gradient of own parameters.
787   Model::ParameterGradients gradients;
788   double output_time =
789       async_interleave_many->OutputTime(&input_times, &gradients);
790   interleave_parameter->value += kParameterStep;
791   double new_output_time =
792       async_interleave_many->OutputTime(&input_times, nullptr);
793   EXPECT_NEAR(gradients[std::make_pair(async_interleave_many->long_name(),
794                                        interleave_parameter->name)],
795               (new_output_time - output_time) / kParameterStep,
796               kComparisonPrecision);
797 
798   // Test propagation of input's gradient.
799   interleave_parameter->value -= kParameterStep;
800   source1_parameter->value += kParameterStep;
801   new_output_time = async_interleave_many->OutputTime(&input_times, nullptr);
802   EXPECT_NEAR(
803       gradients[std::make_pair(source1->long_name(), source1_parameter->name)],
804       (new_output_time - output_time) / kParameterStep, kComparisonPrecision);
805 }
806 
807 class AsyncKnownRatioGradientTest : public ::testing::TestWithParam<string> {};
808 
TEST_P(AsyncKnownRatioGradientTest,Model)809 TEST_P(AsyncKnownRatioGradientTest, Model) {
810   const string parameter_name = GetParam();
811   const double input_time = 100;
812   const int64_t num_inputs_per_output = 2;
813 
814   std::shared_ptr<Parameter> known_parameter =
815       model::MakeParameter(parameter_name,
816                            std::make_shared<SharedState>(
817                                /*value=*/model::kAutotune, nullptr, nullptr),
818                            /*min=*/1,
819                            /*max=*/5);
820   std::shared_ptr<Node> async_known_many =
821       model::MakeAsyncKnownRatioNode({0, "async_known_many", nullptr},
822                                      num_inputs_per_output, {known_parameter});
823   std::shared_ptr<Parameter> source1_parameter =
824       model::MakeParameter(parameter_name,
825                            std::make_shared<SharedState>(
826                                /*value=*/model::kAutotune, nullptr, nullptr),
827                            /*min=*/1,
828                            /*max=*/7);
829   std::shared_ptr<Node> source1 = model::MakeAsyncKnownRatioNode(
830       {1, "source1", async_known_many}, num_inputs_per_output,
831       {source1_parameter});
832   async_known_many->add_input(source1);
833   std::shared_ptr<Node> source2 =
834       model::MakeSourceNode({2, "source2", async_known_many});
835   Model::NodeValues input_times;
836   input_times[kModelInputTimeKey] = input_time;
837   async_known_many->add_input(source2);
838   source1->record_element();
839   source1->add_processing_time(100);
840   source2->record_element();
841   source2->add_processing_time(100);
842   async_known_many->record_element();
843   async_known_many->add_processing_time(300);
844 
845   // Test gradient of own parameters.
846   Model::ParameterGradients gradients;
847   known_parameter->value = 1;
848   source1_parameter->value = 1;
849   double output_time = async_known_many->OutputTime(&input_times, &gradients);
850   known_parameter->value += kParameterStep;
851   double new_output_time = async_known_many->OutputTime(&input_times, nullptr);
852   EXPECT_NEAR(gradients[std::make_pair(async_known_many->long_name(),
853                                        known_parameter->name)],
854               (new_output_time - output_time) / kParameterStep,
855               kComparisonPrecision);
856 
857   // Test propagation of input's gradient.
858   known_parameter->value -= kParameterStep;
859   source1_parameter->value += kParameterStep;
860   new_output_time = async_known_many->OutputTime(&input_times, nullptr);
861   EXPECT_NEAR(
862       gradients[std::make_pair(source1->long_name(), source1_parameter->name)],
863       (new_output_time - output_time) / kParameterStep, kComparisonPrecision);
864 }
865 
866 INSTANTIATE_TEST_SUITE_P(Test, AsyncKnownRatioGradientTest,
867                          ::testing::Values("parallelism", "buffer_size"));
868 
// Checks the output-time gradient that a (synchronous) interleave node reports
// for the "parallelism" parameter of its async known-ratio input against a
// forward finite difference.
TEST(InterleaveManyGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;
  std::shared_ptr<Node> interleave_many = model::MakeInterleaveManyNode(
      {0, "interleave_many", nullptr},
      {model::MakeParameter("cycle_length", nullptr, /*min=*/1, /*max=*/1)});
  std::shared_ptr<Parameter> known_parameter = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_many =
      model::MakeAsyncKnownRatioNode({1, "async_known_many", interleave_many},
                                     num_inputs_per_output, {known_parameter});
  std::shared_ptr<Node> source1 =
      model::MakeSourceNode({2, "source1", interleave_many});

  interleave_many->record_element();
  interleave_many->add_processing_time(100);
  interleave_many->add_input(source1);
  interleave_many->add_input(async_known_many);
  async_known_many->record_element();
  async_known_many->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  known_parameter->value = 1;
  const double base_time =
      interleave_many->OutputTime(&input_times, &gradients);
  known_parameter->value += kParameterStep;
  const double stepped_time =
      interleave_many->OutputTime(&input_times, nullptr);

  const auto analytic_gradient = gradients[std::make_pair(
      async_known_many->long_name(), known_parameter->name)];
  const double numeric_gradient = (stepped_time - base_time) / kParameterStep;
  EXPECT_NEAR(analytic_gradient, numeric_gradient, kComparisonPrecision);
}
904 
// Checks the output-time gradient that a known-ratio node reports for the
// "parallelism" parameter of its async known-ratio input against a forward
// finite difference.
TEST(KnownRatioGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;
  std::shared_ptr<Node> known_many = model::MakeKnownRatioNode(
      {0, "known_many", nullptr}, num_inputs_per_output);
  std::shared_ptr<Parameter> known_parameter = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_many =
      model::MakeAsyncKnownRatioNode({1, "async_known_many", known_many},
                                     num_inputs_per_output, {known_parameter});

  known_many->record_element();
  known_many->add_processing_time(100);
  known_many->add_input(async_known_many);
  async_known_many->record_element();
  async_known_many->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  known_parameter->value = 1;
  const double base_time = known_many->OutputTime(&input_times, &gradients);
  known_parameter->value += kParameterStep;
  const double stepped_time = known_many->OutputTime(&input_times, nullptr);

  const auto analytic_gradient = gradients[std::make_pair(
      async_known_many->long_name(), known_parameter->name)];
  const double numeric_gradient = (stepped_time - base_time) / kParameterStep;
  EXPECT_NEAR(analytic_gradient, numeric_gradient, kComparisonPrecision);
}
936 
// Checks the output-time gradient that an unknown-ratio node reports for the
// "parallelism" parameter of its async known-ratio input against a forward
// finite difference.
TEST(UnknownRatioGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;
  std::shared_ptr<Node> unknown_many =
      model::MakeUnknownRatioNode({0, "unknown_many", nullptr});
  std::shared_ptr<Parameter> known_parameter = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_many =
      model::MakeAsyncKnownRatioNode({1, "async_known_many", unknown_many},
                                     num_inputs_per_output, {known_parameter});

  unknown_many->record_element();
  unknown_many->add_processing_time(100);
  unknown_many->add_input(async_known_many);
  async_known_many->record_element();
  async_known_many->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  known_parameter->value = 1;
  const double base_time = unknown_many->OutputTime(&input_times, &gradients);
  known_parameter->value += kParameterStep;
  const double stepped_time = unknown_many->OutputTime(&input_times, nullptr);

  const auto analytic_gradient = gradients[std::make_pair(
      async_known_many->long_name(), known_parameter->name)];
  const double numeric_gradient = (stepped_time - base_time) / kParameterStep;
  EXPECT_NEAR(analytic_gradient, numeric_gradient, kComparisonPrecision);
}
968 
// Checks the output-time gradient that an unknown node reports for the
// "parallelism" parameter of its async known-ratio input against a forward
// finite difference.
TEST(UnknownGradientTest, Model) {
  const double input_time = 100;
  const int64_t num_inputs_per_output = 2;
  std::shared_ptr<Node> unknown =
      model::MakeUnknownNode({0, "unknown", nullptr});
  std::shared_ptr<Parameter> known_parameter = model::MakeParameter(
      "parallelism",
      std::make_shared<SharedState>(/*value=*/model::kAutotune, nullptr,
                                    nullptr),
      /*min=*/1, /*max=*/5);
  std::shared_ptr<Node> async_known_many =
      model::MakeAsyncKnownRatioNode({1, "async_known_many", unknown},
                                     num_inputs_per_output, {known_parameter});

  unknown->record_element();
  unknown->add_processing_time(100);
  unknown->add_input(async_known_many);
  async_known_many->record_element();
  async_known_many->add_processing_time(300);

  Model::NodeValues input_times;
  input_times[kModelInputTimeKey] = input_time;
  Model::ParameterGradients gradients;
  known_parameter->value = 1;
  const double base_time = unknown->OutputTime(&input_times, &gradients);
  known_parameter->value += kParameterStep;
  const double stepped_time = unknown->OutputTime(&input_times, nullptr);

  const auto analytic_gradient = gradients[std::make_pair(
      async_known_many->long_name(), known_parameter->name)];
  const double numeric_gradient = (stepped_time - base_time) / kParameterStep;
  EXPECT_NEAR(analytic_gradient, numeric_gradient, kComparisonPrecision);
}
1000 
// Builds a linear chain of unknown nodes whose autotune flags are set from
// std::rand(), then snapshots it. The copy must mirror ids, names, autotune
// flags and output links node by node, without aliasing any original node.
TEST(SnapshotTest, Model) {
  const int64_t num_nodes = 20;

  std::shared_ptr<Node> root =
      model::MakeUnknownNode({0, std::to_string(0), nullptr});
  {
    std::shared_ptr<Node> tail = root;
    for (int64_t id = 1; id < num_nodes; ++id) {
      std::shared_ptr<Node> next =
          model::MakeUnknownNode({id, std::to_string(id), tail});
      next->set_autotune(std::rand() % 2 == 1);
      tail->add_input(next);
      tail = next;
    }
  }

  std::shared_ptr<Node> cloned_root = root->Snapshot();
  std::shared_ptr<Node> original = root;
  std::shared_ptr<Node> clone = cloned_root;

  for (int64_t depth = 0; depth < num_nodes; ++depth) {
    EXPECT_EQ(original->id(), clone->id());
    EXPECT_EQ(original->name(), clone->name());
    EXPECT_EQ(original->autotune(), clone->autotune());
    EXPECT_NE(original.get(), clone.get());

    if (depth == 0) {
      // The root has no consumer in either graph.
      EXPECT_EQ(original->output(), nullptr);
      EXPECT_EQ(clone->output(), nullptr);
    } else {
      EXPECT_EQ(original->output()->long_name(),
                clone->output()->long_name());
      EXPECT_EQ(original->output()->autotune(), clone->output()->autotune());
      EXPECT_NE(original->output(), clone->output());
    }

    // Step one level down the chain, except after the last node.
    if (depth + 1 < num_nodes) {
      original = original->inputs().front();
      clone = clone->inputs().front();
    }
  }
}
1042 
TEST(SaveModelTest,Model)1043 TEST(SaveModelTest, Model) {
1044   model::Model model;
1045   std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown0", nullptr});
1046   model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
1047                 nullptr, &root);
1048   std::shared_ptr<Node> current = root;
1049 
1050   int64_t num_nodes = 20;
1051   for (int64_t i = 1; i < num_nodes; i++) {
1052     std::shared_ptr<Node> input;
1053     switch (i % 6) {
1054       case 0:
1055         input = model::MakeInterleaveManyNode(
1056             {i, "interleave_many" + std::to_string(i), current},
1057             {model::MakeParameter("cycle_length", nullptr, /*min=*/1,
1058                                   /*max=*/1)});
1059         break;
1060       case 1:
1061         input = model::MakeAsyncInterleaveManyNode(
1062             {i, "async_interleave_many", current},
1063             {model::MakeParameter(
1064                  "parallelism",
1065                  std::make_shared<SharedState>(
1066                      /*value=*/model::kAutotune, nullptr, nullptr),
1067                  /*min=*/1,
1068                  /*max=*/7),
1069              model::MakeParameter("cycle_length", nullptr, /*min=*/1,
1070                                   /*max=*/1)});
1071         break;
1072       case 2:
1073         input = model::MakeKnownRatioNode(
1074             {i, "known_many" + std::to_string(i), current}, 3);
1075         break;
1076       case 3:
1077         input = model::MakeAsyncKnownRatioNode(
1078             {i, "async_known_many", current}, 4,
1079             {model::MakeParameter(
1080                  "parallelism",
1081                  std::make_shared<SharedState>(
1082                      /*value=*/model::kAutotune, nullptr, nullptr),
1083                  /*min=*/1,
1084                  /*max=*/5),
1085              model::MakeParameter("cycle_length", nullptr, /*min=*/1,
1086                                   /*max=*/1)});
1087         break;
1088       case 4:
1089         input = model::MakeUnknownRatioNode(
1090             {i, "unknown_many" + std::to_string(i), current});
1091         break;
1092       default:
1093         input =
1094             model::MakeUnknownNode({i, "unknown" + std::to_string(i), current});
1095     }
1096     input->record_element();
1097     input->add_processing_time(i * 50);
1098     input->record_buffer_event(i * 33, i * 5);
1099     input->set_autotune(true);
1100     model.AddNode([&input](model::Node::Args args) { return input; },
1101                   input->name(), current, &input);
1102     current = input;
1103   }
1104 
1105   // Make Save->Load roundtrip.
1106   ModelProto::OptimizationParams optimization_params;
1107   optimization_params.set_algorithm(AutotuneAlgorithm::GRADIENT_DESCENT);
1108   optimization_params.set_cpu_budget(64);
1109   optimization_params.set_ram_budget(1024);
1110   optimization_params.set_model_input_time(43653.34534);
1111   TF_ASSERT_OK(model.Save("/tmp/autotune_model_test",
1112                           model.output()->Snapshot(), optimization_params));
1113 
1114   std::unique_ptr<model::Model> restored_model;
1115   ModelProto::OptimizationParams restored_optimization_params;
1116   TF_ASSERT_OK(model.Load("/tmp/autotune_model_test", &restored_model,
1117                           &restored_optimization_params));
1118 
1119   // Check optimization parameters.
1120   EXPECT_EQ(optimization_params.algorithm(),
1121             restored_optimization_params.algorithm());
1122   EXPECT_EQ(optimization_params.cpu_budget(),
1123             restored_optimization_params.cpu_budget());
1124   EXPECT_EQ(optimization_params.ram_budget(),
1125             restored_optimization_params.ram_budget());
1126   EXPECT_EQ(optimization_params.model_input_time(),
1127             restored_optimization_params.model_input_time());
1128 
1129   std::shared_ptr<Node> restored_root = restored_model->output();
1130   std::shared_ptr<Node> restored_current = restored_root;
1131   current = root;
1132   EXPECT_EQ(current->output(), nullptr);
1133   EXPECT_EQ(restored_current->output(), nullptr);
1134   while (!current->inputs().empty() && !restored_current->inputs().empty()) {
1135     EXPECT_EQ(current->id(), restored_current->id());
1136     EXPECT_EQ(current->name(), restored_current->name());
1137     EXPECT_EQ(current->autotune(), restored_current->autotune());
1138     Model::NodeValues input_times_actual, input_times_expected;
1139     input_times_actual.clear();
1140     input_times_expected.clear();
1141     EXPECT_EQ(current->OutputTime(&input_times_actual, nullptr),
1142               restored_current->OutputTime(&input_times_expected, nullptr));
1143     EXPECT_EQ(current->TotalBufferedBytes(),
1144               restored_current->TotalBufferedBytes());
1145     EXPECT_EQ(current->TotalMaximumBufferedBytes(),
1146               restored_current->TotalMaximumBufferedBytes());
1147     EXPECT_EQ(current->Ratio(), restored_current->Ratio());
1148     EXPECT_NE(current.get(), restored_current.get());
1149 
1150     current = current->inputs().front();
1151     restored_current = restored_current->inputs().front();
1152 
1153     EXPECT_EQ(current->output()->long_name(), current->output()->long_name());
1154     EXPECT_EQ(current->output()->autotune(),
1155               restored_current->output()->autotune());
1156     EXPECT_NE(current->output(), restored_current->output());
1157   }
1158   EXPECT_TRUE(current->inputs().empty());
1159   EXPECT_TRUE(restored_current->inputs().empty());
1160 }
1161 
1162 class ComputeWaitTimeTest
1163     : public ::testing::TestWithParam<std::tuple<double, double, double>> {};
1164 
TEST_P(ComputeWaitTimeTest,Model)1165 TEST_P(ComputeWaitTimeTest, Model) {
1166   const double producer_time = std::get<0>(GetParam());
1167   const double consumer_time = std::get<1>(GetParam());
1168   const double buffer_size = std::get<2>(GetParam());
1169 
1170   double producer_time_derivative = 0.0L;
1171   double consumer_time_derivative = 0.0L;
1172   double buffer_size_derivative = 0.0L;
1173 
1174   double wait_time = model::Node::ComputeWaitTime(
1175       producer_time, consumer_time, buffer_size, &producer_time_derivative,
1176       &consumer_time_derivative, &buffer_size_derivative);
1177 
1178   double new_wait_time = model::Node::ComputeWaitTime(
1179       producer_time + kParameterStep, consumer_time, buffer_size, nullptr,
1180       nullptr, nullptr);
1181   EXPECT_NEAR(producer_time_derivative,
1182               (new_wait_time - wait_time) / kParameterStep,
1183               kComparisonPrecision);
1184 
1185   if (producer_time >= kParameterStep) {
1186     new_wait_time = model::Node::ComputeWaitTime(producer_time - kParameterStep,
1187                                                  consumer_time, buffer_size,
1188                                                  nullptr, nullptr, nullptr);
1189     EXPECT_NEAR(producer_time_derivative,
1190                 (wait_time - new_wait_time) / kParameterStep,
1191                 kComparisonPrecision);
1192   }
1193 
1194   new_wait_time = model::Node::ComputeWaitTime(
1195       producer_time, consumer_time + kParameterStep, buffer_size, nullptr,
1196       nullptr, nullptr);
1197   EXPECT_NEAR(consumer_time_derivative,
1198               (new_wait_time - wait_time) / kParameterStep,
1199               kComparisonPrecision);
1200 
1201   if (consumer_time >= kParameterStep) {
1202     new_wait_time = model::Node::ComputeWaitTime(
1203         producer_time, consumer_time - kParameterStep, buffer_size, nullptr,
1204         nullptr, nullptr);
1205     EXPECT_NEAR(consumer_time_derivative,
1206                 (wait_time - new_wait_time) / kParameterStep,
1207                 kComparisonPrecision);
1208   }
1209 
1210   new_wait_time = model::Node::ComputeWaitTime(producer_time, consumer_time,
1211                                                buffer_size + kParameterStep,
1212                                                nullptr, nullptr, nullptr);
1213   EXPECT_NEAR(buffer_size_derivative,
1214               (new_wait_time - wait_time) / kParameterStep,
1215               kComparisonPrecision);
1216 
1217   if (buffer_size >= kParameterStep) {
1218     new_wait_time = model::Node::ComputeWaitTime(producer_time, consumer_time,
1219                                                  buffer_size - kParameterStep,
1220                                                  nullptr, nullptr, nullptr);
1221     EXPECT_NEAR(buffer_size_derivative,
1222                 (wait_time - new_wait_time) / kParameterStep,
1223                 kComparisonPrecision);
1224   }
1225 }
1226 
1227 INSTANTIATE_TEST_SUITE_P(
1228     Test, ComputeWaitTimeTest,
1229     ::testing::Combine(::testing::Values(0, 20, 40, 80, 100),
1230                        ::testing::Values(0, 20, 40, 80, 100),
1231                        ::testing::Values(0, 1, 2, 4, 10, 20, 40)));
1232 
1233 class SelfProcessingTimeTest : public ::testing::TestWithParam<int64_t> {};
1234 
// Records processing times 0, 1, ..., add_times - 1, one element per sample.
// The source's self processing time must then equal (add_times - 1) / 2, the
// mean of those samples, or 0 when nothing was recorded.
TEST_P(SelfProcessingTimeTest, Model) {
  const int64_t add_times = GetParam();
  std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr});
  for (int i = 0; i < add_times; i++) {
    source->add_processing_time(i);
    source->record_element();
  }

  double expected = 0.0;
  if (add_times != 0) {
    expected = (static_cast<double>(add_times) - 1.0) / 2.0;
  }
  EXPECT_EQ(source->SelfProcessingTime(), expected);
}
1246 
1247 INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
1248                          ::testing::Values(0, 1, 2, 5, 10, 20, 40));
1249 
1250 class OptimizeZeroRamBudgetTest
1251     : public ::testing::TestWithParam<model::AutotuneAlgorithm> {};
1252 
TEST_P(OptimizeZeroRamBudgetTest,Model)1253 TEST_P(OptimizeZeroRamBudgetTest, Model) {
1254   const model::AutotuneAlgorithm algorithm = GetParam();
1255 
1256   std::shared_ptr<mutex> mutex1 = std::make_shared<mutex>();
1257   std::shared_ptr<condition_variable> cv1 =
1258       std::make_shared<condition_variable>();
1259   std::shared_ptr<Node> node1 = model::MakeAsyncKnownRatioNode(
1260       {1, "1", nullptr}, 2,
1261       {model::MakeParameter("parallelism",
1262                             std::make_shared<SharedState>(
1263                                 /*value=*/model::kAutotune, mutex1, cv1),
1264                             /*min=*/1, /*max=*/5)});
1265   node1->record_buffer_event(1, 1);
1266   node1->record_element();
1267 
1268   std::shared_ptr<mutex> mutex2 = std::make_shared<mutex>();
1269   std::shared_ptr<condition_variable> cv2 =
1270       std::make_shared<condition_variable>();
1271   std::shared_ptr<Node> node2 = model::MakeAsyncKnownRatioNode(
1272       {2, "2", node1}, 5,
1273       {model::MakeParameter("buffer_size",
1274                             std::make_shared<SharedState>(
1275                                 /*value=*/model::kAutotune, mutex2, cv2),
1276                             /*min=*/0, /*max=*/6)});
1277   node2->record_buffer_event(1, 1);
1278   node2->record_element();
1279 
1280   std::shared_ptr<mutex> mutex3 = std::make_shared<mutex>();
1281   std::shared_ptr<condition_variable> cv3 =
1282       std::make_shared<condition_variable>();
1283   std::shared_ptr<Node> node3 = model::MakeAsyncInterleaveManyNode(
1284       {3, "3", node2},
1285       {model::MakeParameter("parallelism",
1286                             std::make_shared<SharedState>(
1287                                 /*value=*/model::kAutotune, mutex3, cv3),
1288                             /*min=*/1, /*max=*/7),
1289        model::MakeParameter(kCycleLength, nullptr,
1290                             /*min=*/1,
1291                             /*max=*/1)});
1292   node3->record_buffer_event(1, 1);
1293   node3->record_element();
1294 
1295   EXPECT_EQ(node1->parameter_value("parallelism"), model::kAutotune);
1296   EXPECT_EQ(node2->parameter_value("buffer_size"), model::kAutotune);
1297   EXPECT_EQ(node3->parameter_value("parallelism"), model::kAutotune);
1298 
1299   model::Model model;
1300   model.AddNode([&node1](model::Node::Args args) { return node1; }, "1",
1301                 nullptr, &node1);
1302   model.AddNode([&node2](model::Node::Args args) { return node2; }, "2", node1,
1303                 &node2);
1304   model.AddNode([&node3](model::Node::Args args) { return node3; }, "3", node2,
1305                 &node3);
1306 
1307   CancellationManager cancellation_manager;
1308   model.Optimize(algorithm, 40, 0, 0, &cancellation_manager);
1309   EXPECT_EQ(node1->parameter_value("parallelism"), 1);
1310   EXPECT_EQ(node2->parameter_value("buffer_size"), 0);
1311   EXPECT_EQ(node3->parameter_value("parallelism"), 1);
1312 }
1313 
1314 INSTANTIATE_TEST_SUITE_P(Test, OptimizeZeroRamBudgetTest,
1315                          ::testing::Values(0, 1, 2, 3));
1316 
// A node reports is_recording() only between a record_start() call and the
// following record_stop() call.
TEST(RecordTimeTest, RecordTimeTest) {
  const std::shared_ptr<Node> node = model::MakeSourceNode({});
  // Nothing has been started yet.
  EXPECT_FALSE(node->is_recording());

  constexpr int64_t kStartTime = 100;
  constexpr int64_t kStopTime = 200;

  node->record_start(kStartTime);
  EXPECT_TRUE(node->is_recording());

  node->record_stop(kStopTime);
  EXPECT_FALSE(node->is_recording());
}
1325 
// After a node is added to a Model, the "/tensorflow/data/model" metric cell
// keyed by the model's address is expected to describe that node.
TEST(ModelTest, ModelMetrics) {
  CellReader<std::string> metric_reader("/tensorflow/data/model");
  model::Model model;
  std::shared_ptr<Node> unknown_node =
      model::MakeUnknownNode({0, "unknown0", nullptr});
  // The factory simply hands back the pre-built node.
  auto node_factory = [&unknown_node](model::Node::Args args) {
    return unknown_node;
  };
  model.AddNode(node_factory, unknown_node->name(), nullptr, &unknown_node);

  // The metric cell is keyed by the numeric address of the Model object.
  const std::string model_key =
      strings::StrCat(reinterpret_cast<uint64>(&model));
  const std::string description = metric_reader.Read(model_key);
  EXPECT_THAT(description,
              AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
                    HasSubstr("autotune: true")));
}
1337 
1338 class ModelTimingTest : public ::testing::Test {
1339  public:
1340   // Builds a Model from its text proto.
BuildModelFromProto(const std::string & model_pbtxt)1341   void BuildModelFromProto(const std::string& model_pbtxt) {
1342     ModelProto model_proto;
1343     protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
1344     TF_CHECK_OK(Model::FromProto(model_proto, &model_));
1345     auto nodes = model_->output()->CollectNodes(
1346         TraversalOrder::BFS, [](const std::shared_ptr<Node>) { return true; });
1347     node_map_.clear();
1348     node_map_[model_->output()->id()] = model_->output().get();
1349     for (const auto& node : nodes) {
1350       node_map_[node->id()] = node.get();
1351     }
1352   }
1353 
1354   // Computes the timing given a Model text proto.
ComputeModelTiming(const std::string & model_pbtxt)1355   void ComputeModelTiming(const std::string& model_pbtxt) {
1356     BuildModelFromProto(model_pbtxt);
1357     model_timing_ = std::make_unique<ModelTiming>(model_->output());
1358   }
1359 
1360   // Gets the timing information of a node given its id.
GetNodeTiming(int64_t node_id) const1361   const ModelTiming::NodeTiming* GetNodeTiming(int64_t node_id) const {
1362     return model_timing_->GetTiming(node_map_.at(node_id));
1363   }
1364 
1365   // Gets the node given its id.
GetNode(int64_t node_id) const1366   const Node* GetNode(int64_t node_id) const { return node_map_.at(node_id); }
MutableGetNode(int64_t node_id) const1367   Node* MutableGetNode(int64_t node_id) const { return node_map_.at(node_id); }
1368 
1369  protected:
1370   std::unique_ptr<Model> model_;
1371   std::unique_ptr<ModelTiming> model_timing_;
1372   absl::flat_hash_map<int64_t, Node*> node_map_;
1373 };
1374 
// Batch -> Interleave (cycle_length 2) over two Batch inputs; checks pipeline
// ratio, self time and total time for each node.
TEST_F(ModelTimingTest, Interleave) {
  ComputeModelTiming(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "Batch"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 2
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Interleave"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: INTERLEAVE_MANY
        inputs: 3
        inputs: 4
        parameters: { name: "cycle_length" value: 2 tunable: false }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "Batch"
        autotune: true
        num_elements: 60
        processing_time: 1200
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    nodes: {
      key: 4
      value: {
        id: 4
        name: "Batch"
        autotune: true
        num_elements: 40
        processing_time: 800
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    output: 1
  )pb");

  struct NodeExpectation {
    int64_t node_id;
    double pipeline_ratio;
    double self_time_nsec;
    double total_time_nsec;
  };
  const NodeExpectation expectations[] = {
      {1, 1, 10, 40},
      {2, 1, 10, 30},
      {3, 0.5, 20, 20},
      {4, 0.5, 20, 20},
  };
  for (const NodeExpectation& expected : expectations) {
    SCOPED_TRACE(::testing::Message() << "node_id=" << expected.node_id);
    const ModelTiming::NodeTiming* timing = GetNodeTiming(expected.node_id);
    EXPECT_DOUBLE_EQ(expected.pipeline_ratio, timing->pipeline_ratio);
    EXPECT_DOUBLE_EQ(expected.self_time_nsec, timing->self_time_nsec);
    EXPECT_DOUBLE_EQ(expected.total_time_nsec, timing->total_time_nsec);
  }
}
1446 
1447 class ParallelInterleaveTimingTest
1448     : public ModelTimingTest,
1449       public ::testing::WithParamInterface<
1450           std::tuple<int32_t, int32_t, int32_t>> {};
1451 
TEST_P(ParallelInterleaveTimingTest,ScenarioTest)1452 TEST_P(ParallelInterleaveTimingTest, ScenarioTest) {
1453   const int32_t parallelism = std::get<0>(GetParam());
1454   const int32_t deterministic = std::get<1>(GetParam());
1455   const int32_t cycle_length = std::get<2>(GetParam());
1456   ComputeModelTiming(strings::Printf(
1457       R"pb(
1458         nodes: {
1459           key: 1
1460           value: {
1461             id: 1
1462             name: "Batch"
1463             autotune: true
1464             num_elements: 100
1465             processing_time: 1000
1466             node_class: KNOWN_RATIO
1467             ratio: 1
1468             inputs: 2
1469           }
1470         }
1471         nodes: {
1472           key: 2
1473           value: {
1474             id: 2
1475             name: "ParallelInterleaveV4"
1476             autotune: true
1477             num_elements: 100
1478             processing_time: 2000
1479             node_class: ASYNC_INTERLEAVE_MANY
1480             inputs: 3
1481             inputs: 4
1482             inputs: 5
1483             inputs: 6
1484             inputs: 7
1485             parameters: {
1486               name: "parallelism"
1487               value: %d
1488               min: 1
1489               max: 10
1490               tunable: true
1491             }
1492             parameters: { name: "deterministic" value: %d tunable: false }
1493             parameters: { name: "cycle_length" value: %d tunable: false }
1494           }
1495         }
1496         nodes: {
1497           key: 3
1498           value: {
1499             id: 3
1500             name: "Batch"
1501             autotune: true
1502             num_elements: 60
1503             processing_time: 60
1504             node_class: KNOWN_RATIO
1505             ratio: 1
1506           }
1507         }
1508         nodes: {
1509           key: 4
1510           value: {
1511             id: 4
1512             name: "Batch"
1513             autotune: true
1514             num_elements: 60
1515             processing_time: 1200
1516             node_class: KNOWN_RATIO
1517             ratio: 1
1518           }
1519         }
1520         nodes: {
1521           key: 5
1522           value: {
1523             id: 5
1524             name: "Batch"
1525             autotune: true
1526             num_elements: 40
1527             processing_time: 1200
1528             node_class: KNOWN_RATIO
1529             ratio: 1
1530           }
1531         }
1532         nodes: {
1533           key: 6
1534           value: {
1535             id: 6
1536             name: "Batch"
1537             autotune: true
1538             num_elements: 60
1539             processing_time: 2400
1540             node_class: KNOWN_RATIO
1541             ratio: 1
1542           }
1543         }
1544         nodes: {
1545           key: 7
1546           value: {
1547             id: 7
1548             name: "Batch"
1549             autotune: false  # Marked as an inactive input
1550             num_elements: 40
1551             processing_time: 2000
1552             node_class: KNOWN_RATIO
1553             ratio: 1
1554           }
1555         }
1556         output: 1
1557       )pb",
1558       parallelism, deterministic, cycle_length));
1559 
1560   EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/1)->pipeline_ratio);
1561   EXPECT_DOUBLE_EQ(1, GetNodeTiming(/*node_id=*/2)->pipeline_ratio);
1562   EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1563                    GetNodeTiming(/*node_id=*/3)->pipeline_ratio);
1564   EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1565                    GetNodeTiming(/*node_id=*/4)->pipeline_ratio);
1566   EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1567                    GetNodeTiming(/*node_id=*/5)->pipeline_ratio);
1568   EXPECT_DOUBLE_EQ(1.0 / cycle_length,
1569                    GetNodeTiming(/*node_id=*/6)->pipeline_ratio);
1570   EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/7)->pipeline_ratio);
1571 
1572   const double expected_self_time_1 = 1000.0 / 100.0;
1573   const double expected_self_time_2 = 2000.0 / 100.0 / parallelism;
1574   const double expected_self_time_3 = 60.0 / 60.0;
1575   const double expected_self_time_4 = 1200.0 / 60.0;
1576   const double expected_self_time_5 = 1200.0 / 40.0;
1577   const double expected_self_time_6 = 2400.0 / 60.0;
1578   const double expected_self_time_7 = 2000.0 / 40.0;
1579 
1580   EXPECT_DOUBLE_EQ(expected_self_time_1,
1581                    GetNodeTiming(/*node_id=*/1)->self_time_nsec);
1582   EXPECT_DOUBLE_EQ(expected_self_time_2,
1583                    GetNodeTiming(/*node_id=*/2)->self_time_nsec);
1584   EXPECT_DOUBLE_EQ(expected_self_time_3,
1585                    GetNodeTiming(/*node_id=*/3)->self_time_nsec);
1586   EXPECT_DOUBLE_EQ(expected_self_time_4,
1587                    GetNodeTiming(/*node_id=*/4)->self_time_nsec);
1588   EXPECT_DOUBLE_EQ(expected_self_time_5,
1589                    GetNodeTiming(/*node_id=*/5)->self_time_nsec);
1590   EXPECT_DOUBLE_EQ(expected_self_time_6,
1591                    GetNodeTiming(/*node_id=*/6)->self_time_nsec);
1592   EXPECT_DOUBLE_EQ(expected_self_time_7,
1593                    GetNodeTiming(/*node_id=*/7)->self_time_nsec);
1594 
1595   EXPECT_DOUBLE_EQ(expected_self_time_1,
1596                    GetNodeTiming(/*node_id=*/1)->total_time_nsec);
1597   EXPECT_DOUBLE_EQ(expected_self_time_3,
1598                    GetNodeTiming(/*node_id=*/3)->total_time_nsec);
1599   EXPECT_DOUBLE_EQ(expected_self_time_4,
1600                    GetNodeTiming(/*node_id=*/4)->total_time_nsec);
1601   EXPECT_DOUBLE_EQ(expected_self_time_5,
1602                    GetNodeTiming(/*node_id=*/5)->total_time_nsec);
1603   EXPECT_DOUBLE_EQ(expected_self_time_6,
1604                    GetNodeTiming(/*node_id=*/6)->total_time_nsec);
1605   EXPECT_DOUBLE_EQ(0, GetNodeTiming(/*node_id=*/7)->total_time_nsec);
1606 
1607   const double max_input_time = expected_self_time_6;
1608   double input_throughput = 1.0 / expected_self_time_4 +
1609                             1.0 / expected_self_time_5 +
1610                             1.0 / expected_self_time_6;
1611   const double active_inputs = 3.0;
1612   double expected_input_time;
1613   if (deterministic == 1) {
1614     expected_input_time = max_input_time / std::min(parallelism, cycle_length);
1615   } else {
1616     if (std::min(parallelism, cycle_length) < active_inputs) {
1617       input_throughput *= std::min(parallelism, cycle_length) / active_inputs;
1618     }
1619     expected_input_time = 1.0 / input_throughput;
1620   }
1621   EXPECT_DOUBLE_EQ(expected_input_time + expected_self_time_2,
1622                    GetNodeTiming(/*node_id=*/2)->total_time_nsec);
1623 }
1624 
1625 INSTANTIATE_TEST_SUITE_P(ParallelInterleaveTimingTest,
1626                          ParallelInterleaveTimingTest,
1627                          ::testing::Combine(::testing::Values(1, 2, 3),
1628                                             ::testing::Values(0, 1),
1629                                             ::testing::Values(1, 2, 3)));
1630 
// Batch -> ParallelInterleaveV4 (cycle_length 2) whose inputs are a Batch node
// (3) and two Batch nodes (4, 5) that each read from a ParallelMapV2 (6, 7);
// checks pipeline ratio, self time and total time for every node.
TEST_F(ModelTimingTest, ParallelInterleave_Batch_ParallelMap) {
  ComputeModelTiming(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "Batch"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 2
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "ParallelInterleaveV4"
        autotune: true
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_INTERLEAVE_MANY
        inputs: 3
        inputs: 4
        inputs: 5
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 10
          tunable: true
        }
        parameters: { name: "cycle_length" value: 2 tunable: false }
        parameters: { name: "deterministic" value: 0 tunable: false }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "Batch"
        autotune: true
        num_elements: 60
        processing_time: 60
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    nodes: {
      key: 4
      value: {
        id: 4
        name: "Batch"
        autotune: true
        num_elements: 60
        processing_time: 1200
        node_class: KNOWN_RATIO
        ratio: 2
        inputs: 6
      }
    }
    nodes: {
      key: 5
      value: {
        id: 5
        name: "Batch"
        autotune: true
        num_elements: 40
        processing_time: 800
        node_class: KNOWN_RATIO
        ratio: 2
        inputs: 7
      }
    }
    nodes: {
      key: 6
      value: {
        id: 6
        name: "ParallelMapV2"
        autotune: true
        num_elements: 120
        processing_time: 2400
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 7
      value: {
        id: 7
        name: "ParallelMapV2"
        autotune: true
        num_elements: 120
        processing_time: 2400
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    output: 1
  )pb");

  struct NodeExpectation {
    int64_t node_id;
    double pipeline_ratio;
    double self_time_nsec;
    double total_time_nsec;
  };
  const NodeExpectation expectations[] = {
      {1, 1, 10, 10},   {2, 1, 10, 20},   {3, 0.5, 1, 1},  {4, 0.5, 20, 20},
      {5, 0.5, 20, 20}, {6, 1, 10, 10},   {7, 1, 10, 10},
  };
  for (const NodeExpectation& expected : expectations) {
    SCOPED_TRACE(::testing::Message() << "node_id=" << expected.node_id);
    const ModelTiming::NodeTiming* timing = GetNodeTiming(expected.node_id);
    EXPECT_DOUBLE_EQ(expected.pipeline_ratio, timing->pipeline_ratio);
    EXPECT_DOUBLE_EQ(expected.self_time_nsec, timing->self_time_nsec);
    EXPECT_DOUBLE_EQ(expected.total_time_nsec, timing->total_time_nsec);
  }
}
1772 
1773 class BufferSizeTest : public ::testing::Test {
1774  public:
ReadModel(const std::string & model_pbtxt)1775   void ReadModel(const std::string& model_pbtxt) {
1776     ModelProto model_proto;
1777     protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
1778     TF_CHECK_OK(Model::FromProto(model_proto, &model_));
1779     auto nodes = model_->output()->CollectNodes(
1780         TraversalOrder::BFS, [](const std::shared_ptr<Node>) { return true; });
1781     node_map_.clear();
1782     node_map_[model_->output()->id()] = model_->output();
1783     for (const auto& node : nodes) {
1784       node_map_[node->id()] = node;
1785     }
1786   }
1787 
1788   // Returns a node given its node id. If node id does not exist, it will fail.
GetNode(int64_t node_id) const1789   std::shared_ptr<Node> GetNode(int64_t node_id) const {
1790     return node_map_.at(node_id);
1791   }
1792 
1793  protected:
1794   std::unique_ptr<Model> model_;
1795   absl::flat_hash_map<int64_t, std::shared_ptr<Node>> node_map_;
1796 };
1797 
TEST_F(BufferSizeTest,OptimizeBuffers_PlentyOfMemory)1798 TEST_F(BufferSizeTest, OptimizeBuffers_PlentyOfMemory) {
1799   ReadModel(R"pb(
1800     nodes: {
1801       key: 1
1802       value: {
1803         id: 1
1804         name: "Prefetch"
1805         autotune: true
1806         bytes_produced: 10000
1807         num_elements: 100
1808         processing_time: 2000
1809         node_class: ASYNC_KNOWN_RATIO
1810         inputs: 2
1811         ratio: 1
1812         parameters: {
1813           name: "buffer_size"
1814           value: 3
1815           state_value: 3
1816           min: 1
1817           max: 10
1818           tunable: true
1819         }
1820       }
1821     }
1822     nodes: {
1823       key: 2
1824       value: {
1825         id: 2
1826         name: "Prefetch"
1827         autotune: true
1828         bytes_produced: 10000
1829         num_elements: 100
1830         processing_time: 2000
1831         node_class: ASYNC_KNOWN_RATIO
1832         inputs: 3
1833         ratio: 1
1834         parameters: {
1835           name: "buffer_size"
1836           value: 5
1837           state_value: 5
1838           min: 1
1839           max: 10
1840           tunable: true
1841         }
1842       }
1843     }
1844     nodes: {
1845       key: 3
1846       value: {
1847         id: 3
1848         name: "Prefetch"
1849         autotune: true
1850         bytes_produced: 10000
1851         num_elements: 100
1852         processing_time: 2000
1853         node_class: ASYNC_KNOWN_RATIO
1854         inputs: 4
1855         ratio: 1
1856         parameters: {
1857           name: "buffer_size"
1858           value: 5
1859           state_value: 5
1860           min: 1
1861           max: 10
1862           tunable: true
1863         }
1864       }
1865     }
1866     nodes: {
1867       key: 4
1868       value: {
1869         id: 4
1870         name: "Prefetch"
1871         autotune: true
1872         bytes_produced: 10000
1873         num_elements: 100
1874         processing_time: 2000
1875         node_class: ASYNC_KNOWN_RATIO
1876         inputs: 5
1877         ratio: 1
1878         parameters: {
1879           name: "buffer_size"
1880           value: 5
1881           state_value: 5
1882           min: 1
1883           max: 8
1884           tunable: true
1885         }
1886       }
1887     }
1888     nodes: {
1889       key: 5
1890       value: {
1891         id: 5
1892         name: "Prefetch"
1893         autotune: true
1894         bytes_produced: 10000
1895         num_elements: 100
1896         processing_time: 2000
1897         node_class: ASYNC_KNOWN_RATIO
1898         ratio: 1
1899         parameters: {
1900           name: "buffer_size"
1901           value: 8
1902           state_value: 8
1903           min: 1
1904           max: 8
1905           tunable: true
1906         }
1907       }
1908     }
1909     output: 1
1910   )pb");
1911 
1912   std::shared_ptr<Node> node_1 = GetNode(1);
1913   std::shared_ptr<Node> node_2 = GetNode(2);
1914   std::shared_ptr<Node> node_3 = GetNode(3);
1915   std::shared_ptr<Node> node_4 = GetNode(4);
1916   std::shared_ptr<Node> node_5 = GetNode(5);
1917   // Set node 1 low watermark to 1 and high watermark to 2. Expect that it is
1918   // downsized to 2.
1919   node_1->record_buffer_event(100, 1);
1920   node_1->record_buffer_event(100, 1);
1921   EXPECT_EQ(1, node_1->buffered_elements_low());
1922   EXPECT_EQ(2, node_1->buffered_elements_high());
1923   // Set node 2 low watermark to 1 and high watermark to 5. Expect that it is
1924   // not changed.
1925   node_2->record_buffer_event(100, 1);
1926   node_2->record_buffer_event(400, 4);
1927   node_2->record_buffer_event(-100, -1);
1928   EXPECT_EQ(1, node_2->buffered_elements_low());
1929   EXPECT_EQ(5, node_2->buffered_elements_high());
1930   // Set node 3 low watermark to 0 and high watermark to 5. Expect that it is
1931   // upsized to 10.
1932   node_3->record_buffer_event(100, 1);
1933   node_3->record_buffer_event(-100, -1);
1934   node_3->record_buffer_event(500, 5);
1935   node_3->record_buffer_event(-100, -1);
1936   EXPECT_EQ(0, node_3->buffered_elements_low());
1937   EXPECT_EQ(5, node_3->buffered_elements_high());
1938   // Set node 4 low watermark to 0 and high watermark to 5. Its max buffer size
1939   // is set to 8. Expect that it is upsized to 8.
1940   node_4->record_buffer_event(100, 1);
1941   node_4->record_buffer_event(-100, -1);
1942   node_4->record_buffer_event(500, 5);
1943   node_4->record_buffer_event(-100, -1);
1944   EXPECT_EQ(0, node_4->buffered_elements_low());
1945   EXPECT_EQ(5, node_4->buffered_elements_high());
1946   // Set node 5 low watermark to 1 and high watermark to 2. Its current buffer
1947   // size is set to 8. Expect that it is downsized to 8/2 rather than (2 - 1 + 1
1948   // = 3) because downsize is capped to half its size.
1949   node_5->record_buffer_event(100, 1);
1950   node_5->record_buffer_event(-100, 1);
1951   EXPECT_EQ(1, node_5->buffered_elements_low());
1952   EXPECT_EQ(2, node_5->buffered_elements_high());
1953 
1954   model_->OptimizeBuffers(node_1->Snapshot(), 10000);
1955 
1956   EXPECT_EQ(2, node_1->parameter_value(kBufferSize));
1957   EXPECT_EQ(5, node_2->parameter_value(kBufferSize));
1958   EXPECT_EQ(10, node_3->parameter_value(kBufferSize));
1959   EXPECT_EQ(8, node_4->parameter_value(kBufferSize));
1960   EXPECT_EQ(6, node_5->parameter_value(kBufferSize));
1961   EXPECT_EQ(2, node_1->buffered_elements_low());
1962   EXPECT_EQ(2, node_1->buffered_elements_high());
1963   EXPECT_EQ(4, node_2->buffered_elements_low());
1964   EXPECT_EQ(4, node_2->buffered_elements_high());
1965   EXPECT_EQ(4, node_3->buffered_elements_low());
1966   EXPECT_EQ(4, node_3->buffered_elements_high());
1967   EXPECT_EQ(4, node_4->buffered_elements_low());
1968   EXPECT_EQ(4, node_4->buffered_elements_high());
1969   EXPECT_EQ(2, node_5->buffered_elements_low());
1970   EXPECT_EQ(2, node_5->buffered_elements_high());
1971 }
1972 
// Four chained Prefetch nodes, each driven to low watermark 0 and high
// watermark 5 (equal to its current buffer size). With a RAM budget of 3000
// bytes, every buffer is expected to end at 7.
TEST_F(BufferSizeTest, OptimizeBuffers_TightMemory) {
  ReadModel(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 2
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 3
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        inputs: 4
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 10
          tunable: true
        }
      }
    }
    nodes: {
      key: 4
      value: {
        id: 4
        name: "Prefetch"
        autotune: true
        bytes_produced: 10000
        num_elements: 100
        processing_time: 2000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        parameters: {
          name: "buffer_size"
          value: 5
          state_value: 5
          min: 1
          max: 8
          tunable: true
        }
      }
    }
    output: 1
  )pb");

  const std::shared_ptr<Node> first = GetNode(1);
  const std::shared_ptr<Node> second = GetNode(2);
  const std::shared_ptr<Node> third = GetNode(3);
  const std::shared_ptr<Node> fourth = GetNode(4);

  // Checks a node's current low and high buffered-element watermarks.
  auto expect_watermarks = [](const std::shared_ptr<Node>& node, int64_t low,
                              int64_t high) {
    EXPECT_EQ(low, node->buffered_elements_low());
    EXPECT_EQ(high, node->buffered_elements_high());
  };

  // Drive every node to low watermark 0 and high watermark 5.
  first->record_buffer_event(100, 1);
  first->record_buffer_event(-100, -1);
  first->record_buffer_event(500, 5);
  expect_watermarks(first, 0, 5);

  second->record_buffer_event(100, 1);
  second->record_buffer_event(100, 1);
  second->record_buffer_event(-100, -1);
  second->record_buffer_event(-100, -1);
  second->record_buffer_event(400, 4);
  second->record_buffer_event(100, 1);
  expect_watermarks(second, 0, 5);

  third->record_buffer_event(100, 1);
  third->record_buffer_event(-100, -1);
  third->record_buffer_event(500, 5);
  expect_watermarks(third, 0, 5);

  // Node 4 ends one element lower than the others (4 buffered).
  fourth->record_buffer_event(100, 1);
  fourth->record_buffer_event(-100, -1);
  fourth->record_buffer_event(100, 1);
  fourth->record_buffer_event(-100, -1);
  fourth->record_buffer_event(500, 5);
  fourth->record_buffer_event(-100, -1);
  expect_watermarks(fourth, 0, 5);

  model_->OptimizeBuffers(first->Snapshot(), 3000);

  const std::shared_ptr<Node> all_nodes[] = {first, second, third, fourth};
  for (const std::shared_ptr<Node>& node : all_nodes) {
    EXPECT_DOUBLE_EQ(7.0, node->parameter_value(kBufferSize));
  }
  // After optimization both watermarks equal each node's current number of
  // buffered elements.
  expect_watermarks(first, 5, 5);
  expect_watermarks(second, 5, 5);
  expect_watermarks(third, 5, 5);
  expect_watermarks(fourth, 4, 4);
}
2112 
// A single ParallelMapV2 stage above Map and SSTable nodes; with the
// STAGE_BASED algorithm its parallelism is expected to move from 4 to 5.
TEST_F(ModelTimingTest, OptimizeStageBased_OneStage) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 5000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Map"
        autotune: true
        num_elements: 100
        processing_time: 3000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 3
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb");

  CancellationManager cancellation;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 20, 1000, 50,
                   &cancellation);

  const Node* parallel_map = GetNode(/*node_id=*/1);
  EXPECT_EQ(5, parallel_map->parameter_value("parallelism"));
}
2170 
// Same pipeline as the one-stage case, but the parallelism parameter's max is
// 3, so the tuned value is expected to stop at 3 (it would otherwise be 5).
TEST_F(ModelTimingTest, OptimizeStageBased_CappedByParameterMax) {
  BuildModelFromProto(R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 5000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: { name: "parallelism" value: 4 min: 1 max: 3 tunable: true }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Map"
        autotune: true
        num_elements: 100
        processing_time: 3000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 3
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb");

  CancellationManager cancellation;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 20, 1000, 50,
                   &cancellation);

  const Node* parallel_map = GetNode(/*node_id=*/1);
  EXPECT_EQ(3, parallel_map->parameter_value("parallelism"));
}
2223 
TEST_F(ModelTimingTest, OptimizeStageBased_TwoStages) {
  // Two tunable async stages in a row:
  // SSTable (id 3) -> ParallelMapV2 (id 2) -> ParallelMapV2 (id 1, output).
  // Both stages start at `parallelism` 4 within the range [1, 16].
  constexpr char kModelProto[] = R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 25000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 20000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 3
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb";
  BuildModelFromProto(kModelProto);

  CancellationManager cancel_manager;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, /*cpu_budget=*/5,
                   /*ram_budget=*/1000, /*model_input_time=*/50,
                   &cancel_manager);

  // Reads the tuned `parallelism` of the node with the given id.
  const auto parallelism_of = [&](auto node_id) {
    return GetNode(node_id)->parameter_value("parallelism");
  };
  // Both stages are expected to end up at 5.
  EXPECT_EQ(5, parallelism_of(/*node_id=*/1));
  EXPECT_EQ(5, parallelism_of(/*node_id=*/2));
}
2290 
TEST_F(ModelTimingTest, OptimizeStageBased_TwoStages_RamBudgetExceeded) {
  // Two tunable async stages:
  // SSTable (id 3) -> ParallelMapV2 (id 2) -> ParallelMapV2 (id 1, output).
  // Both start at `parallelism` 4 (and `state_value` 4) within [1, 16].
  // The test alternates between a tiny and a generous RAM budget and checks
  // that the tuned values only move when the budget allows it.
  constexpr char kModelProto[] = R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 25000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          state_value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 20000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 3
        parameters: {
          name: "parallelism"
          value: 4
          state_value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb";
  BuildModelFromProto(kModelProto);

  // Reads the current `parallelism` of the node with the given id.
  const auto parallelism_of = [&](auto node_id) {
    return GetNode(node_id)->parameter_value("parallelism");
  };
  CancellationManager cancel_manager;

  // Phase 1 - RAM budget too small: the initial values must stay at 4.
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, /*cpu_budget=*/10,
                   /*ram_budget=*/100, /*model_input_time=*/0,
                   &cancel_manager);
  EXPECT_EQ(4, parallelism_of(/*node_id=*/1));
  EXPECT_EQ(4, parallelism_of(/*node_id=*/2));

  // Phase 2 - ample RAM budget: both stages are allowed to grow.
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, /*cpu_budget=*/10,
                   /*ram_budget=*/100000, /*model_input_time=*/0,
                   &cancel_manager);
  EXPECT_EQ(12, parallelism_of(/*node_id=*/1));
  EXPECT_EQ(16, parallelism_of(/*node_id=*/2));

  // Phase 3 - RAM budget too small again: the values tuned in phase 2 must
  // be kept as they are.
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, /*cpu_budget=*/10,
                   /*ram_budget=*/100, /*model_input_time=*/0,
                   &cancel_manager);
  EXPECT_EQ(12, parallelism_of(/*node_id=*/1));
  EXPECT_EQ(16, parallelism_of(/*node_id=*/2));
}
2369 
TEST_F(ModelTimingTest, OptimizeStageBased_PipelineRatio) {
  // Pipeline: SSTable (id 3, ratio 2) -> Map (id 2, ratio 1) ->
  // ParallelBatch (id 1, output, ratio 2). Only node 1 has a tunable
  // `parallelism`, starting at 4 within [1, 16].
  constexpr char kModelProto[] = R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelBatch"
        autotune: true
        num_elements: 100
        processing_time: 5000
        bytes_produced: 10000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 2
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 4
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "Map"
        autotune: true
        num_elements: 100
        processing_time: 3000
        node_class: KNOWN_RATIO
        ratio: 1
        inputs: 3
      }
    }
    nodes: {
      key: 3
      value: {
        id: 3
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 1000
        node_class: KNOWN_RATIO
        ratio: 2
      }
    }
    output: 1
  )pb";
  BuildModelFromProto(kModelProto);

  CancellationManager cancel_manager;
  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, /*cpu_budget=*/20,
                   /*ram_budget=*/10000, /*model_input_time=*/50,
                   &cancel_manager);

  // The expected result is the parameter's `max` (16). Presumably the
  // non-unit ratios multiply the upstream work per output element.
  // NOTE(review): confirm against the stage-based optimizer.
  const auto parallelism =
      GetNode(/*node_id=*/1)->parameter_value("parallelism");
  EXPECT_EQ(16, parallelism);
}
2427 
TEST_F(ModelTimingTest, ComputeTargetTime) {
  model_ = std::make_unique<Model>();

  // Five typical gaps of 10 followed by a single 1000 gap.
  for (int i = 0; i < 5; ++i) {
    model_->RecordIteratorGapTime(10);
  }
  model_->RecordIteratorGapTime(1000);
  // Gap times that are >= 10 seconds are always dropped.
  model_->RecordIteratorGapTime(10000000);

  // The result (scaled by 1e-3) must still be 10. The lone 1000 gap does not
  // pull the target away from the typical gap time.
  EXPECT_DOUBLE_EQ(10, model_->ComputeTargetTimeNsec() * 1e-3);
}
2442 
TEST_F(ModelTimingTest, ComputeTargetTime_NoOutlier) {
  model_ = std::make_unique<Model>();

  // Four gaps of 10, then four gaps of 20; none of them is an outlier.
  for (int i = 0; i < 4; ++i) {
    model_->RecordIteratorGapTime(10);
  }
  for (int i = 0; i < 4; ++i) {
    model_->RecordIteratorGapTime(20);
  }
  // Gap times that are >= 10 seconds are always dropped.
  model_->RecordIteratorGapTime(10000000);

  // With nothing filtered out, the expected target is the plain mean of the
  // eight retained gaps: (4 * 10 + 4 * 20) / 8 = 15.
  EXPECT_DOUBLE_EQ(15.0, model_->ComputeTargetTimeNsec() * 1e-3);
}
2459 
TEST_F(ModelTimingTest, ComputeTargetTime_TestWindow) {
  model_ = std::make_unique<Model>();

  // Record 100 gaps of 20 followed by 100 gaps of 10. The window size is 100,
  // so only the most recent 100 gap times (all 10s) feed the target time.
  for (int i = 0; i < 200; ++i) {
    model_->RecordIteratorGapTime(i < 100 ? 20 : 10);
  }

  EXPECT_DOUBLE_EQ(10.0, model_->ComputeTargetTimeNsec() * 1e-3);
}
2474 
TEST_F(ModelTimingTest, SelfTime) {
  // Pipeline: SSTable (id 2) -> ParallelMapV2 (id 1, output, parallelism 2).
  constexpr char kModelProto[] = R"pb(
    nodes: {
      key: 1
      value: {
        id: 1
        name: "ParallelMapV2"
        autotune: true
        num_elements: 100
        processing_time: 20000
        node_class: ASYNC_KNOWN_RATIO
        ratio: 1
        inputs: 2
        parameters: {
          name: "parallelism"
          value: 2
          min: 1
          max: 16
          tunable: true
        }
      }
    }
    nodes: {
      key: 2
      value: {
        id: 2
        name: "SSTable"
        autotune: true
        num_elements: 100
        processing_time: 100000
        node_class: KNOWN_RATIO
        ratio: 1
      }
    }
    output: 1
  )pb";
  BuildModelFromProto(kModelProto);

  // Records one additional element on `node` that took `processing_time`.
  const auto record_one_element = [](const auto& node, auto processing_time) {
    node->add_processing_time(processing_time);
    node->record_element();
  };

  // The expected values are consistent with an average per-element time
  // (20000 / 100 = 200 for the map; 100000 / 100 = 1000 for the source),
  // divided by the map's parallelism of 2. After one more element they match
  // a moving average that weights the newest element by 0.1, e.g.
  // (0.9 * 200 + 0.1 * 400) / 2 = 110.
  // NOTE(review): confirm this against Node::ComputeSelfTime.
  auto map_node = MutableGetNode(/*node_id=*/1);
  EXPECT_DOUBLE_EQ(100, map_node->ComputeSelfTime());
  record_one_element(map_node, 400);
  EXPECT_DOUBLE_EQ(110, map_node->ComputeSelfTime());

  auto source_node = MutableGetNode(/*node_id=*/2);
  EXPECT_DOUBLE_EQ(1000, source_node->ComputeSelfTime());
  record_one_element(source_node, 100);
  EXPECT_DOUBLE_EQ(910, source_node->ComputeSelfTime());
}
2523 
2524 }  // namespace
2525 }  // namespace model
2526 }  // namespace data
2527 }  // namespace tensorflow
2528