/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

class PinToHostOptimizerTest : public GrapplerTest {};

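// With no devices registered, TryFindHostDevice should fail to find a host
// device and return an empty string.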
TEST_F(PinToHostOptimizerTest, TryFindHostDeviceNoDevices) {
  gtl::FlatSet<string> devices = {};

  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "ABC"), "");
}

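// With a CPU and an XLA_GPU device registered, requests for the XLA_GPU
// device (exact or wildcard) should resolve to the CPU host device.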
TEST_F(PinToHostOptimizerTest, TryFindHostDeviceCpuXlaGpu) {
  gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};

  EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
            "/device:CPU:0");
  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
            "/device:CPU:0");
}

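// The small ops `d`, `e`, and `f` should be pinned to the host device, while
// the large Const `a` and its Shape `c` should be left unplaced.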
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
  Output c = ops::Shape(s.WithOpName("c"), a);
  Output d = ops::Const(s.WithOpName("d"), 0, {1});
  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
  // The first `num_int32` fetched tensors are int32; the final fetch `f` is a
  // string.
  int num_int32 = 4;
  Output f = ops::Const(s.WithOpName("f"), {"test"});

  GrapplerItem item;
  item.fetch = {"a", "c", "d", "e", "f"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Evaluate the optimized graph and check it against the original results.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    if (i < num_int32) {
      test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
    } else {
      test::ExpectTensorEqual<tstring>(tensors[i], tensors_expected[i]);
    }
  }

  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "a" || node.name() == "c") {
      EXPECT_TRUE(node.device().empty());
    } else if (node.name() == "d" || node.name() == "e" || node.name() == "f") {
      EXPECT_EQ(node.device(), "/device:CPU:0");
    }
    ++found;
  }
  EXPECT_EQ(found, 5);
}

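// Small float Consts feeding QuantizeAndDequantizeV2 should be pinned to the
// host device on GPU builds, since the kernel requires them in host memory.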
TEST_F(PinToHostOptimizerTest, OptimizeSmallFloatOpsToHost) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 0.0f, {1024, 1024});
  Output input_min = ops::Const(s.WithOpName("input_min"), 0.0f);
  Output input_max = ops::Const(s.WithOpName("input_max"), 6.0f);
  Output b =
      ops::QuantizeAndDequantizeV2(s.WithOpName("b"), a, input_min, input_max);

  GrapplerItem item;
  item.fetch = {"b"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Evaluate the optimized graph and check it against the original results.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorEqual<float>(tensors[i], tensors_expected[i]);
  }

  // QuantizeAndDequantizeV2 requires input_min and input_max on CPU, so
  // pin_to_host_optimizer should pin them to host.
  for (const NodeDef& node : output.node()) {
    if (node.name() == "input_min" || node.name() == "input_max") {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      EXPECT_EQ(node.device(), "/device:CPU:0");
#else
      EXPECT_TRUE(node.device().empty());
#endif
    }
  }
}

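// The optimizer must tolerate graphs whose nodes are not in topological
// order: reversing the node list should not change the optimization result.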
TEST_F(PinToHostOptimizerTest, TopologicalSort) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
  Output c = ops::Shape(s.WithOpName("c"), a);
  Output d = ops::Const(s.WithOpName("d"), 0, {1});
  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);

  GrapplerItem item;
  item.fetch = {"a", "c", "d", "e"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Reverse the graph, and hence rely on the optimizer to sort it.
  std::reverse(item.graph.mutable_node()->begin(),
               item.graph.mutable_node()->end());

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Evaluate the optimized graph and check it against the original results.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
  }

  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "a" || node.name() == "c") {
      EXPECT_TRUE(node.device().empty());
    } else if (node.name() == "d" || node.name() == "e") {
      EXPECT_EQ(node.device(), "/device:CPU:0");
    }
    ++found;
  }
  EXPECT_EQ(found, 4);
}

TEST_F(PinToHostOptimizerTest, NoSwap) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // `b` should be too big to swap, so `c` should not be swapped either.
  // PinToHostOptimizer should then detect that `a` should not be swapped.
  Output a = ops::Const(s.WithOpName("a"), 1, {1, 1});
  Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024});
  Output c = ops::MatMul(s.WithOpName("c"), a, b);

  GrapplerItem item;
  item.fetch = {"a", "b", "c"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Evaluate the optimized graph and check it against the original results.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
  }

  // No node should have been assigned a device.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_TRUE(node.device().empty());
    ++found;
  }
  EXPECT_EQ(found, 3);
}

TEST_F(PinToHostOptimizerTest, Identity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // `a` and `c` are on GPU, `d` is on CPU; consequently `e` should not be
  // swapped. `b` should be placed onto Host since `c` pins its input to Host
  // memory.
  Output a =
      ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64});
  Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2});
  Output c =
      ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b);
  Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c);
  Output e = ops::Multiply(s.WithOpName("e"), d, d);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "a" || node.name() == "c") {
      EXPECT_EQ(node.device(), "/device:GPU:0");
    } else if (node.name() == "b") {
      // When built with CUDA or ROCm, there is a GPU kernel registration that
      // pins this input to Host memory. Consequently, `b` will be mapped to
      // Host only if such a GPU kernel is registered.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      EXPECT_EQ(node.device(), "/device:CPU:0");
#else
      EXPECT_TRUE(node.device().empty());
#endif
    } else if (node.name() == "d") {
      EXPECT_EQ(node.device(), "/device:CPU:0");
    } else if (node.name() == "e") {
      EXPECT_TRUE(node.device().empty());
    }
    ++found;
  }
  EXPECT_EQ(found, 5);
}

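// `b` is a multi-input, multi-output ShapeN; the optimizer has to translate
// output port ids into arg ids when deciding whether to pin the nodes.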
TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
  ops::ShapeN b(s.WithOpName("b"), {a, a, a});

  GrapplerItem item;
  item.fetch = {"a", "b"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Evaluate the optimized graph and check it against the original results.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
  }

  // Both nodes should be pinned to the host device.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_EQ(node.device(), "/device:CPU:0");
    ++found;
  }
  EXPECT_EQ(found, 2);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow