1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <string>
21
22 #include "absl/strings/match.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/framework/allocation_description.pb.h"
25 #include "tensorflow/core/framework/tensor_description.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
28 #include "tensorflow/core/grappler/costs/graph_properties.h"
29 #include "tensorflow/core/grappler/costs/utils.h"
30 #include "tensorflow/core/grappler/costs/virtual_placer.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/test.h"
33
34 namespace tensorflow {
35 namespace grappler {
36 namespace {
37
// Device names:
constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/cpu:0";
constexpr char kCPU1[] = "/job:localhost/replica:0/task:0/cpu:1";
// Pseudo-device names representing cross-device transfer channels; used as
// the "device" of _Send nodes in the CompositeNodeManager tests below.
constexpr char kChannelFrom0To1[] = "Channel from CPU0 to CPU1";
constexpr char kChannelFrom1To0[] = "Channel from CPU1 to CPU0";
// Op names:
constexpr char kConv2D[] = "Conv2D";
constexpr char kSend[] = "_Send";
constexpr char kRecv[] = "_Recv";
47
// Common fixture for the ReadyNodeManager tests. Provides six Conv2D nodes
// placed on kCPU0 whose time_ready values are the reverse of their numeric
// order (Node1 becomes ready last, Node6 first), plus the NodeState map that
// managers requiring Init() read time_ready/device information from.
class ReadyNodeManagerTest : public ::testing::Test {
 protected:
  ReadyNodeManagerTest() {
    // node1_ to node6_ on kCPU0, with time_ready in reverse_order.
    NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
    NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
    NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
    NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
    NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
    NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
  }

  // Initializes `node` with the given name/op/device and records a matching
  // NodeState entry (time_ready and device_name) in node_states_.
  void NodeSetUp(const string& name, const string& op_name,
                 const string& device_name, const uint64 time_ready,
                 NodeDef* node) {
    node->set_name(name);
    node->set_op(op_name);
    node->set_device(device_name);

    node_states_[node] = NodeState();
    node_states_[node].time_ready = time_ready;
    node_states_[node].device_name = device_name;
  }

  NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
  // Keyed by node pointer; passed to managers via Init().
  std::unordered_map<const NodeDef*, NodeState> node_states_;
};
75
76 // Tests that FIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest,GetSingleNodeFIFOManager)77 TEST_F(ReadyNodeManagerTest, GetSingleNodeFIFOManager) {
78 FIFOManager manager = FIFOManager();
79 manager.AddNode(&node1_);
80 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
81 }
82
83 // Tests that FIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest,RemoveSingleNodeFIFOManager)84 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFIFOManager) {
85 FIFOManager manager = FIFOManager();
86 manager.AddNode(&node1_);
87
88 // Removes the only node in FIFOManager.
89 manager.RemoveCurrNode();
90 EXPECT_TRUE(manager.Empty());
91 }
92
93 // Tests that FIFOManager can remove multiple nodes and returns the current node
94 // in the right order.
TEST_F(ReadyNodeManagerTest,GetAndRemoveMultipleFIFOManager)95 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFIFOManager) {
96 FIFOManager manager = FIFOManager();
97 manager.AddNode(&node1_);
98 manager.AddNode(&node2_);
99 manager.AddNode(&node3_);
100 manager.AddNode(&node4_);
101
102 // Keeps checking current node while removing nodes from manager.
103 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
104 manager.RemoveCurrNode();
105 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
106 manager.RemoveCurrNode();
107 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
108 manager.RemoveCurrNode();
109 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
110 manager.RemoveCurrNode();
111 EXPECT_TRUE(manager.Empty());
112 }
113
114 // Tests that FIFOManager can remove multiple nodes and add more nodes, still
115 // returning the current node in the right order.
TEST_F(ReadyNodeManagerTest,AddAndRemoveMultipleFIFOManager)116 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleFIFOManager) {
117 FIFOManager manager = FIFOManager();
118 manager.AddNode(&node1_);
119 manager.AddNode(&node2_);
120 manager.AddNode(&node3_);
121 manager.AddNode(&node4_);
122
123 // Keeps checking current node as nodes are removed and added.
124 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
125 manager.RemoveCurrNode();
126 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
127 manager.AddNode(&node5_);
128 // GetCurrNode() should return the same node even if some nodes are added,
129 // until RemoveCurrNode() is called.
130 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
131 manager.RemoveCurrNode();
132 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
133 manager.RemoveCurrNode();
134 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
135 manager.AddNode(&node6_);
136 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
137 manager.RemoveCurrNode();
138 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
139 manager.RemoveCurrNode();
140 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
141 manager.RemoveCurrNode();
142 EXPECT_TRUE(manager.Empty());
143 }
144
145 // Tests that LIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest,GetSingleNodeLIFOManager)146 TEST_F(ReadyNodeManagerTest, GetSingleNodeLIFOManager) {
147 LIFOManager manager = LIFOManager();
148 manager.AddNode(&node1_);
149 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
150 }
151
152 // Tests that LIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest,RemoveSingleNodeLIFOManager)153 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeLIFOManager) {
154 LIFOManager manager = LIFOManager();
155 manager.AddNode(&node1_);
156
157 // Removes the only node in LIFOManager.
158 manager.RemoveCurrNode();
159 EXPECT_TRUE(manager.Empty());
160 }
161
162 // Tests that LIFOManager can remove multiple nodes and returns the current node
163 // in the right order.
TEST_F(ReadyNodeManagerTest,GetAndRemoveMultipleLIFOManager)164 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleLIFOManager) {
165 LIFOManager manager = LIFOManager();
166 manager.AddNode(&node1_);
167 manager.AddNode(&node2_);
168 manager.AddNode(&node3_);
169 manager.AddNode(&node4_);
170
171 // Keeps checking current node while removing nodes from manager.
172 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
173 manager.RemoveCurrNode();
174 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
175 manager.RemoveCurrNode();
176 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
177 manager.RemoveCurrNode();
178 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
179 manager.RemoveCurrNode();
180 EXPECT_TRUE(manager.Empty());
181 }
182
183 // Tests that LIFOManager can remove multiple nodes (must be removing the
184 // current node) and add more nodes, still returning the current node in the
185 // right order.
TEST_F(ReadyNodeManagerTest,AddAndRemoveMultipleLIFOManager)186 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
187 LIFOManager manager = LIFOManager();
188 manager.AddNode(&node1_);
189 manager.AddNode(&node2_);
190 manager.AddNode(&node3_);
191 manager.AddNode(&node4_);
192
193 // Keeps checking current node as nodes are removed and added.
194 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
195 manager.RemoveCurrNode();
196 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
197 manager.AddNode(&node5_);
198 // GetCurrNode() should return the same node even if some nodes are added,
199 // until RemoveCurrNode() is called.
200 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
201 manager.RemoveCurrNode();
202 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
203 manager.RemoveCurrNode();
204 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
205 manager.AddNode(&node6_);
206 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
207 manager.RemoveCurrNode();
208 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
209 manager.RemoveCurrNode();
210 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
211 manager.RemoveCurrNode();
212 EXPECT_TRUE(manager.Empty());
213 }
214
// Tests that LIFOManager defers Merge ops: a Merge node is scheduled after
// all other ready nodes, regardless of its insertion position.
TEST_F(ReadyNodeManagerTest, MergeOrderInLIFOManager) {
  // Direct default construction instead of copy-initializing from a temporary.
  LIFOManager manager;
  node3_.set_op("Merge");
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Merge node (node3) will be scheduled at the end (even though it's added
  // after node2).
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
}
234
// Tests that FirstReadyManager returns the only node it contains.
TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  const NodeDef* curr = manager.GetCurrNode();
  EXPECT_EQ(curr->name(), "Node1");
}
241
// Tests that FirstReadyManager becomes empty once its sole node is removed.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  // Removing the one and only node leaves nothing behind.
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
249
// Tests that FirstReadyManager drains nodes in ascending time_ready order,
// independent of insertion order.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  // Insert nodes in some random order.
  manager.AddNode(&node2_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node3_);
  manager.AddNode(&node6_);

  // In whatever order we insert nodes, we get the same order based on nodes'
  // time_ready (Node6 is earliest-ready, Node1 latest).
  for (const char* expected :
       {"Node6", "Node5", "Node4", "Node3", "Node2", "Node1"}) {
    EXPECT_EQ(manager.GetCurrNode()->name(), expected);
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
277
// Tests that FirstReadyManager keeps GetCurrNode() stable across AddNode()
// calls (even when the added nodes are ready earlier than the current node),
// only advancing when RemoveCurrNode() is called.
TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));

  // Inserts nodes in some random order.
  manager.AddNode(&node2_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node3_);
  manager.AddNode(&node6_);

  // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
  // should return it. (Argument order normalized to actual-first, matching
  // every other assertion in this file.)
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  // Now inserts a few other nodes, but their time_ready's are even smaller than
  // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return
  // the same node, Node6, in this case.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);

  manager.AddNode(&node7);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  manager.AddNode(&node8);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  manager.RemoveCurrNode();
  // Now Node6 is removed, and GetCurrNode() will return Node8.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");

  // Again, AddNode shouldn't change GetCurrNode().
  manager.AddNode(&node9);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");

  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
335
// Tests that FirstReadyManager breaks time_ready ties deterministically:
// two managers fed the same nodes in different orders must drain identically.
TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
  FirstReadyManager manager1;
  TF_EXPECT_OK(manager1.Init(&node_states_));
  FirstReadyManager manager2;
  TF_EXPECT_OK(manager2.Init(&node_states_));

  // 6 nodes with same time_ready.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
  NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
  NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);

  // Adds the above 6 nodes to manager1.
  manager1.AddNode(&node7);
  manager1.AddNode(&node8);
  manager1.AddNode(&node9);
  manager1.AddNode(&node10);
  manager1.AddNode(&node11);
  manager1.AddNode(&node12);

  // Adds the above 6 nodes to manager2, but in a different order.
  manager2.AddNode(&node8);
  manager2.AddNode(&node11);
  manager2.AddNode(&node9);
  manager2.AddNode(&node10);
  manager2.AddNode(&node7);
  manager2.AddNode(&node12);

  // Expects both managers to return the same node at every step for
  // deterministic node scheduling. (Collapsed six copy-pasted stanzas into a
  // loop; behavior is identical.)
  for (int i = 0; i < 6; ++i) {
    EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
    manager1.RemoveCurrNode();
    manager2.RemoveCurrNode();
  }

  EXPECT_TRUE(manager1.Empty());
  EXPECT_TRUE(manager2.Empty());
}
401
// Tests that PriorityReadyManager schedules by the configured priorities,
// falling back to ready-first order on ties.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultiplePriorityReadyManager) {
  PriorityReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));

  // Sets up node priorities.
  std::unordered_map<string, int> node_priority = {
      {"Node1", 1}, {"Node2", 2}, {"Node3", 2}, {"Node4", 4}, {"Node5", 5}};
  TF_EXPECT_OK(manager.SetPriority(node_priority));

  // Inserts nodes in some random order.
  manager.AddNode(&node3_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node2_);
  manager.AddNode(&node6_);

  // Expected schedule, based on priority:
  //  - Node6 first: it is absent from the priority map, so it defaults to the
  //    lowest priority value.
  //  - Node1 (priority 1) next.
  //  - Node3 before Node2: they share priority 2, so the tie is broken
  //    ready-first.
  //  - Node4, then Node5.
  for (const char* expected :
       {"Node6", "Node1", "Node3", "Node2", "Node4", "Node5"}) {
    EXPECT_EQ(manager.GetCurrNode()->name(), expected);
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
436
// Tests that CompositeNodeManager becomes empty once its sole node is removed.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  // Removing the one and only node leaves nothing behind.
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
444
// Tests that CompositeNodeManager drains single-device normal ops in LIFO
// order, and that AddNode() never changes the current node mid-step.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Convenience accessor for the current node's name.
  auto curr_name = [&manager]() { return manager.GetCurrNode()->name(); };

  // Keeps checking current node as nodes are removed and added.
  EXPECT_EQ(curr_name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(curr_name(), "Node3");
  manager.AddNode(&node5_);
  // GetCurrNode() must return the same node even after additions, until
  // RemoveCurrNode() is called.
  EXPECT_EQ(curr_name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(curr_name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(curr_name(), "Node2");
  manager.AddNode(&node6_);
  EXPECT_EQ(curr_name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(curr_name(), "Node6");
  manager.RemoveCurrNode();
  EXPECT_EQ(curr_name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
474
// Tests CompositeNodeManager across two devices plus _Send/_Recv nodes:
// per device the last-added normal op is the candidate (LIFO), and among the
// per-device candidates and all Send/Recv nodes, the earliest time_ready wins.
TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  // Additional nodes on kCPU1.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
  NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
  NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);

  // Send and Recv nodes.
  NodeDef send1;
  NodeDef send2;
  NodeDef recv1;
  NodeDef recv2;
  NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
  NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
  NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
  NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);

  // Inserts nodes.
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node6_);
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);
  manager.AddNode(&send1);
  manager.AddNode(&send2);
  manager.AddNode(&recv1);
  manager.AddNode(&recv2);

  // On kCPU0, the last one added is node6_; on kCPU1, the last one is node9;
  // so choose the one that has earliest time_ready among node6_, node9,
  // Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
  // among node5_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Send1");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Recv1");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Recv2");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, and Send2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Send2");
  manager.RemoveCurrNode();
  // Next, choose between node4_, node9.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node9.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node8.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node7.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  manager.RemoveCurrNode();
  // Then, just the remaining nodes on kCPU0 -- LIFO order.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
553
// Tests the tie-break rules of CompositeNodeManager: with equal time_ready,
// _Send is scheduled before _Recv, which precedes other ops; an earlier
// time_ready always wins over op-type preference; and ties across devices are
// broken deterministically regardless of insertion order.
TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  CompositeNodeManager manager2;
  TF_EXPECT_OK(manager2.Init(&node_states_));

  // 6 nodes with same time_ready.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
  NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
  NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);

  // Adds Nodes 7 to 9 to manager.
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);

  // It should return in _Send, _Recv, then other-op order, when the candidate
  // nodes have the same time_ready.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Adds Nodes 7 to 9 to manager, but in a different order.
  manager.AddNode(&node9);
  manager.AddNode(&node8);
  manager.AddNode(&node7);

  // Expects same order (_Send, _Recv, and the other op), regardless of Add
  // order.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Conv2D's time_ready < Send's time_ready; expects Conv2D first.
  manager.AddNode(&node8);
  manager.AddNode(&node10);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node10");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Recv's time_ready < Send's time_ready; expects Recv first.
  manager.AddNode(&node11);
  manager.AddNode(&node8);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node11");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Node7 and 12 are normal ops with the same time_ready, placed on different
  // devices. These two nodes are added to manager and manager2, but in
  // different orders; expects GetCurrNode() returns the nodes in the same
  // order.
  manager.AddNode(&node7);
  manager.AddNode(&node12);

  manager2.AddNode(&node12);
  manager2.AddNode(&node7);

  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
650
// Class for testing virtual scheduler.
class TestVirtualScheduler : public VirtualScheduler {
 public:
  // Wires up a VirtualScheduler with a VirtualPlacer built from `cluster`'s
  // devices, and enables memory usage tracking so tests can inspect
  // per-device memory statistics.
  TestVirtualScheduler(const bool use_static_shapes,
                       const bool use_aggressive_shape_inference,
                       ReadyNodeManager* ready_node_manager, Cluster* cluster)
      : VirtualScheduler(
            use_static_shapes, use_aggressive_shape_inference, cluster,
            ready_node_manager,
            std::make_unique<VirtualPlacer>(cluster->GetDevices())) {
    enable_mem_usage_tracking();
  }

  // These tests need access to non-public scheduler state.
  FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
  FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
  FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
  FRIEND_TEST(VirtualSchedulerTest, Variable);
  FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
};
670
671 class VirtualSchedulerTest : public ::testing::Test {
672 protected:
  // Builds a two-CPU VirtualCluster and a TestVirtualScheduler (static shapes
  // and aggressive shape inference on, FirstReadyManager as the node manager).
  VirtualSchedulerTest() {
    // Initializes cluster_ and scheduler_.
    std::unordered_map<string, DeviceProperties> devices;

    // Set some dummy CPU properties
    DeviceProperties cpu_device = GetDummyCPUDevice();

    // IMPORTANT: Device is not actually ever used in the test case since
    // force_cpu_type is defaulted to "Haswell"
    devices[kCPU0] = cpu_device;
    devices[kCPU1] = cpu_device;
    cluster_ = std::make_unique<VirtualCluster>(devices);
    scheduler_ = std::make_unique<TestVirtualScheduler>(
        /*use_static_shapes=*/true,
        /*use_aggressive_shape_inference=*/true, &first_ready_manager_,
        cluster_.get());
  }
690
GetDummyCPUDevice()691 DeviceProperties GetDummyCPUDevice() {
692 // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
693 // - 8 Gflops
694 // - 2 GB/s
695 DeviceProperties cpu_device;
696 cpu_device.set_type("CPU");
697 cpu_device.set_frequency(4000);
698 cpu_device.set_num_cores(2);
699 cpu_device.set_bandwidth(2000000);
700 return cpu_device;
701 }
702
  // Three Conv2Ds sharing filter `f`, with only two (c0, c1) in the fetch
  // nodes; z and c2 are therefore outside the fetch fan-in.
  void CreateGrapplerItemWithConv2Ds() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    auto x = ops::RandomUniform(
        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto y = ops::RandomUniform(
        s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto z = ops::RandomUniform(
        s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto f = ops::RandomUniform(
        s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
    std::vector<int> strides = {1, 1, 1, 1};
    auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
    auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
    auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");

    grappler_item_ = std::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_conv2d_graph";
    grappler_item_->fetch = {"c0", "c1"};

    // Expected input sets for the fetched nodes (c2/z intentionally omitted).
    dependency_["c0"] = {"x", "f"};
    dependency_["c1"] = {"y", "f"};
  }
727
  // A Conv2D whose filter `f` is a Variable (not a generated tensor).
  void CreateGrapplerItemWithConv2DAndVariable() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    auto x = ops::RandomUniform(
        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto f = ops::Variable(s.WithOpName("f"),
                           {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
    std::vector<int> strides = {1, 1, 1, 1};
    auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");

    grappler_item_ = std::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_conv2d_var_graph";

    grappler_item_->fetch = {"y"};

    dependency_["y"] = {"x", "f"};
  }
746
  // A chain of MatMuls (ab -> abc -> abcd -> abcde) over 3200x3200 inputs.
  void CreateGrapplerItemWithMatmulChain() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    // Add control dependencies to ensure tests do not rely on specific
    // manager and the order remains consistent for the test.
    auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
    auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
                                {3200, 3200}, DT_FLOAT);
    auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
                                {3200, 3200}, DT_FLOAT);
    auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
                                {3200, 3200}, DT_FLOAT);
    auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
                                {3200, 3200}, DT_FLOAT);

    auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
    auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
    auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
    auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);

    grappler_item_ = std::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_matmul_sequence_graph";
    grappler_item_->fetch = {"abcde"};

    dependency_["ab"] = {"a", "b"};
    dependency_["abc"] = {"ab", "c"};
    dependency_["abcd"] = {"abc", "d"};
    dependency_["abcde"] = {"abcd", "e"};
  }
776
  // AddN that takes 4 tensors with 10x10x10x10.
  void CreateGrapplerItemWithAddN() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
    auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
    auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
    auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
    OutputList input_tensors = {x, y, z, w};
    auto add = ops::AddN(s.WithOpName("add"), input_tensors);
    auto out = ops::Identity(s.WithOpName("out"), add);

    grappler_item_ = std::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_addn_graph";
    grappler_item_->fetch = {"out"};

    dependency_["out"] = {"x", "y", "z", "w", "add"};
  }
795
796 // Graph with some placeholder feed nodes that are not in the fetch fan-in.
CreateGrapplerItemWithUnnecessaryPlaceholderNodes()797 void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
798 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
799 auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
800 auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
801
802 grappler_item_ = std::make_unique<GrapplerItem>();
803 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
804
805 grappler_item_->id = "test_extra_placeholders";
806 grappler_item_->fetch = {"x"};
807
808 // Grappler Item Builder puts all placeholder nodes into the feed
809 // list by default.
810 grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
811 }
812
813 // NoOp that takes 7 NoOps as control dependency.
CreateGrapplerItemWithControlDependency()814 void CreateGrapplerItemWithControlDependency() {
815 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
816 std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
817 std::vector<Operation> input_tensors;
818 for (const auto& input : input_noop_names) {
819 auto x = ops::NoOp(s.WithOpName(input));
820 input_tensors.push_back(x.operation);
821 }
822 auto out =
823 ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
824
825 grappler_item_ = std::make_unique<GrapplerItem>();
826 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
827
828 grappler_item_->id = "test_control_dependency_graph";
829 grappler_item_->fetch = {"out"};
830
831 dependency_["out"] = input_noop_names;
832 }
833
CreateGrapplerItemWithAddFromOneTensor()834 void CreateGrapplerItemWithAddFromOneTensor() {
835 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
836 auto x = tensorflow::ops::RandomUniform(
837 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
838
839 auto y = tensorflow::ops::Add(s.WithOpName("y"), x, x);
840 Output fetch = ops::Identity(s.WithOpName("fetch"), y);
841
842 grappler_item_ = std::make_unique<GrapplerItem>();
843 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
844
845 grappler_item_->id = "test_add_from_one_tensor";
846 grappler_item_->fetch = {"fetch"};
847
848 dependency_["fetch"] = {"y"};
849 dependency_["y"] = {"x"};
850 }
851
  // Builds a graph with a Switch feeding a Merge:
  //   switch = Switch(x, pred)
  //   a = Add(switch:1, b)        // true branch of the Switch
  //   m = Merge(switch:0, a)      // merges the false branch with a
  //   y = Add(m, z)
  // Fetches "y".
  void CreateGrapplerItemWithSwitchMergeInput() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    auto x = ops::RandomUniform(
        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    // pred is a constant false, so at runtime the false output of the Switch
    // would be taken; the scheduler still sees both branches.
    auto pred = ops::Const(s.WithOpName("pred"), false, {});
    auto sw = ops::Switch(s.WithOpName("switch"), x, pred);
    auto b = ops::RandomUniform(
        s.WithOpName("b"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    // "a" consumes the true output (port 1) of the Switch.
    auto a = ops::Add(s.WithOpName("a"), sw.output_true, b);
    // Merge takes the false output (port 0) and "a".
    auto m = ops::Merge(s.WithOpName("m"), {sw.output_false, a.z});
    auto z = ops::RandomUniform(
        s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto y = ops::Add(s.WithOpName("y"), m.output, z);

    grappler_item_ = std::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));

    grappler_item_->id = "test_add_merge_switch";
    grappler_item_->fetch = {"y"};

    // Scheduling-order check: y runs after both the Merge and z.
    dependency_["y"] = {"m", "z"};
  }
879
880 // FusedBN [an op with multiple outputs] with multiple consumers (including
881 // control dependency).
CreateGrapplerItemWithBatchNorm()882 void CreateGrapplerItemWithBatchNorm() {
883 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
884 auto x = ops::RandomUniform(
885 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
886 auto scale =
887 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
888 auto offset =
889 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
890 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
891 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
892
893 auto batch_norm = ops::FusedBatchNorm(
894 s.WithOpName("bn"), x, scale, offset, mean, var,
895 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
896 auto y = batch_norm.y;
897 auto batch_mean = batch_norm.batch_mean;
898 auto batch_var = batch_norm.batch_variance;
899
900 auto z1 = ops::Add(s.WithOpName("z1"), x, y);
901 auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
902 auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
903 std::vector<Operation> input_tensors = {
904 batch_mean.op(),
905 z1.z.op(),
906 z2.z.op(),
907 z3.z.op(),
908 };
909 auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
910
911 grappler_item_ = std::make_unique<GrapplerItem>();
912 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
913
914 grappler_item_->id = "test_complex_dependency_graph";
915 grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
916
917 dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
918 dependency_["z1"] = {"x", "bn"};
919 dependency_["z2"] = {"bn"};
920 dependency_["z3"] = {"bn"};
921 dependency_["z4"] = {"bn"};
922 }
923
  // Builds, from a text-format GraphDef, a Const (128x32 float) feeding a
  // _Send/_Recv pair. All three nodes live on CPU0 and share the rendezvous
  // key attrs (tensor_name "test", matching send/recv devices). Fetches
  // "Recv".
  void CreateGrapplerItemWithSendRecv() {
    const string gdef_ascii = R"EOF(
node {
  name: "Const"
  op: "Const"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list { shape {
        dim { size: 128 }
        dim { size: 32 }
      }}}
  }
  attr {
    key: "shape"
    value {
      list { shape {
        dim { size: 128 }
        dim { size: 32 }
      }}}
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim { size: 128 }
          dim { size: 32 }
        }
        float_val: 3.1415
      }
    }
  }
}
node {
  name: "Send"
  op: "_Send"
  input: "Const"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list { shape {
        dim { size: 128 }
        dim { size: 32 }
      }}}
  }
  attr {
    key: "shape"
    value {
      list { shape {
        dim { size: 128 }
        dim { size: 32 }
      }}}
  }
  attr {
    key: "client_terminated"
    value {
      b: false
    }
  }
  attr {
    key: "recv_device"
    value {
      s: "/job:localhost/replica:0/task:0/device:CPU:0"
    }
  }
  attr {
    key: "send_device"
    value {
      s: "/job:localhost/replica:0/task:0/device:CPU:0"
    }
  }
  attr {
    key: "send_device_incarnation"
    value {
      i: 0
    }
  }
  attr {
    key: "tensor_name"
    value {
      s: "test"
    }
  }
}
node {
  name: "Recv"
  op: "_Recv"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "client_terminated"
    value {
      b: false
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list { shape {
        dim { size: 128 }
        dim { size: 32 }
      }}}
  }
  attr {
    key: "shape"
    value {
      list { shape {
        dim { size: 128 }
        dim { size: 32 }
      }}}
  }
  attr {
    key: "recv_device"
    value {
      s: "/job:localhost/replica:0/task:0/device:CPU:0"
    }
  }
  attr {
    key: "send_device"
    value {
      s: "/job:localhost/replica:0/task:0/device:CPU:0"
    }
  }
  attr {
    key: "send_device_incarnation"
    value {
      i: 0
    }
  }
  attr {
    key: "tensor_name"
    value {
      s: "test"
    }
  }
  attr {
    key: "tensor_type"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}
versions {
  producer: 24
}
)EOF";

    grappler_item_ = std::make_unique<GrapplerItem>();
    // TextFormat::ParseFromString returns bool, hence CHECK (not TF_CHECK_OK).
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"Recv"};
  }
1095
  // Builds a graph containing only a _Recv node with no matching _Send, to
  // exercise scheduler handling of an unpaired Recv. Fetches "Recv".
  void CreateGrapplerItemWithRecvWithoutSend() {
    const string gdef_ascii = R"EOF(
node {
  name: "Recv"
  op: "_Recv"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "client_terminated"
    value {
      b: false
    }
  }
  attr {
    key: "recv_device"
    value {
      s: "/job:localhost/replica:0/task:0/device:CPU:0"
    }
  }
  attr {
    key: "send_device"
    value {
      s: "/job:localhost/replica:0/task:0/device:CPU:0"
    }
  }
  attr {
    key: "send_device_incarnation"
    value {
      i: 0
    }
  }
  attr {
    key: "tensor_name"
    value {
      s: "test"
    }
  }
  attr {
    key: "tensor_type"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}
versions {
  producer: 24
}
)EOF";

    grappler_item_ = std::make_unique<GrapplerItem>();
    // TextFormat::ParseFromString returns bool, hence CHECK (not TF_CHECK_OK).
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"Recv"};
  }
1152
1153 // A simple while loop
  // Builds a simple tf.while_loop graph from a text-format GraphDef:
  // counter i runs from 0 while i < 10; the loop body increments i and
  // concatenates a 2x2 matrix with itself. Contains the full control-flow
  // scaffolding (Enter/Merge/Switch/Identity/NextIteration/Exit).
  // Fetches both Exit nodes.
  void CreateGrapplerItemWithLoop() {
    // Test graph produced in python using:
    /*
      with tf.Graph().as_default():
      i0 = tf.constant(0)
      m0 = tf.ones([2, 2])
      c = lambda i, m: i < 10
      b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
      r = tf.while_loop(
          c, b, loop_vars=[i0, m0],
          shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
      with open('/tmp/graph.pbtxt', 'w') as f:
      f.write(str(tf.get_default_graph().as_graph_def()))
    */
    const string gdef_ascii = R"EOF(
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "ones"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 2
          }
          dim {
            size: 2
          }
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "while/Enter"
  op: "Enter"
  input: "Const"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "frame_name"
    value {
      s: "while/while/"
    }
  }
  attr {
    key: "is_constant"
    value {
      b: false
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 10
    }
  }
}
node {
  name: "while/Enter_1"
  op: "Enter"
  input: "ones"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "frame_name"
    value {
      s: "while/while/"
    }
  }
  attr {
    key: "is_constant"
    value {
      b: false
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 10
    }
  }
}
node {
  name: "while/Merge"
  op: "Merge"
  input: "while/Enter"
  input: "while/NextIteration"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Merge_1"
  op: "Merge"
  input: "while/Enter_1"
  input: "while/NextIteration_1"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "while/Less/y"
  op: "Const"
  input: "^while/Merge"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 10
      }
    }
  }
}
node {
  name: "while/Less"
  op: "Less"
  input: "while/Merge"
  input: "while/Less/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/LoopCond"
  op: "LoopCond"
  input: "while/Less"
}
node {
  name: "while/Switch"
  op: "Switch"
  input: "while/Merge"
  input: "while/LoopCond"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@while/Merge"
      }
    }
  }
}
node {
  name: "while/Switch_1"
  op: "Switch"
  input: "while/Merge_1"
  input: "while/LoopCond"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@while/Merge_1"
      }
    }
  }
}
node {
  name: "while/Identity"
  op: "Identity"
  input: "while/Switch:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Identity_1"
  op: "Identity"
  input: "while/Switch_1:1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "while/add/y"
  op: "Const"
  input: "^while/Identity"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "while/add"
  op: "Add"
  input: "while/Identity"
  input: "while/add/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/concat/axis"
  op: "Const"
  input: "^while/Identity"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "while/concat"
  op: "ConcatV2"
  input: "while/Identity_1"
  input: "while/Identity_1"
  input: "while/concat/axis"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "Tidx"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/NextIteration"
  op: "NextIteration"
  input: "while/add"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/NextIteration_1"
  op: "NextIteration"
  input: "while/concat"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "while/Exit"
  op: "Exit"
  input: "while/Switch"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Exit_1"
  op: "Exit"
  input: "while/Switch_1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 21
}
)EOF";

    grappler_item_ = std::make_unique<GrapplerItem>();
    // TextFormat::ParseFromString returns bool, hence CHECK (not TF_CHECK_OK).
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
  }
1548
  // The same while loop as above, annotated with per-node _execution_count
  // attributes and _output_slot_vector attributes on the Switch nodes.
CreateGrapplerItemWithLoopAnnotated()1550 void CreateGrapplerItemWithLoopAnnotated() {
1551 // Test graph produced in python using:
1552 /*
1553 with tf.Graph().as_default():
1554 i0 = tf.constant(0)
1555 m0 = tf.ones([2, 2])
1556 c = lambda i, m: i < 10
1557 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1558 r = tf.while_loop(
1559 c, b, loop_vars=[i0, m0],
1560 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1561 with open('/tmp/graph.pbtxt', 'w') as f:
1562 f.write(str(tf.get_default_graph().as_graph_def()))
1563 */
1564 const string gdef_ascii = R"EOF(
1565 node {
1566 name: "Const"
1567 op: "Const"
1568 attr {
1569 key: "dtype"
1570 value {
1571 type: DT_INT32
1572 }
1573 }
1574 attr {
1575 key: "value"
1576 value {
1577 tensor {
1578 dtype: DT_INT32
1579 tensor_shape {
1580 }
1581 int_val: 0
1582 }
1583 }
1584 }
1585 attr {
1586 key: "_execution_count"
1587 value {
1588 i: 1
1589 }
1590 }
1591 }
1592 node {
1593 name: "ones"
1594 op: "Const"
1595 attr {
1596 key: "dtype"
1597 value {
1598 type: DT_FLOAT
1599 }
1600 }
1601 attr {
1602 key: "value"
1603 value {
1604 tensor {
1605 dtype: DT_FLOAT
1606 tensor_shape {
1607 dim {
1608 size: 2
1609 }
1610 dim {
1611 size: 2
1612 }
1613 }
1614 float_val: 1.0
1615 }
1616 }
1617 }
1618 attr {
1619 key: "_execution_count"
1620 value {
1621 i: 1
1622 }
1623 }
1624 }
1625 node {
1626 name: "while/Enter"
1627 op: "Enter"
1628 input: "Const"
1629 attr {
1630 key: "T"
1631 value {
1632 type: DT_INT32
1633 }
1634 }
1635 attr {
1636 key: "frame_name"
1637 value {
1638 s: "while/while/"
1639 }
1640 }
1641 attr {
1642 key: "is_constant"
1643 value {
1644 b: false
1645 }
1646 }
1647 attr {
1648 key: "parallel_iterations"
1649 value {
1650 i: 10
1651 }
1652 }
1653 attr {
1654 key: "_execution_count"
1655 value {
1656 i: 1
1657 }
1658 }
1659 }
1660 node {
1661 name: "while/Enter_1"
1662 op: "Enter"
1663 input: "ones"
1664 attr {
1665 key: "T"
1666 value {
1667 type: DT_FLOAT
1668 }
1669 }
1670 attr {
1671 key: "frame_name"
1672 value {
1673 s: "while/while/"
1674 }
1675 }
1676 attr {
1677 key: "is_constant"
1678 value {
1679 b: false
1680 }
1681 }
1682 attr {
1683 key: "parallel_iterations"
1684 value {
1685 i: 10
1686 }
1687 }
1688 attr {
1689 key: "_execution_count"
1690 value {
1691 i: 1
1692 }
1693 }
1694 }
1695 node {
1696 name: "while/Merge"
1697 op: "Merge"
1698 input: "while/Enter"
1699 input: "while/NextIteration"
1700 attr {
1701 key: "N"
1702 value {
1703 i: 2
1704 }
1705 }
1706 attr {
1707 key: "T"
1708 value {
1709 type: DT_INT32
1710 }
1711 }
1712 attr {
1713 key: "_execution_count"
1714 value {
1715 i: 10
1716 }
1717 }
1718 }
1719 node {
1720 name: "while/Merge_1"
1721 op: "Merge"
1722 input: "while/Enter_1"
1723 input: "while/NextIteration_1"
1724 attr {
1725 key: "N"
1726 value {
1727 i: 2
1728 }
1729 }
1730 attr {
1731 key: "T"
1732 value {
1733 type: DT_FLOAT
1734 }
1735 }
1736 attr {
1737 key: "_execution_count"
1738 value {
1739 i: 10
1740 }
1741 }
1742 }
1743 node {
1744 name: "while/Less/y"
1745 op: "Const"
1746 input: "^while/Merge"
1747 attr {
1748 key: "dtype"
1749 value {
1750 type: DT_INT32
1751 }
1752 }
1753 attr {
1754 key: "value"
1755 value {
1756 tensor {
1757 dtype: DT_INT32
1758 tensor_shape {
1759 }
1760 int_val: 10
1761 }
1762 }
1763 }
1764 attr {
1765 key: "_execution_count"
1766 value {
1767 i: 10
1768 }
1769 }
1770 }
1771 node {
1772 name: "while/Less"
1773 op: "Less"
1774 input: "while/Merge"
1775 input: "while/Less/y"
1776 attr {
1777 key: "T"
1778 value {
1779 type: DT_INT32
1780 }
1781 }
1782 attr {
1783 key: "_execution_count"
1784 value {
1785 i: 10
1786 }
1787 }
1788 }
1789 node {
1790 name: "while/LoopCond"
1791 op: "LoopCond"
1792 input: "while/Less"
1793 attr {
1794 key: "_execution_count"
1795 value {
1796 i: 10
1797 }
1798 }
1799 }
1800 node {
1801 name: "while/Switch"
1802 op: "Switch"
1803 input: "while/Merge"
1804 input: "while/LoopCond"
1805 attr {
1806 key: "T"
1807 value {
1808 type: DT_INT32
1809 }
1810 }
1811 attr {
1812 key: "_class"
1813 value {
1814 list {
1815 s: "loc:@while/Merge"
1816 }
1817 }
1818 }
1819 attr {
1820 key: "_execution_count"
1821 value {
1822 i: 11
1823 }
1824 }
1825 attr {
1826 key: "_output_slot_vector"
1827 value {
1828 list {
1829 i: 1
1830 i: 1
1831 i: 1
1832 i: 1
1833 i: 1
1834 i: 1
1835 i: 1
1836 i: 1
1837 i: 1
1838 i: 1
1839 i: 0
1840 }
1841 }
1842 }
1843 }
1844 node {
1845 name: "while/Switch_1"
1846 op: "Switch"
1847 input: "while/Merge_1"
1848 input: "while/LoopCond"
1849 attr {
1850 key: "T"
1851 value {
1852 type: DT_FLOAT
1853 }
1854 }
1855 attr {
1856 key: "_class"
1857 value {
1858 list {
1859 s: "loc:@while/Merge_1"
1860 }
1861 }
1862 }
1863 attr {
1864 key: "_execution_count"
1865 value {
1866 i: 11
1867 }
1868 }
1869 attr {
1870 key: "_output_slot_vector"
1871 value {
1872 list {
1873 i: 1
1874 i: 1
1875 i: 1
1876 i: 1
1877 i: 1
1878 i: 1
1879 i: 1
1880 i: 1
1881 i: 1
1882 i: 1
1883 i: 0
1884 }
1885 }
1886 }
1887 }
1888 node {
1889 name: "while/Identity"
1890 op: "Identity"
1891 input: "while/Switch:1"
1892 attr {
1893 key: "T"
1894 value {
1895 type: DT_INT32
1896 }
1897 }
1898 attr {
1899 key: "_execution_count"
1900 value {
1901 i: 10
1902 }
1903 }
1904 }
1905 node {
1906 name: "while/Identity_1"
1907 op: "Identity"
1908 input: "while/Switch_1:1"
1909 attr {
1910 key: "T"
1911 value {
1912 type: DT_FLOAT
1913 }
1914 }
1915 attr {
1916 key: "_execution_count"
1917 value {
1918 i: 10
1919 }
1920 }
1921 }
1922 node {
1923 name: "while/add/y"
1924 op: "Const"
1925 input: "^while/Identity"
1926 attr {
1927 key: "dtype"
1928 value {
1929 type: DT_INT32
1930 }
1931 }
1932 attr {
1933 key: "value"
1934 value {
1935 tensor {
1936 dtype: DT_INT32
1937 tensor_shape {
1938 }
1939 int_val: 1
1940 }
1941 }
1942 }
1943 attr {
1944 key: "_execution_count"
1945 value {
1946 i: 10
1947 }
1948 }
1949 }
1950 node {
1951 name: "while/add"
1952 op: "Add"
1953 input: "while/Identity"
1954 input: "while/add/y"
1955 attr {
1956 key: "T"
1957 value {
1958 type: DT_INT32
1959 }
1960 }
1961 attr {
1962 key: "_execution_count"
1963 value {
1964 i: 10
1965 }
1966 }
1967 }
1968 node {
1969 name: "while/concat/axis"
1970 op: "Const"
1971 input: "^while/Identity"
1972 attr {
1973 key: "dtype"
1974 value {
1975 type: DT_INT32
1976 }
1977 }
1978 attr {
1979 key: "value"
1980 value {
1981 tensor {
1982 dtype: DT_INT32
1983 tensor_shape {
1984 }
1985 int_val: 0
1986 }
1987 }
1988 }
1989 attr {
1990 key: "_execution_count"
1991 value {
1992 i: 10
1993 }
1994 }
1995 }
1996 node {
1997 name: "while/concat"
1998 op: "ConcatV2"
1999 input: "while/Identity_1"
2000 input: "while/Identity_1"
2001 input: "while/concat/axis"
2002 attr {
2003 key: "N"
2004 value {
2005 i: 2
2006 }
2007 }
2008 attr {
2009 key: "T"
2010 value {
2011 type: DT_FLOAT
2012 }
2013 }
2014 attr {
2015 key: "Tidx"
2016 value {
2017 type: DT_INT32
2018 }
2019 }
2020 attr {
2021 key: "_execution_count"
2022 value {
2023 i: 10
2024 }
2025 }
2026 }
2027 node {
2028 name: "while/NextIteration"
2029 op: "NextIteration"
2030 input: "while/add"
2031 attr {
2032 key: "T"
2033 value {
2034 type: DT_INT32
2035 }
2036 }
2037 attr {
2038 key: "_execution_count"
2039 value {
2040 i: 10
2041 }
2042 }
2043 }
2044 node {
2045 name: "while/NextIteration_1"
2046 op: "NextIteration"
2047 input: "while/concat"
2048 attr {
2049 key: "T"
2050 value {
2051 type: DT_FLOAT
2052 }
2053 }
2054 attr {
2055 key: "_execution_count"
2056 value {
2057 i: 10
2058 }
2059 }
2060 }
2061 node {
2062 name: "while/Exit"
2063 op: "Exit"
2064 input: "while/Switch"
2065 attr {
2066 key: "T"
2067 value {
2068 type: DT_INT32
2069 }
2070 }
2071 attr {
2072 key: "_execution_count"
2073 value {
2074 i: 1
2075 }
2076 }
2077 }
2078 node {
2079 name: "while/Exit_1"
2080 op: "Exit"
2081 input: "while/Switch_1"
2082 attr {
2083 key: "T"
2084 value {
2085 type: DT_FLOAT
2086 }
2087 }
2088 attr {
2089 key: "_execution_count"
2090 value {
2091 i: 1
2092 }
2093 }
2094 }
2095 versions {
2096 producer: 21
2097 }
2098 )EOF";
2099
2100 grappler_item_.reset(new GrapplerItem);
2101 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2102 &grappler_item_->graph));
2103 grappler_item_->id = "test_graph";
2104 grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
2105 }
2106
2107 // A simple condition graph.
  // Builds a hand-crafted conditional graph from a text-format GraphDef:
  // a/Less (both Consts) -> Switch -> First/Second (Identity on each Switch
  // port) -> Merge. Fetches "Merge".
  void CreateGrapplerItemWithCondition() {
    const string gdef_ascii = R"EOF(
node {
  name: "a"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "Less"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        tensor_content: "\001"
      }
    }
  }
}
node {
  name: "Switch"
  op: "Switch"
  input: "a"
  input: "Less"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "First"
  op: "Identity"
  input: "Switch"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Second"
  op: "Identity"
  input: "Switch:1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Merge"
  op: "Merge"
  input: "First"
  input: "Second"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 27
})EOF";

    grappler_item_ = std::make_unique<GrapplerItem>();
    // TextFormat::ParseFromString returns bool, hence CHECK (not TF_CHECK_OK).
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"Merge"};
  }
2215
  // Graph with a multi-output FusedBatchNorm on CPU0 whose outputs are
  // consumed on CPU1, forcing inter-device (_Send/_Recv) transfers.
CreateGrapplerItemWithInterDeviceTransfers()2217 void CreateGrapplerItemWithInterDeviceTransfers() {
2218 tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
2219
2220 // Create a FusedBatchNorm op that has multiple output ports.
2221 auto x = ops::RandomUniform(
2222 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
2223 auto scale =
2224 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
2225 auto offset =
2226 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
2227 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
2228 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
2229
2230 auto batch_norm = ops::FusedBatchNorm(
2231 s.WithOpName("bn"), x, scale, offset, mean, var,
2232 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
2233 auto y = batch_norm.y;
2234 auto batch_mean = batch_norm.batch_mean;
2235 auto batch_var = batch_norm.batch_variance;
2236 // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
2237 auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
2238 auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
2239 // batch_mean1 and batch_var1 take different output ports, so each will
2240 // initiate Send/Recv.
2241 auto batch_mean1 = ops::Identity(
2242 s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
2243 auto batch_var1 =
2244 ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
2245 // This is control dependency.
2246 auto control_dep = ops::NoOp(s.WithOpName("control_dep")
2247 .WithControlDependencies(y)
2248 .WithDevice(kCPU1));
2249
2250 grappler_item_ = std::make_unique<GrapplerItem>();
2251 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
2252 grappler_item_->id = "test_conv2d_graph";
2253 grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
2254 "control_dep"};
2255
2256 dependency_["bn"] = {"x", "mean", "var"};
2257 dependency_["y1"] = {"bn"};
2258 dependency_["y2"] = {"bn"};
2259 dependency_["batch_mean1"] = {"bn"};
2260 dependency_["batch_var1"] = {"bn"};
2261 dependency_["control_dep"] = {"bn"};
2262 }
2263
  // Initializes scheduler_ from the current grappler_item_. Call this after
  // creating grappler_item_ and setting up dependency_.
  void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
2266
2267 // Returns cost based on op.
SimplePredictCosts(const OpContext & op_context) const2268 Costs SimplePredictCosts(const OpContext& op_context) const {
2269 Costs c;
2270 int64_t exec_cost = 0;
2271 if (op_context.op_info.op() == "MatMul") {
2272 exec_cost = 2000000000;
2273 } else if (op_context.op_info.op() == "RandomUniform") {
2274 exec_cost = 1000000000;
2275 } else {
2276 exec_cost = 1000;
2277 }
2278 c.execution_time = Costs::NanoSeconds(exec_cost);
2279 return c;
2280 }
2281
2282 // Call this after init scheduler_. Scheduler stops after executing
2283 // target_node.
  // Drives scheduler_ node by node until either the node named `target_node`
  // has been executed or the scheduler runs out of nodes. Along the way,
  // verifies (via dependency_) that each node's preceding nodes were executed
  // first. Returns the map of executed node name -> its OpContext.
  // Call this only after InitScheduler().
  std::unordered_map<string, OpContext> RunScheduler(
      const string& target_node) {
    std::unordered_map<string, OpContext> ops_executed;
    bool more_nodes = true;
    do {
      OpContext op_context = scheduler_->GetCurrNode();
      ops_executed[op_context.name] = op_context;
      // Print the schedule for easier debugging of test failures.
      std::cout << op_context.name << std::endl;

      Costs node_costs = SimplePredictCosts(op_context);

      // Check scheduling order: every recorded predecessor must already be in
      // ops_executed.
      auto it = dependency_.find(op_context.name);
      if (it != dependency_.end()) {
        for (const auto& preceding_node : it->second) {
          EXPECT_GT(ops_executed.count(preceding_node), 0);
        }
      }
      more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);

      if (op_context.name == target_node) {
        // Scheduler has the state after executing the target node.
        break;
      }
    } while (more_nodes);
    return ops_executed;
  }
2311
2312 // Helper method for validating a vector.
2313 template <typename T>
ExpectVectorEq(const std::vector<T> & expected,const std::vector<T> & test_elements)2314 void ExpectVectorEq(const std::vector<T>& expected,
2315 const std::vector<T>& test_elements) {
2316 // Set of expected elements for an easy comparison.
2317 std::set<T> expected_set(expected.begin(), expected.end());
2318 for (const auto& element : test_elements) {
2319 EXPECT_GT(expected_set.count(element), 0);
2320 }
2321 EXPECT_EQ(expected.size(), test_elements.size());
2322 }
2323
2324 // Helper method that checks the name of nodes.
ValidateNodeDefs(const std::vector<string> & expected,const std::vector<const NodeDef * > & node_defs)2325 void ValidateNodeDefs(const std::vector<string>& expected,
2326 const std::vector<const NodeDef*>& node_defs) {
2327 std::vector<string> node_names;
2328 std::transform(node_defs.begin(), node_defs.end(),
2329 std::back_inserter(node_names),
2330 [](const NodeDef* node) { return node->name(); });
2331 ExpectVectorEq(expected, node_names);
2332 }
2333
2334 // Helper method for validating a set.
2335 template <typename T>
ExpectSetEq(const std::set<T> & expected,const std::set<T> & test_elements)2336 void ExpectSetEq(const std::set<T>& expected,
2337 const std::set<T>& test_elements) {
2338 for (const auto& element : test_elements) {
2339 EXPECT_GT(expected.count(element), 0);
2340 }
2341 EXPECT_EQ(expected.size(), test_elements.size());
2342 }
2343
2344 // Helper method for validating an unordered map.
2345 template <typename T, typename U>
ExpectUnorderedMapEq(const std::unordered_map<T,U> & expected,const std::unordered_map<T,U> & test_map)2346 void ExpectUnorderedMapEq(const std::unordered_map<T, U>& expected,
2347 const std::unordered_map<T, U>& test_map) {
2348 EXPECT_EQ(expected.size(), test_map.size());
2349 for (const auto& key_val : expected) {
2350 EXPECT_GT(test_map.count(key_val.first), 0);
2351 EXPECT_EQ(test_map.at(key_val.first), key_val.second);
2352 }
2353 }
2354
2355 // Helper method that checks name - port pairs.
ValidateMemoryUsageSnapshot(const std::vector<string> & expected_names,const int port_num_expected,const std::unordered_set<std::pair<const NodeDef *,int>,DeviceState::NodePairHash> & mem_usage_snapshot)2356 void ValidateMemoryUsageSnapshot(
2357 const std::vector<string>& expected_names, const int port_num_expected,
2358 const std::unordered_set<std::pair<const NodeDef*, int>,
2359 DeviceState::NodePairHash>& mem_usage_snapshot) {
2360 std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
2361 std::transform(
2362 mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
2363 std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
2364 [](const std::pair<const NodeDef*, int>& node_port) {
2365 return std::make_pair(node_port.first->name(), node_port.second);
2366 });
2367 std::set<std::pair<string, int>> expected;
2368 std::transform(expected_names.begin(), expected_names.end(),
2369 std::inserter(expected, expected.begin()),
2370 [port_num_expected](const string& name) {
2371 return std::make_pair(name, port_num_expected);
2372 });
2373 ExpectSetEq(expected, nodes_at_peak_mem_usage);
2374 }
2375
2376 // Helper method for checking nodes dependency.
ValidateDependencyChain(const std::unordered_map<string,int64_t> & start_times,const std::vector<string> & nodes_in_dependency_order)2377 void ValidateDependencyChain(
2378 const std::unordered_map<string, int64_t>& start_times,
2379 const std::vector<string>& nodes_in_dependency_order) {
2380 int64_t prev_node_time = -1;
2381 for (const auto& node : nodes_in_dependency_order) {
2382 int64_t curr_node_time = start_times.at(node);
2383 EXPECT_GE(curr_node_time, prev_node_time);
2384 prev_node_time = curr_node_time;
2385 }
2386 }
2387
  // cluster_ and scheduler_ are initialized in the c'tor.
  std::unique_ptr<VirtualCluster> cluster_;
  std::unique_ptr<TestVirtualScheduler> scheduler_;
  // Ready-node managers available to the tests; some tests re-create
  // scheduler_ with composite_node_manager_ (e.g. AddMergeSwitch).
  FirstReadyManager first_ready_manager_;
  CompositeNodeManager composite_node_manager_;

  // grappler_item_ will be initialized differently for each test case.
  std::unique_ptr<GrapplerItem> grappler_item_;
  // Node name -> its preceding nodes map for testing scheduling order.
  std::unordered_map<string, std::vector<string>> dependency_;

  // Shared params for Conv2D related graphs:
  const int batch_size_ = 4;
  const int width_ = 10;
  const int height_ = 10;
  const int depth_in_ = 8;
  const int kernel_ = 3;
  const int depth_out_ = 16;
2406 };
2407
2408 // Create small graph, run predict costs on it, make sure the costs from the
2409 // summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest,SummaryCostTest)2410 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
2411 // Run matmul test.
2412 CreateGrapplerItemWithMatmulChain();
2413 InitScheduler();
2414 auto ops_executed = RunScheduler("");
2415 Costs c = scheduler_->Summary();
2416
2417 // RandomUniform - 5 * 1s
2418 // Matmuls - 4 * 2s = 8
2419 // Misc - 5 * 1us
2420 // Total: 13000005
2421 EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2422 EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2423 EXPECT_FALSE(c.inaccurate);
2424 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2425 }
2426
2427 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2428 // correct.
TEST_F(VirtualSchedulerTest,SummaryCostStepStatsTest)2429 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
2430 // Run matmul test.
2431 CreateGrapplerItemWithMatmulChain();
2432 InitScheduler();
2433 auto ops_executed = RunScheduler("");
2434 RunMetadata metadata;
2435 Costs c = scheduler_->Summary(&metadata);
2436 StepStats stepstats = metadata.step_stats();
2437 EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2438 EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2439 EXPECT_FALSE(c.inaccurate);
2440 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2441
2442 // Should only be 1 device!
2443 EXPECT_EQ(1, stepstats.dev_stats().size());
2444
2445 // Create a map of op name -> start and end times (micros).
2446 std::map<string, std::pair<int64_t, int64_t>> start_end_times;
2447 for (const auto& device_step_stats : stepstats.dev_stats()) {
2448 for (const auto& stats : device_step_stats.node_stats()) {
2449 int64_t start = stats.all_start_micros();
2450 int64_t end = start + stats.all_end_rel_micros();
2451 start_end_times[stats.node_name()] =
2452 std::pair<int64_t, int64_t>(start, end);
2453
2454 // Make sure that the output properties are correct for
2455 // MatMul and RandomUniform operations.
2456 // We only check for dtype, and shape (excluding alloc)
2457 // since alloc is not set by the virtual scheduler.
2458 if (stats.timeline_label() == "MatMul" ||
2459 stats.timeline_label() == "RandomUniform") {
2460 EXPECT_EQ(1, stats.output().size());
2461 for (const auto& output : stats.output()) {
2462 EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
2463 EXPECT_EQ(2, output.tensor_description().shape().dim().size());
2464 for (const auto& dim : output.tensor_description().shape().dim()) {
2465 EXPECT_EQ(3200, dim.size());
2466 }
2467 }
2468 }
2469 }
2470 }
2471
2472 // The base start_time is the time to compute RandomUniforms
2473 int64_t cur_time = static_cast<int64_t>(5000005);
2474 // The increment is the execution time of one matmul. See
2475 // CreateGrapplerItemWithMatmulChain for details.
2476 int64_t increment = static_cast<int64_t>(2000000);
2477 auto op_names = {"ab", "abc", "abcd", "abcde"};
2478 for (const auto& op_name : op_names) {
2479 int64_t actual_start = start_end_times[op_name].first;
2480 int64_t actual_end = start_end_times[op_name].second;
2481 int64_t expected_start = cur_time;
2482 int64_t expected_end = cur_time + increment;
2483 EXPECT_EQ(expected_start, actual_start);
2484 EXPECT_EQ(expected_end, actual_end);
2485 cur_time += increment;
2486 }
2487 }
2488
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
  // Init.
  CreateGrapplerItemWithConv2Ds();
  InitScheduler();

  // Run the scheduler over the whole graph.
  auto ops_executed = RunScheduler("");

  // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
  // executed.
  EXPECT_EQ(8, ops_executed.size());

  // x, y, f, c0, and c1 should be in the ops executed.
  for (const auto& executed_name : {"x", "y", "f", "c0", "c1"}) {
    EXPECT_GT(ops_executed.count(executed_name), 0);
  }

  // z and c2 shouldn't be part of it.
  for (const auto& skipped_name : {"z", "c2"}) {
    EXPECT_EQ(ops_executed.count(skipped_name), 0);
  }

  // Check input / output properties.
  EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
  EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
  EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
  EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
  EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
}
2519
TEST_F(VirtualSchedulerTest, MemoryUsage) {
  // Init.
  CreateGrapplerItemWithAddN();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
  // is 4 x the input tensor size while executing the out node.
  const int64_t one_input_node_size = 4 * 10 * 10 * 10 * 10;
  const std::vector<string> expected_names = {"x", "y", "z", "w", "add"};
  EXPECT_EQ(expected_names.size() * one_input_node_size,
            cpu_state.max_memory_usage);
  ValidateMemoryUsageSnapshot(expected_names, /*port_num_expected=*/0,
                              cpu_state.mem_usage_snapshot_at_peak);

  // Total 10 nodes: Four const, x, y, z, w, add, out.
  ASSERT_EQ(cpu_state.temporary_memory_usage_trace.size(), 10);
  // Spot-check the per-node temporary-memory trace entries.
  const auto& x_usage = cpu_state.temporary_memory_usage_trace.at(4);
  EXPECT_EQ(x_usage.first, "x");
  EXPECT_EQ(x_usage.second, one_input_node_size);
  const auto& add_usage = cpu_state.temporary_memory_usage_trace.at(8);
  EXPECT_EQ(add_usage.first, "add");
  EXPECT_EQ(add_usage.second, 5 * one_input_node_size);
  const auto& out_usage = cpu_state.temporary_memory_usage_trace.at(9);
  EXPECT_EQ(out_usage.first, "out");
  EXPECT_EQ(out_usage.second, one_input_node_size);
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 64)},
      scheduler_->GetPersistentMemoryUsage());
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 200000)},
      scheduler_->GetPeakMemoryUsage());
}
2561
TEST_F(VirtualSchedulerTest, MemoryUsageForStreamingOps) {
  // Init.
  CreateGrapplerItemWithAddN();
  // Nodes add and out are placed on CPU1; nodes z and w are marked as
  // streaming nodes, while x and y stay as regular in-memory tensors.
  for (auto& node : *grappler_item_->graph.mutable_node()) {
    const string& node_name = node.name();
    if (node_name == "out" || node_name == "add") {
      node.set_device(kCPU1);
    }
    if (node_name == "z" || node_name == "w") {
      (*node.mutable_attr())[kStreaming].mutable_list()->add_b(true);
    }
  }

  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state_0 = device_states->at(kCPU0);
  const auto& cpu_state_1 = device_states->at(kCPU1);
  // All tensors are of the same size, 10 x 10 x 10 x 10.
  const int64_t tensor_size = 4 * 10 * 10 * 10 * 10;
  // z and w are streaming, so they are absent from both expected lists.
  const std::vector<string> cpu_0_expected_tensors = {"x", "y"};
  const std::vector<string> cpu_1_expected_tensors = {"x", "y", "add"};
  EXPECT_EQ(cpu_0_expected_tensors.size() * tensor_size,
            cpu_state_0.max_memory_usage);
  EXPECT_EQ(cpu_1_expected_tensors.size() * tensor_size,
            cpu_state_1.max_memory_usage);
  // After the graph is executed, at the end, memory usage for the device
  // should be zero.
  EXPECT_EQ(cpu_state_0.memory_usage, 0);
  EXPECT_EQ(cpu_state_1.memory_usage, 0);
}
2597
TEST_F(VirtualSchedulerTest, MemoryUsageWithExecutionCount) {
  // Init.
  CreateGrapplerItemWithAddN();
  // Annotate every node to be executed 10000 times.
  for (auto& node : *grappler_item_->graph.mutable_node()) {
    (*node.mutable_attr())[kExecutionCount].set_i(10000);
  }

  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  const auto& cpu_state_0 = scheduler_->GetDeviceStates()->at(kCPU0);
  // All tensors are of the same size, 10 x 10 x 10 x 10.
  const int64_t tensor_size = 4 * 10 * 10 * 10 * 10;
  const std::vector<string> expected_names = {"x", "y", "z", "w", "add"};
  // Max memory usage does not rely on the number of executions.
  EXPECT_EQ(expected_names.size() * tensor_size,
            cpu_state_0.max_memory_usage);
  // After the graph is executed, at the end, memory usage for the device
  // should be zero.
  EXPECT_EQ(cpu_state_0.memory_usage, 0);

  Costs c = scheduler_->Summary();
  EXPECT_EQ(64, c.persistent_memory);
  EXPECT_EQ(200000, c.temporary_memory);
  EXPECT_EQ(200064, c.max_memory);
}
2629
TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
  CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
  InitScheduler();

  // Test that scheduler can run graphs with extra unnecessary feed nodes:
  // only "x" should end up executed.
  const auto ops_executed = RunScheduler("");
  ASSERT_EQ(ops_executed.size(), 1);
  ASSERT_EQ(ops_executed.count("x"), 1);
}
2639
TEST_F(VirtualSchedulerTest, ControlDependency) {
  // Init.
  CreateGrapplerItemWithControlDependency();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto& cpu_state = scheduler_->GetDeviceStates()->at(kCPU0);

  // The graph has a NoOp that takes control dependency from 7 NoOps. The peak
  // memory usage is when executing the final NoOp.
  const int64_t control_dependency_size = 4;
  const std::vector<string> expected_names = {"x", "y", "z", "w",
                                              "u", "v", "t"};
  EXPECT_EQ(expected_names.size() * control_dependency_size,
            cpu_state.max_memory_usage);
  ValidateMemoryUsageSnapshot(expected_names, /*port_num_expected=*/-1,
                              cpu_state.mem_usage_snapshot_at_peak);
  // No persistent memory; peak is the 7 control dependencies x 4 bytes.
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 0)},
      scheduler_->GetPersistentMemoryUsage());
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 28)},
      scheduler_->GetPeakMemoryUsage());
}
2667
TEST_F(VirtualSchedulerTest, ComplexDependency) {
  // Init.
  CreateGrapplerItemWithBatchNorm();
  InitScheduler();

  // Run the scheduler up to node "bn".
  RunScheduler("bn");

  const auto& device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // The graph is
  // bn = FusedBatchNorm(x, scale, offset, mean, var)
  // z1 = bn.y + x
  // z2 = bn.var + bn.var
  // z3 = bn.var + bn.var
  // z4 = control dependency from bn.
  // Note that bn.mean doesn't have any consumer.
  const int x_size = batch_size_ * width_ * height_ * depth_in_;
  const int64_t expected_size =
      4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
           1 /* control dependency */);
  EXPECT_EQ(expected_size, cpu_state.memory_usage);

  // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
  std::set<std::pair<string, int>> nodes_in_memory;
  for (const auto& node_port : cpu_state.nodes_in_memory) {
    nodes_in_memory.emplace(node_port.first->name(), node_port.second);
  }
  const std::set<std::pair<string, int>> expected = {
      {"bn", -1},
      {"bn", 0},
      {"bn", 2},
      {"x", 0},
  };
  ExpectSetEq(expected, nodes_in_memory);

  // Locate the per-node states for "bn" and "x".
  const auto* node_states = scheduler_->GetNodeStates();
  const NodeState* bn_node = nullptr;
  const NodeState* x_node = nullptr;
  for (const auto& nodedef_node_state : *node_states) {
    const string& name = nodedef_node_state.first->name();
    if (name == "bn") {
      bn_node = &nodedef_node_state.second;
    } else if (name == "x") {
      x_node = &nodedef_node_state.second;
    }
  }
  CHECK_NOTNULL(bn_node);
  CHECK_NOTNULL(x_node);

  ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
  ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
  ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
  // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
  ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
}
2730
TEST_F(VirtualSchedulerTest, Variable) {
  // Init.
  CreateGrapplerItemWithConv2DAndVariable();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto& cpu_state = scheduler_->GetDeviceStates()->at(kCPU0);

  // There is one Conv2D that takes x and f, but f is variable, so it should
  // be in persistent nodes.
  ValidateMemoryUsageSnapshot({"f", "Const/Const"}, /*port_num_expected=*/0,
                              cpu_state.persistent_nodes);
  // Only x in peak memory usage snapshot.
  ValidateMemoryUsageSnapshot({"x"}, /*port_num_expected=*/0,
                              cpu_state.mem_usage_snapshot_at_peak);
  // Persistent usage covers f and the const; peak usage is x's tensor.
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 4624)},
      scheduler_->GetPersistentMemoryUsage());
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 12800)},
      scheduler_->GetPeakMemoryUsage());
}
2756
TEST_F(VirtualSchedulerTest, WhileLoop) {
  // Init.
  CreateGrapplerItemWithLoop();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  // Check the timeline
  RunMetadata metadata;
  scheduler_->Summary(&metadata);

  // Nodes in topological order:
  // * const, ones
  // * while/Enter, while/Enter_1
  // * while/Merge, while/Merge_1
  // * while/Less/y
  // * while/Less
  // * while/LoopCond
  // * while/Switch, while/Switch_1
  // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
  // * while/add/y, while/concat/axis
  // * while/add, while/concat
  // * while/NextIteration, while/NextIteration_1

  int num_next_iteration = 0;
  int num_next_iteration_1 = 0;
  int num_exit = 0;
  int num_exit_1 = 0;
  // Initialize the start times: if any of the nodes below is unexpectedly
  // missing from the step stats, the EXPECT_NE checks further down would
  // otherwise read uninitialized values (undefined behavior) instead of
  // failing deterministically along with the count checks.
  int64_t next_iter_start_micro = -1;
  int64_t next_iter_1_start_micro = -1;
  int64_t exit_start_micro = -1;
  int64_t exit_1_start_micro = -1;

  std::unordered_map<string, int64_t> start_times;
  for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
    for (const auto& stats : device_step_stats.node_stats()) {
      start_times[stats.node_name()] = stats.all_start_micros();
      if (stats.node_name() == "while/NextIteration") {
        ++num_next_iteration;
        next_iter_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/NextIteration_1") {
        ++num_next_iteration_1;
        next_iter_1_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/Exit") {
        ++num_exit;
        exit_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/Exit_1") {
        ++num_exit_1;
        exit_1_start_micro = stats.all_start_micros();
      }
    }
  }

  // Make sure we went though the body of the loop once, and that the output of
  // the loop was scheduled as well.
  EXPECT_EQ(1, num_next_iteration);
  EXPECT_EQ(1, num_next_iteration_1);
  EXPECT_EQ(1, num_exit);
  EXPECT_EQ(1, num_exit_1);

  // Start times of while/NextIteration and while/NextIteration_1 should be
  // different, so should be those of while/Exit and while/Exit_1.
  EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
  EXPECT_NE(exit_start_micro, exit_1_start_micro);

  // Check dependency among the nodes; no matter what scheduling mechanism we
  // use, the scheduled ops should follow these dependency chains.
  // Note that currently, VirtualScheduler executes while/Merge twice; hence,
  // we're not testing dependency chains related to while/Merge.
  // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
  // order of Enter, Merge, ...loop condition ..., ... loop body ...,
  // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
  // chaining test w/ Merge nodes.
  ValidateDependencyChain(
      start_times,
      {"Const", "while/Enter",  // "while/Merge",
       "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
       "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
  // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
  ValidateDependencyChain(start_times,
                          {"ones", "while/Enter_1",  // "while/Merge_1",
                           "while/Switch_1", "while/Identity_1", "while/concat",
                           "while/NextIteration_1"});
  ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
  ValidateDependencyChain(
      start_times, {"while/Identity", "while/concat/axis", "while/concat"});
  ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
  ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
}
2847
TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
  // Plain loop: the scheduler unrolls a single iteration.
  {
    // Init.
    CreateGrapplerItemWithLoop();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    const Costs summary = scheduler_->Summary();

    EXPECT_EQ(23, summary.execution_time.asMicroSeconds().count());
    // Both while/Merge and while/Merge_1 are scheduled twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 2, summary.num_ops_total);
    EXPECT_FALSE(summary.inaccurate);
    EXPECT_EQ(0, summary.num_ops_with_unknown_shapes);
  }

  // Annotated loop: execution counts scale the cost of the loop body.
  {
    // Init.
    CreateGrapplerItemWithLoopAnnotated();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    const Costs summary = scheduler_->Summary();

    // The costs for Merge is accumulated twice for execution_count times, but
    // since Merge's cost is minimal, we keep this behavior here.
    EXPECT_EQ(178, summary.execution_time.asMicroSeconds().count());
    // Both while/Merge and while/Merge_1 are scheduled twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 2, summary.num_ops_total);
    EXPECT_FALSE(summary.inaccurate);
    EXPECT_EQ(0, summary.num_ops_with_unknown_shapes);
  }
}
2883
TEST_F(VirtualSchedulerTest, Condition) {
  // Counts how many times each node name appears in the recorded step stats.
  // Shared by both sub-cases below to avoid duplicating the counting loops.
  auto count_node_executions = [](const RunMetadata& metadata) {
    std::unordered_map<string, int> counts;
    for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
      for (const auto& stats : device_step_stats.node_stats()) {
        ++counts[stats.node_name()];
      }
    }
    return counts;
  };

  // Without annotation.
  {
    // Inits.
    CreateGrapplerItemWithCondition();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    RunMetadata metadata;
    Costs c = scheduler_->Summary(&metadata);

    // Nodes in topological order: a/Less, Switch, First/Second, Merge.
    // Both branches run; Merge is executed twice.
    auto counts = count_node_executions(metadata);
    EXPECT_EQ(1, counts["a"]);
    EXPECT_EQ(1, counts["Less"]);
    EXPECT_EQ(1, counts["Switch"]);
    EXPECT_EQ(1, counts["First"]);
    EXPECT_EQ(1, counts["Second"]);
    EXPECT_EQ(2, counts["Merge"]);

    EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
    // Merge is executed twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }

  // With annotation.
  {
    // Inits.
    CreateGrapplerItemWithCondition();

    // Annotates the Switch node.
    for (auto& node : *grappler_item_->graph.mutable_node()) {
      if (node.name() == "Switch") {
        AttrValue attr_output_info;
        // Adds one output slot 0 so that Second shouldn't be executed.
        (*attr_output_info.mutable_list()).add_i(0);
        AddNodeAttr(kOutputSlots, attr_output_info, &node);
      }
    }

    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    RunMetadata metadata;
    Costs c = scheduler_->Summary(&metadata);

    // Nodes in topological order: a/Less, Switch, Merge. With the annotation,
    // only the First branch runs and Merge fires once.
    auto counts = count_node_executions(metadata);
    EXPECT_EQ(1, counts["a"]);
    EXPECT_EQ(1, counts["Less"]);
    EXPECT_EQ(1, counts["Switch"]);
    EXPECT_EQ(1, counts["First"]);
    EXPECT_EQ(0, counts["Second"]);
    EXPECT_EQ(1, counts["Merge"]);

    EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
    // Second is not executed.
    EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }
}
2998
TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
  // Init.
  CreateGrapplerItemWithInterDeviceTransfers();
  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  // Helper lambda to extract port num from _Send and _Recv op name.
  auto get_port_num = [](const string& name) -> int {
    if (absl::StrContains(name, "bn_0")) {
      return 0;
    } else if (absl::StrContains(name, "bn_1")) {
      return 1;
    } else if (absl::StrContains(name, "bn_2")) {
      return 2;
    } else if (absl::StrContains(name, "bn_minus1")) {
      return -1;
    }
    return -999;
  };

  // Reorganize ops_executed for further testing.
  std::unordered_map<string, int> op_count;
  std::unordered_map<int, string> recv_op_names;
  std::unordered_map<int, string> send_op_names;
  for (const auto& x : ops_executed) {
    const auto& name = x.first;
    const auto& node_info = x.second;
    const auto& op = node_info.op_info.op();
    if (op == kRecv) {
      recv_op_names[get_port_num(name)] = name;
    } else if (op == kSend) {
      send_op_names[get_port_num(name)] = name;
    }
    op_count[op]++;
  }

  // Same number of _Send and _Recv.
  EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));

  // Expect 3 Send and Recvs each: port 0, 1, and, 2.
  // Control dependency bypasses the channel.
  EXPECT_EQ(op_count.at(kRecv), 3);
  EXPECT_EQ(op_count.at(kSend), 3);

  // Helper lambda for extracting output Tensor size. Captures ops_executed
  // by reference (the original by-value capture copied the entire map into
  // the closure) and uses int64_t instead of the legacy int64 alias.
  auto get_output_size = [this, &ops_executed](const string& name) -> int64_t {
    const auto& outputs = ops_executed.at(name).op_info.outputs();
    // CalculateOutputSize takes a std::vector; copy the repeated field once.
    std::vector<OpInfo::TensorProperties> output_properties(outputs.begin(),
                                                            outputs.end());
    return CalculateOutputSize(output_properties, 0);
  };

  // Validate transfer size.
  // Batchnorm output y is 4D vector: batch x width x width x depth.
  int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
  EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
  EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
  // Mean and vars are 1-D vector with size depth_in_.
  EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
}
3066
TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
  // Init.
  CreateGrapplerItemWithSendRecv();
  InitScheduler();

  // Run the scheduler.
  const auto ops_executed = RunScheduler("");

  // All three nodes of the graph must have been scheduled.
  for (const auto& node_name : {"Const", "Send", "Recv"}) {
    EXPECT_GT(ops_executed.count(node_name), 0);
  }
}
3079
TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
  // Init.
  CreateGrapplerItemWithSendRecv();
  // Change Recv node's device so that Send and Recv are placed on different
  // devices.
  const string recv_device = kCPU1;
  for (auto& node : *grappler_item_->graph.mutable_node()) {
    if (node.name() == "Recv") {
      node.set_device(recv_device);
      (*node.mutable_attr())["recv_device"].set_s(recv_device);
    } else if (node.name() == "Send") {
      (*node.mutable_attr())["recv_device"].set_s(recv_device);
    }
  }
  InitScheduler();

  // Run the scheduler.
  const auto ops_executed = RunScheduler("");

  // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops.
  EXPECT_GT(ops_executed.count("Const"), 0);
  EXPECT_GT(ops_executed.count("Send"), 0);
  EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
                               "task_0/cpu_0_to_/job_localhost"
                               "/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count(
                "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count("Recv"), 0);
}
3115
TEST_F(VirtualSchedulerTest, GraphWihtOnlyRecv) {
  // Init.
  CreateGrapplerItemWithRecvWithoutSend();
  InitScheduler();

  // Run the scheduler.
  const auto ops_executed = RunScheduler("");

  // A Recv without a matching Send is treated as an initially-ready node.
  EXPECT_GT(ops_executed.count("Recv"), 0);
}
3127
TEST_F(VirtualSchedulerTest, AddMergeSwitch) {
  // This case needs CompositeNodeManager, so override the default scheduler_.
  scheduler_ = std::make_unique<TestVirtualScheduler>(
      /*use_static_shapes=*/true,
      /*use_aggressive_shape_inference=*/true, &composite_node_manager_,
      cluster_.get());
  CreateGrapplerItemWithSwitchMergeInput();
  InitScheduler();

  // pred --+ z --+
  //        |     |
  //        V     V
  // x -> Switch --------> Merge ---> Add --> y
  //        |                ^
  //        |                |
  //        +-----> Add -----+
  //                 ^
  //                 |
  // b --------------+

  // Run the scheduler. The current VirtualScheduler, w/o annotation, triggers
  // both outputs of Switch; then Merge (as long as one input is ready, it's z
  // is ready, if we just use num_inputs_ready counter, the final Add becomes
  // ready. possible to skip scheduling z. (Need to use CompositeNodeManager
  // to test this case).
  const auto ops_executed = RunScheduler("");

  // z must still be scheduled.
  EXPECT_GT(ops_executed.count("z"), 0);
}
3157
TEST_F(VirtualSchedulerTest, AddFromOneTensor) {
  CreateGrapplerItemWithAddFromOneTensor();
  InitScheduler();

  // x -+----> Add --> y
  //    |       ^
  //    |       |
  //    +-------+

  // Run the scheduler; both x and its consumer y must be executed.
  const auto ops_executed = RunScheduler("");
  EXPECT_GT(ops_executed.count("x"), 0);
  EXPECT_GT(ops_executed.count("y"), 0);
}
3172
// Verifies that per-node persistent/temporary memory costs supplied via
// MarkCurrNodeExecuted are reflected in the device state, in the final
// Summary(), and in the RunMetadata allocation/memory stats.
TEST_F(VirtualSchedulerTest, TestNodeCostOutputTensorSize) {
  // Create a schedule with more than 2 ops to be executed.
  CreateGrapplerItemWithMatmulChain();
  InitScheduler();
  RunScheduler("ab");

  // The scheduler reports memory quantities as 64-bit values; use int64_t
  // locals (instead of int32_t) to avoid silent narrowing/truncation.
  int64_t persistent_memory_before =
      scheduler_->GetPersistentMemoryUsage().at(kCPU0);

  auto* device_states = scheduler_->GetDeviceStates();
  int64_t memory_usage = device_states->at(kCPU0).memory_usage;

  // Set temporary/persistent memory to some values for the first node cost.
  Costs node_costs = Costs::ZeroCosts(false);

  const int64_t node_one_cost = 12345;
  const int64_t node_two_cost = 98765;

  // Bytes per input tensor of the matmul chain: float (4 bytes) x 3200 x 3200.
  const int64_t input_size = 4 * 3200 * 3200;

  node_costs.persistent_memory = node_one_cost;
  node_costs.temporary_memory = 0;
  node_costs.output_tensor_size_bytes = {{0, node_one_cost}};
  node_costs.persistent_output_ports = {0};

  // Mark first node executed and check device state.
  scheduler_->MarkCurrNodeExecuted(node_costs);
  device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state_0 = device_states->at(kCPU0);

  // The expected memory usage is previous memory usage minus the size
  // of the two inputs for the multiply operation.
  memory_usage -= 2 * input_size;
  EXPECT_EQ(cpu_state_0.memory_usage, memory_usage);

  int64_t persistent_memory = node_one_cost + persistent_memory_before;
  EXPECT_EQ(scheduler_->GetPersistentMemoryUsage().at(kCPU0),
            persistent_memory);

  // Set second node costs to temporary memory.
  node_costs = Costs::ZeroCosts(false);
  node_costs.persistent_memory = 0;
  node_costs.temporary_memory = node_two_cost;
  node_costs.output_tensor_size_bytes = {{0, node_two_cost}};

  scheduler_->MarkCurrNodeExecuted(node_costs);
  device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state_1 = device_states->at(kCPU0);

  // Again we remove the inputs from memory usage. The output of the previous
  // operation is not subtracted because it is set as persistent memory.
  memory_usage += node_two_cost - input_size;
  EXPECT_EQ(cpu_state_1.memory_usage, memory_usage);
  EXPECT_EQ(scheduler_->GetPersistentMemoryUsage().at(kCPU0),
            persistent_memory);

  // Finish off the schedule to test if the Summary counts persistent memory
  // correctly.
  bool more_nodes = true;
  do {
    OpContext op_context = scheduler_->GetCurrNode();
    node_costs = SimplePredictCosts(op_context);
    more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
  } while (more_nodes);

  RunMetadata metadata;
  Costs final_cost = scheduler_->Summary(&metadata);

  EXPECT_EQ(final_cost.persistent_memory, persistent_memory);

  // Since we adjusted the node costs, we expect the requested and allocated
  // memory are not equal for nodes "abc" and "abcd".
  StepStats stepstats = metadata.step_stats();
  for (const auto& device_step_stats : stepstats.dev_stats()) {
    for (const auto& stats : device_step_stats.node_stats()) {
      const auto& allocation_description =
          stats.output().at(0).tensor_description().allocation_description();
      if (stats.node_name() == "abc") {
        EXPECT_NE(allocation_description.allocated_bytes(),
                  allocation_description.requested_bytes());
        const auto& mem_stats = stats.memory_stats();
        EXPECT_EQ(mem_stats.persistent_memory_size(), node_one_cost);
      } else if (stats.node_name() == "abcd") {
        EXPECT_NE(allocation_description.allocated_bytes(),
                  allocation_description.requested_bytes());
      } else {
        EXPECT_EQ(allocation_description.allocated_bytes(),
                  allocation_description.requested_bytes());
      }
    }
  }
}
3265
3266 } // namespace
3267 } // end namespace grappler
3268 } // end namespace tensorflow
3269