1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <string>
21 
22 #include "absl/strings/match.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/framework/allocation_description.pb.h"
25 #include "tensorflow/core/framework/tensor_description.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
28 #include "tensorflow/core/grappler/costs/graph_properties.h"
29 #include "tensorflow/core/grappler/costs/utils.h"
30 #include "tensorflow/core/grappler/costs/virtual_placer.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/test.h"
33 
34 namespace tensorflow {
35 namespace grappler {
36 namespace {
37 
38 // Device names:
39 constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/cpu:0";
40 constexpr char kCPU1[] = "/job:localhost/replica:0/task:0/cpu:1";
41 constexpr char kChannelFrom0To1[] = "Channel from CPU0 to CPU1";
42 constexpr char kChannelFrom1To0[] = "Channel from CPU1 to CPU0";
43 // Op names:
44 constexpr char kConv2D[] = "Conv2D";
45 constexpr char kSend[] = "_Send";
46 constexpr char kRecv[] = "_Recv";
47 
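// Test fixture that provides NodeDefs node1_..node6_ (all Conv2D ops on kCPU0,
// with strictly decreasing time_ready values) and their NodeStates, shared by
// the ReadyNodeManager tests below.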
48 class ReadyNodeManagerTest : public ::testing::Test {
49  protected:
50   ReadyNodeManagerTest() {
51     // node1_ to node6_ on kCPU0, with time_ready in reverse order.
52     NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
53     NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
54     NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
55     NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
56     NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
57     NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
58   }
59 
60   void NodeSetUp(const string& name, const string& op_name,
61                  const string& device_name, const uint64 time_ready,
62                  NodeDef* node) {
63     node->set_name(name);
64     node->set_op(op_name);
65     node->set_device(device_name);
66 
67     node_states_[node] = NodeState();
68     node_states_[node].time_ready = time_ready;
69     node_states_[node].device_name = device_name;
70   }
71 
72   NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
73   std::unordered_map<const NodeDef*, NodeState> node_states_;
74 };
75 
76 // Tests that FIFOManager correctly returns the current node with only 1 node.
77 TEST_F(ReadyNodeManagerTest, GetSingleNodeFIFOManager) {
78   FIFOManager manager = FIFOManager();
79   manager.AddNode(&node1_);
80   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
81 }
82 
83 // Tests that FIFOManager removes the only node contained within.
84 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFIFOManager) {
85   FIFOManager manager = FIFOManager();
86   manager.AddNode(&node1_);
87 
88   // Removes the only node in FIFOManager.
89   manager.RemoveCurrNode();
90   EXPECT_TRUE(manager.Empty());
91 }
92 
93 // Tests that FIFOManager can remove multiple nodes and return the current node
94 // in the right order.
95 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFIFOManager) {
96   FIFOManager manager = FIFOManager();
97   manager.AddNode(&node1_);
98   manager.AddNode(&node2_);
99   manager.AddNode(&node3_);
100   manager.AddNode(&node4_);
101 
102   // Keeps checking current node while removing nodes from manager.
103   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
104   manager.RemoveCurrNode();
105   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
106   manager.RemoveCurrNode();
107   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
108   manager.RemoveCurrNode();
109   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
110   manager.RemoveCurrNode();
111   EXPECT_TRUE(manager.Empty());
112 }
113 
114 // Tests that FIFOManager can remove multiple nodes and add more nodes, still
115 // returning the current node in the right order.
116 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleFIFOManager) {
117   FIFOManager manager = FIFOManager();
118   manager.AddNode(&node1_);
119   manager.AddNode(&node2_);
120   manager.AddNode(&node3_);
121   manager.AddNode(&node4_);
122 
123   // Keeps checking current node as nodes are removed and added.
124   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
125   manager.RemoveCurrNode();
126   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
127   manager.AddNode(&node5_);
128   // GetCurrNode() should return the same node even if some nodes are added,
129   // until RemoveCurrNode() is called.
130   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
131   manager.RemoveCurrNode();
132   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
133   manager.RemoveCurrNode();
134   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
135   manager.AddNode(&node6_);
136   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
137   manager.RemoveCurrNode();
138   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
139   manager.RemoveCurrNode();
140   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
141   manager.RemoveCurrNode();
142   EXPECT_TRUE(manager.Empty());
143 }
144 
145 // Tests that LIFOManager correctly returns the current node with only 1 node.
146 TEST_F(ReadyNodeManagerTest, GetSingleNodeLIFOManager) {
147   LIFOManager manager = LIFOManager();
148   manager.AddNode(&node1_);
149   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
150 }
151 
152 // Tests that LIFOManager removes the only node contained within.
153 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeLIFOManager) {
154   LIFOManager manager = LIFOManager();
155   manager.AddNode(&node1_);
156 
157   // Removes the only node in LIFOManager.
158   manager.RemoveCurrNode();
159   EXPECT_TRUE(manager.Empty());
160 }
161 
162 // Tests that LIFOManager can remove multiple nodes and return the current node
163 // in the right order.
164 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleLIFOManager) {
165   LIFOManager manager = LIFOManager();
166   manager.AddNode(&node1_);
167   manager.AddNode(&node2_);
168   manager.AddNode(&node3_);
169   manager.AddNode(&node4_);
170 
171   // Keeps checking current node while removing nodes from manager.
172   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
173   manager.RemoveCurrNode();
174   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
175   manager.RemoveCurrNode();
176   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
177   manager.RemoveCurrNode();
178   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
179   manager.RemoveCurrNode();
180   EXPECT_TRUE(manager.Empty());
181 }
182 
183 // Tests that LIFOManager can remove multiple nodes (must be removing the
184 // current node) and add more nodes, still returning the current node in the
185 // right order.
186 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
187   LIFOManager manager = LIFOManager();
188   manager.AddNode(&node1_);
189   manager.AddNode(&node2_);
190   manager.AddNode(&node3_);
191   manager.AddNode(&node4_);
192 
193   // Keeps checking current node as nodes are removed and added.
194   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
195   manager.RemoveCurrNode();
196   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
197   manager.AddNode(&node5_);
198   // GetCurrNode()  should return the same node even if some nodes are added,
199   // until RemoveCurrNode() is called.
200   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
201   manager.RemoveCurrNode();
202   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
203   manager.RemoveCurrNode();
204   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
205   manager.AddNode(&node6_);
206   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
207   manager.RemoveCurrNode();
208   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
209   manager.RemoveCurrNode();
210   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
211   manager.RemoveCurrNode();
212   EXPECT_TRUE(manager.Empty());
213 }
214 
215 TEST_F(ReadyNodeManagerTest, MergeOrderInLIFOManager) {
216   LIFOManager manager = LIFOManager();
217   node3_.set_op("Merge");
218   manager.AddNode(&node1_);
219   manager.AddNode(&node2_);
220   manager.AddNode(&node3_);
221   manager.AddNode(&node4_);
222 
223   // Merge node (node3) will be scheduled at the end (even though it's added
224   // after node2).
225   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
226   manager.RemoveCurrNode();
227   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
228   manager.RemoveCurrNode();
229   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
230   manager.RemoveCurrNode();
231   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
232   manager.RemoveCurrNode();
233 }
234 
235 TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
236   FirstReadyManager manager;
237   TF_EXPECT_OK(manager.Init(&node_states_));
238   manager.AddNode(&node1_);
239   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
240 }
241 
242 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
243   FirstReadyManager manager;
244   TF_EXPECT_OK(manager.Init(&node_states_));
245   manager.AddNode(&node1_);
246   manager.RemoveCurrNode();
247   EXPECT_TRUE(manager.Empty());
248 }
249 
250 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
251   FirstReadyManager manager;
252   TF_EXPECT_OK(manager.Init(&node_states_));
253   // Insert nodes in some random order.
254   manager.AddNode(&node2_);
255   manager.AddNode(&node1_);
256   manager.AddNode(&node4_);
257   manager.AddNode(&node5_);
258   manager.AddNode(&node3_);
259   manager.AddNode(&node6_);
260 
261   // In whatever order we insert nodes, we get the same order based on nodes'
262   // time_ready.
263   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
264   manager.RemoveCurrNode();
265   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
266   manager.RemoveCurrNode();
267   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
268   manager.RemoveCurrNode();
269   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
270   manager.RemoveCurrNode();
271   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
272   manager.RemoveCurrNode();
273   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
274   manager.RemoveCurrNode();
275   EXPECT_TRUE(manager.Empty());
276 }
277 
278 TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
279   FirstReadyManager manager;
280   TF_EXPECT_OK(manager.Init(&node_states_));
281 
282   // Inserts nodes in some random order.
283   manager.AddNode(&node2_);
284   manager.AddNode(&node1_);
285   manager.AddNode(&node4_);
286   manager.AddNode(&node5_);
287   manager.AddNode(&node3_);
288   manager.AddNode(&node6_);
289 
290   // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
291   // should return it.
292   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
293 
294   // Now inserts a few other nodes, but their time_ready's are even smaller than
295   // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return
296   // the same node, Node6, in this case.
297   NodeDef node7;
298   NodeDef node8;
299   NodeDef node9;
300   NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
301   NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
302   NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);
303 
304   manager.AddNode(&node7);
305   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
306 
307   manager.AddNode(&node8);
308   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
309 
310   manager.RemoveCurrNode();
311   // Now Node6 is removed, and GetCurrNode() will return Node8.
312   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
313 
314   // Again, AddNode shouldn't change GetCurrNode().
315   manager.AddNode(&node9);
316   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
317 
318   manager.RemoveCurrNode();
319   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
320   manager.RemoveCurrNode();
321   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
322   manager.RemoveCurrNode();
323   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
324   manager.RemoveCurrNode();
325   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
326   manager.RemoveCurrNode();
327   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
328   manager.RemoveCurrNode();
329   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
330   manager.RemoveCurrNode();
331   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
332   manager.RemoveCurrNode();
333   EXPECT_TRUE(manager.Empty());
334 }
335 
336 TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
337   FirstReadyManager manager1;
338   TF_EXPECT_OK(manager1.Init(&node_states_));
339   FirstReadyManager manager2;
340   TF_EXPECT_OK(manager2.Init(&node_states_));
341 
342   // 6 nodes with same time_ready.
343   NodeDef node7;
344   NodeDef node8;
345   NodeDef node9;
346   NodeDef node10;
347   NodeDef node11;
348   NodeDef node12;
349   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
350   NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
351   NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
352   NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
353   NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
354   NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);
355 
356   // Adds the above 6 nodes to manager1.
357   manager1.AddNode(&node7);
358   manager1.AddNode(&node8);
359   manager1.AddNode(&node9);
360   manager1.AddNode(&node10);
361   manager1.AddNode(&node11);
362   manager1.AddNode(&node12);
363 
364   // Adds the above 6 nodes to manager2, but in a different order.
365   manager2.AddNode(&node8);
366   manager2.AddNode(&node11);
367   manager2.AddNode(&node9);
368   manager2.AddNode(&node10);
369   manager2.AddNode(&node7);
370   manager2.AddNode(&node12);
371 
372   // Expects both managers return the same nodes for deterministic node
373   // scheduling.
374   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
375   manager1.RemoveCurrNode();
376   manager2.RemoveCurrNode();
377 
378   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
379   manager1.RemoveCurrNode();
380   manager2.RemoveCurrNode();
381 
382   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
383   manager1.RemoveCurrNode();
384   manager2.RemoveCurrNode();
385 
386   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
387   manager1.RemoveCurrNode();
388   manager2.RemoveCurrNode();
389 
390   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
391   manager1.RemoveCurrNode();
392   manager2.RemoveCurrNode();
393 
394   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
395   manager1.RemoveCurrNode();
396   manager2.RemoveCurrNode();
397 
398   EXPECT_TRUE(manager1.Empty());
399   EXPECT_TRUE(manager2.Empty());
400 }
401 
402 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultiplePriorityReadyManager) {
403   PriorityReadyManager manager;
404   TF_EXPECT_OK(manager.Init(&node_states_));
405 
406   // Sets up node priorities.
407   std::unordered_map<string, int> node_priority = {
408       {"Node1", 1}, {"Node2", 2}, {"Node3", 2}, {"Node4", 4}, {"Node5", 5}};
409   TF_EXPECT_OK(manager.SetPriority(node_priority));
410 
411   // Inserts nodes in some random order.
412   manager.AddNode(&node3_);
413   manager.AddNode(&node1_);
414   manager.AddNode(&node4_);
415   manager.AddNode(&node5_);
416   manager.AddNode(&node2_);
417   manager.AddNode(&node6_);
418 
419   // Expects nodes scheduled based on priority.
420   // Node6 is not in the priority map, so it gets the default priority value and is scheduled first.
421   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
422   manager.RemoveCurrNode();
423   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
424   manager.RemoveCurrNode();
425   // Nodes 2 and 3 have equal priority, so the tie is broken by time_ready (earliest first).
426   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
427   manager.RemoveCurrNode();
428   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
429   manager.RemoveCurrNode();
430   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
431   manager.RemoveCurrNode();
432   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
433   manager.RemoveCurrNode();
434   EXPECT_TRUE(manager.Empty());
435 }
436 
437 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
438   CompositeNodeManager manager;
439   TF_EXPECT_OK(manager.Init(&node_states_));
440   manager.AddNode(&node1_);
441   manager.RemoveCurrNode();
442   EXPECT_TRUE(manager.Empty());
443 }
444 
445 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleCompositeNodeManager) {
446   CompositeNodeManager manager;
447   TF_EXPECT_OK(manager.Init(&node_states_));
448   manager.AddNode(&node1_);
449   manager.AddNode(&node2_);
450   manager.AddNode(&node3_);
451   manager.AddNode(&node4_);
452 
453   // Keeps checking current node as nodes are removed and added.
454   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
455   manager.RemoveCurrNode();
456   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
457   manager.AddNode(&node5_);
458   // GetCurrNode()  should return the same node even if some nodes are added,
459   // until RemoveCurrNode() is called.
460   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
461   manager.RemoveCurrNode();
462   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
463   manager.RemoveCurrNode();
464   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
465   manager.AddNode(&node6_);
466   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
467   manager.RemoveCurrNode();
468   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
469   manager.RemoveCurrNode();
470   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
471   manager.RemoveCurrNode();
472   EXPECT_TRUE(manager.Empty());
473 }
474 
475 TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvCompositeNodeManager) {
476   CompositeNodeManager manager;
477   TF_EXPECT_OK(manager.Init(&node_states_));
478   // Additional nodes on kCPU1.
479   NodeDef node7;
480   NodeDef node8;
481   NodeDef node9;
482   NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
483   NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
484   NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);
485 
486   // Send and Recv nodes.
487   NodeDef send1;
488   NodeDef send2;
489   NodeDef recv1;
490   NodeDef recv2;
491   NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
492   NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
493   NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
494   NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);
495 
496   // Inserts nodes.
497   manager.AddNode(&node1_);
498   manager.AddNode(&node2_);
499   manager.AddNode(&node3_);
500   manager.AddNode(&node4_);
501   manager.AddNode(&node5_);
502   manager.AddNode(&node6_);
503   manager.AddNode(&node7);
504   manager.AddNode(&node8);
505   manager.AddNode(&node9);
506   manager.AddNode(&send1);
507   manager.AddNode(&send2);
508   manager.AddNode(&recv1);
509   manager.AddNode(&recv2);
510 
511   // On kCPU0 the last node added is node6_, and on kCPU1 it is node9; so choose
512   // the one with the earliest time_ready among node6_, node9,
513   // Send1, Send2, Recv1, and Recv2.
514   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
515   manager.RemoveCurrNode();
516   // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
517   // among node5_, node9, Send1, Send2, Recv1, and Recv2.
518   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
519   manager.RemoveCurrNode();
520   // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
521   EXPECT_EQ(manager.GetCurrNode()->name(), "Send1");
522   manager.RemoveCurrNode();
523   // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
524   EXPECT_EQ(manager.GetCurrNode()->name(), "Recv1");
525   manager.RemoveCurrNode();
526   // Next, choose among node4_, node9, Send2, and Recv2.
527   EXPECT_EQ(manager.GetCurrNode()->name(), "Recv2");
528   manager.RemoveCurrNode();
529   // Next, choose among node4_, node9, and Send2.
530   EXPECT_EQ(manager.GetCurrNode()->name(), "Send2");
531   manager.RemoveCurrNode();
532   // Next, choose between node4_, node9.
533   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
534   manager.RemoveCurrNode();
535   // Next, choose between node3_, node9.
536   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
537   manager.RemoveCurrNode();
538   // Next, choose between node3_, node8.
539   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
540   manager.RemoveCurrNode();
541   // Next, choose between node3_, node7.
542   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
543   manager.RemoveCurrNode();
544   // Then, just the nodes on kCPU0 remain -- LIFO order.
545   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
546   manager.RemoveCurrNode();
547   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
548   manager.RemoveCurrNode();
549   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
550   manager.RemoveCurrNode();
551   EXPECT_TRUE(manager.Empty());
552 }
553 
554 TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
555   CompositeNodeManager manager;
556   TF_EXPECT_OK(manager.Init(&node_states_));
557   CompositeNodeManager manager2;
558   TF_EXPECT_OK(manager2.Init(&node_states_));
559 
560   // 6 nodes with same time_ready.
561   NodeDef node7;
562   NodeDef node8;
563   NodeDef node9;
564   NodeDef node10;
565   NodeDef node11;
566   NodeDef node12;
567   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
568   NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
569   NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
570   NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
571   NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
572   NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);
573 
574   // Adds Nodes 7 to 9 to manager.
575   manager.AddNode(&node7);
576   manager.AddNode(&node8);
577   manager.AddNode(&node9);
578 
579   // It should return _Send, then _Recv, then the other op, in that order, when
580   // the candidate nodes have the same time_ready.
581   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
582   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
583   manager.RemoveCurrNode();
584   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
585   EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
586   manager.RemoveCurrNode();
587   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
588   EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
589   manager.RemoveCurrNode();
590   EXPECT_TRUE(manager.Empty());
591 
592   // Adds Nodes 7 to 9 to manager, but in a different order.
593   manager.AddNode(&node9);
594   manager.AddNode(&node8);
595   manager.AddNode(&node7);
596 
597   // Expects same order (_Send, _Recv, and the other op), regardless of Add
598   // order.
599   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
600   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
601   manager.RemoveCurrNode();
602   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
603   EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
604   manager.RemoveCurrNode();
605   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
606   EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
607   manager.RemoveCurrNode();
608   EXPECT_TRUE(manager.Empty());
609 
610   // Conv2D's time_ready < Send's time_ready; Expects Conv2D first.
611   manager.AddNode(&node8);
612   manager.AddNode(&node10);
613   EXPECT_EQ(manager.GetCurrNode()->name(), "Node10");
614   EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
615   manager.RemoveCurrNode();
616   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
617   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
618   manager.RemoveCurrNode();
619   EXPECT_TRUE(manager.Empty());
620 
621   // Recv's time_ready < Send's time_ready; Expects Recv first.
622   manager.AddNode(&node11);
623   manager.AddNode(&node8);
624   EXPECT_EQ(manager.GetCurrNode()->name(), "Node11");
625   EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
626   manager.RemoveCurrNode();
627   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
628   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
629   manager.RemoveCurrNode();
630   EXPECT_TRUE(manager.Empty());
631 
632   // Node7 and 12 are normal ops with the same time_ready, placed on different
633   // devices. These two nodes are added to manager and manager2, but in
634   // different orders; Expects GetCurrNode() returns the nodes in the same
635   // order.
636   manager.AddNode(&node7);
637   manager.AddNode(&node12);
638 
639   manager2.AddNode(&node12);
640   manager2.AddNode(&node7);
641 
642   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
643   manager.RemoveCurrNode();
644   manager2.RemoveCurrNode();
645   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
646   manager.RemoveCurrNode();
647   manager2.RemoveCurrNode();
648   EXPECT_TRUE(manager.Empty());
649 }
650 
651 // Class for testing virtual scheduler.
652 class TestVirtualScheduler : public VirtualScheduler {
653  public:
654   TestVirtualScheduler(const bool use_static_shapes,
655                        const bool use_aggressive_shape_inference,
656                        ReadyNodeManager* ready_node_manager, Cluster* cluster)
657       : VirtualScheduler(
658             use_static_shapes, use_aggressive_shape_inference, cluster,
659             ready_node_manager,
660             std::make_unique<VirtualPlacer>(cluster->GetDevices())) {
661     enable_mem_usage_tracking();
662   }
663 
664   FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
665   FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
666   FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
667   FRIEND_TEST(VirtualSchedulerTest, Variable);
668   FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
669 };
670 
671 class VirtualSchedulerTest : public ::testing::Test {
672  protected:
673   VirtualSchedulerTest() {
674     // Initializes cluster_ and scheduler_.
675     std::unordered_map<string, DeviceProperties> devices;
676 
677     // Set some dummy CPU properties
678     DeviceProperties cpu_device = GetDummyCPUDevice();
679 
680     // IMPORTANT: Device is not actually ever used in the test case since
681     // force_cpu_type is defaulted to "Haswell"
682     devices[kCPU0] = cpu_device;
683     devices[kCPU1] = cpu_device;
684     cluster_ = std::make_unique<VirtualCluster>(devices);
685     scheduler_ = std::make_unique<TestVirtualScheduler>(
686         /*use_static_shapes=*/true,
687         /*use_aggressive_shape_inference=*/true, &first_ready_manager_,
688         cluster_.get());
689   }
690 
691   DeviceProperties GetDummyCPUDevice() {
692     // Creates a CPU with 2 cores, 4 GHz frequency, and 2 GB/s memory bandwidth:
693     // - 8 GFLOPS
694     // - 2 GB/s
695     DeviceProperties cpu_device;
696     cpu_device.set_type("CPU");
697     cpu_device.set_frequency(4000);
698     cpu_device.set_num_cores(2);
699     cpu_device.set_bandwidth(2000000);
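    // (DeviceProperties takes frequency in MHz and bandwidth in KB/s, so the
    // values above correspond to 2 cores * 4 GHz = 8 GFLOPS and
    // 2,000,000 KB/s = 2 GB/s of memory bandwidth.)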
700     return cpu_device;
701   }
702 
703   // Three Conv2Ds with only two in fetch nodes.
704   void CreateGrapplerItemWithConv2Ds() {
705     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
706     auto x = ops::RandomUniform(
707         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
708     auto y = ops::RandomUniform(
709         s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
710     auto z = ops::RandomUniform(
711         s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
712     auto f = ops::RandomUniform(
713         s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
714     std::vector<int> strides = {1, 1, 1, 1};
715     auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
716     auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
717     auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
718 
719     grappler_item_ = std::make_unique<GrapplerItem>();
720     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
721     grappler_item_->id = "test_conv2d_graph";
722     grappler_item_->fetch = {"c0", "c1"};
723 
724     dependency_["c0"] = {"x", "f"};
725     dependency_["c1"] = {"y", "f"};
726   }
727 
728   // A Conv2D with a variable.
729   void CreateGrapplerItemWithConv2DAndVariable() {
730     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
731     auto x = ops::RandomUniform(
732         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
733     auto f = ops::Variable(s.WithOpName("f"),
734                            {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
735     std::vector<int> strides = {1, 1, 1, 1};
736     auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
737 
738     grappler_item_ = std::make_unique<GrapplerItem>();
739     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
740     grappler_item_->id = "test_conv2d_var_graph";
741 
742     grappler_item_->fetch = {"y"};
743 
744     dependency_["y"] = {"x", "f"};
745   }
746 
747   void CreateGrapplerItemWithMatmulChain() {
748     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
749     // Adds control dependencies to ensure the tests do not rely on a specific
750     // ready node manager and the order remains consistent for the test.
751     auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
752     auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
753                                 {3200, 3200}, DT_FLOAT);
754     auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
755                                 {3200, 3200}, DT_FLOAT);
756     auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
757                                 {3200, 3200}, DT_FLOAT);
758     auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
759                                 {3200, 3200}, DT_FLOAT);
760 
761     auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
762     auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
763     auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
764     auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
765 
766     grappler_item_ = std::make_unique<GrapplerItem>();
767     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
768     grappler_item_->id = "test_matmul_sequence_graph";
769     grappler_item_->fetch = {"abcde"};
770 
771     dependency_["ab"] = {"a", "b"};
772     dependency_["abc"] = {"ab", "c"};
773     dependency_["abcd"] = {"abc", "d"};
774     dependency_["abcde"] = {"abcd", "e"};
775   }
776 
777   // AddN that takes 4 tensors with 10x10x10x10.
778   void CreateGrapplerItemWithAddN() {
779     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
780     auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
781     auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
782     auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
783     auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
784     OutputList input_tensors = {x, y, z, w};
785     auto add = ops::AddN(s.WithOpName("add"), input_tensors);
786     auto out = ops::Identity(s.WithOpName("out"), add);
787 
788     grappler_item_ = std::make_unique<GrapplerItem>();
789     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
790     grappler_item_->id = "test_addn_graph";
791     grappler_item_->fetch = {"out"};
792 
793     dependency_["out"] = {"x", "y", "z", "w", "add"};
794   }
795 
796   // Graph with some placeholder feed nodes that are not in the fetch fan-in.
797   void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
798     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
799     auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
800     auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
801 
802     grappler_item_ = std::make_unique<GrapplerItem>();
803     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
804 
805     grappler_item_->id = "test_extra_placeholders";
806     grappler_item_->fetch = {"x"};
807 
808     // Grappler Item Builder puts all placeholder nodes into the feed
809     // list by default.
810     grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
811   }
812 
813   // NoOp that takes 7 NoOps as control dependencies.
814   void CreateGrapplerItemWithControlDependency() {
815     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
816     std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
817     std::vector<Operation> input_tensors;
818     for (const auto& input : input_noop_names) {
819       auto x = ops::NoOp(s.WithOpName(input));
820       input_tensors.push_back(x.operation);
821     }
822     auto out =
823         ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
824 
825     grappler_item_ = std::make_unique<GrapplerItem>();
826     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
827 
828     grappler_item_->id = "test_control_dependency_graph";
829     grappler_item_->fetch = {"out"};
830 
831     dependency_["out"] = input_noop_names;
832   }
833 
834   void CreateGrapplerItemWithAddFromOneTensor() {
835     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
836     auto x = tensorflow::ops::RandomUniform(
837         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
838 
839     auto y = tensorflow::ops::Add(s.WithOpName("y"), x, x);
840     Output fetch = ops::Identity(s.WithOpName("fetch"), y);
841 
842     grappler_item_ = std::make_unique<GrapplerItem>();
843     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
844 
845     grappler_item_->id = "test_add_from_one_tensor";
846     grappler_item_->fetch = {"fetch"};
847 
848     dependency_["fetch"] = {"y"};
849     dependency_["y"] = {"x"};
850   }
851 
852   void CreateGrapplerItemWithSwitchMergeInput() {
853     // sw = Switch(x, pred)
854     // a = Add(sw:1, b)
855     // m = Merge(sw:0, a)
856     // y = Add(m, z)
857 
858     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
859     auto x = ops::RandomUniform(
860         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
861     auto pred = ops::Const(s.WithOpName("pred"), false, {});
862     auto sw = ops::Switch(s.WithOpName("switch"), x, pred);
863     auto b = ops::RandomUniform(
864         s.WithOpName("b"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
865     auto a = ops::Add(s.WithOpName("a"), sw.output_true, b);
866     auto m = ops::Merge(s.WithOpName("m"), {sw.output_false, a.z});
867     auto z = ops::RandomUniform(
868         s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
869     auto y = ops::Add(s.WithOpName("y"), m.output, z);
870 
871     grappler_item_ = std::make_unique<GrapplerItem>();
872     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
873 
874     grappler_item_->id = "test_add_merge_switch";
875     grappler_item_->fetch = {"y"};
876 
877     dependency_["y"] = {"m", "z"};
878   }
879 
880   // FusedBN [an op with multiple outputs] with multiple consumers (including
881   // control dependency).
882   void CreateGrapplerItemWithBatchNorm() {
883     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
884     auto x = ops::RandomUniform(
885         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
886     auto scale =
887         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
888     auto offset =
889         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
890     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
891     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
892 
893     auto batch_norm = ops::FusedBatchNorm(
894         s.WithOpName("bn"), x, scale, offset, mean, var,
895         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
896     auto y = batch_norm.y;
897     auto batch_mean = batch_norm.batch_mean;
898     auto batch_var = batch_norm.batch_variance;
899 
900     auto z1 = ops::Add(s.WithOpName("z1"), x, y);
901     auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
902     auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
903     std::vector<Operation> input_tensors = {
904         batch_mean.op(),
905         z1.z.op(),
906         z2.z.op(),
907         z3.z.op(),
908     };
909     auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
910 
911     grappler_item_ = std::make_unique<GrapplerItem>();
912     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
913 
914     grappler_item_->id = "test_complex_dependency_graph";
915     grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
916 
917     dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
918     dependency_["z1"] = {"x", "bn"};
919     dependency_["z2"] = {"bn"};
920     dependency_["z3"] = {"bn"};
921     dependency_["z4"] = {"bn"};
922   }
923 
924   void CreateGrapplerItemWithSendRecv() {
925     const string gdef_ascii = R"EOF(
926 node {
927   name: "Const"
928   op: "Const"
929   device: "/job:localhost/replica:0/task:0/device:CPU:0"
930   attr {
931     key: "dtype"
932     value {
933       type: DT_FLOAT
934     }
935   }
936   attr {
937     key: "_output_shapes"
938     value {
939       list { shape {
940         dim { size: 128 }
941         dim { size: 32 }
942       }}}
943   }
944   attr {
945     key: "shape"
946     value {
947       list { shape {
948         dim { size: 128 }
949         dim { size: 32 }
950       }}}
951   }
952   attr {
953     key: "value"
954     value {
955       tensor {
956         dtype: DT_FLOAT
957         tensor_shape {
958           dim { size: 128 }
959           dim { size: 32 }
960         }
961         float_val: 3.1415
962       }
963     }
964   }
965 }
966 node {
967   name: "Send"
968   op: "_Send"
969   input: "Const"
970   device: "/job:localhost/replica:0/task:0/device:CPU:0"
971   attr {
972     key: "T"
973     value {
974       type: DT_FLOAT
975     }
976   }
977   attr {
978     key: "_output_shapes"
979     value {
980       list { shape {
981         dim { size: 128 }
982         dim { size: 32 }
983       }}}
984   }
985   attr {
986     key: "shape"
987     value {
988       list { shape {
989         dim { size: 128 }
990         dim { size: 32 }
991       }}}
992   }
993   attr {
994     key: "client_terminated"
995     value {
996       b: false
997     }
998   }
999   attr {
1000     key: "recv_device"
1001     value {
1002       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1003     }
1004   }
1005   attr {
1006     key: "send_device"
1007     value {
1008       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1009     }
1010   }
1011   attr {
1012     key: "send_device_incarnation"
1013     value {
1014       i: 0
1015     }
1016   }
1017   attr {
1018     key: "tensor_name"
1019     value {
1020       s: "test"
1021     }
1022   }
1023 }
1024 node {
1025   name: "Recv"
1026   op: "_Recv"
1027   device: "/job:localhost/replica:0/task:0/device:CPU:0"
1028   attr {
1029     key: "client_terminated"
1030     value {
1031       b: false
1032     }
1033   }
1034   attr {
1035     key: "_output_shapes"
1036     value {
1037       list { shape {
1038         dim { size: 128 }
1039         dim { size: 32 }
1040       }}}
1041   }
1042   attr {
1043     key: "shape"
1044     value {
1045       list { shape {
1046         dim { size: 128 }
1047         dim { size: 32 }
1048       }}}
1049   }
1050   attr {
1051     key: "recv_device"
1052     value {
1053       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1054     }
1055   }
1056   attr {
1057     key: "send_device"
1058     value {
1059       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1060     }
1061   }
1062   attr {
1063     key: "send_device_incarnation"
1064     value {
1065       i: 0
1066     }
1067   }
1068   attr {
1069     key: "tensor_name"
1070     value {
1071       s: "test"
1072     }
1073   }
1074   attr {
1075     key: "tensor_type"
1076     value {
1077       type: DT_FLOAT
1078     }
1079   }
1080 }
1081 library {
1082 }
1083 versions {
1084   producer: 24
1085 }
1086     )EOF";
1087 
1088     grappler_item_ = std::make_unique<GrapplerItem>();
1089 
1090     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1091                                                 &grappler_item_->graph));
1092     grappler_item_->id = "test_graph";
1093     grappler_item_->fetch = {"Recv"};
1094   }
1095 
1096   void CreateGrapplerItemWithRecvWithoutSend() {
1097     const string gdef_ascii = R"EOF(
1098 node {
1099   name: "Recv"
1100   op: "_Recv"
1101   device: "/job:localhost/replica:0/task:0/device:CPU:0"
1102   attr {
1103     key: "client_terminated"
1104     value {
1105       b: false
1106     }
1107   }
1108   attr {
1109     key: "recv_device"
1110     value {
1111       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1112     }
1113   }
1114   attr {
1115     key: "send_device"
1116     value {
1117       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1118     }
1119   }
1120   attr {
1121     key: "send_device_incarnation"
1122     value {
1123       i: 0
1124     }
1125   }
1126   attr {
1127     key: "tensor_name"
1128     value {
1129       s: "test"
1130     }
1131   }
1132   attr {
1133     key: "tensor_type"
1134     value {
1135       type: DT_FLOAT
1136     }
1137   }
1138 }
1139 library {
1140 }
1141 versions {
1142   producer: 24
1143 }
1144     )EOF";
1145 
1146     grappler_item_ = std::make_unique<GrapplerItem>();
1147     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1148                                                 &grappler_item_->graph));
1149     grappler_item_->id = "test_graph";
1150     grappler_item_->fetch = {"Recv"};
1151   }
1152 
1153   // A simple while loop
1154   void CreateGrapplerItemWithLoop() {
1155     // Test graph produced in python using:
1156     /*
1157       with tf.Graph().as_default():
1158       i0 = tf.constant(0)
1159       m0 = tf.ones([2, 2])
1160       c = lambda i, m: i < 10
1161       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1162       r = tf.while_loop(
1163       c, b, loop_vars=[i0, m0],
1164       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1165       with open('/tmp/graph.pbtxt', 'w') as f:
1166       f.write(str(tf.get_default_graph().as_graph_def()))
1167     */
1168     const string gdef_ascii = R"EOF(
1169 node {
1170   name: "Const"
1171   op: "Const"
1172   attr {
1173     key: "dtype"
1174     value {
1175       type: DT_INT32
1176     }
1177   }
1178   attr {
1179     key: "value"
1180     value {
1181       tensor {
1182         dtype: DT_INT32
1183         tensor_shape {
1184         }
1185         int_val: 0
1186       }
1187     }
1188   }
1189 }
1190 node {
1191   name: "ones"
1192   op: "Const"
1193   attr {
1194     key: "dtype"
1195     value {
1196       type: DT_FLOAT
1197     }
1198   }
1199   attr {
1200     key: "value"
1201     value {
1202       tensor {
1203         dtype: DT_FLOAT
1204         tensor_shape {
1205           dim {
1206             size: 2
1207           }
1208           dim {
1209             size: 2
1210           }
1211         }
1212         float_val: 1.0
1213       }
1214     }
1215   }
1216 }
1217 node {
1218   name: "while/Enter"
1219   op: "Enter"
1220   input: "Const"
1221   attr {
1222     key: "T"
1223     value {
1224       type: DT_INT32
1225     }
1226   }
1227   attr {
1228     key: "frame_name"
1229     value {
1230       s: "while/while/"
1231     }
1232   }
1233   attr {
1234     key: "is_constant"
1235     value {
1236       b: false
1237     }
1238   }
1239   attr {
1240     key: "parallel_iterations"
1241     value {
1242       i: 10
1243     }
1244   }
1245 }
1246 node {
1247   name: "while/Enter_1"
1248   op: "Enter"
1249   input: "ones"
1250   attr {
1251     key: "T"
1252     value {
1253       type: DT_FLOAT
1254     }
1255   }
1256   attr {
1257     key: "frame_name"
1258     value {
1259       s: "while/while/"
1260     }
1261   }
1262   attr {
1263     key: "is_constant"
1264     value {
1265       b: false
1266     }
1267   }
1268   attr {
1269     key: "parallel_iterations"
1270     value {
1271       i: 10
1272     }
1273   }
1274 }
1275 node {
1276   name: "while/Merge"
1277   op: "Merge"
1278   input: "while/Enter"
1279   input: "while/NextIteration"
1280   attr {
1281     key: "N"
1282     value {
1283       i: 2
1284     }
1285   }
1286   attr {
1287     key: "T"
1288     value {
1289       type: DT_INT32
1290     }
1291   }
1292 }
1293 node {
1294   name: "while/Merge_1"
1295   op: "Merge"
1296   input: "while/Enter_1"
1297   input: "while/NextIteration_1"
1298   attr {
1299     key: "N"
1300     value {
1301       i: 2
1302     }
1303   }
1304   attr {
1305     key: "T"
1306     value {
1307       type: DT_FLOAT
1308     }
1309   }
1310 }
1311 node {
1312   name: "while/Less/y"
1313   op: "Const"
1314   input: "^while/Merge"
1315   attr {
1316     key: "dtype"
1317     value {
1318       type: DT_INT32
1319     }
1320   }
1321   attr {
1322     key: "value"
1323     value {
1324       tensor {
1325         dtype: DT_INT32
1326         tensor_shape {
1327         }
1328         int_val: 10
1329       }
1330     }
1331   }
1332 }
1333 node {
1334   name: "while/Less"
1335   op: "Less"
1336   input: "while/Merge"
1337   input: "while/Less/y"
1338   attr {
1339     key: "T"
1340     value {
1341       type: DT_INT32
1342     }
1343   }
1344 }
1345 node {
1346   name: "while/LoopCond"
1347   op: "LoopCond"
1348   input: "while/Less"
1349 }
1350 node {
1351   name: "while/Switch"
1352   op: "Switch"
1353   input: "while/Merge"
1354   input: "while/LoopCond"
1355   attr {
1356     key: "T"
1357     value {
1358       type: DT_INT32
1359     }
1360   }
1361   attr {
1362     key: "_class"
1363     value {
1364       list {
1365         s: "loc:@while/Merge"
1366       }
1367     }
1368   }
1369 }
1370 node {
1371   name: "while/Switch_1"
1372   op: "Switch"
1373   input: "while/Merge_1"
1374   input: "while/LoopCond"
1375   attr {
1376     key: "T"
1377     value {
1378       type: DT_FLOAT
1379     }
1380   }
1381   attr {
1382     key: "_class"
1383     value {
1384       list {
1385         s: "loc:@while/Merge_1"
1386       }
1387     }
1388   }
1389 }
1390 node {
1391   name: "while/Identity"
1392   op: "Identity"
1393   input: "while/Switch:1"
1394   attr {
1395     key: "T"
1396     value {
1397       type: DT_INT32
1398     }
1399   }
1400 }
1401 node {
1402   name: "while/Identity_1"
1403   op: "Identity"
1404   input: "while/Switch_1:1"
1405   attr {
1406     key: "T"
1407     value {
1408       type: DT_FLOAT
1409     }
1410   }
1411 }
1412 node {
1413   name: "while/add/y"
1414   op: "Const"
1415   input: "^while/Identity"
1416   attr {
1417     key: "dtype"
1418     value {
1419       type: DT_INT32
1420     }
1421   }
1422   attr {
1423     key: "value"
1424     value {
1425       tensor {
1426         dtype: DT_INT32
1427         tensor_shape {
1428         }
1429         int_val: 1
1430       }
1431     }
1432   }
1433 }
1434 node {
1435   name: "while/add"
1436   op: "Add"
1437   input: "while/Identity"
1438   input: "while/add/y"
1439   attr {
1440     key: "T"
1441     value {
1442       type: DT_INT32
1443     }
1444   }
1445 }
1446 node {
1447   name: "while/concat/axis"
1448   op: "Const"
1449   input: "^while/Identity"
1450   attr {
1451     key: "dtype"
1452     value {
1453       type: DT_INT32
1454     }
1455   }
1456   attr {
1457     key: "value"
1458     value {
1459       tensor {
1460         dtype: DT_INT32
1461         tensor_shape {
1462         }
1463         int_val: 0
1464       }
1465     }
1466   }
1467 }
1468 node {
1469   name: "while/concat"
1470   op: "ConcatV2"
1471   input: "while/Identity_1"
1472   input: "while/Identity_1"
1473   input: "while/concat/axis"
1474   attr {
1475     key: "N"
1476     value {
1477       i: 2
1478     }
1479   }
1480   attr {
1481     key: "T"
1482     value {
1483       type: DT_FLOAT
1484     }
1485   }
1486   attr {
1487     key: "Tidx"
1488     value {
1489       type: DT_INT32
1490     }
1491   }
1492 }
1493 node {
1494   name: "while/NextIteration"
1495   op: "NextIteration"
1496   input: "while/add"
1497   attr {
1498     key: "T"
1499     value {
1500       type: DT_INT32
1501     }
1502   }
1503 }
1504 node {
1505   name: "while/NextIteration_1"
1506   op: "NextIteration"
1507   input: "while/concat"
1508   attr {
1509     key: "T"
1510     value {
1511       type: DT_FLOAT
1512     }
1513   }
1514 }
1515 node {
1516   name: "while/Exit"
1517   op: "Exit"
1518   input: "while/Switch"
1519   attr {
1520     key: "T"
1521     value {
1522       type: DT_INT32
1523     }
1524   }
1525 }
1526 node {
1527   name: "while/Exit_1"
1528   op: "Exit"
1529   input: "while/Switch_1"
1530   attr {
1531     key: "T"
1532     value {
1533       type: DT_FLOAT
1534     }
1535   }
1536 }
1537 versions {
1538   producer: 21
1539 }
1540   )EOF";
1541 
1542     grappler_item_ = std::make_unique<GrapplerItem>();
1543     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1544                                                 &grappler_item_->graph));
1545     grappler_item_->id = "test_graph";
1546     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
1547   }
1548 
1549   // A simple while loop, annotated with execution counts and Switch output slot information.
1550   void CreateGrapplerItemWithLoopAnnotated() {
1551     // Test graph produced in python using:
1552     /*
1553       with tf.Graph().as_default():
1554       i0 = tf.constant(0)
1555       m0 = tf.ones([2, 2])
1556       c = lambda i, m: i < 10
1557       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1558       r = tf.while_loop(
1559       c, b, loop_vars=[i0, m0],
1560       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1561       with open('/tmp/graph.pbtxt', 'w') as f:
1562       f.write(str(tf.get_default_graph().as_graph_def()))
1563     */
1564     const string gdef_ascii = R"EOF(
1565 node {
1566   name: "Const"
1567   op: "Const"
1568   attr {
1569     key: "dtype"
1570     value {
1571       type: DT_INT32
1572     }
1573   }
1574   attr {
1575     key: "value"
1576     value {
1577       tensor {
1578         dtype: DT_INT32
1579         tensor_shape {
1580         }
1581         int_val: 0
1582       }
1583     }
1584   }
1585   attr {
1586     key: "_execution_count"
1587     value {
1588       i: 1
1589     }
1590   }
1591 }
1592 node {
1593   name: "ones"
1594   op: "Const"
1595   attr {
1596     key: "dtype"
1597     value {
1598       type: DT_FLOAT
1599     }
1600   }
1601   attr {
1602     key: "value"
1603     value {
1604       tensor {
1605         dtype: DT_FLOAT
1606         tensor_shape {
1607           dim {
1608             size: 2
1609           }
1610           dim {
1611             size: 2
1612           }
1613         }
1614         float_val: 1.0
1615       }
1616     }
1617   }
1618   attr {
1619     key: "_execution_count"
1620     value {
1621       i: 1
1622     }
1623   }
1624 }
1625 node {
1626   name: "while/Enter"
1627   op: "Enter"
1628   input: "Const"
1629   attr {
1630     key: "T"
1631     value {
1632       type: DT_INT32
1633     }
1634   }
1635   attr {
1636     key: "frame_name"
1637     value {
1638       s: "while/while/"
1639     }
1640   }
1641   attr {
1642     key: "is_constant"
1643     value {
1644       b: false
1645     }
1646   }
1647   attr {
1648     key: "parallel_iterations"
1649     value {
1650       i: 10
1651     }
1652   }
1653   attr {
1654     key: "_execution_count"
1655     value {
1656       i: 1
1657     }
1658   }
1659 }
1660 node {
1661   name: "while/Enter_1"
1662   op: "Enter"
1663   input: "ones"
1664   attr {
1665     key: "T"
1666     value {
1667       type: DT_FLOAT
1668     }
1669   }
1670   attr {
1671     key: "frame_name"
1672     value {
1673       s: "while/while/"
1674     }
1675   }
1676   attr {
1677     key: "is_constant"
1678     value {
1679       b: false
1680     }
1681   }
1682   attr {
1683     key: "parallel_iterations"
1684     value {
1685       i: 10
1686     }
1687   }
1688   attr {
1689     key: "_execution_count"
1690     value {
1691       i: 1
1692     }
1693   }
1694 }
1695 node {
1696   name: "while/Merge"
1697   op: "Merge"
1698   input: "while/Enter"
1699   input: "while/NextIteration"
1700   attr {
1701     key: "N"
1702     value {
1703       i: 2
1704     }
1705   }
1706   attr {
1707     key: "T"
1708     value {
1709       type: DT_INT32
1710     }
1711   }
1712   attr {
1713     key: "_execution_count"
1714     value {
1715       i: 10
1716     }
1717   }
1718 }
1719 node {
1720   name: "while/Merge_1"
1721   op: "Merge"
1722   input: "while/Enter_1"
1723   input: "while/NextIteration_1"
1724   attr {
1725     key: "N"
1726     value {
1727       i: 2
1728     }
1729   }
1730   attr {
1731     key: "T"
1732     value {
1733       type: DT_FLOAT
1734     }
1735   }
1736   attr {
1737     key: "_execution_count"
1738     value {
1739       i: 10
1740     }
1741   }
1742 }
1743 node {
1744   name: "while/Less/y"
1745   op: "Const"
1746   input: "^while/Merge"
1747   attr {
1748     key: "dtype"
1749     value {
1750       type: DT_INT32
1751     }
1752   }
1753   attr {
1754     key: "value"
1755     value {
1756       tensor {
1757         dtype: DT_INT32
1758         tensor_shape {
1759         }
1760         int_val: 10
1761       }
1762     }
1763   }
1764   attr {
1765     key: "_execution_count"
1766     value {
1767       i: 10
1768     }
1769   }
1770 }
1771 node {
1772   name: "while/Less"
1773   op: "Less"
1774   input: "while/Merge"
1775   input: "while/Less/y"
1776   attr {
1777     key: "T"
1778     value {
1779       type: DT_INT32
1780     }
1781   }
1782   attr {
1783     key: "_execution_count"
1784     value {
1785       i: 10
1786     }
1787   }
1788 }
1789 node {
1790   name: "while/LoopCond"
1791   op: "LoopCond"
1792   input: "while/Less"
1793   attr {
1794     key: "_execution_count"
1795     value {
1796       i: 10
1797     }
1798   }
1799 }
1800 node {
1801   name: "while/Switch"
1802   op: "Switch"
1803   input: "while/Merge"
1804   input: "while/LoopCond"
1805   attr {
1806     key: "T"
1807     value {
1808       type: DT_INT32
1809     }
1810   }
1811   attr {
1812     key: "_class"
1813     value {
1814       list {
1815         s: "loc:@while/Merge"
1816       }
1817     }
1818   }
1819   attr {
1820     key: "_execution_count"
1821     value {
1822       i: 11
1823     }
1824   }
1825   attr {
1826     key: "_output_slot_vector"
1827     value {
1828       list {
1829         i: 1
1830         i: 1
1831         i: 1
1832         i: 1
1833         i: 1
1834         i: 1
1835         i: 1
1836         i: 1
1837         i: 1
1838         i: 1
1839         i: 0
1840       }
1841     }
1842   }
1843 }
1844 node {
1845   name: "while/Switch_1"
1846   op: "Switch"
1847   input: "while/Merge_1"
1848   input: "while/LoopCond"
1849   attr {
1850     key: "T"
1851     value {
1852       type: DT_FLOAT
1853     }
1854   }
1855   attr {
1856     key: "_class"
1857     value {
1858       list {
1859         s: "loc:@while/Merge_1"
1860       }
1861     }
1862   }
1863   attr {
1864     key: "_execution_count"
1865     value {
1866       i: 11
1867     }
1868   }
1869   attr {
1870     key: "_output_slot_vector"
1871     value {
1872       list {
1873         i: 1
1874         i: 1
1875         i: 1
1876         i: 1
1877         i: 1
1878         i: 1
1879         i: 1
1880         i: 1
1881         i: 1
1882         i: 1
1883         i: 0
1884       }
1885     }
1886   }
1887 }
1888 node {
1889   name: "while/Identity"
1890   op: "Identity"
1891   input: "while/Switch:1"
1892   attr {
1893     key: "T"
1894     value {
1895       type: DT_INT32
1896     }
1897   }
1898   attr {
1899     key: "_execution_count"
1900     value {
1901       i: 10
1902     }
1903   }
1904 }
1905 node {
1906   name: "while/Identity_1"
1907   op: "Identity"
1908   input: "while/Switch_1:1"
1909   attr {
1910     key: "T"
1911     value {
1912       type: DT_FLOAT
1913     }
1914   }
1915   attr {
1916     key: "_execution_count"
1917     value {
1918       i: 10
1919     }
1920   }
1921 }
1922 node {
1923   name: "while/add/y"
1924   op: "Const"
1925   input: "^while/Identity"
1926   attr {
1927     key: "dtype"
1928     value {
1929       type: DT_INT32
1930     }
1931   }
1932   attr {
1933     key: "value"
1934     value {
1935       tensor {
1936         dtype: DT_INT32
1937         tensor_shape {
1938         }
1939         int_val: 1
1940       }
1941     }
1942   }
1943   attr {
1944     key: "_execution_count"
1945     value {
1946       i: 10
1947     }
1948   }
1949 }
1950 node {
1951   name: "while/add"
1952   op: "Add"
1953   input: "while/Identity"
1954   input: "while/add/y"
1955   attr {
1956     key: "T"
1957     value {
1958       type: DT_INT32
1959     }
1960   }
1961   attr {
1962     key: "_execution_count"
1963     value {
1964       i: 10
1965     }
1966   }
1967 }
1968 node {
1969   name: "while/concat/axis"
1970   op: "Const"
1971   input: "^while/Identity"
1972   attr {
1973     key: "dtype"
1974     value {
1975       type: DT_INT32
1976     }
1977   }
1978   attr {
1979     key: "value"
1980     value {
1981       tensor {
1982         dtype: DT_INT32
1983         tensor_shape {
1984         }
1985         int_val: 0
1986       }
1987     }
1988   }
1989   attr {
1990     key: "_execution_count"
1991     value {
1992       i: 10
1993     }
1994   }
1995 }
1996 node {
1997   name: "while/concat"
1998   op: "ConcatV2"
1999   input: "while/Identity_1"
2000   input: "while/Identity_1"
2001   input: "while/concat/axis"
2002   attr {
2003     key: "N"
2004     value {
2005       i: 2
2006     }
2007   }
2008   attr {
2009     key: "T"
2010     value {
2011       type: DT_FLOAT
2012     }
2013   }
2014   attr {
2015     key: "Tidx"
2016     value {
2017       type: DT_INT32
2018     }
2019   }
2020   attr {
2021     key: "_execution_count"
2022     value {
2023       i: 10
2024     }
2025   }
2026 }
2027 node {
2028   name: "while/NextIteration"
2029   op: "NextIteration"
2030   input: "while/add"
2031   attr {
2032     key: "T"
2033     value {
2034       type: DT_INT32
2035     }
2036   }
2037   attr {
2038     key: "_execution_count"
2039     value {
2040       i: 10
2041     }
2042   }
2043 }
2044 node {
2045   name: "while/NextIteration_1"
2046   op: "NextIteration"
2047   input: "while/concat"
2048   attr {
2049     key: "T"
2050     value {
2051       type: DT_FLOAT
2052     }
2053   }
2054   attr {
2055     key: "_execution_count"
2056     value {
2057       i: 10
2058     }
2059   }
2060 }
2061 node {
2062   name: "while/Exit"
2063   op: "Exit"
2064   input: "while/Switch"
2065   attr {
2066     key: "T"
2067     value {
2068       type: DT_INT32
2069     }
2070   }
2071   attr {
2072     key: "_execution_count"
2073     value {
2074       i: 1
2075     }
2076   }
2077 }
2078 node {
2079   name: "while/Exit_1"
2080   op: "Exit"
2081   input: "while/Switch_1"
2082   attr {
2083     key: "T"
2084     value {
2085       type: DT_FLOAT
2086     }
2087   }
2088   attr {
2089     key: "_execution_count"
2090     value {
2091       i: 1
2092     }
2093   }
2094 }
2095 versions {
2096   producer: 21
2097 }
2098   )EOF";
2099 
2100     grappler_item_ = std::make_unique<GrapplerItem>();
2101     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2102                                                 &grappler_item_->graph));
2103     grappler_item_->id = "test_graph";
2104     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
2105   }
2106 
2107   // A simple condition graph.
CreateGrapplerItemWithCondition()2108   void CreateGrapplerItemWithCondition() {
2109     // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge.
2110     const string gdef_ascii = R"EOF(
2111 node {
2112   name: "a"
2113   op: "Const"
2114   attr {
2115     key: "dtype"
2116     value {
2117       type: DT_FLOAT
2118     }
2119   }
2120   attr {
2121     key: "value"
2122     value {
2123       tensor {
2124         dtype: DT_FLOAT
2125         tensor_shape {
2126         }
2127         float_val: 2.0
2128       }
2129     }
2130   }
2131 }
2132 node {
2133   name: "Less"
2134   op: "Const"
2135   attr {
2136     key: "dtype"
2137     value {
2138       type: DT_BOOL
2139     }
2140   }
2141   attr {
2142     key: "value"
2143     value {
2144       tensor {
2145         dtype: DT_BOOL
2146         tensor_shape {
2147         }
2148         tensor_content: "\001"
2149       }
2150     }
2151   }
2152 }
2153 node {
2154   name: "Switch"
2155   op: "Switch"
2156   input: "a"
2157   input: "Less"
2158   attr {
2159     key: "T"
2160     value {
2161       type: DT_FLOAT
2162     }
2163   }
2164 }
2165 node {
2166   name: "First"
2167   op: "Identity"
2168   input: "Switch"
2169   attr {
2170     key: "T"
2171     value {
2172       type: DT_FLOAT
2173     }
2174   }
2175 }
2176 node {
2177   name: "Second"
2178   op: "Identity"
2179   input: "Switch:1"
2180   attr {
2181     key: "T"
2182     value {
2183       type: DT_FLOAT
2184     }
2185   }
2186 }
2187 node {
2188   name: "Merge"
2189   op: "Merge"
2190   input: "First"
2191   input: "Second"
2192   attr {
2193     key: "N"
2194     value {
2195       i: 2
2196     }
2197   }
2198   attr {
2199     key: "T"
2200     value {
2201       type: DT_FLOAT
2202     }
2203   }
2204 }
2205 versions {
2206   producer: 27
2207 })EOF";
2208 
2209     grappler_item_ = std::make_unique<GrapplerItem>();
2210     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2211                                                 &grappler_item_->graph));
2212     grappler_item_->id = "test_graph";
2213     grappler_item_->fetch = {"Merge"};
2214   }
2215 
2216   // Creates a graph with inter-device transfers, using a FusedBatchNorm op that has multiple output ports.
CreateGrapplerItemWithInterDeviceTransfers()2217   void CreateGrapplerItemWithInterDeviceTransfers() {
2218     tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
2219 
2220     // Create a FusedBatchNorm op that has multiple output ports.
2221     auto x = ops::RandomUniform(
2222         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
2223     auto scale =
2224         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
2225     auto offset =
2226         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
2227     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
2228     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
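    // Note: mean and var are created with shape {0} (empty); with
    // IsTraining(true), FusedBatchNorm computes batch statistics internally,
    // so these inputs may be empty.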
2229 
2230     auto batch_norm = ops::FusedBatchNorm(
2231         s.WithOpName("bn"), x, scale, offset, mean, var,
2232         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
2233     auto y = batch_norm.y;
2234     auto batch_mean = batch_norm.batch_mean;
2235     auto batch_var = batch_norm.batch_variance;
2236     // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
2237     auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
2238     auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
2239     // batch_mean1 and batch_var1 take different output ports, so each will
2240     // initiate Send/Recv.
2241     auto batch_mean1 = ops::Identity(
2242         s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
2243     auto batch_var1 =
2244         ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
2245     // This is a control dependency.
2246     auto control_dep = ops::NoOp(s.WithOpName("control_dep")
2247                                      .WithControlDependencies(y)
2248                                      .WithDevice(kCPU1));
2249 
2250     grappler_item_ = std::make_unique<GrapplerItem>();
2251     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
2252     grappler_item_->id = "test_conv2d_graph";
2253     grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
2254                              "control_dep"};
2255 
2256     dependency_["bn"] = {"x", "mean", "var"};
2257     dependency_["y1"] = {"bn"};
2258     dependency_["y2"] = {"bn"};
2259     dependency_["batch_mean1"] = {"bn"};
2260     dependency_["batch_var1"] = {"bn"};
2261     dependency_["control_dep"] = {"bn"};
2262   }
2263 
2264   // Call this after creating grappler_item_ and setting up dependency_.
InitScheduler()2265   void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
2266 
2267   // Returns predicted costs based on the op type.
SimplePredictCosts(const OpContext & op_context) const2268   Costs SimplePredictCosts(const OpContext& op_context) const {
2269     Costs c;
2270     int64_t exec_cost = 0;
2271     if (op_context.op_info.op() == "MatMul") {
2272       exec_cost = 2000000000;
2273     } else if (op_context.op_info.op() == "RandomUniform") {
2274       exec_cost = 1000000000;
2275     } else {
2276       exec_cost = 1000;
2277     }
2278     c.execution_time = Costs::NanoSeconds(exec_cost);
2279     return c;
2280   }
2281 
2282   // Call this after InitScheduler(). The scheduler stops after executing
2283   // target_node.
RunScheduler(const string & target_node)2284   std::unordered_map<string, OpContext> RunScheduler(
2285       const string& target_node) {
2286     std::unordered_map<string, OpContext> ops_executed;
2287     bool more_nodes = true;
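    // Repeatedly take the next node chosen by the scheduler, record it, and
    // mark it executed with the simple per-op cost model above, until either
    // all nodes are executed or target_node is reached.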
2288     do {
2289       OpContext op_context = scheduler_->GetCurrNode();
2290       ops_executed[op_context.name] = op_context;
2291       std::cout << op_context.name << std::endl;
2292 
2293       Costs node_costs = SimplePredictCosts(op_context);
2294 
2295       // Check scheduling order.
2296       auto it = dependency_.find(op_context.name);
2297       if (it != dependency_.end()) {
2298         for (const auto& preceding_node : it->second) {
2299           EXPECT_GT(ops_executed.count(preceding_node), 0);
2300         }
2301       }
2302       more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
2303 
2304       if (op_context.name == target_node) {
2305         // Scheduler has the state after executing the target node.
2306         break;
2307       }
2308     } while (more_nodes);
2309     return ops_executed;
2310   }
2311 
2312   // Helper method for validating a vector.
2313   template <typename T>
ExpectVectorEq(const std::vector<T> & expected,const std::vector<T> & test_elements)2314   void ExpectVectorEq(const std::vector<T>& expected,
2315                       const std::vector<T>& test_elements) {
2316     // Set of expected elements for an easy comparison.
2317     std::set<T> expected_set(expected.begin(), expected.end());
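    // Each test element must appear in the expected set, and the two vectors
    // must have the same length; this is a membership-plus-size check rather
    // than an exact multiset comparison.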
2318     for (const auto& element : test_elements) {
2319       EXPECT_GT(expected_set.count(element), 0);
2320     }
2321     EXPECT_EQ(expected.size(), test_elements.size());
2322   }
2323 
2324   // Helper method that checks the names of nodes.
ValidateNodeDefs(const std::vector<string> & expected,const std::vector<const NodeDef * > & node_defs)2325   void ValidateNodeDefs(const std::vector<string>& expected,
2326                         const std::vector<const NodeDef*>& node_defs) {
2327     std::vector<string> node_names;
2328     std::transform(node_defs.begin(), node_defs.end(),
2329                    std::back_inserter(node_names),
2330                    [](const NodeDef* node) { return node->name(); });
2331     ExpectVectorEq(expected, node_names);
2332   }
2333 
2334   // Helper method for validating a set.
2335   template <typename T>
ExpectSetEq(const std::set<T> & expected,const std::set<T> & test_elements)2336   void ExpectSetEq(const std::set<T>& expected,
2337                    const std::set<T>& test_elements) {
2338     for (const auto& element : test_elements) {
2339       EXPECT_GT(expected.count(element), 0);
2340     }
2341     EXPECT_EQ(expected.size(), test_elements.size());
2342   }
2343 
2344   // Helper method for validating an unordered map.
2345   template <typename T, typename U>
ExpectUnorderedMapEq(const std::unordered_map<T,U> & expected,const std::unordered_map<T,U> & test_map)2346   void ExpectUnorderedMapEq(const std::unordered_map<T, U>& expected,
2347                             const std::unordered_map<T, U>& test_map) {
2348     EXPECT_EQ(expected.size(), test_map.size());
2349     for (const auto& key_val : expected) {
2350       EXPECT_GT(test_map.count(key_val.first), 0);
2351       EXPECT_EQ(test_map.at(key_val.first), key_val.second);
2352     }
2353   }
2354 
2355   // Helper method that checks name-port pairs.
ValidateMemoryUsageSnapshot(const std::vector<string> & expected_names,const int port_num_expected,const std::unordered_set<std::pair<const NodeDef *,int>,DeviceState::NodePairHash> & mem_usage_snapshot)2356   void ValidateMemoryUsageSnapshot(
2357       const std::vector<string>& expected_names, const int port_num_expected,
2358       const std::unordered_set<std::pair<const NodeDef*, int>,
2359                                DeviceState::NodePairHash>& mem_usage_snapshot) {
2360     std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
2361     std::transform(
2362         mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
2363         std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
2364         [](const std::pair<const NodeDef*, int>& node_port) {
2365           return std::make_pair(node_port.first->name(), node_port.second);
2366         });
2367     std::set<std::pair<string, int>> expected;
2368     std::transform(expected_names.begin(), expected_names.end(),
2369                    std::inserter(expected, expected.begin()),
2370                    [port_num_expected](const string& name) {
2371                      return std::make_pair(name, port_num_expected);
2372                    });
2373     ExpectSetEq(expected, nodes_at_peak_mem_usage);
2374   }
2375 
2376   // Helper method for checking node dependencies.
ValidateDependencyChain(const std::unordered_map<string,int64_t> & start_times,const std::vector<string> & nodes_in_dependency_order)2377   void ValidateDependencyChain(
2378       const std::unordered_map<string, int64_t>& start_times,
2379       const std::vector<string>& nodes_in_dependency_order) {
2380     int64_t prev_node_time = -1;
2381     for (const auto& node : nodes_in_dependency_order) {
2382       int64_t curr_node_time = start_times.at(node);
2383       EXPECT_GE(curr_node_time, prev_node_time);
2384       prev_node_time = curr_node_time;
2385     }
2386   }
2387 
2388   // cluster_ and scheduler_ are initialized in the c'tor.
2389   std::unique_ptr<VirtualCluster> cluster_;
2390   std::unique_ptr<TestVirtualScheduler> scheduler_;
2391   FirstReadyManager first_ready_manager_;
2392   CompositeNodeManager composite_node_manager_;
2393 
2394   // grappler_item_ will be initialized differently for each test case.
2395   std::unique_ptr<GrapplerItem> grappler_item_;
2396   // Node name -> its preceding nodes map for testing scheduling order.
2397   std::unordered_map<string, std::vector<string>> dependency_;
2398 
2399   // Shared params for Conv2D related graphs:
2400   const int batch_size_ = 4;
2401   const int width_ = 10;
2402   const int height_ = 10;
2403   const int depth_in_ = 8;
2404   const int kernel_ = 3;
2405   const int depth_out_ = 16;
2406 };
2407 
2408 // Create small graph, run predict costs on it, make sure the costs from the
2409 // summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest,SummaryCostTest)2410 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
2411   // Run matmul test.
2412   CreateGrapplerItemWithMatmulChain();
2413   InitScheduler();
2414   auto ops_executed = RunScheduler("");
2415   Costs c = scheduler_->Summary();
2416 
2417   // RandomUniform - 5 ops * 1s each = 5s
2418   // MatMul - 4 ops * 2s each = 8s
2419   // Misc - 5 ops * 1us each = 5us
2420   // Total: 13,000,005 us
2421   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2422   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2423   EXPECT_FALSE(c.inaccurate);
2424   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2425 }
2426 
2427 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2428 // correct.
TEST_F(VirtualSchedulerTest,SummaryCostStepStatsTest)2429 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
2430   // Run matmul test.
2431   CreateGrapplerItemWithMatmulChain();
2432   InitScheduler();
2433   auto ops_executed = RunScheduler("");
2434   RunMetadata metadata;
2435   Costs c = scheduler_->Summary(&metadata);
2436   StepStats stepstats = metadata.step_stats();
2437   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2438   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2439   EXPECT_FALSE(c.inaccurate);
2440   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2441 
2442   // Should only be 1 device!
2443   EXPECT_EQ(1, stepstats.dev_stats().size());
2444 
2445   // Create a map of op name -> start and end times (micros).
2446   std::map<string, std::pair<int64_t, int64_t>> start_end_times;
2447   for (const auto& device_step_stats : stepstats.dev_stats()) {
2448     for (const auto& stats : device_step_stats.node_stats()) {
2449       int64_t start = stats.all_start_micros();
2450       int64_t end = start + stats.all_end_rel_micros();
2451       start_end_times[stats.node_name()] =
2452           std::pair<int64_t, int64_t>(start, end);
2453 
2454       // Make sure that the output properties are correct for
2455       // MatMul and RandomUniform operations.
2456       // We only check dtype and shape (excluding allocation),
2457       // since allocation is not set by the virtual scheduler.
2458       if (stats.timeline_label() == "MatMul" ||
2459           stats.timeline_label() == "RandomUniform") {
2460         EXPECT_EQ(1, stats.output().size());
2461         for (const auto& output : stats.output()) {
2462           EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
2463           EXPECT_EQ(2, output.tensor_description().shape().dim().size());
2464           for (const auto& dim : output.tensor_description().shape().dim()) {
2465             EXPECT_EQ(3200, dim.size());
2466           }
2467         }
2468       }
2469     }
2470   }
2471 
2472   // The base start_time is the time to compute the RandomUniforms (5 * 1s + 5 * 1us = 5,000,005 us).
2473   int64_t cur_time = static_cast<int64_t>(5000005);
2474   // The increment is the execution time of one matmul. See
2475   // CreateGrapplerItemWithMatmulChain for details.
2476   int64_t increment = static_cast<int64_t>(2000000);
2477   auto op_names = {"ab", "abc", "abcd", "abcde"};
2478   for (const auto& op_name : op_names) {
2479     int64_t actual_start = start_end_times[op_name].first;
2480     int64_t actual_end = start_end_times[op_name].second;
2481     int64_t expected_start = cur_time;
2482     int64_t expected_end = cur_time + increment;
2483     EXPECT_EQ(expected_start, actual_start);
2484     EXPECT_EQ(expected_end, actual_end);
2485     cur_time += increment;
2486   }
2487 }
2488 
TEST_F(VirtualSchedulerTest,InitAndBasicScheduling)2489 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
2490   // Init.
2491   CreateGrapplerItemWithConv2Ds();
2492   InitScheduler();
2493 
2494   // Run the scheduler.
2495   auto ops_executed = RunScheduler("");  // Run all the nodes.
2496 
2497   // A Const and a RandomUniform for each of x, y, and f, plus c0 and c1,
2498   // makes 8 executed ops; c2 and z shouldn't be executed.
2499   EXPECT_EQ(8, ops_executed.size());
2500 
2501   // x, y, f, c0, and c1 should be in the ops executed.
2502   EXPECT_GT(ops_executed.count("x"), 0);
2503   EXPECT_GT(ops_executed.count("y"), 0);
2504   EXPECT_GT(ops_executed.count("f"), 0);
2505   EXPECT_GT(ops_executed.count("c0"), 0);
2506   EXPECT_GT(ops_executed.count("c1"), 0);
2507 
2508   // z and c2 shouldn't be part of it.
2509   EXPECT_EQ(ops_executed.count("z"), 0);
2510   EXPECT_EQ(ops_executed.count("c2"), 0);
2511 
2512   // Check input / output properties.
2513   EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
2514   EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
2515   EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
2516   EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
2517   EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
2518 }
2519 
TEST_F(VirtualSchedulerTest,MemoryUsage)2520 TEST_F(VirtualSchedulerTest, MemoryUsage) {
2521   // Init.
2522   CreateGrapplerItemWithAddN();
2523   InitScheduler();
2524 
2525   // Run the scheduler.
2526   RunScheduler("");
2527 
2528   const auto* device_states = scheduler_->GetDeviceStates();
2529   const auto& cpu_state = device_states->at(kCPU0);
2530 
2531   // The add node takes 4 input tensors, each 10x10x10x10 floats (40,000 bytes);
2532   // with its own output also live, the peak memory usage is 5 x the input tensor size.
2533   int64_t one_input_node_size = 4 * 10 * 10 * 10 * 10;
2534   const std::vector<string> expected_names = {"x", "y", "z", "w", "add"};
2535   EXPECT_EQ(expected_names.size() * one_input_node_size,
2536             cpu_state.max_memory_usage);
2537   ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
2538                               cpu_state.mem_usage_snapshot_at_peak);
2539 
2540   // Total of 10 nodes: four Const nodes plus x, y, z, w, add, and out.
2541   ASSERT_EQ(cpu_state.temporary_memory_usage_trace.size(), 10);
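  // The trace holds (node name, temporary memory) pairs, one per executed node;
  // the checks below spot-check the x, add, and out entries.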
2542   const std::pair<std::string, int64_t>& x_usage =
2543       cpu_state.temporary_memory_usage_trace.at(4);
2544   EXPECT_EQ(x_usage.first, "x");
2545   EXPECT_EQ(x_usage.second, one_input_node_size);
2546   const std::pair<std::string, int64_t>& add_usage =
2547       cpu_state.temporary_memory_usage_trace.at(8);
2548   EXPECT_EQ(add_usage.first, "add");
2549   EXPECT_EQ(add_usage.second, 5 * one_input_node_size);
2550   const std::pair<std::string, int64_t>& out_usage =
2551       cpu_state.temporary_memory_usage_trace.at(9);
2552   EXPECT_EQ(out_usage.first, "out");
2553   EXPECT_EQ(out_usage.second, one_input_node_size);
2554   ExpectUnorderedMapEq(
2555       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 64)},
2556       scheduler_->GetPersistentMemoryUsage());
2557   ExpectUnorderedMapEq(
2558       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 200000)},
2559       scheduler_->GetPeakMemoryUsage());
2560 }
2561 
TEST_F(VirtualSchedulerTest,MemoryUsageForStreamingOps)2562 TEST_F(VirtualSchedulerTest, MemoryUsageForStreamingOps) {
2563   // Init.
2564   CreateGrapplerItemWithAddN();
2565   auto& graph = grappler_item_->graph;
2566   // Nodes add and out are placed on CPU1.
2567   // Nodes x and y are allocated in memory, while nodes z and w are streaming nodes.
2568   for (auto& node : *graph.mutable_node()) {
2569     if (node.name() == "out" || node.name() == "add") {
2570       node.set_device(kCPU1);
2571     }
2572     if (node.name() == "z" || node.name() == "w")
2573       (*node.mutable_attr())[kStreaming].mutable_list()->add_b(true);
2574   }
2575 
2576   InitScheduler();
2577 
2578   // Run the scheduler.
2579   auto ops_executed = RunScheduler("");
2580 
2581   const auto* device_states = scheduler_->GetDeviceStates();
2582   const auto& cpu_state_0 = device_states->at(kCPU0);
2583   const auto& cpu_state_1 = device_states->at(kCPU1);
2584   // All tensors are of the same size, 10 x 10 x 10 x 10.
2585   int64_t one_input_node_size = 4 * 10 * 10 * 10 * 10;
2586   const std::vector<string> cpu_0_expected_tensors = {"x", "y"};
2587   const std::vector<string> cpu_1_expected_tensors = {"x", "y", "add"};
2588   EXPECT_EQ(cpu_0_expected_tensors.size() * one_input_node_size,
2589             cpu_state_0.max_memory_usage);
2590   EXPECT_EQ(cpu_1_expected_tensors.size() * one_input_node_size,
2591             cpu_state_1.max_memory_usage);
2592   // After the graph is executed, at the end, memory usage for the device
2593   // should be zero.
2594   EXPECT_EQ(cpu_state_0.memory_usage, 0);
2595   EXPECT_EQ(cpu_state_1.memory_usage, 0);
2596 }
2597 
TEST_F(VirtualSchedulerTest,MemoryUsageWithExecutionCount)2598 TEST_F(VirtualSchedulerTest, MemoryUsageWithExecutionCount) {
2599   // Init.
2600   CreateGrapplerItemWithAddN();
2601   auto& graph = grappler_item_->graph;
2602   // Repeat execution for each node.
2603   for (auto& node : *graph.mutable_node()) {
2604     (*node.mutable_attr())[kExecutionCount].set_i(10000);
2605   }
2606 
2607   InitScheduler();
2608 
2609   // Run the scheduler.
2610   auto ops_executed = RunScheduler("");
2611 
2612   const auto* device_states = scheduler_->GetDeviceStates();
2613   const auto& cpu_state_0 = device_states->at(kCPU0);
2614   // All tensors are of the same size, 10 x 10 x 10 x 10.
2615   int64_t one_input_node_size = 4 * 10 * 10 * 10 * 10;
2616   const std::vector<string> expected_names = {"x", "y", "z", "w", "add"};
2617   // Max memory usage does not depend on the number of executions.
2618   EXPECT_EQ(expected_names.size() * one_input_node_size,
2619             cpu_state_0.max_memory_usage);
2620   // After the graph is executed, at the end, memory usage for the device
2621   // should be zero.
2622   EXPECT_EQ(cpu_state_0.memory_usage, 0);
2623 
2624   Costs c = scheduler_->Summary();
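  // Temporary memory matches the peak above: 5 tensors x 40,000 bytes = 200,000.
  // The 64 bytes of persistent memory presumably come from the Const nodes.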
2625   EXPECT_EQ(64, c.persistent_memory);
2626   EXPECT_EQ(200000, c.temporary_memory);
2627   EXPECT_EQ(200064, c.max_memory);
2628 }
2629 
TEST_F(VirtualSchedulerTest,UnnecessaryFeedNodes)2630 TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
2631   CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
2632   InitScheduler();
2633 
2634   // Test that scheduler can run graphs with extra unnecessary feed nodes.
2635   auto ops_executed = RunScheduler("");
2636   ASSERT_EQ(1, ops_executed.size());
2637   ASSERT_EQ(ops_executed.count("x"), 1);
2638 }
2639 
TEST_F(VirtualSchedulerTest,ControlDependency)2640 TEST_F(VirtualSchedulerTest, ControlDependency) {
2641   // Init.
2642   CreateGrapplerItemWithControlDependency();
2643   InitScheduler();
2644 
2645   // Run the scheduler.
2646   RunScheduler("");
2647 
2648   const auto* device_states = scheduler_->GetDeviceStates();
2649   const auto& cpu_state = device_states->at(kCPU0);
2650 
2651   // The graph has a NoOp that takes control dependencies from 7 NoOps. The peak
2652   // memory usage occurs while executing the final NoOp: 7 edges x 4 bytes = 28.
2653   int64_t one_input_node_size = 4;  // control dependency
2654   const std::vector<string> expected_names = {"x", "y", "z", "w",
2655                                               "u", "v", "t"};
2656   EXPECT_EQ(expected_names.size() * one_input_node_size,
2657             cpu_state.max_memory_usage);
2658   ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
2659                               cpu_state.mem_usage_snapshot_at_peak);
2660   ExpectUnorderedMapEq(
2661       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 0)},
2662       scheduler_->GetPersistentMemoryUsage());
2663   ExpectUnorderedMapEq(
2664       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 28)},
2665       scheduler_->GetPeakMemoryUsage());
2666 }
2667 
TEST_F(VirtualSchedulerTest,ComplexDependency)2668 TEST_F(VirtualSchedulerTest, ComplexDependency) {
2669   // Init.
2670   CreateGrapplerItemWithBatchNorm();
2671   InitScheduler();
2672 
2673   // Run the scheduler.
2674   RunScheduler("bn");
2675 
2676   const auto& device_states = scheduler_->GetDeviceStates();
2677   const auto& cpu_state = device_states->at(kCPU0);
2678 
2679   // The graph is
2680   //  bn = FusedBatchNorm(x, scale, offset, mean, var)
2681   //  z1 = bn.y + x
2682   //  z2 = bn.var + bn.var
2683   //  z3 = bn.var + bn.var
2684   //  z4 = control dependency from bn.
2685   //  Note that bn.mean doesn't have any consumer.
2686   const int x_size = batch_size_ * width_ * height_ * depth_in_;
2687   int64_t expected_size =
2688       4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
2689            1 /* control dependency */);
2690   EXPECT_EQ(expected_size, cpu_state.memory_usage);
2691 
2692   // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
2693   std::set<std::pair<string, int>> nodes_in_memory;
2694   std::transform(
2695       cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
2696       std::inserter(nodes_in_memory, nodes_in_memory.begin()),
2697       [](const std::pair<const NodeDef*, int>& node_port) {
2698         return std::make_pair(node_port.first->name(), node_port.second);
2699       });
2700   std::set<std::pair<string, int>> expected = {
2701       std::make_pair("bn", -1),
2702       std::make_pair("bn", 0),
2703       std::make_pair("bn", 2),
2704       std::make_pair("x", 0),
2705   };
2706   ExpectSetEq(expected, nodes_in_memory);
2707 
2708   const auto* node_states = scheduler_->GetNodeStates();
2709   const NodeState* bn_node = nullptr;
2710   const NodeState* x_node = nullptr;
2711   for (const auto& nodedef_node_state : *node_states) {
2712     const NodeDef* node = nodedef_node_state.first;
2713     const NodeState& node_state = nodedef_node_state.second;
2714     if (node->name() == "bn") {
2715       bn_node = &node_state;
2716     }
2717     if (node->name() == "x") {
2718       x_node = &node_state;
2719     }
2720   }
2721   CHECK_NOTNULL(bn_node);
2722   CHECK_NOTNULL(x_node);
2723 
2724   ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
2725   ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
2726   ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
2727   // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
2728   ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
2729 }
2730 
TEST_F(VirtualSchedulerTest,Variable)2731 TEST_F(VirtualSchedulerTest, Variable) {
2732   // Init.
2733   CreateGrapplerItemWithConv2DAndVariable();
2734   InitScheduler();
2735 
2736   // Run the scheduler.
2737   RunScheduler("");
2738 
2739   const auto* device_states = scheduler_->GetDeviceStates();
2740   const auto& cpu_state = device_states->at(kCPU0);
2741 
2742   // There is one Conv2D that takes x and f, but f is variable, so it should be
2743   // in persistent nodes.
2744   ValidateMemoryUsageSnapshot({"f", "Const/Const"}, /*port_num_expected=*/0,
2745                               cpu_state.persistent_nodes);
2746   // Only x in peak memory usage snapshot.
2747   ValidateMemoryUsageSnapshot({"x"}, /*port_num_expected=*/0,
2748                               cpu_state.mem_usage_snapshot_at_peak);
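  // Persistent memory (4,624 bytes) is presumably the Conv2D filter f, a
  // 3x3x8x16 float tensor (4,608 bytes), plus the 16-byte Const/Const output;
  // peak memory (12,800 bytes) is x, a 4x10x10x8 float tensor.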
2749   ExpectUnorderedMapEq(
2750       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 4624)},
2751       scheduler_->GetPersistentMemoryUsage());
2752   ExpectUnorderedMapEq(
2753       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 12800)},
2754       scheduler_->GetPeakMemoryUsage());
2755 }
2756 
TEST_F(VirtualSchedulerTest,WhileLoop)2757 TEST_F(VirtualSchedulerTest, WhileLoop) {
2758   // Init.
2759   CreateGrapplerItemWithLoop();
2760   InitScheduler();
2761 
2762   // Run the scheduler.
2763   RunScheduler("");
2764 
2765   // Check the timeline
2766   RunMetadata metadata;
2767   scheduler_->Summary(&metadata);
2768 
2769   // Nodes in topological order:
2770   // * const, ones
2771   // * while/Enter, while/Enter_1
2772   // * while/Merge, while/Merge_1
2773   // * while/Less/y
2774   // * while/Less
2775   // * while/LoopCond
2776   // * while/Switch, while/Switch_1
2777   // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
2778   // * while/add/y, while/concat/axis
2779   // * while/add, while/concat
2780   // * while/NextIteration, while/NextIteration_1
2781 
2782   int num_next_iteration = 0;
2783   int num_next_iteration_1 = 0;
2784   int num_exit = 0;
2785   int num_exit_1 = 0;
2786   int64_t next_iter_start_micro = 0;
2787   int64_t next_iter_1_start_micro = 0;
2788   int64_t exit_start_micro = 0;
2789   int64_t exit_1_start_micro = 0;
2790 
2791   std::unordered_map<string, int64_t> start_times;
2792   for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2793     for (const auto& stats : device_step_stats.node_stats()) {
2794       start_times[stats.node_name()] = stats.all_start_micros();
2795       if (stats.node_name() == "while/NextIteration") {
2796         ++num_next_iteration;
2797         next_iter_start_micro = stats.all_start_micros();
2798       } else if (stats.node_name() == "while/NextIteration_1") {
2799         ++num_next_iteration_1;
2800         next_iter_1_start_micro = stats.all_start_micros();
2801       } else if (stats.node_name() == "while/Exit") {
2802         ++num_exit;
2803         exit_start_micro = stats.all_start_micros();
2804       } else if (stats.node_name() == "while/Exit_1") {
2805         ++num_exit_1;
2806         exit_1_start_micro = stats.all_start_micros();
2807       }
2808     }
2809   }
2810 
2811   // Make sure we went through the body of the loop once, and that the output of
2812   // the loop was scheduled as well.
2813   EXPECT_EQ(1, num_next_iteration);
2814   EXPECT_EQ(1, num_next_iteration_1);
2815   EXPECT_EQ(1, num_exit);
2816   EXPECT_EQ(1, num_exit_1);
2817 
2818   // Start times of while/NextIteration and while/NextIteration_1 should be
2819   // different, so should be those of while/Exit and while/Exit_1.
2820   EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
2821   EXPECT_NE(exit_start_micro, exit_1_start_micro);
2822 
2823   // Check dependency among the nodes; no matter what scheduling mechanism we
2824   // use, the scheduled ops should follow these dependency chains.
2825   // Note that currently, VirtualScheduler executes while/Merge twice; hence,
2826   // we're not testing dependency chains related to while/Merge.
2827   // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
2828   // order of Enter, Merge, ...loop condition ..., ... loop body ...,
2829   // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
2830   // chaining test w/ Merge nodes.
2831   ValidateDependencyChain(
2832       start_times,
2833       {"Const", "while/Enter",  // "while/Merge",
2834        "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
2835        "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
2836   // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
2837   ValidateDependencyChain(start_times,
2838                           {"ones", "while/Enter_1",  // "while/Merge_1",
2839                            "while/Switch_1", "while/Identity_1", "while/concat",
2840                            "while/NextIteration_1"});
2841   ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
2842   ValidateDependencyChain(
2843       start_times, {"while/Identity", "while/concat/axis", "while/concat"});
2844   ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
2845   ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
2846 }
2847 
TEST_F(VirtualSchedulerTest,AnnotatedWhileLoop)2848 TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
2849   {
2850     // Init.
2851     CreateGrapplerItemWithLoop();
2852     InitScheduler();
2853 
2854     // Runs the scheduler.
2855     RunScheduler("");
2856     Costs c = scheduler_->Summary();
2857 
2858     EXPECT_EQ(23, c.execution_time.asMicroSeconds().count());
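    // With 1us per op, 23us means 23 scheduled ops: every graph node once,
    // plus while/Merge and while/Merge_1 scheduled a second time.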
2859     // Both while/Merge and while/Merge_1 are scheduled twice.
2860     EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2861     EXPECT_FALSE(c.inaccurate);
2862     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2863   }
2864 
2865   {
2866     // Init.
2867     CreateGrapplerItemWithLoopAnnotated();
2868     InitScheduler();
2869 
2870     // Runs the scheduler.
2871     RunScheduler("");
2872     Costs c = scheduler_->Summary();
2873 
2874     // The cost for each Merge is accumulated twice, each time scaled by its
2875     // execution_count, but since Merge's cost is minimal, we keep this behavior here.
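    // Summing _execution_count over all nodes gives 158us at 1us per execution;
    // the extra pass over each Merge adds 2 x 10us more, for 178us total.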
2876     EXPECT_EQ(178, c.execution_time.asMicroSeconds().count());
2877     // Both while/Merge and while/Merge_1 are scheduled twice.
2878     EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2879     EXPECT_FALSE(c.inaccurate);
2880     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2881   }
2882 }
2883 
TEST_F(VirtualSchedulerTest,Condition)2884 TEST_F(VirtualSchedulerTest, Condition) {
2885   // Without annotation.
2886   {
2887     // Inits.
2888     CreateGrapplerItemWithCondition();
2889     InitScheduler();
2890 
2891     // Runs the scheduler.
2892     RunScheduler("");
2893     RunMetadata metadata;
2894     Costs c = scheduler_->Summary(&metadata);
2895 
2896     // Nodes in topological order: a/Less, Switch, First/Second, Merge.
2897     int num_a = 0;
2898     int num_less = 0;
2899     int num_switch = 0;
2900     int num_first = 0;
2901     int num_second = 0;
2902     int num_merge = 0;
2903 
2904     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2905       for (const auto& stats : device_step_stats.node_stats()) {
2906         if (stats.node_name() == "a") {
2907           ++num_a;
2908         } else if (stats.node_name() == "Less") {
2909           ++num_less;
2910         } else if (stats.node_name() == "Switch") {
2911           ++num_switch;
2912         } else if (stats.node_name() == "First") {
2913           ++num_first;
2914         } else if (stats.node_name() == "Second") {
2915           ++num_second;
2916         } else if (stats.node_name() == "Merge") {
2917           ++num_merge;
2918         }
2919       }
2920     }
2921 
2922     EXPECT_EQ(1, num_a);
2923     EXPECT_EQ(1, num_less);
2924     EXPECT_EQ(1, num_switch);
2925     EXPECT_EQ(1, num_first);
2926     EXPECT_EQ(1, num_second);
2927     EXPECT_EQ(2, num_merge);
2928 
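    // 7 ops are scheduled in total (a, Less, Switch, First, Second, and Merge
    // twice), each costing 1us in SimplePredictCosts.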
2929     EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
2930     // Merge is executed twice.
2931     EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
2932     EXPECT_FALSE(c.inaccurate);
2933     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2934   }
2935 
2936   // With annotation.
2937   {
2938     // Inits.
2939     CreateGrapplerItemWithCondition();
2940 
2941     // Annotates the Switch node.
2942     for (auto& node : *grappler_item_->graph.mutable_node()) {
2943       if (node.name() == "Switch") {
2944         AttrValue attr_output_info;
2945         // Adds one output slot 0 so that Second shouldn't be executed.
2946         (*attr_output_info.mutable_list()).add_i(0);
2947         AddNodeAttr(kOutputSlots, attr_output_info, &node);
2948       }
2949     }
2950 
2951     InitScheduler();
2952 
2953     // Runs the scheduler.
2954     RunScheduler("");
2955     RunMetadata metadata;
2956     Costs c = scheduler_->Summary(&metadata);
2957 
2958     // Nodes in topological order: a/Less, Switch, Merge
2959     int num_a = 0;
2960     int num_less = 0;
2961     int num_switch = 0;
2962     int num_first = 0;
2963     int num_second = 0;
2964     int num_merge = 0;
2965 
2966     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2967       for (const auto& stats : device_step_stats.node_stats()) {
2968         if (stats.node_name() == "a") {
2969           ++num_a;
2970         } else if (stats.node_name() == "Less") {
2971           ++num_less;
2972         } else if (stats.node_name() == "Switch") {
2973           ++num_switch;
2974         } else if (stats.node_name() == "First") {
2975           ++num_first;
2976         } else if (stats.node_name() == "Second") {
2977           ++num_second;
2978         } else if (stats.node_name() == "Merge") {
2979           ++num_merge;
2980         }
2981       }
2982     }
2983 
2984     EXPECT_EQ(1, num_a);
2985     EXPECT_EQ(1, num_less);
2986     EXPECT_EQ(1, num_switch);
2987     EXPECT_EQ(1, num_first);
2988     EXPECT_EQ(0, num_second);
2989     EXPECT_EQ(1, num_merge);
2990 
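    // Only 5 ops are scheduled (a, Less, Switch, First, and Merge once), at
    // 1us each.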
2991     EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
2992     // Second is not executed.
2993     EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
2994     EXPECT_FALSE(c.inaccurate);
2995     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2996   }
2997 }
2998 
TEST_F(VirtualSchedulerTest,InterDeviceTransfer)2999 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
3000   // Init.
3001   CreateGrapplerItemWithInterDeviceTransfers();
3002   InitScheduler();
3003 
3004   // Run the scheduler.
3005   auto ops_executed = RunScheduler("");
3006 
3007   // Helper lambda to extract port num from _Send and _Recv op name.
3008   auto get_port_num = [](const string& name) -> int {
3009     if (absl::StrContains(name, "bn_0")) {
3010       return 0;
3011     } else if (absl::StrContains(name, "bn_1")) {
3012       return 1;
3013     } else if (absl::StrContains(name, "bn_2")) {
3014       return 2;
3015     } else if (absl::StrContains(name, "bn_minus1")) {
3016       return -1;
3017     }
3018     return -999;
3019   };
3020 
3021   // Reorganize ops_executed for further testing.
3022   std::unordered_map<string, int> op_count;
3023   std::unordered_map<int, string> recv_op_names;
3024   std::unordered_map<int, string> send_op_names;
3025   for (const auto& x : ops_executed) {
3026     const auto& name = x.first;
3027     const auto& node_info = x.second;
3028     const auto& op = node_info.op_info.op();
3029     if (op == kRecv) {
3030       recv_op_names[get_port_num(name)] = name;
3031     } else if (op == kSend) {
3032       send_op_names[get_port_num(name)] = name;
3033     }
3034     op_count[op]++;
3035   }
3036 
3037   // Same number of _Send and _Recv.
3038   EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
3039 
3040   // Expect 3 Sends and 3 Recvs: one pair each for ports 0, 1, and 2.
3041   // The control dependency bypasses the channel.
3042   EXPECT_EQ(op_count.at(kRecv), 3);
3043   EXPECT_EQ(op_count.at(kSend), 3);
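  // The three transfers are bn's y (port 0, shared by y1 and y2), batch_mean
  // (port 1), and batch_variance (port 2); control_dep adds no Send/Recv.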
3044 
3045   // Helper lambda for extracting output Tensor size.
3046   auto get_output_size = [this, &ops_executed](const string& name) -> int64_t {
3047     const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
3048     std::vector<OpInfo::TensorProperties> output_properties;
3049     for (const auto& output_property : output_properties_) {
3050       output_properties.push_back(output_property);
3051     }
3052     return CalculateOutputSize(output_properties, 0);
3053   };
3054 
3055   // Validate transfer size.
3056   // Batchnorm output y is a 4D tensor: batch x width x height x depth.
3057   int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
3058   EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
3059   EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
3060   // Mean and variance are 1-D vectors of size depth_in_.
3061   EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
3062   EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
3063   EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
3064   EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
3065 }
3066 
TEST_F(VirtualSchedulerTest,GraphWithSendRecv)3067 TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
3068   // Init.
3069   CreateGrapplerItemWithSendRecv();
3070   InitScheduler();
3071 
3072   // Run the scheduler.
3073   auto ops_executed = RunScheduler("");
3074 
3075   EXPECT_GT(ops_executed.count("Const"), 0);
3076   EXPECT_GT(ops_executed.count("Send"), 0);
3077   EXPECT_GT(ops_executed.count("Recv"), 0);
3078 }
3079 
TEST_F(VirtualSchedulerTest,GraphWithSendRecvDifferentDevice)3080 TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
3081   // Init.
3082   CreateGrapplerItemWithSendRecv();
3083   // Change Recv node's device so that Send and Recv are placed on different
3084   // devices.
3085   auto& graph = grappler_item_->graph;
3086   const string recv_device = kCPU1;
3087   for (int i = 0; i < graph.node_size(); i++) {
3088     auto* node = graph.mutable_node(i);
3089     if (node->name() == "Recv") {
3090       node->set_device(recv_device);
3091       auto* attr = node->mutable_attr();
3092       (*attr)["recv_device"].set_s(recv_device);
3093     } else if (node->name() == "Send") {
3094       auto* attr = node->mutable_attr();
3095       (*attr)["recv_device"].set_s(recv_device);
3096     }
3097   }
3098   InitScheduler();
3099 
3100   // Run the scheduler.
3101   auto ops_executed = RunScheduler("");
3102 
3103   // Expect Const, Send, Recv, and the VirtualScheduler-created Send and Recv ops.
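  // In the generated node names below, ':' in the device names appears as '_'.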
3104   EXPECT_GT(ops_executed.count("Const"), 0);
3105   EXPECT_GT(ops_executed.count("Send"), 0);
3106   EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
3107                                "task_0/cpu_0_to_/job_localhost"
3108                                "/replica_0/task_0/cpu_1"),
3109             0);
3110   EXPECT_GT(ops_executed.count(
3111                 "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
3112             0);
3113   EXPECT_GT(ops_executed.count("Recv"), 0);
3114 }
3115 
TEST_F(VirtualSchedulerTest,GraphWithOnlyRecv)3116 TEST_F(VirtualSchedulerTest, GraphWithOnlyRecv) {
3117   // Init.
3118   CreateGrapplerItemWithRecvWithoutSend();
3119   InitScheduler();
3120 
3121   // Run the scheduler.
3122   auto ops_executed = RunScheduler("");
3123 
3124   // Recv without Send will be treated as initially ready node.
3125   EXPECT_GT(ops_executed.count("Recv"), 0);
3126 }
3127 
TEST_F(VirtualSchedulerTest,AddMergeSwitch)3128 TEST_F(VirtualSchedulerTest, AddMergeSwitch) {
3129   // Override scheduler_ with CompositeNodeManager.
3130   scheduler_ = std::make_unique<TestVirtualScheduler>(
3131       /*use_static_shapes=*/true,
3132       /*use_aggressive_shape_inference=*/true, &composite_node_manager_,
3133       cluster_.get());
3134   CreateGrapplerItemWithSwitchMergeInput();
3135   InitScheduler();
3136 
3137   // pred --+                      z --+
3138   //        |                          |
3139   //        V                          V
3140   // x -> Switch --------> Merge ---> Add --> y
3141   //        |                ^
3142   //        |                |
3143   //        +-----> Add -----+
3144   //                 ^
3145   //                 |
3146   // b --------------+
3147 
3148   // Run the scheduler. The current VirtualScheduler, w/o annotation, triggers
3149   // both outputs of Switch; Merge becomes ready as soon as one of its inputs is
3150   // ready. If we just used a num_inputs_ready counter, the final Add could become
3151   // ready before z is scheduled, making it possible to skip scheduling z. (Need
3152   // to use CompositeNodeManager to test this case.)
3153   auto ops_executed = RunScheduler("");
3154 
3155   EXPECT_GT(ops_executed.count("z"), 0);
3156 }
3157 
TEST_F(VirtualSchedulerTest,AddFromOneTensor)3158 TEST_F(VirtualSchedulerTest, AddFromOneTensor) {
3159   CreateGrapplerItemWithAddFromOneTensor();
3160   InitScheduler();
3161 
3162   // x -+----> Add --> y
3163   //    |       ^
3164   //    |       |
3165   //    +-------+
3166 
3167   // Run the scheduler.
3168   auto ops_executed = RunScheduler("");
3169   EXPECT_GT(ops_executed.count("y"), 0);
3170   EXPECT_GT(ops_executed.count("x"), 0);
3171 }
3172 
TEST_F(VirtualSchedulerTest,TestNodeCostOutputTensorSize)3173 TEST_F(VirtualSchedulerTest, TestNodeCostOutputTensorSize) {
3174   // Create a schedule with more than 2 ops to be executed.
3175   CreateGrapplerItemWithMatmulChain();
3176   InitScheduler();
3177   RunScheduler("ab");
3178 
3179   int32_t persistent_memory_before =
3180       scheduler_->GetPersistentMemoryUsage().at(kCPU0);
3181 
3182   auto* device_states = scheduler_->GetDeviceStates();
3183   int32_t memory_usage = device_states->at(kCPU0).memory_usage;
3184 
3185   // Set temporary/persistent memory to some values for the first node cost.
3186   Costs node_costs = Costs::ZeroCosts(false);
3187 
3188   const int32_t node_one_cost = 12345;
3189   const int32_t node_two_cost = 98765;
3190 
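  // Each tensor in the MatMul chain is a 3200x3200 float matrix, so one input
  // is 4 * 3200 * 3200 bytes (see the shape checks in SummaryCostStepStatsTest).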
3191   const int32_t input_size = 4 * 3200 * 3200;
3192 
3193   node_costs.persistent_memory = node_one_cost;
3194   node_costs.temporary_memory = 0;
3195   node_costs.output_tensor_size_bytes = {{0, node_one_cost}};
3196   node_costs.persistent_output_ports = {0};
3197 
3198   // Mark first node executed and check device state.
3199   scheduler_->MarkCurrNodeExecuted(node_costs);
3200   device_states = scheduler_->GetDeviceStates();
3201   const auto& cpu_state_0 = device_states->at(kCPU0);
3202 
3203   // The expected memory usage is previous memory usage minus the size
3204   // of the two inputs for the multiply operation.
3205   memory_usage -= 2 * input_size;
3206   EXPECT_EQ(cpu_state_0.memory_usage, memory_usage);
3207 
3208   int64_t persistent_memory = node_one_cost + persistent_memory_before;
3209   EXPECT_EQ(scheduler_->GetPersistentMemoryUsage().at(kCPU0),
3210             persistent_memory);
3211 
3212   // Set second node costs to temporary memory.
3213   node_costs = Costs::ZeroCosts(false);
3214   node_costs.persistent_memory = 0;
3215   node_costs.temporary_memory = node_two_cost;
3216   node_costs.output_tensor_size_bytes = {{0, node_two_cost}};
3217 
3218   scheduler_->MarkCurrNodeExecuted(node_costs);
3219   device_states = scheduler_->GetDeviceStates();
3220   const auto& cpu_state_1 = device_states->at(kCPU0);
3221 
3222   // Again we remove the inputs from memory usage.  The output of the previous
3223   // operation is not subtracted because it is set as persistent memory.
3224   memory_usage += node_two_cost - input_size;
3225   EXPECT_EQ(cpu_state_1.memory_usage, memory_usage);
3226   EXPECT_EQ(scheduler_->GetPersistentMemoryUsage().at(kCPU0),
3227             persistent_memory);
3228 
3229   // Finish off the schedule to test if the Summary counts persistent memory
3230   // correctly.
3231   bool more_nodes = true;
3232   do {
3233     OpContext op_context = scheduler_->GetCurrNode();
3234     node_costs = SimplePredictCosts(op_context);
3235     more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
3236   } while (more_nodes);
3237 
3238   RunMetadata metadata;
3239   Costs final_cost = scheduler_->Summary(&metadata);
3240 
3241   EXPECT_EQ(final_cost.persistent_memory, persistent_memory);
3242 
3243   // Since we adjusted the node costs, we expect the requested and allocated
3244   // memory to differ for nodes "abc" and "abcd".
3245   StepStats stepstats = metadata.step_stats();
3246   for (const auto& device_step_stats : stepstats.dev_stats()) {
3247     for (const auto& stats : device_step_stats.node_stats()) {
3248       const auto& allocation_description =
3249           stats.output().at(0).tensor_description().allocation_description();
3250       if (stats.node_name() == "abc") {
3251         EXPECT_NE(allocation_description.allocated_bytes(),
3252                   allocation_description.requested_bytes());
3253         const auto& mem_stats = stats.memory_stats();
3254         EXPECT_EQ(mem_stats.persistent_memory_size(), node_one_cost);
3255       } else if (stats.node_name() == "abcd") {
3256         EXPECT_NE(allocation_description.allocated_bytes(),
3257                   allocation_description.requested_bytes());
3258       } else {
3259         EXPECT_EQ(allocation_description.allocated_bytes(),
3260                   allocation_description.requested_bytes());
3261       }
3262     }
3263   }
3264 }
3265 
3266 }  // namespace
3267 }  // end namespace grappler
3268 }  // end namespace tensorflow
3269