1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
17
18 #include "absl/memory/memory.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/cc/ops/array_ops.h"
21 #include "tensorflow/cc/ops/const_op.h"
22 #include "tensorflow/cc/ops/nn_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30 #include "tensorflow/core/grappler/devices.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/utils/graph_view.h"
33 #include "tensorflow/core/grappler/utils/grappler_test.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36
37 namespace tensorflow {
38 namespace grappler {
39
40 using ::tensorflow::Scope;
41 using ::tensorflow::ops::Conv2D;
42 using ::tensorflow::ops::Identity;
43 using ::tensorflow::ops::RandomUniform;
44
45 constexpr int kBatchSize = 32;
46 constexpr int kWidth = 10;
47 constexpr int kHeight = 10;
48 constexpr int kDepthIn = 8;
49 constexpr int kKernel = 3;
50 constexpr int kDepthOut = 16;
51
52 // When there is a GPU, we test generic_layout_optimization for the conversion
53 // from NHWC to NCHW format. When there is only CPU, we test the conversion
54 // from NCHW to NHWC format. The following macros help setting tensor shapes,
55 // source and destination format strings, and transpose permutation vectors
56 // appropriately for NHWC -> NCHW conversion (when GPU) and NCHW -> NHWC
57 // conversion (when only CPU).
58
59 #if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
60 #define DIMS(n, h, w, c) \
61 { n, h, w, c }
62 #define SRC_DATA_FORMAT "NHWC"
63 #define DST_DATA_FORMAT "NCHW"
64 #define DEVICE "GPU"
65 #define REWRITER_CONFIG \
66 RewriterConfig::DEFAULT, RewriterConfig::NO_CONVERSION_ON_CPU
67 #define PERMUTATION_SRC_TO_DST \
68 { 0, 3, 1, 2 }
69 #define PERMUTATION_DST_TO_SRC \
70 { 0, 2, 3, 1 }
71 #else
72 #define DIMS(n, h, w, c) \
73 { n, c, h, w }
74 #define SRC_DATA_FORMAT "NCHW"
75 #define DST_DATA_FORMAT "NHWC"
76 #define DEVICE "CPU"
77 #define REWRITER_CONFIG RewriterConfig::DEFAULT, RewriterConfig::NCHW_TO_NHWC
78 #define PERMUTATION_SRC_TO_DST \
79 { 0, 2, 3, 1 }
80 #define PERMUTATION_DST_TO_SRC \
81 { 0, 3, 1, 2 }
82 #endif // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
83
84 template <typename T = float>
SimpleConv2D(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,const string & device)85 Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
86 const string& padding, const string& device) {
87 int batch_size = 8;
88 int input_height = input_size;
89 int input_width = input_size;
90 int input_depth = 3;
91 int filter_count = 2;
92 int stride = 1;
93 TensorShape input_shape(
94 DIMS(batch_size, input_height, input_width, input_depth));
95 Tensor input_data(DataTypeToEnum<T>::value, input_shape);
96 test::FillIota<T>(&input_data, static_cast<T>(1));
97 Output input =
98 ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
99
100 TensorShape filter_shape(
101 {filter_size, filter_size, input_depth, filter_count});
102 Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
103 test::FillIota<T>(&filter_data, static_cast<T>(1));
104 Output filter =
105 ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
106
107 Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
108 filter, DIMS(1, stride, stride, 1), padding,
109 ops::Conv2D::Attrs().DataFormat(SRC_DATA_FORMAT));
110 return conv;
111 }
112
SimpleConv2DBackpropInput(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,bool dilated,const int input_sizes_length)113 Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
114 int filter_size, const string& padding,
115 bool dilated, const int input_sizes_length) {
116 int batch_size = 128;
117 int input_height = input_size;
118 int input_width = input_size;
119 int input_depth = 3;
120 int filter_count = 2;
121 int stride = 1;
122 TensorShape input_sizes_shape({input_sizes_length});
123 Tensor input_data(DT_INT32, input_sizes_shape);
124 if (input_sizes_length == 4) {
125 test::FillValues<int>(
126 &input_data, DIMS(batch_size, input_height, input_width, input_depth));
127 } else {
128 test::FillValues<int>(&input_data, {input_height, input_width});
129 }
130 Output input_sizes =
131 ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
132
133 TensorShape filter_shape(
134 {filter_size, filter_size, input_depth, filter_count});
135 Output filter =
136 ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);
137
138 int output_height = input_height;
139 int output_width = input_width;
140 TensorShape output_shape(
141 DIMS(batch_size, output_height, output_width, filter_count));
142 Tensor output_data(DT_FLOAT, output_shape);
143 test::FillIota<float>(&output_data, 1.0f);
144 Output output =
145 ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
146
147 Output conv_backprop_input;
148 Output input_sizes_i =
149 ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
150 ops::Conv2DBackpropInput::Attrs attrs;
151 attrs = attrs.DataFormat(SRC_DATA_FORMAT);
152 if (dilated) {
153 attrs = attrs.Dilations(DIMS(1, 2, 2, 1));
154 }
155 conv_backprop_input = ops::Conv2DBackpropInput(
156 s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
157 DIMS(1, stride, stride, 1), padding, attrs);
158
159 return conv_backprop_input;
160 }
161
162 class GenericLayoutOptimizerTest : public GrapplerTest {
163 protected:
SetUp()164 void SetUp() override {
165 bool gpu_available = GetNumAvailableGPUs() > 0;
166
167 if (gpu_available) {
168 virtual_cluster_ =
169 std::make_unique<SingleMachine>(/*timeout_s=*/10, 1, 1);
170 } else {
171 DeviceProperties cpu_device;
172 cpu_device.set_type("CPU");
173 cpu_device.set_frequency(1000);
174 cpu_device.set_num_cores(4);
175 cpu_device.set_bandwidth(32);
176 cpu_device.set_l1_cache_size(32 * 1024);
177 cpu_device.set_l2_cache_size(256 * 1024);
178 cpu_device.set_l3_cache_size(4 * 1024 * 1024);
179 cpu_device.set_memory_size(1024 * 1024);
180 #if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
181 DeviceProperties gpu_device;
182 gpu_device.set_type("GPU");
183 gpu_device.mutable_environment()->insert({"architecture", "6"});
184 virtual_cluster_ =
185 absl::WrapUnique(new VirtualCluster({{"/CPU:0", cpu_device},
186 { "/GPU:1",
187 gpu_device }}));
188 #else
189 virtual_cluster_ =
190 absl::WrapUnique(new VirtualCluster({{"/CPU:0", cpu_device}}));
191 #endif // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
192 }
193 TF_ASSERT_OK(virtual_cluster_->Provision());
194 }
195
TearDown()196 void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }
197
198 std::unique_ptr<Cluster> virtual_cluster_;
199 };
200
VerifyRegularFaninMatch(const utils::NodeView * node,int port,absl::string_view fanin_name,int fanin_port)201 void VerifyRegularFaninMatch(const utils::NodeView* node, int port,
202 absl::string_view fanin_name, int fanin_port) {
203 ASSERT_GE(node->NumRegularFanins(), port);
204 const auto& fanin = node->GetRegularFanin(port);
205 EXPECT_EQ(fanin.node_view()->GetName(), fanin_name);
206 EXPECT_EQ(fanin.index(), fanin_port);
207 }
208
VerifyRegularFanoutMatch(const utils::NodeView * node,int port,absl::string_view fanout_name,int fanout_port)209 void VerifyRegularFanoutMatch(const utils::NodeView* node, int port,
210 absl::string_view fanout_name, int fanout_port) {
211 bool found = false;
212 for (const auto& regular_fanout : node->GetRegularFanout(port)) {
213 if (regular_fanout.node_view()->GetName() == fanout_name &&
214 regular_fanout.index() == fanout_port) {
215 found = true;
216 }
217 }
218 EXPECT_TRUE(found);
219 }
220
VerifyDataFormatAttributeMatch(const utils::NodeView * node,absl::string_view attr_value)221 void VerifyDataFormatAttributeMatch(const utils::NodeView* node,
222 absl::string_view attr_value) {
223 const auto* attr = node->GetAttr("data_format");
224 ASSERT_NE(attr, nullptr);
225 EXPECT_EQ(attr->s(), attr_value);
226 }
227
TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv2DGraph) {
  // Graph: Input/Filter consts -> Conv2D -> Identity ("Output").
  // The source format is NHWC on GPU and NCHW on CPU-only builds.
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope, 4, 2, "VALID", "");
  auto identity = Identity(scope.WithOpName("Output"), conv2d);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  // With no device assigned, the Conv2D keeps its source data format and its
  // filter fanin is untouched.
  auto* conv_view = view.GetNode("Conv2D");
  ASSERT_NE(conv_view, nullptr);
  ASSERT_EQ(conv_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv_view, 1, "Filter", 0);
  VerifyDataFormatAttributeMatch(conv_view, SRC_DATA_FORMAT);

  auto* out_view = view.GetNode("Output");
  ASSERT_NE(out_view, nullptr);
  ASSERT_EQ(out_view->NumRegularFanins(), 1);
}
256
TEST_F(GenericLayoutOptimizerTest, PreserveFetch) {
  // A fetched Conv2D must not have its data format rewritten.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "");
  auto i = ops::Identity(s.WithOpName("i"), conv);

  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_view = view.GetNode("Conv2D");
  ASSERT_NE(conv_view, nullptr);
  VerifyDataFormatAttributeMatch(conv_view, SRC_DATA_FORMAT);
}
276
TEST_F(GenericLayoutOptimizerTest, EmptyDevice) {
  // With an empty device string the Conv2D keeps its source data format.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_view = view.GetNode("Conv2D");
  ASSERT_NE(conv_view, nullptr);
  VerifyDataFormatAttributeMatch(conv_view, SRC_DATA_FORMAT);
}
295
TEST_F(GenericLayoutOptimizerTest, GPUDevice) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // A Conv2D explicitly placed on a GPU is rewritten to NCHW.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:GPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_view = view.GetNode("Conv2D");
  ASSERT_NE(conv_view, nullptr);
  VerifyDataFormatAttributeMatch(conv_view, "NCHW");
}
318
TEST_F(GenericLayoutOptimizerTest, CPUDevice) {
  // A Conv2D pinned to CPU always ends up in NHWC: on GPU builds that is the
  // unchanged source format, on CPU-only builds it is the rewrite target.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_view = view.GetNode("Conv2D");
  ASSERT_NE(conv_view, nullptr);
#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  VerifyDataFormatAttributeMatch(conv_view, "NHWC");
#else
  VerifyDataFormatAttributeMatch(conv_view, DST_DATA_FORMAT);
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
}
341
TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) {
  // Integer (int32) convolutions are left in the source data format.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D<int32>(&s, 4, 2, "VALID", "");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_view = view.GetNode("Conv2D");
  ASSERT_NE(conv_view, nullptr);
  VerifyDataFormatAttributeMatch(conv_view, SRC_DATA_FORMAT);
}
360
TEST_F(GenericLayoutOptimizerTest, Connectivity) {
  Scope scope = Scope::NewRootScope();
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto i1 = ops::Identity(scope.WithOpName("i1"), conv);
  auto i2 = ops::Identity(scope.WithOpName("i2"), i1);
  auto i3 = ops::Identity(scope.WithOpName("i3"), i2);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  // Deliberately break topological order by swapping i1 and i2 in the node
  // list. This forces the optimizer to discover multi-hop connectivity (two
  // nodes count as connected when every node between them is layout
  // agnostic); with a topologically sorted graph, single-hop checks suffice.
  Status status;
  utils::GraphView original_view(&item.graph, &status);
  const int index_i1 = original_view.GetNode("i1")->node_index();
  const int index_i2 = original_view.GetNode("i2")->node_index();
  item.graph.mutable_node()->SwapElements(index_i1, index_i2);

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  // i2 must be processed: it is connected to the Conv2D two hops away. i1 is
  // processed too, being directly connected to the Conv2D.
  auto* i2_view = view.GetNode("i2");
  ASSERT_NE(i2_view, nullptr);
  ASSERT_EQ(i2_view->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(i2_view, 0, "i1", 0);
}
395
TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  // Exercise both rank-2 and rank-4 input_sizes fed through an Identity.
  for (const int input_sizes_length : {2, 4}) {
    Scope s = Scope::NewRootScope();
    auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
                                          input_sizes_length);
    Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});

    GrapplerItem item;
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));

    GraphDef optimized_graph;
    GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
    TF_ASSERT_OK(
        optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

    Status status;
    utils::GraphView view(&optimized_graph, &status);
    TF_ASSERT_OK(status);

    // The non-Const input_sizes fanin must survive the rewrite.
    auto* backprop_view = view.GetNode("Conv2DBackpropInput");
    ASSERT_NE(backprop_view, nullptr);
    ASSERT_EQ(backprop_view->NumRegularFanins(), 3);
    VerifyRegularFaninMatch(backprop_view, 0, "InputSizesIdentity", 0);
  }
}
418
TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
  Scope scope =
      Scope::NewRootScope().WithDevice(absl::StrCat("/device:", DEVICE, ":0"));
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto shape = ops::Shape(scope.WithOpName("shape"), conv);
  auto value = ops::Const(scope.WithOpName("value"), 0, {});
  auto fill = ops::Fill(scope.WithOpName("fill"), shape, value);
  auto i = ops::Identity(scope.WithOpName("i"), fill);

  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  // Before optimization:
  //   input -> conv2d -> shape -> fill -> output
  // After expansion (T = Transpose, D = DataFormatVecPermute):
  //   input -> T -> conv2d -> T' -> T -> shape -> D' -> D -> fill -> T' ->
  //   output
  // After collapsing cancelling pairs:
  //   input -> T -> conv2d -> shape -> fill -> T' -> output
  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  auto* conv_view = view.GetNode("Conv2D");
  ASSERT_NE(conv_view, nullptr);
  ASSERT_EQ(conv_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(
      conv_view, 0,
      absl::StrCat("Conv2D-0-Transpose", SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  auto* shape_view = view.GetNode("shape");
  ASSERT_NE(shape_view, nullptr);
  ASSERT_EQ(shape_view->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(shape_view, 0, conv_view->GetName(), 0);

  auto* fill_view = view.GetNode("fill");
  ASSERT_NE(fill_view, nullptr);
  ASSERT_EQ(fill_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(fill_view, 0, shape_view->GetName(), 0);
  VerifyRegularFanoutMatch(
      fill_view, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  auto* identity_view = view.GetNode("i");
  ASSERT_NE(identity_view, nullptr);
  ASSERT_EQ(identity_view->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(
      identity_view, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);
}
479
TEST_F(GenericLayoutOptimizerTest, DoNotPruneNonAddedCancellableTransposes) {
  GrapplerItem item;
  {
    Scope scope = Scope::NewRootScope().WithDevice(
        absl::StrCat("/device:", DEVICE, ":0"));
    auto input = ops::RandomUniform(scope.WithOpName("input"),
                                    DIMS(kBatchSize, kHeight, kWidth, kDepthIn),
                                    DT_FLOAT);
    // User-written transpose pair around the input:
    // src -> dst (GPU: NHWC->NCHW {0,3,1,2}; CPU: NCHW->NHWC {0,2,3,1}) ...
    auto input_in_transpose =
        ops::Transpose(scope.WithOpName("input_in_transpose"), input,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    // ... followed by dst -> src (GPU: NCHW->NHWC {0,2,3,1}; CPU:
    // NHWC->NCHW {0,3,1,2}).
    auto input_out_transpose = ops::Transpose(
        scope.WithOpName("input_out_transpose"), input_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    Tensor bias_data(DT_FLOAT, TensorShape({kDepthIn}));
    test::FillIota<float>(&bias_data, 1.0f);
    auto bias_add = ops::BiasAdd(
        scope.WithOpName("bias_add"), input_out_transpose, bias_data,
        ops::BiasAdd::Attrs().DataFormat(SRC_DATA_FORMAT));
    // A second user-written cancellable pair around the output.
    auto output_in_transpose =
        ops::Transpose(scope.WithOpName("output_in_transpose"), bias_add,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    auto output_out_transpose = ops::Transpose(
        scope.WithOpName("output_out_transpose"), output_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    auto output =
        ops::Identity(scope.WithOpName("output"), output_out_transpose);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  }

  GraphDef optimized_graph;
  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  // All four user-written transposes must survive; only optimizer-added
  // transposes (the "-LayoutOptimizer" pair around bias_add) may be inserted.
  auto* input_view = view.GetNode("input");
  ASSERT_NE(input_view, nullptr);

  auto* in_transpose_view = view.GetNode("input_in_transpose");
  ASSERT_NE(in_transpose_view, nullptr);
  ASSERT_EQ(in_transpose_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(in_transpose_view, 0, input_view->GetName(), 0);

  auto* out_transpose_view = view.GetNode("input_out_transpose");
  ASSERT_NE(out_transpose_view, nullptr);
  ASSERT_EQ(out_transpose_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(out_transpose_view, 0, in_transpose_view->GetName(),
                          0);

  auto* bias_in_transpose_view = view.GetNode(
      absl::StrCat("bias_add-0-Transpose", SRC_DATA_FORMAT, "To",
                   DST_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_in_transpose_view, nullptr);
  ASSERT_EQ(bias_in_transpose_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_in_transpose_view, 0,
                          out_transpose_view->GetName(), 0);

  auto* bias_add_view = view.GetNode("bias_add");
  ASSERT_NE(bias_add_view, nullptr);
  ASSERT_EQ(bias_add_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_view, 0, bias_in_transpose_view->GetName(),
                          0);

  auto* bias_out_transpose_view = view.GetNode(
      absl::StrCat("bias_add-0-0-Transpose", DST_DATA_FORMAT, "To",
                   SRC_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_out_transpose_view, nullptr);
  ASSERT_EQ(bias_out_transpose_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_out_transpose_view, 0, bias_add_view->GetName(),
                          0);

  auto* out_in_transpose_view = view.GetNode("output_in_transpose");
  ASSERT_NE(out_in_transpose_view, nullptr);
  ASSERT_EQ(out_in_transpose_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(out_in_transpose_view, 0,
                          bias_out_transpose_view->GetName(), 0);

  auto* out_out_transpose_view = view.GetNode("output_out_transpose");
  ASSERT_NE(out_out_transpose_view, nullptr);
  ASSERT_EQ(out_out_transpose_view->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(out_out_transpose_view, 0,
                          out_in_transpose_view->GetName(), 0);

  auto* output_view = view.GetNode("output");
  ASSERT_NE(output_view, nullptr);
  ASSERT_EQ(output_view->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_view, 0, out_out_transpose_view->GetName(),
                          0);
}
578
TEST_F(GenericLayoutOptimizerTest, CancelTransposeAroundPad) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(
      RewriterConfig::AGGRESSIVE,
      RewriterConfig::NCHW_TO_NHWC /* CPU settings*/);

  const Tensor kNhwcToNchwPerm = test::AsTensor<int32>({0, 3, 1, 2});
  const Tensor kNchwToNhwcPerm = test::AsTensor<int32>({0, 2, 3, 1});
  const Tensor kPaddings =
      test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});

  // transpose(NHWC->NCHW) -> Pad -> transpose(NCHW->NHWC) x2.
  GrapplerItem item;
  item.graph = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {}, {{"dtype", DT_INT32}, {"value", kPaddings}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kNhwcToNchwPerm}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kNchwToNhwcPerm}}),

      NDef("transpose_0", "Transpose", {"x", "perm_nhwc_to_nchw"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("transpose_2", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
  });

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(virtual_cluster_.get(), item,
                                  &optimized_graph));

  // The paddings const is permuted into NHWC order so the transposes around
  // the Pad can be cancelled and replaced by Identity nodes.
  const Tensor kNhwcPaddings =
      test::AsTensor<int32>({1, 2, 5, 6, 7, 8, 3, 4}, {4, 2});

  GraphDef expected = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {},
           {{"dtype", DT_INT32}, {"value", kNhwcPaddings}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kNhwcToNchwPerm}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kNchwToNhwcPerm}}),

      // Transpose nodes replaced by Identity nodes.
      NDef("transpose_0", "Identity", {"x"}, {{"T", DT_FLOAT}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
      NDef("transpose_2", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
  });

  CompareGraphs(expected, optimized_graph);

  // Numerically check that the rewritten graph computes the same values.
  Tensor x = GenerateRandomTensor<DT_FLOAT>({2, 6, 6, 8});
  item.fetch = {"transpose_1", "transpose_2"};
  item.feed.emplace_back("x", x);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}
647
TEST_F(GenericLayoutOptimizerTest, PreserveInputShapes) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(RewriterConfig::AGGRESSIVE);

  // A single _Arg node carrying an unknown-rank-1 "_output_shapes" attr.
  AttrValue output_shapes;
  output_shapes.mutable_list()->add_shape()->add_dim()->set_size(-1);

  GrapplerItem item;
  item.graph = test::function::GDef({NDef(
      "x", "_Arg", {},
      {{"T", DT_FLOAT}, {"index", 0}, {"_output_shapes", output_shapes}})});
  item.feed.emplace_back("x", Tensor(DT_FLOAT));

  GraphDef optimized_graph;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item,
                                  &optimized_graph));

  Status status;
  utils::GraphView view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  // The "_output_shapes" attribute must survive the rewrite untouched.
  auto* arg_view = view.GetNode("x");
  ASSERT_NE(arg_view, nullptr);
  EXPECT_TRUE(arg_view->HasAttr("_output_shapes"));
  EXPECT_EQ(arg_view->GetAttr("_output_shapes")->DebugString(),
            output_shapes.DebugString());
}
676
677 // TODO(yanzha): Add more complex Graph for test.
678
679 } // namespace grappler
680 } // namespace tensorflow
681