1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
31 class SortByExecutionOrderTest : public ::testing::Test {
32 protected:
GetOrder(const GraphDef & graph_def,std::map<string,int> * order)33 void GetOrder(const GraphDef& graph_def, std::map<string, int>* order) {
34 for (int i = 0; i < graph_def.node_size(); ++i) {
35 const NodeDef& node = graph_def.node(i);
36 (*order)[node.name()] = i;
37 }
38 }
39
TestSimpleAdd()40 void TestSimpleAdd() {
41 GraphDef graph_def;
42 NodeDef* add_node = graph_def.add_node();
43 add_node->set_name("add_node");
44 add_node->set_op("Add");
45 add_node->add_input("a_node");
46 add_node->add_input("b_node");
47
48 NodeDef* b_node = graph_def.add_node();
49 b_node->set_name("b_node");
50 b_node->set_op("Const");
51
52 NodeDef* a_node = graph_def.add_node();
53 a_node->set_name("a_node");
54 a_node->set_op("Const");
55
56 GraphDef result;
57 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
58
59 std::map<string, int> order;
60 GetOrder(result, &order);
61 EXPECT_EQ(2, order["add_node"]);
62 EXPECT_GT(2, order["a_node"]);
63 EXPECT_GT(2, order["b_node"]);
64 }
65
TestSimpleLinear()66 void TestSimpleLinear() {
67 GraphDef graph_def;
68
69 NodeDef* negative_node = graph_def.add_node();
70 negative_node->set_name("negative_node");
71 negative_node->set_op("Negative");
72 negative_node->add_input("sqrt_node");
73
74 NodeDef* relu_node = graph_def.add_node();
75 relu_node->set_name("relu_node");
76 relu_node->set_op("Relu");
77 relu_node->add_input("const_node");
78
79 NodeDef* sqrt_node = graph_def.add_node();
80 sqrt_node->set_name("sqrt_node");
81 sqrt_node->set_op("Sqrt");
82 sqrt_node->add_input("relu_node");
83
84 NodeDef* const_node = graph_def.add_node();
85 const_node->set_name("const_node");
86 const_node->set_op("Const");
87
88 GraphDef result;
89 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
90
91 std::map<string, int> order;
92 GetOrder(result, &order);
93 EXPECT_EQ(3, order["negative_node"]);
94 EXPECT_EQ(2, order["sqrt_node"]);
95 EXPECT_EQ(1, order["relu_node"]);
96 EXPECT_EQ(0, order["const_node"]);
97 }
98
TestSimpleTree()99 void TestSimpleTree() {
100 GraphDef graph_def;
101
102 NodeDef* add_node1 = graph_def.add_node();
103 add_node1->set_name("add_node1");
104 add_node1->set_op("Add");
105 add_node1->add_input("add_node2");
106 add_node1->add_input("add_node3");
107
108 NodeDef* add_node2 = graph_def.add_node();
109 add_node2->set_name("add_node2");
110 add_node2->set_op("Add");
111 add_node2->add_input("const_node1");
112 add_node2->add_input("const_node2");
113
114 NodeDef* add_node3 = graph_def.add_node();
115 add_node3->set_name("add_node3");
116 add_node3->set_op("Add");
117 add_node3->add_input("const_node3");
118 add_node3->add_input("const_node4");
119
120 NodeDef* const_node1 = graph_def.add_node();
121 const_node1->set_name("const_node1");
122 const_node1->set_op("Const");
123
124 NodeDef* const_node2 = graph_def.add_node();
125 const_node2->set_name("const_node2");
126 const_node2->set_op("Const");
127
128 NodeDef* const_node3 = graph_def.add_node();
129 const_node3->set_name("const_node3");
130 const_node3->set_op("Const");
131
132 NodeDef* const_node4 = graph_def.add_node();
133 const_node4->set_name("const_node4");
134 const_node4->set_op("Const");
135
136 GraphDef result;
137 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
138
139 std::map<string, int> order;
140 GetOrder(result, &order);
141 EXPECT_EQ(6, order["add_node1"]);
142 EXPECT_GT(6, order["add_node2"]);
143 EXPECT_GT(6, order["add_node3"]);
144 EXPECT_GT(5, order["const_node1"]);
145 EXPECT_GT(5, order["const_node2"]);
146 EXPECT_GT(5, order["const_node3"]);
147 EXPECT_GT(5, order["const_node4"]);
148 }
149
TestCommonAncestor()150 void TestCommonAncestor() {
151 GraphDef graph_def;
152
153 NodeDef* add_node1 = graph_def.add_node();
154 add_node1->set_name("add_node1");
155 add_node1->set_op("Add");
156 add_node1->add_input("add_node2");
157 add_node1->add_input("add_node3");
158
159 NodeDef* add_node2 = graph_def.add_node();
160 add_node2->set_name("add_node2");
161 add_node2->set_op("Add");
162 add_node2->add_input("const_node1");
163 add_node2->add_input("const_node2");
164
165 NodeDef* add_node3 = graph_def.add_node();
166 add_node3->set_name("add_node3");
167 add_node3->set_op("Add");
168 add_node3->add_input("const_node1");
169 add_node3->add_input("const_node3");
170
171 NodeDef* const_node1 = graph_def.add_node();
172 const_node1->set_name("const_node1");
173 const_node1->set_op("Const");
174
175 NodeDef* const_node2 = graph_def.add_node();
176 const_node2->set_name("const_node2");
177 const_node2->set_op("Const");
178
179 NodeDef* const_node3 = graph_def.add_node();
180 const_node3->set_name("const_node3");
181 const_node3->set_op("Const");
182
183 GraphDef result;
184 TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
185
186 std::map<string, int> order;
187 GetOrder(result, &order);
188 EXPECT_EQ(5, order["add_node1"]);
189 EXPECT_GT(5, order["add_node2"]);
190 EXPECT_GT(5, order["add_node3"]);
191 EXPECT_GT(4, order["const_node2"]);
192 EXPECT_GT(4, order["const_node3"]);
193 EXPECT_GT(3, order["const_node1"]);
194 }
195 };
196
// Two Const nodes feeding a single Add, listed consumer-first.
TEST_F(SortByExecutionOrderTest, TestSimpleAdd) {
  this->TestSimpleAdd();
}
198
// A four-node chain whose sorted order is fully determined.
TEST_F(SortByExecutionOrderTest, TestSimpleLinear) {
  this->TestSimpleLinear();
}
200
// A binary tree of Add nodes over four distinct Const leaves.
TEST_F(SortByExecutionOrderTest, TestSimpleTree) {
  this->TestSimpleTree();
}
202
// A diamond where one Const is shared by two Add nodes.
TEST_F(SortByExecutionOrderTest, TestCommonAncestor) {
  this->TestCommonAncestor();
}
204
205 } // namespace graph_transforms
206 } // namespace tensorflow
207