1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
17 
18 #include "absl/memory/memory.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/cc/ops/array_ops.h"
21 #include "tensorflow/cc/ops/const_op.h"
22 #include "tensorflow/cc/ops/nn_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/clusters/single_machine.h"
29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30 #include "tensorflow/core/grappler/devices.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/utils/graph_view.h"
33 #include "tensorflow/core/grappler/utils/grappler_test.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36 
37 namespace tensorflow {
38 namespace grappler {
39 
40 using ::tensorflow::Scope;
41 using ::tensorflow::ops::Conv2D;
42 using ::tensorflow::ops::Identity;
43 using ::tensorflow::ops::RandomUniform;
44 
// Common tensor dimensions shared by the tests in this file.
constexpr int kBatchSize = 32;
constexpr int kWidth = 10;
constexpr int kHeight = 10;
constexpr int kDepthIn = 8;
constexpr int kKernel = 3;
constexpr int kDepthOut = 16;

// When there is a GPU, we test generic_layout_optimization for the conversion
// from NHWC to NCHW format. When there is only CPU, we test the conversion
// from NCHW to NHWC format. The following macros help setting tensor shapes,
// source and destination format strings, and transpose permutation vectors
// appropriately for NHWC -> NCHW conversion (when GPU) and NCHW -> NHWC
// conversion (when only CPU).

#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
// DIMS takes logical (n, h, w, c) and emits a braced list in the source
// data format's layout order.
#define DIMS(n, h, w, c) \
  { n, h, w, c }
#define SRC_DATA_FORMAT "NHWC"
#define DST_DATA_FORMAT "NCHW"
#define DEVICE "GPU"
#define REWRITER_CONFIG \
  RewriterConfig::DEFAULT, RewriterConfig::NO_CONVERSION_ON_CPU
#define PERMUTATION_SRC_TO_DST \
  { 0, 3, 1, 2 }
#define PERMUTATION_DST_TO_SRC \
  { 0, 2, 3, 1 }
#else
#define DIMS(n, h, w, c) \
  { n, c, h, w }
#define SRC_DATA_FORMAT "NCHW"
#define DST_DATA_FORMAT "NHWC"
#define DEVICE "CPU"
#define REWRITER_CONFIG RewriterConfig::DEFAULT, RewriterConfig::NCHW_TO_NHWC
#define PERMUTATION_SRC_TO_DST \
  { 0, 2, 3, 1 }
#define PERMUTATION_DST_TO_SRC \
  { 0, 3, 1, 2 }
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
83 
84 template <typename T = float>
SimpleConv2D(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,const string & device)85 Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
86                     const string& padding, const string& device) {
87   int batch_size = 8;
88   int input_height = input_size;
89   int input_width = input_size;
90   int input_depth = 3;
91   int filter_count = 2;
92   int stride = 1;
93   TensorShape input_shape(
94       DIMS(batch_size, input_height, input_width, input_depth));
95   Tensor input_data(DataTypeToEnum<T>::value, input_shape);
96   test::FillIota<T>(&input_data, static_cast<T>(1));
97   Output input =
98       ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
99 
100   TensorShape filter_shape(
101       {filter_size, filter_size, input_depth, filter_count});
102   Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
103   test::FillIota<T>(&filter_data, static_cast<T>(1));
104   Output filter =
105       ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
106 
107   Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
108                             filter, DIMS(1, stride, stride, 1), padding,
109                             ops::Conv2D::Attrs().DataFormat(SRC_DATA_FORMAT));
110   return conv;
111 }
112 
SimpleConv2DBackpropInput(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,bool dilated,const int input_sizes_length)113 Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
114                                  int filter_size, const string& padding,
115                                  bool dilated, const int input_sizes_length) {
116   int batch_size = 128;
117   int input_height = input_size;
118   int input_width = input_size;
119   int input_depth = 3;
120   int filter_count = 2;
121   int stride = 1;
122   TensorShape input_sizes_shape({input_sizes_length});
123   Tensor input_data(DT_INT32, input_sizes_shape);
124   if (input_sizes_length == 4) {
125     test::FillValues<int>(
126         &input_data, DIMS(batch_size, input_height, input_width, input_depth));
127   } else {
128     test::FillValues<int>(&input_data, {input_height, input_width});
129   }
130   Output input_sizes =
131       ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
132 
133   TensorShape filter_shape(
134       {filter_size, filter_size, input_depth, filter_count});
135   Output filter =
136       ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);
137 
138   int output_height = input_height;
139   int output_width = input_width;
140   TensorShape output_shape(
141       DIMS(batch_size, output_height, output_width, filter_count));
142   Tensor output_data(DT_FLOAT, output_shape);
143   test::FillIota<float>(&output_data, 1.0f);
144   Output output =
145       ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
146 
147   Output conv_backprop_input;
148   Output input_sizes_i =
149       ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
150   ops::Conv2DBackpropInput::Attrs attrs;
151   attrs = attrs.DataFormat(SRC_DATA_FORMAT);
152   if (dilated) {
153     attrs = attrs.Dilations(DIMS(1, 2, 2, 1));
154   }
155   conv_backprop_input = ops::Conv2DBackpropInput(
156       s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
157       DIMS(1, stride, stride, 1), padding, attrs);
158 
159   return conv_backprop_input;
160 }
161 
// Test fixture that provisions a cluster for each test: a real SingleMachine
// when a physical GPU is available, otherwise a VirtualCluster with a
// simulated CPU (plus a simulated GPU when built with CUDA/ROCm support).
class GenericLayoutOptimizerTest : public GrapplerTest {
 protected:
  void SetUp() override {
    bool gpu_available = GetNumAvailableGPUs() > 0;

    if (gpu_available) {
      virtual_cluster_ =
          std::make_unique<SingleMachine>(/*timeout_s=*/10, 1, 1);
    } else {
      DeviceProperties cpu_device;
      cpu_device.set_type("CPU");
      cpu_device.set_frequency(1000);
      cpu_device.set_num_cores(4);
      cpu_device.set_bandwidth(32);
      cpu_device.set_l1_cache_size(32 * 1024);
      cpu_device.set_l2_cache_size(256 * 1024);
      cpu_device.set_l3_cache_size(4 * 1024 * 1024);
      cpu_device.set_memory_size(1024 * 1024);
#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
      // With a GPU build but no physical GPU, simulate one so the optimizer
      // exercises the NHWC -> NCHW path.
      DeviceProperties gpu_device;
      gpu_device.set_type("GPU");
      gpu_device.mutable_environment()->insert({"architecture", "6"});
      virtual_cluster_ =
          absl::WrapUnique(new VirtualCluster({{"/CPU:0", cpu_device},
                                               { "/GPU:1",
                                                 gpu_device }}));
#else
      virtual_cluster_ =
          absl::WrapUnique(new VirtualCluster({{"/CPU:0", cpu_device}}));
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    }
    TF_ASSERT_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }

  // Cluster passed to every Optimize() call in the tests below.
  std::unique_ptr<Cluster> virtual_cluster_;
};
200 
VerifyRegularFaninMatch(const utils::NodeView * node,int port,absl::string_view fanin_name,int fanin_port)201 void VerifyRegularFaninMatch(const utils::NodeView* node, int port,
202                              absl::string_view fanin_name, int fanin_port) {
203   ASSERT_GE(node->NumRegularFanins(), port);
204   const auto& fanin = node->GetRegularFanin(port);
205   EXPECT_EQ(fanin.node_view()->GetName(), fanin_name);
206   EXPECT_EQ(fanin.index(), fanin_port);
207 }
208 
VerifyRegularFanoutMatch(const utils::NodeView * node,int port,absl::string_view fanout_name,int fanout_port)209 void VerifyRegularFanoutMatch(const utils::NodeView* node, int port,
210                               absl::string_view fanout_name, int fanout_port) {
211   bool found = false;
212   for (const auto& regular_fanout : node->GetRegularFanout(port)) {
213     if (regular_fanout.node_view()->GetName() == fanout_name &&
214         regular_fanout.index() == fanout_port) {
215       found = true;
216     }
217   }
218   EXPECT_TRUE(found);
219 }
220 
VerifyDataFormatAttributeMatch(const utils::NodeView * node,absl::string_view attr_value)221 void VerifyDataFormatAttributeMatch(const utils::NodeView* node,
222                                     absl::string_view attr_value) {
223   const auto* attr = node->GetAttr("data_format");
224   ASSERT_NE(attr, nullptr);
225   EXPECT_EQ(attr->s(), attr_value);
226 }
227 
// Simplest case: Input/Filter consts feeding a Conv2D, then an Identity.
// Data format starts as NHWC on GPU builds and NCHW on CPU-only builds.
TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv2DGraph) {
  Scope root = Scope::NewRootScope();
  auto conv = SimpleConv2D(&root, 4, 2, "VALID", "");
  auto out = Identity(root.WithOpName("Output"), conv);

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  // With no device assigned, the Conv2D keeps its original data format and
  // its Filter fanin.
  auto* conv2d_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv2d_node, 1, "Filter", 0);
  VerifyDataFormatAttributeMatch(conv2d_node, SRC_DATA_FORMAT);

  auto* output_node = graph_view.GetNode("Output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
}
256 
// A node listed in item.fetch must keep its original data format.
TEST_F(GenericLayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&root, 4, 2, "VALID", "");
  auto identity = ops::Identity(root.WithOpName("i"), conv);

  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
}
276 
// A Conv2D with no device assignment keeps its original data format.
TEST_F(GenericLayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&root, 4, 2, "VALID", "");
  Output fetch = ops::Identity(root.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
}
295 
// A Conv2D explicitly placed on a GPU device is rewritten to NCHW.
// Requires a CUDA/ROCm build; skipped otherwise.
TEST_F(GenericLayoutOptimizerTest, GPUDevice) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&root, 4, 2, "VALID",
                           "/job:w/replica:0/task:0/device:GPU:0");
  Output fetch = ops::Identity(root.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  // Default-constructed optimizer (default rewriter settings).
  GenericLayoutOptimizer optimizer;
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, "NCHW");
}
318 
// A Conv2D pinned to CPU ends up in NHWC in both build configurations:
// on GPU builds it stays NHWC; on CPU-only builds NCHW_TO_NHWC rewrites it
// to the destination format (NHWC).
TEST_F(GenericLayoutOptimizerTest, CPUDevice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&root, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(root.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  VerifyDataFormatAttributeMatch(conv_node, "NHWC");
#else
  VerifyDataFormatAttributeMatch(conv_node, DST_DATA_FORMAT);
#endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
}
341 
// An integer (int32) Conv2D is left untouched: its data format must not
// change.
TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D<int32>(&root, 4, 2, "VALID", "");
  Output fetch = ops::Identity(root.WithOpName("Fetch"), {conv});

  GrapplerItem item;
  TF_ASSERT_OK(root.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);
  auto* conv_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv_node, nullptr);
  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
}
360 
// Verifies multi-hop connectivity handling: i2 is separated from the Conv2D
// by the layout-agnostic node i1, and the graph is deliberately shuffled out
// of topological order before optimization.
TEST_F(GenericLayoutOptimizerTest, Connectivity) {
  Scope scope = Scope::NewRootScope();
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto i1 = ops::Identity(scope.WithOpName("i1"), conv);
  auto i2 = ops::Identity(scope.WithOpName("i2"), i1);
  auto i3 = ops::Identity(scope.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in the
  // middle are layout agnostic). If the graph is already in topological order,
  // the problem is easier, where layout optimizer only needs to check
  // single-hop connectivity.
  Status status;
  utils::GraphView graph_view_original(&item.graph, &status);
  // Check construction succeeded before dereferencing GetNode() results;
  // previously `status` was used unchecked here.
  TF_ASSERT_OK(status);
  const int i1_index = graph_view_original.GetNode("i1")->node_index();
  const int i2_index = graph_view_original.GetNode("i2")->node_index();
  item.graph.mutable_node()->SwapElements(i1_index, i2_index);

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  auto* node_i2_output = graph_view.GetNode("i2");
  ASSERT_NE(node_i2_output, nullptr);
  // Layout optimizer should process i2, as it detects i2 is connected with the
  // Conv2D node two hops away. Similarly i1 is processed as well, as i1 is
  // directly connected to the Conv2D node.
  ASSERT_EQ(node_i2_output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(node_i2_output, 0, "i1", 0);
}
395 
// A Conv2DBackpropInput whose input_sizes arrives through an Identity (i.e.
// not a Const) must keep that fanin after optimization, for both the
// 4-element and the 2-element (spatial-only) input_sizes forms.
TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  for (const int input_sizes_length : {2, 4}) {
    Scope root = Scope::NewRootScope();
    auto conv = SimpleConv2DBackpropInput(&root, 7, 2, "SAME",
                                          /*dilated=*/false,
                                          input_sizes_length);
    Output fetch = ops::Identity(root.WithOpName("Fetch"), {conv});

    GrapplerItem item;
    TF_ASSERT_OK(root.ToGraphDef(&item.graph));

    GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
    GraphDef optimized_graph;
    TF_ASSERT_OK(
        optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

    Status status;
    utils::GraphView graph_view(&optimized_graph, &status);
    TF_ASSERT_OK(status);
    auto* backprop_node = graph_view.GetNode("Conv2DBackpropInput");
    ASSERT_NE(backprop_node, nullptr);
    ASSERT_EQ(backprop_node->NumRegularFanins(), 3);
    VerifyRegularFaninMatch(backprop_node, 0, "InputSizesIdentity", 0);
  }
}
418 
// The optimizer's inserted transpose/data-format-vec-permute pairs around
// shape-consuming ops must collapse, leaving only the outermost transposes.
TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
  Scope scope =
      Scope::NewRootScope().WithDevice(absl::StrCat("/device:", DEVICE, ":0"));
  auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
                           absl::StrCat("/device:", DEVICE, ":0"));
  auto shape = ops::Shape(scope.WithOpName("shape"), conv);
  auto value = ops::Const(scope.WithOpName("value"), 0, {});
  auto fill = ops::Fill(scope.WithOpName("fill"), shape, value);
  auto i = ops::Identity(scope.WithOpName("i"), fill);
  GrapplerItem item;
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  // Graph before optimization:
  // input -> conv2d -> shape -> fill -> output
  //
  // Graph after expansion:
  // input -> T -> conv2d -> T' -> T -> shape -> D' -> D -> fill -> T' -> output
  //
  // Graph after collapse:
  // input -> T -> conv2d -> shape -> fill -> T' -> output
  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);
  // Conv2D's data fanin is the optimizer-added input transpose.
  auto* conv2d_node = graph_view.GetNode("Conv2D");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(
      conv2d_node, 0,
      absl::StrCat("Conv2D-0-Transpose", SRC_DATA_FORMAT, "To", DST_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  // shape and fill connect directly (intermediate pairs collapsed).
  auto* shape_node = graph_view.GetNode("shape");
  ASSERT_NE(shape_node, nullptr);
  ASSERT_EQ(shape_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(shape_node, 0, conv2d_node->GetName(), 0);

  auto* fill_node = graph_view.GetNode("fill");
  ASSERT_NE(fill_node, nullptr);
  ASSERT_EQ(fill_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(fill_node, 0, shape_node->GetName(), 0);
  VerifyRegularFanoutMatch(
      fill_node, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);

  // The final Identity reads through the output transpose.
  auto* graph_output = graph_view.GetNode("i");
  ASSERT_NE(graph_output, nullptr);
  ASSERT_EQ(graph_output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(
      graph_output, 0,
      absl::StrCat("fill-0-0-Transpose", DST_DATA_FORMAT, "To", SRC_DATA_FORMAT,
                   "-LayoutOptimizer"),
      0);
}
479 
// Transposes that already existed in the input graph must not be pruned even
// though they cancel each other out; only optimizer-added transposes may be
// collapsed. The optimizer still inserts its own transpose pair around the
// BiasAdd it rewrites.
TEST_F(GenericLayoutOptimizerTest, DoNotPruneNonAddedCancellableTransposes) {
  GrapplerItem item;
  {
    Scope scope = Scope::NewRootScope().WithDevice(
        absl::StrCat("/device:", DEVICE, ":0"));
    auto input = ops::RandomUniform(scope.WithOpName("input"),
                                    DIMS(kBatchSize, kHeight, kWidth, kDepthIn),
                                    DT_FLOAT);
    // Permutation for source to destination data format.
    // GPU: NHWC -> NCHW: {0, 3, 1, 2}
    // CPU: NCHW -> NHWC: {0, 2, 3, 1}
    auto input_in_transpose =
        ops::Transpose(scope.WithOpName("input_in_transpose"), input,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    // Permutation for destination to source data format.
    // GPU: NCHW -> NHWC: {0, 2, 3, 1}
    // CPU: NHWC -> NCHW: {0, 3, 1, 2}
    auto input_out_transpose = ops::Transpose(
        scope.WithOpName("input_out_transpose"), input_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    Tensor bias_data(DT_FLOAT, TensorShape({kDepthIn}));
    test::FillIota<float>(&bias_data, 1.0f);
    auto bias_add = ops::BiasAdd(
        scope.WithOpName("bias_add"), input_out_transpose, bias_data,
        ops::BiasAdd::Attrs().DataFormat(SRC_DATA_FORMAT));
    auto output_in_transpose =
        ops::Transpose(scope.WithOpName("output_in_transpose"), bias_add,
                       ops::Const(scope, PERMUTATION_SRC_TO_DST, {4}));
    auto output_out_transpose = ops::Transpose(
        scope.WithOpName("output_out_transpose"), output_in_transpose,
        ops::Const(scope, PERMUTATION_DST_TO_SRC, {4}));
    auto output =
        ops::Identity(scope.WithOpName("output"), output_out_transpose);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  }

  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  Status status;
  utils::GraphView graph_view(&output, &status);
  TF_ASSERT_OK(status);

  // All four pre-existing transposes must survive, chained in their original
  // order.
  auto* input_node = graph_view.GetNode("input");
  ASSERT_NE(input_node, nullptr);

  auto* input_in_transpose_node = graph_view.GetNode("input_in_transpose");
  ASSERT_NE(input_in_transpose_node, nullptr);
  ASSERT_EQ(input_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_in_transpose_node, 0, input_node->GetName(), 0);

  auto* input_out_transpose_node = graph_view.GetNode("input_out_transpose");
  ASSERT_NE(input_out_transpose_node, nullptr);
  ASSERT_EQ(input_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_out_transpose_node, 0,
                          input_in_transpose_node->GetName(), 0);

  // Transpose the optimizer inserted in front of the rewritten bias_add.
  auto* bias_add_in_transpose_node = graph_view.GetNode(
      absl::StrCat("bias_add-0-Transpose", SRC_DATA_FORMAT, "To",
                   DST_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_add_in_transpose_node, nullptr);
  ASSERT_EQ(bias_add_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_in_transpose_node, 0,
                          input_out_transpose_node->GetName(), 0);

  auto* bias_add_node = graph_view.GetNode("bias_add");
  ASSERT_NE(bias_add_node, nullptr);
  ASSERT_EQ(bias_add_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_node, 0,
                          bias_add_in_transpose_node->GetName(), 0);

  // Transpose the optimizer inserted after bias_add to restore the source
  // data format.
  auto* bias_add_out_transpose_node = graph_view.GetNode(
      absl::StrCat("bias_add-0-0-Transpose", DST_DATA_FORMAT, "To",
                   SRC_DATA_FORMAT, "-LayoutOptimizer"));
  ASSERT_NE(bias_add_out_transpose_node, nullptr);
  ASSERT_EQ(bias_add_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(bias_add_out_transpose_node, 0,
                          bias_add_node->GetName(), 0);

  auto* output_in_transpose_node = graph_view.GetNode("output_in_transpose");
  ASSERT_NE(output_in_transpose_node, nullptr);
  ASSERT_EQ(output_in_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_in_transpose_node, 0,
                          bias_add_out_transpose_node->GetName(), 0);

  auto* output_out_transpose_node = graph_view.GetNode("output_out_transpose");
  ASSERT_NE(output_out_transpose_node, nullptr);
  ASSERT_EQ(output_out_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_out_transpose_node, 0,
                          output_in_transpose_node->GetName(), 0);

  auto* output_node = graph_view.GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_out_transpose_node->GetName(),
                          0);
}
578 
// A Transpose -> Pad -> Transpose chain whose transposes cancel: the
// optimizer should replace all three Transpose nodes with Identity nodes and
// instead permute the Pad's paddings constant. Verified both structurally
// (CompareGraphs) and numerically (EvaluateFetchNodes).
TEST_F(GenericLayoutOptimizerTest, CancelTransposeAroundPad) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(
      RewriterConfig::AGGRESSIVE,
      RewriterConfig::NCHW_TO_NHWC /* CPU settings*/);

  const Tensor kPermuteNhwcToNchw = test::AsTensor<int32>({0, 3, 1, 2});
  const Tensor kPermuteNchwToNhwc = test::AsTensor<int32>({0, 2, 3, 1});
  // One (before, after) padding pair per dimension, shape {4, 2}.
  const Tensor kPad = test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});

  GrapplerItem item;
  item.graph = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {}, {{"dtype", DT_INT32}, {"value", kPad}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNhwcToNchw}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNchwToNhwc}}),

      NDef("transpose_0", "Transpose", {"x", "perm_nhwc_to_nchw"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
      NDef("transpose_2", "Transpose", {"pad", "perm_nchw_to_nhwc"},
           {{"T", DT_FLOAT}, {"Tperm", DT_INT32}}),
  });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  // The rows of kPad reordered by the NCHW -> NHWC permutation {0, 2, 3, 1}.
  const Tensor kPermutedPaddings =
      test::AsTensor<int32>({1, 2, 5, 6, 7, 8, 3, 4}, {4, 2});

  GraphDef expected = test::function::GDef({
      NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),

      NDef("paddings", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermutedPaddings}}),
      NDef("perm_nhwc_to_nchw", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNhwcToNchw}}),
      NDef("perm_nchw_to_nhwc", "Const", {},
           {{"dtype", DT_INT32}, {"value", kPermuteNchwToNhwc}}),

      // Transpose nodes replaced by Identity nodes.
      NDef("transpose_0", "Identity", {"x"}, {{"T", DT_FLOAT}}),
      NDef("pad", "Pad", {"transpose_0", "paddings"},
           {{"T", DT_FLOAT}, {"Tpaddings", DT_INT32}}),
      NDef("transpose_1", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
      NDef("transpose_2", "Identity", {"pad"}, {{"T", DT_FLOAT}}),
  });

  CompareGraphs(expected, output);

  // Numeric check: the optimized graph must produce the same fetch tensors
  // as the original for a random input.
  Tensor x = GenerateRandomTensor<DT_FLOAT>({2, 6, 6, 8});
  item.fetch = {"transpose_1", "transpose_2"};
  item.feed.emplace_back("x", x);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}
647 
// The "_output_shapes" attribute on a fed _Arg node must survive
// optimization unchanged.
TEST_F(GenericLayoutOptimizerTest, PreserveInputShapes) {
  using test::function::NDef;

  GenericLayoutOptimizer optimizer(RewriterConfig::AGGRESSIVE);

  // A single shape entry with one unknown dimension: [-1].
  AttrValue output_shapes;
  output_shapes.mutable_list()->add_shape()->add_dim()->set_size(-1);

  GrapplerItem item;
  item.graph = test::function::GDef({NDef(
      "x", "_Arg", {},
      {{"T", DT_FLOAT}, {"index", 0}, {"_output_shapes", output_shapes}})});
  item.feed.emplace_back("x", Tensor(DT_FLOAT));

  GraphDef optimized_graph;
  TF_ASSERT_OK(
      optimizer.Optimize(virtual_cluster_.get(), item, &optimized_graph));

  Status status;
  utils::GraphView graph_view(&optimized_graph, &status);
  TF_ASSERT_OK(status);

  auto* arg = graph_view.GetNode("x");
  ASSERT_NE(arg, nullptr);
  EXPECT_TRUE(arg->HasAttr("_output_shapes"));
  EXPECT_EQ(arg->GetAttr("_output_shapes")->DebugString(),
            output_shapes.DebugString());
}
676 
677 // TODO(yanzha): Add more complex Graph for test.
678 
679 }  // namespace grappler
680 }  // namespace tensorflow
681