/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {

namespace {

using ::tensorflow::test::ExpectTensorEqual;

constexpr int kBatchSize = 32;
constexpr int kWidth = 10;
constexpr int kHeight = 10;
constexpr int kDepthIn = 8;
constexpr int kKernel = 2;
constexpr int kStride1 = 2;
constexpr int kStride2 = 4;
constexpr int kOutWidth = 5;
constexpr int kOutHeight = 5;
constexpr int kDepthOut = 16;
constexpr int kDilation = 2;
constexpr int kPaddingTop = 1;
constexpr int kPaddingBottom = 2;
constexpr int kPaddingLeft = 3;
constexpr int kPaddingRight = 4;
constexpr char kSrcFormat[] = "NHWC";
constexpr char kDstFormat[] = "NCHW";
constexpr char kGPU[] = "GPU";
constexpr char kAttrOutputShapes[] = "_output_shapes";
constexpr char kAttrDataFormat[] = "data_format";
constexpr char kOpTranspose[] = "Transpose";

class TransposerImpl : public Transposer {
 public:
  explicit TransposerImpl() : Transposer() {}
  Status TransposeNode(TransposeContext*, utils::MutableNodeView*) override {
    return OkStatus();
  }
};

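// Checks that the regular fanin of `node` at `port` comes from output
// `fanin_port` of the node named `fanin_name`.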
void VerifyRegularFaninMatch(const utils::MutableNodeView* node, int port,
                             absl::string_view fanin_name, int fanin_port) {
  ASSERT_GT(node->NumRegularFanins(), port);
  const auto& fanin = node->GetRegularFanin(port);
  EXPECT_EQ(fanin.node_view()->GetName(), fanin_name);
  EXPECT_EQ(fanin.index(), fanin_port);
}

void VerifyShapeAttributeMatch(const utils::MutableNodeView* node,
                               absl::string_view attr_value) {
  const auto* attr = node->GetAttr(kAttrOutputShapes);
  ASSERT_NE(attr, nullptr);
  EXPECT_EQ(attr->shape().DebugString(), attr_value);
}

void VerifyShapeAttributeMatch(const utils::MutableNodeView* node,
                               int shape_index, absl::string_view attr_value) {
  const auto* attr = node->GetAttr(kAttrOutputShapes);
  ASSERT_NE(attr, nullptr);
  ASSERT_GT(attr->list().shape_size(), shape_index);
  EXPECT_EQ(attr->list().shape(shape_index).DebugString(), attr_value);
}

void VerifyDataFormatAttributeMatch(const utils::MutableNodeView* node,
                                    absl::string_view attr_value) {
  const auto* attr = node->GetAttr(kAttrDataFormat);
  ASSERT_NE(attr, nullptr);
  EXPECT_EQ(attr->s(), attr_value);
}

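// Builds an NHWC Conv2D node named "conv2d" on GPU, fed by random "input" and
// "filter" nodes.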
Output SimpleConv2D(const Scope* scope, const DataType& data_type = DT_FLOAT) {
  auto input =
      ops::RandomUniform(scope->WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
  auto filter =
      ops::RandomUniform(scope->WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, data_type);
  auto conv2d = ops::Conv2D(
      scope->WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, kStride1, kStride2, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  return conv2d;
}

Status CreateSimpleConv2DGraph(GraphDef* graph,
                               const DataType& data_type = DT_FLOAT) {
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope, data_type);
  auto output = ops::Identity(scope.WithOpName("output"), conv2d);

  return scope.ToGraphDef(graph);
}

Status CreateSimpleFusedBatchNorm(GraphDef* graph,
                                  const DataType& data_type = DT_FLOAT) {
  Scope scope = Scope::NewRootScope();
  auto x =
      ops::RandomUniform(scope.WithOpName("x"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
  auto scale =
      ops::RandomUniform(scope.WithOpName("scale"), {kDepthIn}, DT_FLOAT);
  auto offset =
      ops::RandomUniform(scope.WithOpName("offset"), {kDepthIn}, DT_FLOAT);
  auto mean =
      ops::RandomUniform(scope.WithOpName("mean"), {kDepthIn}, DT_FLOAT);
  auto var = ops::RandomUniform(scope.WithOpName("var"), {kDepthIn}, DT_FLOAT);
  auto batch_norm = ops::FusedBatchNormV2(
      scope.WithOpName("bn").WithDevice("/device:GPU:0"), x, scale, offset,
      mean, var, ops::FusedBatchNormV2::IsTraining(false).Epsilon(0.1f));

  auto output_y = ops::Identity(scope.WithOpName("output_y"), batch_norm.y);
  auto output_mean =
      ops::Identity(scope.WithOpName("output_mean"), batch_norm.batch_mean);
  auto output_variance = ops::Identity(scope.WithOpName("output_variance"),
                                       batch_norm.batch_variance);

  return scope.ToGraphDef(graph);
}

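// Builds a graph with an NHWC "maxpool_grad" node on GPU; uses MaxPoolGradGrad
// when `use_grad_grad` is true and MaxPoolGrad otherwise.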
Status CreateSimpleMaxPoolGrad(GraphDef* graph, bool use_grad_grad) {
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("orig_input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto output_data = ops::RandomUniform(
      scope.WithOpName("orig_output"),
      {kBatchSize, kOutHeight, kOutWidth, kDepthIn}, DT_FLOAT);
  auto output_grad =
      ops::RandomUniform(scope.WithOpName("grad"),
                         {kBatchSize, use_grad_grad ? kHeight : kOutHeight,
                          use_grad_grad ? kWidth : kOutWidth, kDepthIn},
                         DT_FLOAT);
  Output maxpool_grad;
  if (use_grad_grad) {
    maxpool_grad = ops::MaxPoolGradGrad(
        scope.WithOpName("maxpool_grad").WithDevice("/device:GPU:0"), input,
        output_data, output_grad, {1, kKernel, kKernel, 1},
        {1, kStride1, kStride1, 1}, "VALID");
  } else {
    maxpool_grad = ops::internal::MaxPoolGrad(
        scope.WithOpName("maxpool_grad").WithDevice("/device:GPU:0"), input,
        output_data, output_grad, {1, kKernel, kKernel, 1},
        {1, kStride1, kStride1, 1}, "VALID");
  }

  auto output = ops::Identity(scope.WithOpName("output"), maxpool_grad);

  return scope.ToGraphDef(graph);
}

Status CreateSimpleBiasAddGrad(GraphDef* graph, const Input& shape) {
  Scope scope = Scope::NewRootScope();
  auto input = ops::RandomUniform(scope.WithOpName("input"), shape, DT_FLOAT);
  auto bag =
      ops::BiasAddGrad(scope.WithOpName("bag").WithDevice("/device:GPU:0"),
                       input, ops::BiasAddGrad::DataFormat(kSrcFormat));
  auto output = ops::Identity(scope.WithOpName("output"), bag);

  return scope.ToGraphDef(graph);
}

Status CreateSimpleConv2DBackpropFilter(GraphDef* graph,
                                        const DataType& data_type = DT_FLOAT,
                                        absl::string_view padding = "SAME") {
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
  auto out_backprop =
      ops::RandomUniform(scope.WithOpName("out_backprop"),
                         {kBatchSize, kHeight, kWidth, kDepthOut}, data_type);
  if (padding == "EXPLICIT") {
    auto conv2d_backprop_filter = ops::Conv2DBackpropFilter(
        scope.WithOpName("conv2d_backprop_filter").WithDevice("/device:GPU:0"),
        input, {kHeight, kWidth, kDepthIn, kDepthOut}, out_backprop,
        {1, 2, 4, 1}, padding,
        ops::Conv2DBackpropFilter::Attrs()
            .Dilations({1, kDilation, kDilation, 1})
            .ExplicitPaddings({0, 0, kPaddingTop, kPaddingBottom, kPaddingLeft,
                               kPaddingRight, 0, 0})
            .DataFormat(kSrcFormat));
    auto output =
        ops::Identity(scope.WithOpName("output"), conv2d_backprop_filter);
  } else {
    auto conv2d_backprop_filter = ops::Conv2DBackpropFilter(
        scope.WithOpName("conv2d_backprop_filter").WithDevice("/device:GPU:0"),
        input, {kHeight, kWidth, kDepthIn, kDepthOut}, out_backprop,
        {1, 2, 4, 1}, padding,
        ops::Conv2DBackpropFilter::DataFormat(kSrcFormat));
    auto output =
        ops::Identity(scope.WithOpName("output"), conv2d_backprop_filter);
  }

  return scope.ToGraphDef(graph);
}

Status CreateSimpleConv2DBackpropInput(GraphDef* graph,
                                       const DataType& data_type = DT_FLOAT) {
  Scope scope = Scope::NewRootScope();
  auto input_sizes = ops::Const(scope.WithOpName("input_sizes"),
                                {kBatchSize, kHeight, kWidth, kDepthIn});
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, data_type);
  auto out_backprop =
      ops::RandomUniform(scope.WithOpName("out_backprop"),
                         {kBatchSize, kHeight, kWidth, kDepthOut}, data_type);
  auto conv2d_backprop_input = ops::Conv2DBackpropInput(
      scope.WithOpName("conv2d_backprop_input").WithDevice("/device:GPU:0"),
      input_sizes, filter, out_backprop, {1, kStride1, kStride1, 1}, "VALID");
  auto output =
      ops::Identity(scope.WithOpName("output"), conv2d_backprop_input);

  return scope.ToGraphDef(graph);
}

Status CreateSimpleFusedBatchNormGrad(GraphDef* graph, bool is_training,
                                      const DataType& data_type = DT_FLOAT) {
  Scope scope = Scope::NewRootScope();
  auto y_backprop =
      ops::RandomUniform(scope.WithOpName("y_backprop"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
  auto x =
      ops::RandomUniform(scope.WithOpName("x"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
  auto scale =
      ops::RandomUniform(scope.WithOpName("scale"), {kDepthIn}, DT_FLOAT);
  auto reserve_space_1 = ops::RandomUniform(scope.WithOpName("reserve_space_1"),
                                            {kDepthIn}, DT_FLOAT);
  auto reserve_space_2 = ops::RandomUniform(scope.WithOpName("reserve_space_2"),
                                            {kDepthIn}, DT_FLOAT);
  auto fused_batch_norm_grad = ops::FusedBatchNormGradV2(
      scope.WithOpName("fused_batch_norm_grad").WithDevice("/device:GPU:0"),
      y_backprop, x, scale, reserve_space_1, reserve_space_2,
      ops::FusedBatchNormGradV2::DataFormat(kSrcFormat)
          .IsTraining(is_training)
          .Epsilon(0.1f));
  auto x_backprop = ops::Identity(scope.WithOpName("x_backprop"),
                                  fused_batch_norm_grad.x_backprop);
  auto scale_backprop = ops::Identity(scope.WithOpName("scale_backprop"),
                                      fused_batch_norm_grad.scale_backprop);
  auto offset_backprop = ops::Identity(scope.WithOpName("offset_backprop"),
                                       fused_batch_norm_grad.offset_backprop);
  auto reserve_space_3 = ops::Identity(scope.WithOpName("reserve_space_3"),
                                       fused_batch_norm_grad.reserve_space_3);
  auto reserve_space_4 = ops::Identity(scope.WithOpName("reserve_space_4"),
                                       fused_batch_norm_grad.reserve_space_4);

  return scope.ToGraphDef(graph);
}

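// Builds a graph where "add_n" sums three random tensors together with the
// output of an NHWC Conv2D.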
Status CreateSimpleAddN(GraphDef* graph) {
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  Output a = ops::RandomUniform(scope.WithOpName("a"),
                                {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  Output b = ops::RandomUniform(scope.WithOpName("b"),
                                {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  Output c = ops::RandomUniform(scope.WithOpName("c"),
                                {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  auto add_n = ops::AddN(scope.WithOpName("add_n").WithDevice("/device:GPU:0"),
                         {a, b, c, conv2d});
  auto output = ops::Identity(scope.WithOpName("output"), add_n);

  return scope.ToGraphDef(graph);
}

Status CreateSimpleIdentityN(GraphDef* graph) {
  Scope scope = Scope::NewRootScope();
  auto conv2d_1_input =
      ops::RandomUniform(scope.WithOpName("conv2d_1_input"),
                         {kBatchSize, kDepthIn, kHeight, kWidth}, DT_FLOAT);
  auto conv2d_1_filter =
      ops::RandomUniform(scope.WithOpName("conv2d_1_filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d_1 =
      ops::Conv2D(scope.WithOpName("conv2d_1").WithDevice("/device:GPU:0"),
                  conv2d_1_input, conv2d_1_filter, {1, 1, 2, 4}, "SAME",
                  ops::Conv2D::DataFormat(kDstFormat));
  auto conv2d_2_input =
      ops::RandomUniform(scope.WithOpName("conv2d_2_input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto conv2d_2_filter =
      ops::RandomUniform(scope.WithOpName("conv2d_2_filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d_2 =
      ops::Conv2D(scope.WithOpName("conv2d_2").WithDevice("/device:GPU:0"),
                  conv2d_2_input, conv2d_2_filter, {1, 2, 4, 1}, "SAME",
                  ops::Conv2D::DataFormat(kSrcFormat));
  Output a = ops::RandomUniform(
      scope.WithOpName("a"), {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  Output b = ops::RandomUniform(scope.WithOpName("b"), {kBatchSize, kDepthIn},
                                DT_FLOAT);
  auto identity_n =
      ops::IdentityN(scope.WithOpName("identity_n").WithDevice("/device:GPU:0"),
                     {conv2d_1, conv2d_2, a, b});
  auto conv2d_1_output =
      ops::Identity(scope.WithOpName("conv2d_1_output"), identity_n.output[0]);
  auto conv2d_2_output =
      ops::Identity(scope.WithOpName("conv2d_2_output"), identity_n.output[1]);
  auto a_output =
      ops::Identity(scope.WithOpName("a_output"), identity_n.output[2]);
  auto b_output =
      ops::Identity(scope.WithOpName("b_output"), identity_n.output[3]);

  return scope.ToGraphDef(graph);
}

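// Test fixture that provisions a SingleMachine cluster when a GPU is
// available, and a VirtualCluster with a simulated GPU device otherwise.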
class TransposerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    bool gpu_available = GetNumAvailableGPUs() > 0;

    if (gpu_available) {
      virtual_cluster_ =
          std::make_unique<SingleMachine>(/*timeout_s=*/10, 1, 1);
    } else {
      DeviceProperties gpu_device;
      gpu_device.set_type(kGPU);
      gpu_device.mutable_environment()->insert({"architecture", "6"});
      virtual_cluster_ =
          absl::WrapUnique(new VirtualCluster({{"/GPU:1", gpu_device}}));
    }
    TF_ASSERT_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }

  template <typename T>
  void ReduceTransposerKeepDims() {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GrapplerItem item;
    Scope scope = Scope::NewRootScope();

    auto input =
        ops::RandomUniform(scope.WithOpName("input"),
                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
    auto filter =
        ops::RandomUniform(scope.WithOpName("filter"),
                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
    Output conv2d = ops::Conv2D(
        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
    auto attrs = ops::Sum::Attrs().KeepDims(true);
    auto sum_op = ops::Sum(scope.WithOpName("sum").WithDevice("/device:GPU:0"),
                           conv2d, axis, attrs);

    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

    TransposeContext context;
    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
        item, virtual_cluster_.get(), &context));
    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
    auto* c2d = context.graph_view->GetNode("conv2d");
    ASSERT_NE(c2d, nullptr);
    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

    ReduceTransposer reducer_transposer;
    auto* sum = context.graph_view->GetNode("sum");
    ASSERT_NE(sum, nullptr);
    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, sum));

    auto* input_transpose_node = context.graph_view->GetNode(
        "sum-0-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node, nullptr);

    auto* updated_sum_node = context.graph_view->GetNode("sum");
    ASSERT_NE(updated_sum_node, nullptr);
    ASSERT_EQ(updated_sum_node->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(updated_sum_node, 0,
                            input_transpose_node->GetName(), 0);

    auto* axis_node = context.graph_view->GetNode(
        "sum-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(axis_node, nullptr);
    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);

    auto* output_transpose_node = context.graph_view->GetNode(
        "sum-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
    ASSERT_NE(output_transpose_node, nullptr);

    auto* z_output_node = context.graph_view->GetNode("z");
    ASSERT_NE(z_output_node, nullptr);
    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                            0);
  }

  template <typename T>
  void ReduceTransposerValidAxisNode() {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GrapplerItem item;
    Scope scope = Scope::NewRootScope();

    auto input =
        ops::RandomUniform(scope.WithOpName("input"),
                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
    auto filter =
        ops::RandomUniform(scope.WithOpName("filter"),
                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
    Output conv2d = ops::Conv2D(
        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
    auto sum_op = ops::Max(scope.WithOpName("max").WithDevice("/device:GPU:0"),
                           conv2d, axis);

    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

    TransposeContext context;
    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
        item, virtual_cluster_.get(), &context));
    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
    auto* c2d = context.graph_view->GetNode("conv2d");
    ASSERT_NE(c2d, nullptr);
    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

    ReduceTransposer reducer_transposer;
    auto* max = context.graph_view->GetNode("max");
    ASSERT_NE(max, nullptr);
    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, max));

    auto* input_transpose_node = context.graph_view->GetNode(
        "max-0-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node, nullptr);

    auto* updated_max_node = context.graph_view->GetNode("max");
    ASSERT_NE(updated_max_node, nullptr);
    ASSERT_EQ(updated_max_node->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(updated_max_node, 0,
                            input_transpose_node->GetName(), 0);

    auto* axis_node = context.graph_view->GetNode(
        "max-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(axis_node, nullptr);
    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);

    auto* z_output_node = context.graph_view->GetNode("z");
    ASSERT_NE(z_output_node, nullptr);
    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(z_output_node, 0, updated_max_node->GetName(), 0);
  }

  std::unique_ptr<Cluster> virtual_cluster_;
};

TEST_F(TransposerTest, CreateConstPermNode) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  TransposerImpl transposer;
  constexpr char kNodeName[] = "const_perm_node";
  constexpr char kDevice[] = "/device:GPU:0";
  utils::MutationNewNode added_node;
  EXPECT_FALSE(context.graph_view->HasNode(kNodeName));
  TF_ASSERT_OK(transposer.CreateConstPermNode(&context, kNodeName, kDevice,
                                              {0, 3, 1, 2}, "", &added_node));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  utils::MutableNodeView* const_perm_node =
      context.graph_view->GetNode(kNodeName);
  EXPECT_EQ(const_perm_node->GetName(), kNodeName);
  EXPECT_EQ(const_perm_node->GetDevice(), kDevice);
  const auto* value_attr = const_perm_node->GetAttr("value");
  ASSERT_NE(value_attr, nullptr);

  Tensor tensor;
  ASSERT_TRUE(tensor.FromProto(value_attr->tensor()));
  Tensor expected(DT_INT32, {4});
  ::tensorflow::test::FillValues<int32>(&expected, {0, 3, 1, 2});
  ExpectTensorEqual<int32>(tensor, expected);
}

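// Builds a TensorShapeProto with one dimension per entry in `dims`.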
TensorShapeProto MakeTensorShapeFromDimensions(absl::Span<const int> dims) {
  TensorShapeProto shape_proto = TensorShapeProto();
  for (const int dim : dims) {
    TensorShapeProto_Dim dim_proto = TensorShapeProto_Dim();
    dim_proto.set_size(dim);
    *shape_proto.add_dim() = std::move(dim_proto);
  }
  return shape_proto;
}

TEST_F(TransposerTest, CreateTransposeNode) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  TransposerImpl transposer;
  constexpr char kNodeNameFormat[] =
      "transpose_node-0-$0-NWCHToNCWH-LayoutOptimizer";
  constexpr char kDevice[] = "/device:GPU:0";
  TensorShapeProto input_shape = MakeTensorShapeFromDimensions({1, 2, 3, 4});
  TensorShapeProto expected_shape = MakeTensorShapeFromDimensions({1, 4, 2, 3});
  utils::MutationNewNode added_node;
  string transpose_node_name;
  TF_ASSERT_OK(transposer.CreateTransposeNode(
      &context, kNodeNameFormat, DT_DOUBLE, kDevice, input_shape, {0, 3, 1, 2},
      "", &added_node, &transpose_node_name));

  EXPECT_EQ(transpose_node_name,
            "transpose_node-0-Transpose-NWCHToNCWH-LayoutOptimizer");
  utils::Mutation* mutation = context.graph_view->GetMutationBuilder();
  Status status;
  // Add a placeholder node with an empty name, since the transpose node is
  // created with its first input not set.
  mutation->AddNode({}, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
  auto* transpose_node = context.graph_view->GetNode(transpose_node_name);
  ASSERT_NE(transpose_node, nullptr);
  EXPECT_EQ(transpose_node->GetDevice(), kDevice);
  const auto* output_shapes_attr = transpose_node->GetAttr("_output_shapes");
  EXPECT_EQ(output_shapes_attr->list().shape(0).DebugString(),
            expected_shape.DebugString());
}

TEST_F(TransposerTest, UpdateNode) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer transposer;
  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  auto* updated_conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(updated_conv2d, nullptr);
  VerifyDataFormatAttributeMatch(updated_conv2d, kDstFormat);
}

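// Builds an AttrValue_ListValue containing the ints in `vec`.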
AttrValue_ListValue MakeAttrValueListValueFromVector(
    absl::Span<const int> vec) {
  AttrValue_ListValue list_proto = AttrValue_ListValue();
  for (const int i : vec) {
    list_proto.add_i(i);
  }
  return list_proto;
}

TEST_F(TransposerTest, UpdateStrides) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, "ABCD", "ACBD");

  AttrValue_ListValue expected_original_strides =
      MakeAttrValueListValueFromVector({1, 2, 4, 1});
  AttrValue_ListValue expected_updated_strides =
      MakeAttrValueListValueFromVector({1, 4, 2, 1});
  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  const auto& strides_attr = conv2d->GetAttr("strides");
  ASSERT_NE(strides_attr, nullptr);
  EXPECT_EQ(strides_attr->list().DebugString(),
            expected_original_strides.DebugString());
  AttrValue data_format_attr;
  data_format_attr.set_s("ABCD");
  context.graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
      conv2d, "data_format", data_format_attr);
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  DefaultLayoutSensitiveOpTransposer transposer;
  TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  auto* updated_conv2d = context.graph_view->GetNode("conv2d");
  const auto& updated_strides_attr = updated_conv2d->GetAttr("strides");
  ASSERT_NE(updated_strides_attr, nullptr);
  EXPECT_EQ(updated_strides_attr->list().DebugString(),
            expected_updated_strides.DebugString());
}

TEST_F(TransposerTest, UpdateFaninEdgesTranspose) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  FusedBatchNormGradTransposer transposer;
  auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(fbng, nullptr);
  const auto& fbng_output_shapes_attr = fbng->GetAttr("_output_shapes");
  ASSERT_NE(fbng_output_shapes_attr, nullptr);
  const TensorShapeProto& expected_shape = fbng_output_shapes_attr->shape();
  TF_ASSERT_OK(
      transposer.UpdateFaninEdgesWithOp(&context, {0, 1}, fbng, kOpTranspose));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  // Verify output shape matches input shape.
  auto* transpose_node1 = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_node1, nullptr);
  VerifyShapeAttributeMatch(transpose_node1, expected_shape.DebugString());
  auto* transpose_node2 = context.graph_view->GetNode(
      "fused_batch_norm_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_node2, nullptr);
  VerifyShapeAttributeMatch(transpose_node2, expected_shape.DebugString());

  // Validate a const perm node is created.
  auto* const_node1 = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-PermConstNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(const_node1, nullptr);
  auto* const_node2 = context.graph_view->GetNode(
      "fused_batch_norm_grad-1-PermConstNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(const_node2, nullptr);

  // Validate nodes connected correctly.
  auto* y_backprop = context.graph_view->GetNode("y_backprop");
  ASSERT_NE(y_backprop, nullptr);
  ASSERT_EQ(transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_node1, 0, y_backprop->GetName(), 0);
  VerifyRegularFaninMatch(transpose_node1, 1, const_node1->GetName(), 0);

  auto* x = context.graph_view->GetNode("x");
  ASSERT_NE(x, nullptr);
  ASSERT_EQ(transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_node2, 0, x->GetName(), 0);
  VerifyRegularFaninMatch(transpose_node2, 1, const_node2->GetName(), 0);

  auto* updated_fbng = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(updated_fbng, nullptr);
  ASSERT_EQ(updated_fbng->NumRegularFanins(), 5);
  VerifyRegularFaninMatch(updated_fbng, 0, transpose_node1->GetName(), 0);
  VerifyRegularFaninMatch(updated_fbng, 1, transpose_node2->GetName(), 0);
}

TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  TransposerImpl transposer;
  TensorShapeProto expected_original_shape =
      MakeTensorShapeFromDimensions({32, 5, 3, 16});
  TensorShapeProto expected_updated_shape =
      MakeTensorShapeFromDimensions({32, 16, 5, 3});

  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  VerifyShapeAttributeMatch(conv2d, 0, expected_original_shape.DebugString());

  TF_ASSERT_OK(
      transposer.UpdateFanoutEdgesWithOp(&context, {0}, conv2d, kOpTranspose));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  auto* updated_conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(updated_conv2d, nullptr);
  VerifyShapeAttributeMatch(updated_conv2d, 0,
                            expected_updated_shape.DebugString());

  // Verify output shape matches original shape.
  auto* transpose_node = context.graph_view->GetNode(
      "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(transpose_node, nullptr);
  VerifyShapeAttributeMatch(transpose_node, 0,
                            expected_original_shape.DebugString());

  // Verify a const perm node is created for transpose node.
  auto* const_node = context.graph_view->GetNode(
      "conv2d-0-0-PermConstNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(const_node, nullptr);

  // Verify nodes connected correctly.
  ASSERT_EQ(transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_node, 0, updated_conv2d->GetName(), 0);
  VerifyRegularFaninMatch(transpose_node, 1, const_node->GetName(), 0);

  auto* output = context.graph_view->GetNode("output");
  ASSERT_NE(output, nullptr);
  ASSERT_EQ(output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output, 0, transpose_node->GetName(), 0);
}

TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestFusedBatchNorm) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Use FusedBatchNorm for default transposer test
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleFusedBatchNorm(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer transposer;
  auto* bn = context.graph_view->GetNode("bn");
  TF_ASSERT_OK(transposer.TransposeNode(&context, bn));

  // The expected optimized graph contains 2 extra sets of Transpose nodes and
  // has the FusedBatchNorm's data_format set to "NCHW".
  auto* input_transpose_node =
      context.graph_view->GetNode("bn-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "x", 0);

  auto* bn_node = context.graph_view->GetNode("bn");
  ASSERT_NE(bn_node, nullptr);
  ASSERT_EQ(bn_node->NumRegularFanins(), 5);
  VerifyRegularFaninMatch(bn_node, 0, input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(bn_node, 1, "scale", 0);
  VerifyRegularFaninMatch(bn_node, 2, "offset", 0);
  VerifyRegularFaninMatch(bn_node, 3, "mean", 0);
  VerifyRegularFaninMatch(bn_node, 4, "var", 0);
  VerifyDataFormatAttributeMatch(bn_node, kDstFormat);

  auto* output_transpose_node =
      context.graph_view->GetNode("bn-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, bn_node->GetName(), 0);

  auto* output_y = context.graph_view->GetNode("output_y");
  ASSERT_NE(output_y, nullptr);
  ASSERT_EQ(output_y->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_y, 0, output_transpose_node->GetName(), 0);

  auto* output_mean = context.graph_view->GetNode("output_mean");
  ASSERT_NE(output_mean, nullptr);
  ASSERT_EQ(output_mean->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_mean, 0, bn_node->GetName(), 1);

  auto* output_variance = context.graph_view->GetNode("output_variance");
  ASSERT_NE(output_variance, nullptr);
  ASSERT_EQ(output_variance->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_variance, 0, bn_node->GetName(), 2);
}

TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestConv2D) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Use Conv2D for default transposer test
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer transposer;
  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d));

  // The expected optimized graph contains 2 extra sets of Transpose nodes and
  // has the Conv2D's data_format set to "NCHW".
  auto* input_transpose_node = context.graph_view->GetNode(
      "conv2d-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "input", 0);

  auto* conv2d_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv2d_node, 0, input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(conv2d_node, 1, "filter", 0);
  VerifyDataFormatAttributeMatch(conv2d_node, kDstFormat);
  const auto* strides_attr = conv2d_node->GetAttr("strides");
  ASSERT_NE(strides_attr, nullptr);
  ASSERT_EQ(strides_attr->list().i_size(), 4);
  EXPECT_EQ(strides_attr->list().i(0), 1);
  EXPECT_EQ(strides_attr->list().i(1), 1);
  EXPECT_EQ(strides_attr->list().i(2), kStride1);
  EXPECT_EQ(strides_attr->list().i(3), kStride2);

  auto* output_transpose_node = context.graph_view->GetNode(
      "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, conv2d_node->GetName(), 0);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
}

TEST_F(TransposerTest, MaxPoolGradTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  for (bool use_grad_grad : {false, true}) {
    GrapplerItem item;
    TransposeContext context;
    TF_ASSERT_OK(CreateSimpleMaxPoolGrad(&item.graph, use_grad_grad));
    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
        item, virtual_cluster_.get(), &context));
    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

    MaxPoolGradTransposer transposer;
    auto* maxpool_grad = context.graph_view->GetNode("maxpool_grad");
    ASSERT_NE(maxpool_grad, nullptr);
    TF_ASSERT_OK(transposer.TransposeNode(&context, maxpool_grad));

    auto* input_transpose_node1 = context.graph_view->GetNode(
        "maxpool_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node1, nullptr);
    ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(input_transpose_node1, 0, "orig_input", 0);

    auto* input_transpose_node2 = context.graph_view->GetNode(
        "maxpool_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node2, nullptr);
    ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(input_transpose_node2, 0, "orig_output", 0);

    auto* input_transpose_node3 = context.graph_view->GetNode(
        "maxpool_grad-2-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node3, nullptr);
    ASSERT_EQ(input_transpose_node3->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(input_transpose_node3, 0, "grad", 0);

    auto* updated_maxpool_grad = context.graph_view->GetNode("maxpool_grad");
    VerifyDataFormatAttributeMatch(updated_maxpool_grad, kDstFormat);
    ASSERT_EQ(updated_maxpool_grad->NumRegularFanins(), 3);
    VerifyRegularFaninMatch(updated_maxpool_grad, 0,
                            input_transpose_node1->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpool_grad, 1,
                            input_transpose_node2->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpool_grad, 2,
                            input_transpose_node3->GetName(), 0);

    auto* output_transpose_node = context.graph_view->GetNode(
        "maxpool_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
    ASSERT_NE(output_transpose_node, nullptr);
    ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(output_transpose_node, 0,
                            updated_maxpool_grad->GetName(), 0);
  }
}

TEST_F(TransposerTest, BiasAddGradTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleBiasAddGrad(
      &item.graph, {kBatchSize, kHeight, kWidth, kDepthIn}));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  BiasAddGradTransposer transposer;
  auto* bag = context.graph_view->GetNode("bag");
  ASSERT_NE(bag, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, bag));

  // The expected optimized graph contains 1 extra Transpose node and has the
  // BiasAddGrad's data_format set to "NCHW".
  auto* input_transpose_node =
      context.graph_view->GetNode("bag-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "input", 0);

  auto* bag_node = context.graph_view->GetNode("bag");
  ASSERT_NE(bag_node, nullptr);
  VerifyDataFormatAttributeMatch(bag_node, kDstFormat);

  auto* output_transpose_node = context.graph_view->GetNode(
      "bag-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, bag_node->GetName(), 0);
}

TEST_F(TransposerTest, BiasAddGradTransposerIncorrectInputTest) {
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(
      CreateSimpleBiasAddGrad(&item.graph, {kHeight, kWidth, kDepthIn}));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  BiasAddGradTransposer transposer;
  auto* bag = context.graph_view->GetNode("bag");
  ASSERT_NE(bag, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, bag));

  // Optimization should not occur because of incorrect input dimensions.
  auto* input_transpose_node =
      context.graph_view->GetNode("bag-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node, nullptr);

  auto* bag_node = context.graph_view->GetNode("bag");
  ASSERT_NE(bag_node, nullptr);
  VerifyDataFormatAttributeMatch(bag_node, kSrcFormat);

  auto* output_transpose_node = context.graph_view->GetNode(
      "bag-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, bag_node->GetName(), 0);
}

TEST_F(TransposerTest, Conv2DBackpropFilterTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DBackpropFilter(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  Conv2DBackpropFilterTransposer transposer;
  auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter");
  ASSERT_NE(conv2d_bf, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d_bf));

  // The expected optimized graph contains 2 extra sets of Transpose nodes and
  // has the Conv2DBackpropFilter's data_format set to "NCHW".
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "conv2d_backprop_filter-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0, "input", 0);

  auto* input_transpose_node_filter_sizes = context.graph_view->GetNode(
      "conv2d_backprop_filter-1-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node_filter_sizes, nullptr);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "conv2d_backprop_filter-2-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0, "out_backprop", 0);

  auto* conv2d_bf_node = context.graph_view->GetNode("conv2d_backprop_filter");
  ASSERT_NE(conv2d_bf_node, nullptr);
  ASSERT_EQ(conv2d_bf_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(conv2d_bf_node, 0, input_transpose_node1->GetName(),
                          0);
  VerifyRegularFaninMatch(conv2d_bf_node, 2, input_transpose_node2->GetName(),
                          0);
  VerifyDataFormatAttributeMatch(conv2d_bf_node, kDstFormat);

  auto* output_transpose_node = context.graph_view->GetNode(
      "conv2d_backprop_filter-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, conv2d_bf_node->GetName(), 0);
}

TEST_F(TransposerTest,NodeAttributes)1057 TEST_F(TransposerTest, NodeAttributes) {
1058 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1059   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1060 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1061   GrapplerItem item;
1062   TransposeContext context;
1063   TF_ASSERT_OK(
1064       CreateSimpleConv2DBackpropFilter(&item.graph, DT_FLOAT, "EXPLICIT"));
1065   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1066       item, virtual_cluster_.get(), &context));
1067   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1068 
1069   Conv2DBackpropFilterTransposer transposer;
1070   auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter");
1071   ASSERT_NE(conv2d_bf, nullptr);
1072   TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d_bf));
1073 
1074   auto* conv2d_bf_node = context.graph_view->GetNode("conv2d_backprop_filter");
1075   ASSERT_NE(conv2d_bf_node, nullptr);
1076   ASSERT_EQ(conv2d_bf_node->NumRegularFanins(), 3);
1077   VerifyDataFormatAttributeMatch(conv2d_bf_node, kDstFormat);
1078   auto* dilations_attr = conv2d_bf_node->GetAttr("dilations");
1079   ASSERT_NE(dilations_attr, nullptr);
1080   ASSERT_EQ(dilations_attr->list().i_size(), 4);
1081   EXPECT_EQ(dilations_attr->list().i(0), 1);
1082   EXPECT_EQ(dilations_attr->list().i(1), 1);
1083   EXPECT_EQ(dilations_attr->list().i(2), kDilation);
1084   EXPECT_EQ(dilations_attr->list().i(3), kDilation);
1085   auto* explicit_paddings_attr = conv2d_bf_node->GetAttr("explicit_paddings");
1086   ASSERT_NE(explicit_paddings_attr, nullptr);
1087   ASSERT_EQ(explicit_paddings_attr->list().i_size(), 8);
1088   EXPECT_EQ(explicit_paddings_attr->list().i(0), 0);
1089   EXPECT_EQ(explicit_paddings_attr->list().i(1), 0);
1090   EXPECT_EQ(explicit_paddings_attr->list().i(2), 0);
1091   EXPECT_EQ(explicit_paddings_attr->list().i(3), 0);
1092   EXPECT_EQ(explicit_paddings_attr->list().i(4), kPaddingTop);
1093   EXPECT_EQ(explicit_paddings_attr->list().i(5), kPaddingBottom);
1094   EXPECT_EQ(explicit_paddings_attr->list().i(6), kPaddingLeft);
1095   EXPECT_EQ(explicit_paddings_attr->list().i(7), kPaddingRight);
1096 }
1097 
TEST_F(TransposerTest,Conv2DBackpropInputTransposerTest)1098 TEST_F(TransposerTest, Conv2DBackpropInputTransposerTest) {
1099 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1100   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1101 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1102   GrapplerItem item;
1103   TransposeContext context;
1104   TF_ASSERT_OK(CreateSimpleConv2DBackpropInput(&item.graph));
1105   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1106       item, virtual_cluster_.get(), &context));
1107   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1108 
1109   Conv2DBackpropInputTransposer transposer;
1110   auto* conv2d_i = context.graph_view->GetNode("conv2d_backprop_input");
1111   ASSERT_NE(conv2d_i, nullptr);
1112   TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d_i));
1113 
1114   // The expected optimized graph contains 1 extra set of Transpose nodes,
1115   // 1 DataFormatVecPermute node and has the Conv2DBackpropInput's data_format
1116   // set to "NCHW".
1117   auto* input_vec_permute_node = context.graph_view->GetNode(
1118       "conv2d_backprop_input-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
1119   ASSERT_NE(input_vec_permute_node, nullptr);
1120   ASSERT_EQ(input_vec_permute_node->NumRegularFanins(), 1);
1121   const auto* src_format_attr = input_vec_permute_node->GetAttr(kAttrSrcFormat);
1122   ASSERT_NE(src_format_attr, nullptr);
1123   EXPECT_EQ(src_format_attr->s(), kSrcFormat);
1124   const auto* dst_format_attr = input_vec_permute_node->GetAttr(kAttrDstFormat);
1125   ASSERT_NE(dst_format_attr, nullptr);
1126   EXPECT_EQ(dst_format_attr->s(), kDstFormat);
1127 
1128   auto* input_transpose_node = context.graph_view->GetNode(
1129       "conv2d_backprop_input-2-TransposeNHWCToNCHW-LayoutOptimizer");
1130   ASSERT_NE(input_transpose_node, nullptr);
1131   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
1132   VerifyRegularFaninMatch(input_transpose_node, 0, "out_backprop", 0);
1133 
1134   auto* conv2d_i_node = context.graph_view->GetNode("conv2d_backprop_input");
1135   ASSERT_NE(conv2d_i_node, nullptr);
1136   ASSERT_EQ(conv2d_i_node->NumRegularFanins(), 3);
1137   VerifyRegularFaninMatch(conv2d_i_node, 0, input_vec_permute_node->GetName(),
1138                           0);
1139   VerifyRegularFaninMatch(conv2d_i_node, 1, "filter", 0);
1140   VerifyRegularFaninMatch(conv2d_i_node, 2, input_transpose_node->GetName(), 0);
1141   VerifyDataFormatAttributeMatch(conv2d_i_node, kDstFormat);
1142 
1143   auto* output_transpose_node = context.graph_view->GetNode(
1144       "conv2d_backprop_input-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1145   ASSERT_NE(output_transpose_node, nullptr);
1146   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1147   VerifyRegularFaninMatch(output_transpose_node, 0, conv2d_i_node->GetName(),
1148                           0);
1149 
1150   auto* output_node = context.graph_view->GetNode("output");
1151   ASSERT_NE(output_node, nullptr);
1152   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1153   VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
1154 }
1155 
1156 TEST_F(TransposerTest, FusedBatchNormGradTransposerIsTrainingTest) {
1157 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1158   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1159 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1160   GrapplerItem item;
1161   TransposeContext context;
1162   TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
1163   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1164       item, virtual_cluster_.get(), &context));
1165   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1166 
1167   FusedBatchNormGradTransposer transposer;
1168   auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
1169   ASSERT_NE(fbng, nullptr);
1170   TF_ASSERT_OK(transposer.TransposeNode(&context, fbng));
1171 
1172   // The expected optimized graph contains 3 extra sets of Transpose nodes and
1173   // has the FusedBatchNormGrad's data_format set to "NCHW".
1174   auto* input_transpose_node1 = context.graph_view->GetNode(
1175       "fused_batch_norm_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
1176   ASSERT_NE(input_transpose_node1, nullptr);
1177   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
1178   VerifyRegularFaninMatch(input_transpose_node1, 0, "y_backprop", 0);
1179 
1180   auto* input_transpose_node2 = context.graph_view->GetNode(
1181       "fused_batch_norm_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
1182   ASSERT_NE(input_transpose_node2, nullptr);
1183   ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
1184   VerifyRegularFaninMatch(input_transpose_node2, 0, "x", 0);
1185 
1186   auto* fbng_node = context.graph_view->GetNode("fused_batch_norm_grad");
1187   ASSERT_NE(fbng_node, nullptr);
1188   ASSERT_EQ(fbng_node->NumRegularFanins(), 5);
1189   VerifyRegularFaninMatch(fbng_node, 0, input_transpose_node1->GetName(), 0);
1190   VerifyRegularFaninMatch(fbng_node, 1, input_transpose_node2->GetName(), 0);
1191   VerifyRegularFaninMatch(fbng_node, 2, "scale", 0);
1192   VerifyRegularFaninMatch(fbng_node, 3, "reserve_space_1", 0);
1193   VerifyRegularFaninMatch(fbng_node, 4, "reserve_space_2", 0);
1194   VerifyDataFormatAttributeMatch(fbng_node, kDstFormat);
1195 
1196   auto* output_transpose_node = context.graph_view->GetNode(
1197       "fused_batch_norm_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1198   ASSERT_NE(output_transpose_node, nullptr);
1199   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1200   VerifyRegularFaninMatch(output_transpose_node, 0, fbng_node->GetName(), 0);
1201 
1202   auto* x_backprop = context.graph_view->GetNode("x_backprop");
1203   ASSERT_NE(x_backprop, nullptr);
1204   ASSERT_EQ(x_backprop->NumRegularFanins(), 1);
1205   VerifyRegularFaninMatch(x_backprop, 0, output_transpose_node->GetName(), 0);
1206 
1207   auto* scale_backprop = context.graph_view->GetNode("scale_backprop");
1208   ASSERT_NE(scale_backprop, nullptr);
1209   ASSERT_EQ(scale_backprop->NumRegularFanins(), 1);
1210   VerifyRegularFaninMatch(scale_backprop, 0, fbng_node->GetName(), 1);
1211 
1212   auto* offset_backprop = context.graph_view->GetNode("offset_backprop");
1213   ASSERT_NE(offset_backprop, nullptr);
1214   ASSERT_EQ(offset_backprop->NumRegularFanins(), 1);
1215   VerifyRegularFaninMatch(offset_backprop, 0, fbng_node->GetName(), 2);
1216 
1217   auto* reserve_space_3 = context.graph_view->GetNode("reserve_space_3");
1218   ASSERT_NE(reserve_space_3, nullptr);
1219   ASSERT_EQ(reserve_space_3->NumRegularFanins(), 1);
1220   VerifyRegularFaninMatch(reserve_space_3, 0, fbng_node->GetName(), 3);
1221 
1222   auto* reserve_space_4 = context.graph_view->GetNode("reserve_space_4");
1223   ASSERT_NE(reserve_space_4, nullptr);
1224   ASSERT_EQ(reserve_space_4->NumRegularFanins(), 1);
1225   VerifyRegularFaninMatch(reserve_space_4, 0, fbng_node->GetName(), 4);
1226 }
1227 
1228 TEST_F(TransposerTest, FusedBatchNormGradTransposerNotTrainingTest) {
1229   GrapplerItem item;
1230   TransposeContext context;
1231   TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, false));
1232   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1233       item, virtual_cluster_.get(), &context));
1234   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1235 
1236   FusedBatchNormGradTransposer transposer;
1237   auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
1238   ASSERT_NE(fbng, nullptr);
1239   TF_ASSERT_OK(transposer.TransposeNode(&context, fbng));
1240 
1241   // Optimization should not occur because FusedBatchNormGrad is not set to
1242   // training.
1243   auto* input_transpose_node1 = context.graph_view->GetNode(
1244       "fused_batch_norm_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
1245   EXPECT_EQ(input_transpose_node1, nullptr);
1246 
1247   auto* input_transpose_node2 = context.graph_view->GetNode(
1248       "fused_batch_norm_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
1249   EXPECT_EQ(input_transpose_node2, nullptr);
1250 
1251   auto* fbng_node = context.graph_view->GetNode("fused_batch_norm_grad");
1252   ASSERT_NE(fbng_node, nullptr);
1253   ASSERT_EQ(fbng_node->NumRegularFanins(), 5);
1254   VerifyRegularFaninMatch(fbng_node, 0, "y_backprop", 0);
1255   VerifyRegularFaninMatch(fbng_node, 1, "x", 0);
1256   VerifyRegularFaninMatch(fbng_node, 2, "scale", 0);
1257   VerifyRegularFaninMatch(fbng_node, 3, "reserve_space_1", 0);
1258   VerifyRegularFaninMatch(fbng_node, 4, "reserve_space_2", 0);
1259   VerifyDataFormatAttributeMatch(fbng_node, kSrcFormat);
1260 
1261   auto* output_transpose_node = context.graph_view->GetNode(
1262       "fused_batch_norm_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1263   EXPECT_EQ(output_transpose_node, nullptr);
1264 
1265   auto* x_backprop = context.graph_view->GetNode("x_backprop");
1266   ASSERT_NE(x_backprop, nullptr);
1267   ASSERT_EQ(x_backprop->NumRegularFanins(), 1);
1268   VerifyRegularFaninMatch(x_backprop, 0, fbng_node->GetName(), 0);
1269 
1270   auto* scale_backprop = context.graph_view->GetNode("scale_backprop");
1271   ASSERT_NE(scale_backprop, nullptr);
1272   ASSERT_EQ(scale_backprop->NumRegularFanins(), 1);
1273   VerifyRegularFaninMatch(scale_backprop, 0, fbng_node->GetName(), 1);
1274 
1275   auto* offset_backprop = context.graph_view->GetNode("offset_backprop");
1276   ASSERT_NE(offset_backprop, nullptr);
1277   ASSERT_EQ(offset_backprop->NumRegularFanins(), 1);
1278   VerifyRegularFaninMatch(offset_backprop, 0, fbng_node->GetName(), 2);
1279 
1280   auto* reserve_space_3 = context.graph_view->GetNode("reserve_space_3");
1281   ASSERT_NE(reserve_space_3, nullptr);
1282   ASSERT_EQ(reserve_space_3->NumRegularFanins(), 1);
1283   VerifyRegularFaninMatch(reserve_space_3, 0, fbng_node->GetName(), 3);
1284 
1285   auto* reserve_space_4 = context.graph_view->GetNode("reserve_space_4");
1286   ASSERT_NE(reserve_space_4, nullptr);
1287   ASSERT_EQ(reserve_space_4->NumRegularFanins(), 1);
1288   VerifyRegularFaninMatch(reserve_space_4, 0, fbng_node->GetName(), 4);
1289 }
1290 
1291 TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityTest) {
1292 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1293   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1294 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1295   GrapplerItem item;
1296   Scope scope = Scope::NewRootScope();
1297   auto conv2d = SimpleConv2D(&scope);
1298   auto identity = ops::Identity(
1299       scope.WithOpName("identity").WithDevice("/device:GPU:0"), conv2d);
1300   auto output = ops::Identity(scope.WithOpName("output"), identity);
1301   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1302   TransposeContext context;
1303   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1304       item, virtual_cluster_.get(), &context));
1305   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1306 
1307   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1308   auto* c2d = context.graph_view->GetNode("conv2d");
1309   ASSERT_NE(c2d, nullptr);
1310   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1311 
1312   DefaultLayoutAgnosticOpTransposer transposer;
1313   auto* i = context.graph_view->GetNode("identity");
1314   ASSERT_NE(i, nullptr);
1315   TF_ASSERT_OK(transposer.TransposeNode(&context, i));
1316 
1317   // The expected optimized graph contains 2 extra sets of Transpose nodes.
1318   auto* input_transpose_node = context.graph_view->GetNode(
1319       "identity-0-TransposeNHWCToNCHW-LayoutOptimizer");
1320   ASSERT_NE(input_transpose_node, nullptr);
1321   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
1322   VerifyRegularFaninMatch(input_transpose_node, 0,
1323                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1324 
1325   auto* i_node = context.graph_view->GetNode("identity");
1326   ASSERT_NE(i_node, nullptr);
1327   ASSERT_EQ(i_node->NumRegularFanins(), 1);
1328   VerifyRegularFaninMatch(i_node, 0, input_transpose_node->GetName(), 0);
1329 
1330   auto* output_transpose_node = context.graph_view->GetNode(
1331       "identity-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1332   ASSERT_NE(output_transpose_node, nullptr);
1333   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1334   VerifyRegularFaninMatch(output_transpose_node, 0, i_node->GetName(), 0);
1335 
1336   auto* output_node = context.graph_view->GetNode("output");
1337   ASSERT_NE(output_node, nullptr);
1338   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1339   VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
1340 }
1341 
1342 TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityBadInputTest) {
1343 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1344   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1345 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1346   GrapplerItem item;
1347   Scope scope = Scope::NewRootScope();
1348   auto conv2d = SimpleConv2D(&scope);
1349   auto sum = ops::Sum(scope.WithOpName("sum"), conv2d, {0, 1});
1350   auto identity = ops::Identity(
1351       scope.WithOpName("identity").WithDevice("/device:GPU:0"), sum);
1352   auto output = ops::Identity(scope.WithOpName("output"), identity);
1353   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1354   TransposeContext context;
1355   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1356       item, virtual_cluster_.get(), &context));
1357   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1358 
1359   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1360   auto* c2d = context.graph_view->GetNode("conv2d");
1361   ASSERT_NE(c2d, nullptr);
1362   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1363 
1364   DefaultLayoutAgnosticOpTransposer transposer;
1365   auto* i = context.graph_view->GetNode("identity");
1366   ASSERT_NE(i, nullptr);
1367   TF_ASSERT_OK(transposer.TransposeNode(&context, i));
1368 
1369   // Optimization should not occur because the input to Identity is not 4D
1370   // (the Sum over axes {0, 1} reduces it to 2D).
1371   auto* input_transpose_node = context.graph_view->GetNode(
1372       "identity-0-TransposeNHWCToNCHW-LayoutOptimizer");
1373   EXPECT_EQ(input_transpose_node, nullptr);
1374 
1375   auto* i_node = context.graph_view->GetNode("identity");
1376   ASSERT_NE(i_node, nullptr);
1377   ASSERT_EQ(i_node->NumRegularFanins(), 1);
1378   VerifyRegularFaninMatch(i_node, 0, "sum", 0);
1379 
1380   auto* output_transpose_node = context.graph_view->GetNode(
1381       "identity-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1382   EXPECT_EQ(output_transpose_node, nullptr);
1383 
1384   auto* output_node = context.graph_view->GetNode("output");
1385   ASSERT_NE(output_node, nullptr);
1386   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1387   VerifyRegularFaninMatch(output_node, 0, i_node->GetName(), 0);
1388 }
1389 
1390 TEST_F(TransposerTest, AddNTransposerTest) {
1391 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1392   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1393 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1394   GrapplerItem item;
1395   TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
1396   TransposeContext context;
1397   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1398       item, virtual_cluster_.get(), &context));
1399   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1400 
1401   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1402   auto* conv2d = context.graph_view->GetNode("conv2d");
1403   ASSERT_NE(conv2d, nullptr);
1404   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d));
1405 
1406   AddNTransposer addn_transposer;
1407   auto* an = context.graph_view->GetNode("add_n");
1408   ASSERT_NE(an, nullptr);
1409   TF_ASSERT_OK(addn_transposer.TransposeNode(&context, an));
1410 
1411   // The expected optimized graph contains 5 extra sets of Transpose nodes.
1412   auto* input_transpose_node1 = context.graph_view->GetNode(
1413       "add_n-0-TransposeNHWCToNCHW-LayoutOptimizer");
1414   ASSERT_NE(input_transpose_node1, nullptr);
1415   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
1416   VerifyRegularFaninMatch(input_transpose_node1, 0, "a", 0);
1417 
1418   auto* input_transpose_node2 = context.graph_view->GetNode(
1419       "add_n-1-TransposeNHWCToNCHW-LayoutOptimizer");
1420   ASSERT_NE(input_transpose_node2, nullptr);
1421   ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
1422   VerifyRegularFaninMatch(input_transpose_node2, 0, "b", 0);
1423 
1424   auto* input_transpose_node3 = context.graph_view->GetNode(
1425       "add_n-2-TransposeNHWCToNCHW-LayoutOptimizer");
1426   ASSERT_NE(input_transpose_node3, nullptr);
1427   ASSERT_EQ(input_transpose_node3->NumRegularFanins(), 2);
1428   VerifyRegularFaninMatch(input_transpose_node3, 0, "c", 0);
1429 
1430   auto* input_transpose_node4 = context.graph_view->GetNode(
1431       "add_n-3-TransposeNHWCToNCHW-LayoutOptimizer");
1432   ASSERT_NE(input_transpose_node4, nullptr);
1433   ASSERT_EQ(input_transpose_node4->NumRegularFanins(), 2);
1434   VerifyRegularFaninMatch(input_transpose_node4, 0,
1435                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1436 
1437   auto* an_node = context.graph_view->GetNode("add_n");
1438   ASSERT_NE(an_node, nullptr);
1439   ASSERT_EQ(an_node->NumRegularFanins(), 4);
1440   VerifyRegularFaninMatch(an_node, 0, input_transpose_node1->GetName(), 0);
1441   VerifyRegularFaninMatch(an_node, 1, input_transpose_node2->GetName(), 0);
1442   VerifyRegularFaninMatch(an_node, 2, input_transpose_node3->GetName(), 0);
1443   VerifyRegularFaninMatch(an_node, 3, input_transpose_node4->GetName(), 0);
1444 
1445   auto* output_transpose_node = context.graph_view->GetNode(
1446       "add_n-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1447   ASSERT_NE(output_transpose_node, nullptr);
1448   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1449   VerifyRegularFaninMatch(output_transpose_node, 0, an_node->GetName(), 0);
1450 
1451   auto* output_node = context.graph_view->GetNode("output");
1452   ASSERT_NE(output_node, nullptr);
1453   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1454   VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
1455 }
1456 
1457 TEST_F(TransposerTest, AddNTransposerNotAfterTransformTest) {
1458   GrapplerItem item;
1459   TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
1460   TransposeContext context;
1461   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1462       item, virtual_cluster_.get(), &context));
1463   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1464 
1465   AddNTransposer addn_transposer;
1466   auto* an = context.graph_view->GetNode("add_n");
1467   ASSERT_NE(an, nullptr);
1468   TF_ASSERT_OK(addn_transposer.TransposeNode(&context, an));
1469 
1470   // Optimization should not occur because AddN does not follow a transform.
1471   auto* input_transpose_node1 = context.graph_view->GetNode(
1472       "add_n-0-TransposeNHWCToNCHW-LayoutOptimizer");
1473   EXPECT_EQ(input_transpose_node1, nullptr);
1474 
1475   auto* input_transpose_node2 = context.graph_view->GetNode(
1476       "add_n-1-TransposeNHWCToNCHW-LayoutOptimizer");
1477   EXPECT_EQ(input_transpose_node2, nullptr);
1478 
1479   auto* input_transpose_node3 = context.graph_view->GetNode(
1480       "add_n-2-TransposeNHWCToNCHW-LayoutOptimizer");
1481   EXPECT_EQ(input_transpose_node3, nullptr);
1482 
1483   auto* input_transpose_node4 = context.graph_view->GetNode(
1484       "add_n-3-TransposeNHWCToNCHW-LayoutOptimizer");
1485   EXPECT_EQ(input_transpose_node4, nullptr);
1486 
1487   auto* an_node = context.graph_view->GetNode("add_n");
1488   ASSERT_NE(an_node, nullptr);
1489   ASSERT_EQ(an_node->NumRegularFanins(), 4);
1490   VerifyRegularFaninMatch(an_node, 0, "a", 0);
1491   VerifyRegularFaninMatch(an_node, 1, "b", 0);
1492   VerifyRegularFaninMatch(an_node, 2, "c", 0);
1493   VerifyRegularFaninMatch(an_node, 3, "conv2d", 0);
1494 
1495   auto* output_transpose_node = context.graph_view->GetNode(
1496       "add_n-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1497   EXPECT_EQ(output_transpose_node, nullptr);
1498 
1499   auto* output_node = context.graph_view->GetNode("output");
1500   ASSERT_NE(output_node, nullptr);
1501   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1502   VerifyRegularFaninMatch(output_node, 0, an_node->GetName(), 0);
1503 }
1504 
1505 TEST_F(TransposerTest, IdentityNTransposerTest) {
1506 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1507   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1508 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1509   GrapplerItem item;
1510   TF_ASSERT_OK(CreateSimpleIdentityN(&item.graph));
1511   TransposeContext context;
1512   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1513       item, virtual_cluster_.get(), &context));
1514   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1515 
1516   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1517   auto* conv2d_1 = context.graph_view->GetNode("conv2d_1");
1518   ASSERT_NE(conv2d_1, nullptr);
1519   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_1));
1520   auto* conv2d_2 = context.graph_view->GetNode("conv2d_2");
1521   ASSERT_NE(conv2d_2, nullptr);
1522   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_2));
1523 
1524   IdentityNTransposer identityn_transposer;
1525   auto* in = context.graph_view->GetNode("identity_n");
1526   ASSERT_NE(in, nullptr);
1527   TF_ASSERT_OK(identityn_transposer.TransposeNode(&context, in));
1528 
1529   // The expected optimized graph contains 4 extra sets of Transpose nodes.
1530   auto* input_transpose_node1 = context.graph_view->GetNode(
1531       "identity_n-0-TransposeNHWCToNCHW-LayoutOptimizer");
1532   EXPECT_EQ(input_transpose_node1, nullptr);
1533 
1534   auto* input_transpose_node2 = context.graph_view->GetNode(
1535       "identity_n-1-TransposeNHWCToNCHW-LayoutOptimizer");
1536   ASSERT_NE(input_transpose_node2, nullptr);
1537   ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
1538   VerifyRegularFaninMatch(input_transpose_node2, 0,
1539                           "conv2d_2-0-0-TransposeNCHWToNHWC-LayoutOptimizer",
1540                           0);
1541 
1542   auto* input_transpose_node3 = context.graph_view->GetNode(
1543       "identity_n-2-TransposeNHWCToNCHW-LayoutOptimizer");
1544   EXPECT_EQ(input_transpose_node3, nullptr);
1545 
1546   auto* input_transpose_node4 = context.graph_view->GetNode(
1547       "identity_n-3-TransposeNHWCToNCHW-LayoutOptimizer");
1548   EXPECT_EQ(input_transpose_node4, nullptr);
1549 
1550   auto* in_node = context.graph_view->GetNode("identity_n");
1551   ASSERT_NE(in_node, nullptr);
1552   ASSERT_EQ(in_node->NumRegularFanins(), 4);
1553   VerifyRegularFaninMatch(in_node, 0, "conv2d_1", 0);
1554   VerifyRegularFaninMatch(in_node, 1, input_transpose_node2->GetName(), 0);
1555   VerifyRegularFaninMatch(in_node, 2, "a", 0);
1556   VerifyRegularFaninMatch(in_node, 3, "b", 0);
1557 
1558   auto* output_transpose_node1 = context.graph_view->GetNode(
1559       "identity_n-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1560   EXPECT_EQ(output_transpose_node1, nullptr);
1561 
1562   auto* output_transpose_node2 = context.graph_view->GetNode(
1563       "identity_n-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
1564   ASSERT_NE(output_transpose_node2, nullptr);
1565   ASSERT_EQ(output_transpose_node2->NumRegularFanins(), 2);
1566   VerifyRegularFaninMatch(output_transpose_node2, 0, in_node->GetName(), 1);
1567 
1568   auto* output_transpose_node3 = context.graph_view->GetNode(
1569       "identity_n-2-0-TransposeNCHWToNHWC-LayoutOptimizer");
1570   EXPECT_EQ(output_transpose_node3, nullptr);
1571 
1572   auto* output_transpose_node4 = context.graph_view->GetNode(
1573       "identity_n-3-0-TransposeNCHWToNHWC-LayoutOptimizer");
1574   EXPECT_EQ(output_transpose_node4, nullptr);
1575 
1576   auto* conv2d_1_output_node = context.graph_view->GetNode("conv2d_1_output");
1577   ASSERT_NE(conv2d_1_output_node, nullptr);
1578   ASSERT_EQ(conv2d_1_output_node->NumRegularFanins(), 1);
1579   VerifyRegularFaninMatch(conv2d_1_output_node, 0, in_node->GetName(), 0);
1580 
1581   auto* conv2d_2_output_node = context.graph_view->GetNode("conv2d_2_output");
1582   ASSERT_NE(conv2d_2_output_node, nullptr);
1583   ASSERT_EQ(conv2d_2_output_node->NumRegularFanins(), 1);
1584   VerifyRegularFaninMatch(conv2d_2_output_node, 0,
1585                           output_transpose_node2->GetName(), 0);
1586 
1587   auto* a_output_node = context.graph_view->GetNode("a_output");
1588   ASSERT_NE(a_output_node, nullptr);
1589   ASSERT_EQ(a_output_node->NumRegularFanins(), 1);
1590   VerifyRegularFaninMatch(a_output_node, 0, in_node->GetName(), 2);
1591 
1592   auto* b_output_node = context.graph_view->GetNode("b_output");
1593   ASSERT_NE(b_output_node, nullptr);
1594   ASSERT_EQ(b_output_node->NumRegularFanins(), 1);
1595   VerifyRegularFaninMatch(b_output_node, 0, in_node->GetName(), 3);
1596 }
1597 
1598 TEST_F(TransposerTest, MergeTransposerTestMergeBothInputsConvertible) {
1599 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1600   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1601 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1602   GrapplerItem item;
1603   Scope scope = Scope::NewRootScope();
1604   auto conv2d = SimpleConv2D(&scope);
1605   Output i1 = ops::Identity(scope.WithOpName("i1"), conv2d);
1606   auto merge = ops::Merge(scope.WithOpName("merge").WithDevice("/device:GPU:0"),
1607                           {conv2d, i1});
1608   auto i2 = ops::Identity(scope.WithOpName("i2"), merge.output);
1609   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1610   TransposeContext context;
1611   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1612       item, virtual_cluster_.get(), &context));
1613   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1614 
1615   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1616   auto* c2d = context.graph_view->GetNode("conv2d");
1617   ASSERT_NE(c2d, nullptr);
1618   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1619 
1620   MergeTransposer merge_transposer;
1621   auto* m = context.graph_view->GetNode("merge");
1622   ASSERT_NE(m, nullptr);
1623   TF_ASSERT_OK(merge_transposer.TransposeNode(&context, m));
1624 
1625   // The expected optimized graph contains 3 extra sets of Transpose nodes.
1626   auto* input_transpose_node1 = context.graph_view->GetNode(
1627       "merge-0-TransposeNHWCToNCHW-LayoutOptimizer");
1628   ASSERT_NE(input_transpose_node1, nullptr);
1629   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
1630   VerifyRegularFaninMatch(input_transpose_node1, 0,
1631                           "conv2d-0-1-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1632 
1633   auto* input_transpose_node2 = context.graph_view->GetNode(
1634       "merge-1-TransposeNHWCToNCHW-LayoutOptimizer");
1635   ASSERT_NE(input_transpose_node2, nullptr);
1636   ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
1637   VerifyRegularFaninMatch(input_transpose_node2, 0, "i1", 0);
1638 
1639   auto* m_node = context.graph_view->GetNode("merge");
1640   ASSERT_NE(m_node, nullptr);
1641   ASSERT_EQ(m_node->NumRegularFanins(), 2);
1642   VerifyRegularFaninMatch(m_node, 0, input_transpose_node1->GetName(), 0);
1643   VerifyRegularFaninMatch(m_node, 1, input_transpose_node2->GetName(), 0);
1644 
1645   auto* output_transpose_node = context.graph_view->GetNode(
1646       "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1647   ASSERT_NE(output_transpose_node, nullptr);
1648   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1649   VerifyRegularFaninMatch(output_transpose_node, 0, m_node->GetName(), 0);
1650 
1651   auto* output_node = context.graph_view->GetNode("i2");
1652   ASSERT_NE(output_node, nullptr);
1653   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1654   VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
1655 }
1656 
1657 TEST_F(TransposerTest, MergeTransposerTestMergeOneInputNotConvertible) {
1658 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1659   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1660 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1661   GrapplerItem item;
1662   Scope scope = Scope::NewRootScope();
1663   auto conv2d = SimpleConv2D(&scope);
1664   auto tensor_4d =
1665       ops::Const(scope.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
1666   auto merge = ops::Merge(scope.WithOpName("merge").WithDevice("/device:GPU:0"),
1667                           {conv2d, tensor_4d});
1668   auto i2 = ops::Identity(scope.WithOpName("i2"), merge.output);
1669   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1670   TransposeContext context;
1671   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1672       item, virtual_cluster_.get(), &context));
1673   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1674 
1675   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1676   auto* c2d = context.graph_view->GetNode("conv2d");
1677   ASSERT_NE(c2d, nullptr);
1678   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1679 
1680   MergeTransposer merge_transposer;
1681   auto* m = context.graph_view->GetNode("merge");
1682   ASSERT_NE(m, nullptr);
1683   TF_ASSERT_OK(merge_transposer.TransposeNode(&context, m));
1684 
1685   // Optimization should not occur because not every input to the Merge is a
1686   // transform or comes directly after one (tensor_4d is a plain Const).
1687   auto* input_transpose_node1 = context.graph_view->GetNode(
1688       "merge-0-TransposeNHWCToNCHW-LayoutOptimizer");
1689   EXPECT_EQ(input_transpose_node1, nullptr);
1690 
1691   auto* input_transpose_node2 = context.graph_view->GetNode(
1692       "merge-1-TransposeNHWCToNCHW-LayoutOptimizer");
1693   EXPECT_EQ(input_transpose_node2, nullptr);
1694 
1695   auto* m_node = context.graph_view->GetNode("merge");
1696   ASSERT_NE(m_node, nullptr);
1697   ASSERT_EQ(m_node->NumRegularFanins(), 2);
1698   VerifyRegularFaninMatch(m_node, 0,
1699                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1700   VerifyRegularFaninMatch(m_node, 1, "tensor_4d", 0);
1701 
1702   auto* output_transpose_node = context.graph_view->GetNode(
1703       "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1704   EXPECT_EQ(output_transpose_node, nullptr);
1705 
1706   auto* output_node = context.graph_view->GetNode("i2");
1707   ASSERT_NE(output_node, nullptr);
1708   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1709   VerifyRegularFaninMatch(output_node, 0, m_node->GetName(), 0);
1710 }
1711 
1712 TEST_F(TransposerTest, PadTransposerTest) {
1713 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1714   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1715 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1716   GrapplerItem item;
1717   Scope scope = Scope::NewRootScope();
1718   auto conv2d = SimpleConv2D(&scope);
1719   auto c = ops::Const(scope.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
1720   auto p =
1721       ops::Pad(scope.WithOpName("p").WithDevice("/device:GPU:0"), conv2d, c);
1722   auto o = ops::Identity(scope.WithOpName("o"), p);
1723   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1724   TransposeContext context;
1725   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1726       item, virtual_cluster_.get(), &context));
1727   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1728 
1729   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1730   auto* c2d = context.graph_view->GetNode("conv2d");
1731   ASSERT_NE(c2d, nullptr);
1732   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1733 
1734   PadTransposer pad_transposer;
1735   auto* pad = context.graph_view->GetNode("p");
1736   ASSERT_NE(pad, nullptr);
1737   TF_ASSERT_OK(pad_transposer.TransposeNode(&context, pad));
1738 
1739   // The expected optimized graph contains 2 extra sets of Transpose nodes and 1
1740   // DataFormatVecPermute node.
1741   auto* input_transpose_node =
1742       context.graph_view->GetNode("p-0-TransposeNHWCToNCHW-LayoutOptimizer");
1743   ASSERT_NE(input_transpose_node, nullptr);
1744   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
1745   VerifyRegularFaninMatch(input_transpose_node, 0,
1746                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1747 
1748   auto* padding_transpose_node = context.graph_view->GetNode(
1749       "p-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
1750   ASSERT_NE(padding_transpose_node, nullptr);
1751   ASSERT_EQ(padding_transpose_node->NumRegularFanins(), 1);
1752   VerifyRegularFaninMatch(padding_transpose_node, 0, "c", 0);
1753 
1754   auto* pad_node = context.graph_view->GetNode("p");
1755   ASSERT_NE(pad_node, nullptr);
1756   ASSERT_EQ(pad_node->NumRegularFanins(), 2);
1757   VerifyRegularFaninMatch(pad_node, 0, input_transpose_node->GetName(), 0);
1758   VerifyRegularFaninMatch(pad_node, 1, padding_transpose_node->GetName(), 0);
1759 
1760   auto* output_transpose_node =
1761       context.graph_view->GetNode("p-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1762   ASSERT_NE(output_transpose_node, nullptr);
1763   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1764   VerifyRegularFaninMatch(output_transpose_node, 0, pad_node->GetName(), 0);
1765 
1766   auto* output_node = context.graph_view->GetNode("o");
1767   ASSERT_NE(output_node, nullptr);
1768   ASSERT_EQ(output_node->NumRegularFanins(), 1);
1769   VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
1770 }
1771 
1772 TEST_F(TransposerTest, SwitchTransposerTest) {
1773 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1774   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1775 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1776   GrapplerItem item;
1777   Scope scope = Scope::NewRootScope();
1778   auto conv2d = SimpleConv2D(&scope);
1779   ops::Variable ctrl(scope.WithOpName("ctrl"), {}, DT_BOOL);
1780   auto sw = ops::Switch(scope.WithOpName("switch").WithDevice("/device:GPU:0"),
1781                         conv2d, ctrl);
1782   auto i1 = ops::Identity(scope.WithOpName("i1"), sw.output_false);
1783   auto i2 = ops::Identity(scope.WithOpName("i2"), sw.output_true);
1784   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1785   TransposeContext context;
1786   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1787       item, virtual_cluster_.get(), &context));
1788   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1789 
1790   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1791   auto* c2d = context.graph_view->GetNode("conv2d");
1792   ASSERT_NE(c2d, nullptr);
1793   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1794 
1795   SwitchTransposer switch_transposer;
1796   auto* sw_node = context.graph_view->GetNode("switch");
1797   ASSERT_NE(sw_node, nullptr);
1798   TF_ASSERT_OK(switch_transposer.TransposeNode(&context, sw_node));
1799 
1800   // The expected optimized graph contains 3 extra sets of Transpose nodes.
1801   auto* input_transpose_node = context.graph_view->GetNode(
1802       "switch-0-TransposeNHWCToNCHW-LayoutOptimizer");
1803   ASSERT_NE(input_transpose_node, nullptr);
1804   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
1805   VerifyRegularFaninMatch(input_transpose_node, 0,
1806                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1807 
1808   auto* switch_node = context.graph_view->GetNode("switch");
1809   ASSERT_NE(switch_node, nullptr);
1810   ASSERT_EQ(switch_node->NumRegularFanins(), 2);
1811   VerifyRegularFaninMatch(switch_node, 0, input_transpose_node->GetName(), 0);
1812   VerifyRegularFaninMatch(switch_node, 1, "ctrl", 0);
1813 
1814   auto* output_transpose_node1 = context.graph_view->GetNode(
1815       "switch-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1816   ASSERT_NE(output_transpose_node1, nullptr);
1817   ASSERT_EQ(output_transpose_node1->NumRegularFanins(), 2);
1818   VerifyRegularFaninMatch(output_transpose_node1, 0, switch_node->GetName(), 0);
1819 
1820   auto* output_transpose_node2 = context.graph_view->GetNode(
1821       "switch-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
1822   ASSERT_NE(output_transpose_node2, nullptr);
1823   ASSERT_EQ(output_transpose_node2->NumRegularFanins(), 2);
1824   VerifyRegularFaninMatch(output_transpose_node2, 0, switch_node->GetName(), 1);
1825 
1826   auto* i1_node = context.graph_view->GetNode("i1");
1827   ASSERT_NE(i1_node, nullptr);
1828   ASSERT_EQ(i1_node->NumRegularFanins(), 1);
1829   VerifyRegularFaninMatch(i1_node, 0, output_transpose_node1->GetName(), 0);
1830 
1831   auto* i2_node = context.graph_view->GetNode("i2");
1832   ASSERT_NE(i2_node, nullptr);
1833   ASSERT_EQ(i2_node->NumRegularFanins(), 1);
1834   VerifyRegularFaninMatch(i2_node, 0, output_transpose_node2->GetName(), 0);
1835 }
1836 
1837 TEST_F(TransposerTest, TernaryOpTransposerTest) {
1838 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1839   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1840 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1841   GrapplerItem item;
1842   Scope scope = Scope::NewRootScope();
1843   auto conv2d = SimpleConv2D(&scope);
1844   auto a = ops::RandomUniform(scope.WithOpName("a"),
1845                               {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
1846   auto b = ops::RandomUniform(scope.WithOpName("b"),
1847                               {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
1848   auto beta_inc = ops::Betainc(
1849       scope.WithOpName("beta_inc").WithDevice("/device:GPU:0"), a, b, conv2d);
1850   auto z = ops::Identity(scope.WithOpName("z"), beta_inc);
1851   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1852   TransposeContext context;
1853   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1854       item, virtual_cluster_.get(), &context));
1855   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1856 
1857   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1858   auto* c2d = context.graph_view->GetNode("conv2d");
1859   ASSERT_NE(c2d, nullptr);
1860   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1861 
1862   TernaryOpTransposer ternary_op_transposer;
1863   auto* bi = context.graph_view->GetNode("beta_inc");
1864   ASSERT_NE(bi, nullptr);
1865   TF_ASSERT_OK(ternary_op_transposer.TransposeNode(&context, bi));
1866 
1867   // The expected optimized graph contains 4 extra sets of Transpose nodes.
1868   auto* input_transpose_node1 = context.graph_view->GetNode(
1869       "beta_inc-0-TransposeNHWCToNCHW-LayoutOptimizer");
1870   ASSERT_NE(input_transpose_node1, nullptr);
1871   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
1872   VerifyRegularFaninMatch(input_transpose_node1, 0, "a", 0);
1873 
1874   auto* input_transpose_node2 = context.graph_view->GetNode(
1875       "beta_inc-1-TransposeNHWCToNCHW-LayoutOptimizer");
1876   ASSERT_NE(input_transpose_node2, nullptr);
1877   ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
1878   VerifyRegularFaninMatch(input_transpose_node2, 0, "b", 0);
1879 
1880   auto* input_transpose_node3 = context.graph_view->GetNode(
1881       "beta_inc-2-TransposeNHWCToNCHW-LayoutOptimizer");
1882   ASSERT_NE(input_transpose_node3, nullptr);
1883   ASSERT_EQ(input_transpose_node3->NumRegularFanins(), 2);
1884   VerifyRegularFaninMatch(input_transpose_node3, 0,
1885                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1886 
1887   auto* bi_node = context.graph_view->GetNode("beta_inc");
1888   ASSERT_NE(bi_node, nullptr);
1889   ASSERT_EQ(bi_node->NumRegularFanins(), 3);
1890   VerifyRegularFaninMatch(bi_node, 0, input_transpose_node1->GetName(), 0);
1891   VerifyRegularFaninMatch(bi_node, 1, input_transpose_node2->GetName(), 0);
1892   VerifyRegularFaninMatch(bi_node, 2, input_transpose_node3->GetName(), 0);
1893 
1894   auto* output_transpose_node = context.graph_view->GetNode(
1895       "beta_inc-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1896   ASSERT_NE(output_transpose_node, nullptr);
1897   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1898   VerifyRegularFaninMatch(output_transpose_node, 0, bi_node->GetName(), 0);
1899 
1900   auto* z_output_node = context.graph_view->GetNode("z");
1901   ASSERT_NE(z_output_node, nullptr);
1902   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
1903   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
1904                           0);
1905 }
1906 
1907 TEST_F(TransposerTest, UnaryGradTransposerTestTanhGrad) {
1908 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1909   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1910 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1911   GrapplerItem item;
1912   Scope scope = Scope::NewRootScope();
1913   auto conv2d = SimpleConv2D(&scope);
1914   auto a = ops::RandomUniform(scope.WithOpName("a"),
1915                               {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
1916   auto tanh_grad_op = ops::internal::TanhGrad(
1917       scope.WithOpName("tanh_grad").WithDevice("/device:GPU:0"), conv2d, a);
1918   auto z = ops::Identity(scope.WithOpName("z"), tanh_grad_op);
1919   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1920   TransposeContext context;
1921   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1922       item, virtual_cluster_.get(), &context));
1923   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1924 
1925   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1926   auto* c2d = context.graph_view->GetNode("conv2d");
1927   ASSERT_NE(c2d, nullptr);
1928   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1929 
1930   UnaryGradTransposer unary_grad_transposer;
1931   auto* tanh_grad = context.graph_view->GetNode("tanh_grad");
1932   ASSERT_NE(tanh_grad, nullptr);
1933   TF_ASSERT_OK(unary_grad_transposer.TransposeNode(&context, tanh_grad));
1934 
1935   // The expected optimized graph contains 4 extra sets of Transpose nodes.
1936   auto* input_transpose_node1 = context.graph_view->GetNode(
1937       "tanh_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
1938   ASSERT_NE(input_transpose_node1, nullptr);
1939   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
1940   VerifyRegularFaninMatch(input_transpose_node1, 0,
1941                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
1942 
1943   auto* input_transpose_node2 = context.graph_view->GetNode(
1944       "tanh_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
1945   ASSERT_NE(input_transpose_node2, nullptr);
1946   ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
1947   VerifyRegularFaninMatch(input_transpose_node2, 0, "a", 0);
1948 
1949   auto* tanh_grad_node = context.graph_view->GetNode("tanh_grad");
1950   ASSERT_NE(tanh_grad_node, nullptr);
1951   ASSERT_EQ(tanh_grad_node->NumRegularFanins(), 2);
1952   VerifyRegularFaninMatch(tanh_grad_node, 0, input_transpose_node1->GetName(),
1953                           0);
1954   VerifyRegularFaninMatch(tanh_grad_node, 1, input_transpose_node2->GetName(),
1955                           0);
1956 
1957   auto* output_transpose_node = context.graph_view->GetNode(
1958       "tanh_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
1959   ASSERT_NE(output_transpose_node, nullptr);
1960   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
1961   VerifyRegularFaninMatch(output_transpose_node, 0, tanh_grad_node->GetName(),
1962                           0);
1963 
1964   auto* z_output_node = context.graph_view->GetNode("z");
1965   ASSERT_NE(z_output_node, nullptr);
1966   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
1967   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
1968                           0);
1969 }
1970 
1971 TEST_F(TransposerTest, UnaryGradTransposerTestRelu6Grad) {
1972 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1973   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
1974 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1975   GrapplerItem item;
1976   Scope scope = Scope::NewRootScope();
1977   auto conv2d = SimpleConv2D(&scope);
1978   auto a = ops::RandomUniform(scope.WithOpName("a"),
1979                               {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
1980   auto relu6_grad_op = ops::internal::SigmoidGrad(
1981       scope.WithOpName("relu6_grad").WithDevice("/device:GPU:0"), conv2d, a);
1982   auto z = ops::Identity(scope.WithOpName("z"), relu6_grad_op);
1983   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
1984   TransposeContext context;
1985   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
1986       item, virtual_cluster_.get(), &context));
1987   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
1988 
1989   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
1990   auto* c2d = context.graph_view->GetNode("conv2d");
1991   ASSERT_NE(c2d, nullptr);
1992   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
1993 
1994   UnaryGradTransposer unary_grad_transposer;
1995   auto* relu6_grad = context.graph_view->GetNode("relu6_grad");
1996   ASSERT_NE(relu6_grad, nullptr);
1997   TF_ASSERT_OK(unary_grad_transposer.TransposeNode(&context, relu6_grad));
1998 
1999   // The expected optimized graph contains 4 extra sets of Transpose nodes.
2000   auto* input_transpose_node1 = context.graph_view->GetNode(
2001       "relu6_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
2002   ASSERT_NE(input_transpose_node1, nullptr);
2003   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
2004   VerifyRegularFaninMatch(input_transpose_node1, 0,
2005                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2006 
2007   auto* input_transpose_node2 = context.graph_view->GetNode(
2008       "relu6_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
2009   ASSERT_NE(input_transpose_node2, nullptr);
2010   ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
2011   VerifyRegularFaninMatch(input_transpose_node2, 0, "a", 0);
2012 
2013   auto* relu6_grad_node = context.graph_view->GetNode("relu6_grad");
2014   ASSERT_NE(relu6_grad_node, nullptr);
2015   ASSERT_EQ(relu6_grad_node->NumRegularFanins(), 2);
2016   VerifyRegularFaninMatch(relu6_grad_node, 0, input_transpose_node1->GetName(),
2017                           0);
2018   VerifyRegularFaninMatch(relu6_grad_node, 1, input_transpose_node2->GetName(),
2019                           0);
2020 
2021   auto* output_transpose_node = context.graph_view->GetNode(
2022       "relu6_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2023   ASSERT_NE(output_transpose_node, nullptr);
2024   ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
2025   VerifyRegularFaninMatch(output_transpose_node, 0, relu6_grad_node->GetName(),
2026                           0);
2027 
2028   auto* z_output_node = context.graph_view->GetNode("z");
2029   ASSERT_NE(z_output_node, nullptr);
2030   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2031   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2032                           0);
2033 }
2034 
2035 TEST_F(TransposerTest, SqueezeTransposerTest) {
2036 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2037   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2038 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2039   GrapplerItem item;
2040   Scope scope = Scope::NewRootScope();
2041   auto input =
2042       ops::RandomUniform(scope.WithOpName("input"), {32, 1, 1, 8}, DT_FLOAT);
2043   auto filter =
2044       ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 16}, DT_FLOAT);
2045   auto conv2d = ops::Conv2D(
2046       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2047       {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2048 
2049   auto squeeze_op = ops::Squeeze(
2050       scope.WithOpName("squeeze").WithDevice("/device:GPU:0"), conv2d);
2051   auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
2052   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2053   TransposeContext context;
2054   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2055       item, virtual_cluster_.get(), &context));
2056   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2057 
2058   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2059   auto* c2d = context.graph_view->GetNode("conv2d");
2060   ASSERT_NE(c2d, nullptr);
2061   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2062 
2063   SqueezeTransposer squeeze_transposer;
2064   auto* squeeze = context.graph_view->GetNode("squeeze");
2065   ASSERT_NE(squeeze, nullptr);
2066   TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));
2067 
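       // The expected optimized graph contains 1 extra Transpose node on the input.
       // No output Transpose is expected because the Squeeze removes the spatial
       // dimensions, so its output is no longer 4D.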
2068   auto* input_transpose_node1 = context.graph_view->GetNode(
2069       "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
2070   ASSERT_NE(input_transpose_node1, nullptr);
2071   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
2072   VerifyRegularFaninMatch(input_transpose_node1, 0,
2073                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2074 
2075   auto* squeeze_node = context.graph_view->GetNode("squeeze");
2076   ASSERT_NE(squeeze_node, nullptr);
2077   ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
2078   VerifyRegularFaninMatch(squeeze_node, 0, input_transpose_node1->GetName(), 0);
2079 
2080   auto* output_transpose_node = context.graph_view->GetNode(
2081       "squeeze-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2082   EXPECT_EQ(output_transpose_node, nullptr);
2083 
2084   auto* z_output_node = context.graph_view->GetNode("z");
2085   ASSERT_NE(z_output_node, nullptr);
2086   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2087   VerifyRegularFaninMatch(z_output_node, 0, squeeze_node->GetName(), 0);
2088 }
2089 
2090 TEST_F(TransposerTest, SqueezeTransposerTestUnsupportedInputShape) {
2091 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2092   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2093 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2094   GrapplerItem item;
2095   Scope scope = Scope::NewRootScope();
2096   auto input =
2097       ops::RandomUniform(scope.WithOpName("input"), {32, 5, 5, 8}, DT_FLOAT);
2098   auto filter =
2099       ops::RandomUniform(scope.WithOpName("filter"), {5, 5, 8, 16}, DT_FLOAT);
2100   auto conv2d = ops::Conv2D(
2101       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2102       {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2103 
2104   auto squeeze_op = ops::Squeeze(
2105       scope.WithOpName("squeeze").WithDevice("/device:GPU:0"), conv2d);
2106   auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
2107   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2108   TransposeContext context;
2109   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2110       item, virtual_cluster_.get(), &context));
2111   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2112 
2113   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2114   auto* c2d = context.graph_view->GetNode("conv2d");
2115   ASSERT_NE(c2d, nullptr);
2116   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2117 
2118   SqueezeTransposer squeeze_transposer;
2119   auto* squeeze = context.graph_view->GetNode("squeeze");
2120   ASSERT_NE(squeeze, nullptr);
2121   TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));
2122 
2123   // Expect no changes to the input edge; the 5x5 spatial dims are not squeezable.
2124   auto* input_transpose_node1 = context.graph_view->GetNode(
2125       "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
2126   EXPECT_EQ(input_transpose_node1, nullptr);
2127 }
2128 
2129 TEST_F(TransposerTest, SqueezeTransposerTestInvalidHWAxis) {
2130 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2131   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2132 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2133   GrapplerItem item;
2134   Scope scope = Scope::NewRootScope();
2135   auto input =
2136       ops::RandomUniform(scope.WithOpName("input"), {32, 1, 1, 8}, DT_FLOAT);
2137   auto filter =
2138       ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 16}, DT_FLOAT);
2139   auto conv2d = ops::Conv2D(
2140       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2141       {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2142 
2143   auto squeeze_op =
2144       ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
2145                    conv2d, ops::Squeeze::Attrs().Axis({1}));
2146   auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
2147   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2148   TransposeContext context;
2149   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2150       item, virtual_cluster_.get(), &context));
2151   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2152 
2153   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2154   auto* c2d = context.graph_view->GetNode("conv2d");
2155   ASSERT_NE(c2d, nullptr);
2156   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2157 
2158   SqueezeTransposer squeeze_transposer;
2159   auto* squeeze = context.graph_view->GetNode("squeeze");
2160   ASSERT_NE(squeeze, nullptr);
2161   TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));
2162 
2163   // Expect no changes to the input edge.
2164   auto* input_transpose_node1 = context.graph_view->GetNode(
2165       "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
2166   EXPECT_EQ(input_transpose_node1, nullptr);
2167 }
2168 
2169 TEST_F(TransposerTest, SqueezeTransposerTestInvalidNHWAxis) {
2170 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2171   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2172 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2173   GrapplerItem item;
2174   Scope scope = Scope::NewRootScope();
2175   auto input =
2176       ops::RandomUniform(scope.WithOpName("input"), {32, 1, 1, 8}, DT_FLOAT);
2177   auto filter =
2178       ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
2179   auto conv2d = ops::Conv2D(
2180       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2181       {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2182 
2183   auto squeeze_op =
2184       ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
2185                    conv2d, ops::Squeeze::Attrs().Axis({1, 2, 3}));
2186   auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
2187   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2188   TransposeContext context;
2189   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2190       item, virtual_cluster_.get(), &context));
2191   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2192 
2193   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2194   auto* c2d = context.graph_view->GetNode("conv2d");
2195   ASSERT_NE(c2d, nullptr);
2196   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2197 
2198   SqueezeTransposer squeeze_transposer;
2199   auto* squeeze = context.graph_view->GetNode("squeeze");
2200   ASSERT_NE(squeeze, nullptr);
2201   TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));
2202 
2203   // Expect no changes to the input edge.
2204   auto* input_transpose_node1 = context.graph_view->GetNode(
2205       "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
2206   EXPECT_EQ(input_transpose_node1, nullptr);
2207 }
2208 
2209 TEST_F(TransposerTest, SqueezeTransposerTestSqueezeDimsUpdated) {
2210 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2211   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2212 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2213   GrapplerItem item;
2214   Scope scope = Scope::NewRootScope();
2215   auto input =
2216       ops::RandomUniform(scope.WithOpName("input"), {1, 1, 1, 8}, DT_FLOAT);
2217   auto filter =
2218       ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
2219   auto conv2d = ops::Conv2D(
2220       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2221       {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2222 
2223   auto squeeze_op =
2224       ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
2225                    conv2d, ops::Squeeze::Attrs().Axis({1, 2}));
2226   auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
2227   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2228   TransposeContext context;
2229   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2230       item, virtual_cluster_.get(), &context));
2231   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2232 
2233   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2234   auto* c2d = context.graph_view->GetNode("conv2d");
2235   ASSERT_NE(c2d, nullptr);
2236   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2237 
2238   SqueezeTransposer squeeze_transposer;
2239   auto* squeeze = context.graph_view->GetNode("squeeze");
2240   ASSERT_NE(squeeze, nullptr);
2241   TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));
2242 
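       // The expected optimized graph contains 1 extra Transpose node on the input,
       // and the squeeze_dims attribute is permuted from {1, 2} (H, W in NHWC) to
       // {2, 3} (H, W in NCHW). No output Transpose is expected.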
2243   auto* input_transpose_node1 = context.graph_view->GetNode(
2244       "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
2245   ASSERT_NE(input_transpose_node1, nullptr);
2246   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
2247   VerifyRegularFaninMatch(input_transpose_node1, 0,
2248                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2249 
2250   auto* squeeze_node = context.graph_view->GetNode("squeeze");
2251   ASSERT_NE(squeeze_node, nullptr);
2252   ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
2253   VerifyRegularFaninMatch(squeeze_node, 0, input_transpose_node1->GetName(), 0);
2254   const auto* squeeze_dims_attr = squeeze_node->GetAttr("squeeze_dims");
2255   const auto& list = squeeze_dims_attr->list();
2256   ASSERT_EQ(list.i_size(), 2);
2257   EXPECT_EQ(list.i(0), 2);
2258   EXPECT_EQ(list.i(1), 3);
2259 
2260   auto* output_transpose_node = context.graph_view->GetNode(
2261       "squeeze-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2262   EXPECT_EQ(output_transpose_node, nullptr);
2263 
2264   auto* z_output_node = context.graph_view->GetNode("z");
2265   ASSERT_NE(z_output_node, nullptr);
2266   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2267   VerifyRegularFaninMatch(z_output_node, 0, squeeze_node->GetName(), 0);
2268 }
2269 
2270 // Same as SqueezeTransposerTestSqueezeDimsUpdated but with squeeze dims
2271 // specified with negative values.
2272 TEST_F(TransposerTest, SqueezeTransposerTestNegativeSqueezeDimsUpdated) {
2273 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2274   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2275 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2276   GrapplerItem item;
2277   Scope scope = Scope::NewRootScope();
2278   auto input =
2279       ops::RandomUniform(scope.WithOpName("input"), {1, 1, 1, 8}, DT_FLOAT);
2280   auto filter =
2281       ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
2282   auto conv2d = ops::Conv2D(
2283       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2284       {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2285 
2286   auto squeeze_op =
2287       ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
2288                    conv2d, ops::Squeeze::Attrs().Axis({-3, -2}));
2289   auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
2290   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2291   TransposeContext context;
2292   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2293       item, virtual_cluster_.get(), &context));
2294   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2295 
2296   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2297   auto* c2d = context.graph_view->GetNode("conv2d");
2298   ASSERT_NE(c2d, nullptr);
2299   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2300 
2301   SqueezeTransposer squeeze_transposer;
2302   auto* squeeze = context.graph_view->GetNode("squeeze");
2303   ASSERT_NE(squeeze, nullptr);
2304   TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));
2305 
2306   auto* input_transpose_node1 = context.graph_view->GetNode(
2307       "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
2308   ASSERT_NE(input_transpose_node1, nullptr);
2309   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
2310   VerifyRegularFaninMatch(input_transpose_node1, 0,
2311                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2312 
2313   auto* squeeze_node = context.graph_view->GetNode("squeeze");
2314   ASSERT_NE(squeeze_node, nullptr);
2315   ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
2316   VerifyRegularFaninMatch(squeeze_node, 0, input_transpose_node1->GetName(), 0);
2317   const auto* squeeze_dims_attr = squeeze_node->GetAttr("squeeze_dims");
2318   const auto& list = squeeze_dims_attr->list();
2319   ASSERT_EQ(list.i_size(), 2);
2320   EXPECT_EQ(list.i(0), 2);
2321   EXPECT_EQ(list.i(1), 3);
2322 
2323   auto* output_transpose_node = context.graph_view->GetNode(
2324       "squeeze-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2325   EXPECT_EQ(output_transpose_node, nullptr);
2326 
2327   auto* z_output_node = context.graph_view->GetNode("z");
2328   ASSERT_NE(z_output_node, nullptr);
2329   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2330   VerifyRegularFaninMatch(z_output_node, 0, squeeze_node->GetName(), 0);
2331 }
2332 
2333 // Same as SqueezeTransposerTestSqueezeDimsUpdated, but with the source and
2334 // destination formats swapped (as is done in some cases when the data type is
2335 // DT_HALF).
2336 TEST_F(TransposerTest, SqueezeTransposerTestNCHWToNHWCSqueezeDimsUpdated) {
2337 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2338   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2339 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2340   GrapplerItem item;
2341   Scope scope = Scope::NewRootScope();
2342   auto input =
2343       ops::RandomUniform(scope.WithOpName("input"), {1, 8, 1, 1}, DT_FLOAT);
2344   auto filter =
2345       ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
2346   auto conv2d = ops::Conv2D(
2347       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2348       {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kDstFormat));
2349 
2350   auto squeeze_op =
2351       ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
2352                    conv2d, ops::Squeeze::Attrs().Axis({2, 3}));
2353   auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
2354   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2355   TransposeContext context;
2356   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2357       item, virtual_cluster_.get(), &context));
2358   context.AssignDeviceAndDataFormats(kGPU, kDstFormat, kSrcFormat);
2359 
2360   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2361   auto* c2d = context.graph_view->GetNode("conv2d");
2362   ASSERT_NE(c2d, nullptr);
2363   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2364 
2365   SqueezeTransposer squeeze_transposer;
2366   auto* squeeze = context.graph_view->GetNode("squeeze");
2367   ASSERT_NE(squeeze, nullptr);
2368   TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));
2369 
2370   auto* input_transpose_node1 = context.graph_view->GetNode(
2371       "squeeze-0-TransposeNCHWToNHWC-LayoutOptimizer");
2372   ASSERT_NE(input_transpose_node1, nullptr);
2373   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
2374   VerifyRegularFaninMatch(input_transpose_node1, 0,
2375                           "conv2d-0-0-TransposeNHWCToNCHW-LayoutOptimizer", 0);
2376 
2377   auto* squeeze_node = context.graph_view->GetNode("squeeze");
2378   ASSERT_NE(squeeze_node, nullptr);
2379   ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
2380   VerifyRegularFaninMatch(squeeze_node, 0, input_transpose_node1->GetName(), 0);
2381   const auto* squeeze_dims_attr = squeeze_node->GetAttr("squeeze_dims");
2382   const auto& list = squeeze_dims_attr->list();
2383   ASSERT_EQ(list.i_size(), 2);
2384   EXPECT_EQ(list.i(0), 1);
2385   EXPECT_EQ(list.i(1), 2);
2386 
2387   auto* output_transpose_node = context.graph_view->GetNode(
2388       "squeeze-0-0-TransposeNHWCToNCHW-LayoutOptimizer");
2389   EXPECT_EQ(output_transpose_node, nullptr);
2390 
2391   auto* z_output_node = context.graph_view->GetNode("z");
2392   ASSERT_NE(z_output_node, nullptr);
2393   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2394   VerifyRegularFaninMatch(z_output_node, 0, squeeze_node->GetName(), 0);
2395 }
2396 
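// Verifies that MaxPoolV2Transposer transposes the data input, permutes the
// ksize and strides vector inputs, and transposes the output back to NHWC.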
2397 TEST_F(TransposerTest, MaxPoolV2Transposer) {
2398 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2399   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2400 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2401   GrapplerItem item;
2402   Scope scope = Scope::NewRootScope();
2403   auto input =
2404       ops::RandomUniform(scope.WithOpName("input"),
2405                          {kBatchSize, kWidth, kHeight, kDepthIn}, DT_FLOAT);
2406   auto ksize = ops::Const(scope.WithOpName("ksize"), {1, kKernel, kKernel, 1});
2407   auto strides =
2408       ops::Const(scope.WithOpName("strides"), {1, kKernel, kKernel, 1});
2409   auto maxpool_op =
2410       ops::MaxPoolV2(scope.WithOpName("maxpoolv2").WithDevice("/device:GPU:0"),
2411                      input, ksize, strides, "VALID");
2412   auto z = ops::Identity(scope.WithOpName("z"), maxpool_op);
2413   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2414   TransposeContext context;
2415   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2416       item, virtual_cluster_.get(), &context));
2417   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2418 
2419   MaxPoolV2Transposer maxpool_transposer;
2420   auto* maxpool = context.graph_view->GetNode("maxpoolv2");
2421   ASSERT_NE(maxpool, nullptr);
2422   TF_ASSERT_OK(maxpool_transposer.TransposeNode(&context, maxpool));
2423 
2424   auto* input_transpose_node1 = context.graph_view->GetNode(
2425       "maxpoolv2-0-TransposeNHWCToNCHW-LayoutOptimizer");
2426   ASSERT_NE(input_transpose_node1, nullptr);
2427   auto* input_transpose_node2 = context.graph_view->GetNode(
2428       "maxpoolv2-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
2429   ASSERT_NE(input_transpose_node2, nullptr);
2430   auto* input_transpose_node3 = context.graph_view->GetNode(
2431       "maxpoolv2-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
2432   ASSERT_NE(input_transpose_node3, nullptr);
2433 
2434   auto* updated_maxpool = context.graph_view->GetNode("maxpoolv2");
2435   ASSERT_NE(updated_maxpool, nullptr);
2436   ASSERT_EQ(updated_maxpool->NumRegularFanins(), 3);
2437   VerifyRegularFaninMatch(updated_maxpool, 0, input_transpose_node1->GetName(),
2438                           0);
2439   VerifyRegularFaninMatch(updated_maxpool, 1, input_transpose_node2->GetName(),
2440                           0);
2441   VerifyRegularFaninMatch(updated_maxpool, 2, input_transpose_node3->GetName(),
2442                           0);
2443 
2444   auto* output_transpose_node = context.graph_view->GetNode(
2445       "maxpoolv2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2446   ASSERT_NE(output_transpose_node, nullptr);
2447 
2448   auto* z_output_node = context.graph_view->GetNode("z");
2449   ASSERT_NE(z_output_node, nullptr);
2450   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2451   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2452                           0);
2453 }
2454 
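// Verifies MaxPoolGradV2Transposer and its MaxPoolGradGradV2 variant: the
// three tensor inputs are transposed, the ksize and strides vectors are
// permuted, and the output is transposed back to NHWC.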
2455 TEST_F(TransposerTest, MaxPoolGradV2Transposer) {
2456 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2457   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2458 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2459   for (bool use_grad_grad : {false, true}) {
2460     GrapplerItem item;
2461     Scope scope = Scope::NewRootScope();
2462     auto orig_input =
2463         ops::RandomUniform(scope.WithOpName("orig_input"),
2464                            {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2465     auto orig_output =
2466         ops::RandomUniform(scope.WithOpName("orig_output"),
2467                            {kBatchSize, use_grad_grad ? kOutHeight : kHeight,
2468                             use_grad_grad ? kOutWidth : kWidth, kDepthIn},
2469                            DT_FLOAT);
2470     auto grad =
2471         ops::RandomUniform(scope.WithOpName("grad_input"),
2472                            {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2473     auto ksize =
2474         ops::Const(scope.WithOpName("ksize"), {1, kKernel, kKernel, 1});
2475     auto strides =
2476         ops::Const(scope.WithOpName("strides"), {1, kKernel, kKernel, 1});
2477     Output maxpoolgrad_op;
2478     if (use_grad_grad) {
2479       maxpoolgrad_op = ops::MaxPoolGradGradV2(
2480           scope.WithOpName("maxpoolgradv2").WithDevice("/device:GPU:0"),
2481           orig_input, orig_output, grad, ksize, strides, "VALID");
2482     } else {
2483       maxpoolgrad_op = ops::MaxPoolGradV2(
2484           scope.WithOpName("maxpoolgradv2").WithDevice("/device:GPU:0"),
2485           orig_input, orig_output, grad, ksize, strides, "VALID");
2486     }
2487     auto z = ops::Identity(scope.WithOpName("z"), maxpoolgrad_op);
2488     TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2489     TransposeContext context;
2490     TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2491         item, virtual_cluster_.get(), &context));
2492     context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2493 
2494     MaxPoolGradV2Transposer maxpoolgrad_transposer;
2495     auto* maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2");
2496     ASSERT_NE(maxpoolgrad, nullptr);
2497     TF_ASSERT_OK(maxpoolgrad_transposer.TransposeNode(&context, maxpoolgrad));
2498 
2499     auto* orig_input_transpose_node = context.graph_view->GetNode(
2500         "maxpoolgradv2-0-TransposeNHWCToNCHW-LayoutOptimizer");
2501     ASSERT_NE(orig_input_transpose_node, nullptr);
2502     auto* orig_output_transpose_node = context.graph_view->GetNode(
2503         "maxpoolgradv2-1-TransposeNHWCToNCHW-LayoutOptimizer");
2504     ASSERT_NE(orig_output_transpose_node, nullptr);
2505     auto* grad_input_transpose_node = context.graph_view->GetNode(
2506         "maxpoolgradv2-2-TransposeNHWCToNCHW-LayoutOptimizer");
2507     ASSERT_NE(grad_input_transpose_node, nullptr);
2508     auto* size_node = context.graph_view->GetNode(
2509         "maxpoolgradv2-3-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
2510     ASSERT_NE(size_node, nullptr);
2511     auto* stride_node = context.graph_view->GetNode(
2512         "maxpoolgradv2-4-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
2513     ASSERT_NE(stride_node, nullptr);
2514 
2515     auto* updated_maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2");
2516     ASSERT_NE(updated_maxpoolgrad, nullptr);
2517     ASSERT_EQ(updated_maxpoolgrad->NumRegularFanins(), 5);
2518     VerifyRegularFaninMatch(updated_maxpoolgrad, 0,
2519                             orig_input_transpose_node->GetName(), 0);
2520     VerifyRegularFaninMatch(updated_maxpoolgrad, 1,
2521                             orig_output_transpose_node->GetName(), 0);
2522     VerifyRegularFaninMatch(updated_maxpoolgrad, 2,
2523                             grad_input_transpose_node->GetName(), 0);
2524     VerifyRegularFaninMatch(updated_maxpoolgrad, 3, size_node->GetName(), 0);
2525     VerifyRegularFaninMatch(updated_maxpoolgrad, 4, stride_node->GetName(), 0);
2526 
2527     auto* output_transpose_node = context.graph_view->GetNode(
2528         "maxpoolgradv2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2529     ASSERT_NE(output_transpose_node, nullptr);
2530 
2531     auto* z_output_node = context.graph_view->GetNode("z");
2532     ASSERT_NE(z_output_node, nullptr);
2533     ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2534     VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2535                             0);
2536   }
2537 }
2538 
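// Verifies that BinaryOpTransposer transposes the 4D operand of Add and
// rank-extends the shape-{1} operand through a Reshape so that it still
// broadcasts correctly under NCHW.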
2539 TEST_F(TransposerTest, BinaryOpTransposerAdd) {
2540 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2541   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2542 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2543   GrapplerItem item;
2544   Scope scope = Scope::NewRootScope();
2545   auto input =
2546       ops::RandomUniform(scope.WithOpName("input"),
2547                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2548   auto filter =
2549       ops::RandomUniform(scope.WithOpName("filter"),
2550                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
2551   auto conv2d = ops::Conv2D(
2552       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2553       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2554   auto a = ops::RandomUniform(scope.WithOpName("a"), {1}, DT_FLOAT);
2555   auto add =
2556       ops::Add(scope.WithOpName("Add").WithDevice("/device:GPU:0"), a, conv2d);
2557   auto z = ops::Identity(scope.WithOpName("z"), add);
2558   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2559   TransposeContext context;
2560   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2561       item, virtual_cluster_.get(), &context));
2562   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2563 
2564   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2565   auto* c2d = context.graph_view->GetNode("conv2d");
2566   ASSERT_NE(c2d, nullptr);
2567   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2568 
2569   auto* addop = context.graph_view->GetNode("Add");
2570   ASSERT_NE(addop, nullptr);
2571   BinaryOpTransposer binaryop_transposer;
2572   TF_ASSERT_OK(binaryop_transposer.TransposeNode(&context, addop));
2573 
2574   auto* input_const_node =
2575       context.graph_view->GetNode("Add-0-ReshapeConst-LayoutOptimizer");
2576   ASSERT_NE(input_const_node, nullptr);
2577   EXPECT_EQ(input_const_node->NumRegularFanins(), 0);
2578 
2579   auto* input_reshape_node =
2580       context.graph_view->GetNode("Add-0-ReshapeNHWCToNCHW-LayoutOptimizer");
2581   ASSERT_NE(input_reshape_node, nullptr);
2582   ASSERT_EQ(input_reshape_node->NumRegularFanins(), 2);
2583   VerifyRegularFaninMatch(input_reshape_node, 0, "a", 0);
2584   VerifyRegularFaninMatch(input_reshape_node, 1, input_const_node->GetName(),
2585                           0);
2586 
2587   auto* input_transpose_node =
2588       context.graph_view->GetNode("Add-1-TransposeNHWCToNCHW-LayoutOptimizer");
2589   ASSERT_NE(input_transpose_node, nullptr);
2590   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
2591   VerifyRegularFaninMatch(input_transpose_node, 0,
2592                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2593 
2594   auto* updated_add = context.graph_view->GetNode("Add");
2595   ASSERT_NE(updated_add, nullptr);
2596   ASSERT_EQ(updated_add->NumRegularFanins(), 2);
2597   VerifyRegularFaninMatch(updated_add, 0, input_reshape_node->GetName(), 0);
2598   VerifyRegularFaninMatch(updated_add, 1, input_transpose_node->GetName(), 0);
2599 
2600   auto* output_transpose_node = context.graph_view->GetNode(
2601       "Add-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2602   ASSERT_NE(output_transpose_node, nullptr);
2603 
2604   auto* z_output_node = context.graph_view->GetNode("z");
2605   ASSERT_NE(z_output_node, nullptr);
2606   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2607   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2608                           0);
2609 }
2610 
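// Same as BinaryOpTransposerAdd, but with the operands swapped, so the Reshape
// is applied to input port 1 and the Transpose to input port 0.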
2611 TEST_F(TransposerTest, BinaryOpTransposerMul) {
2612 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2613   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2614 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2615   GrapplerItem item;
2616   Scope scope = Scope::NewRootScope();
2617   auto input =
2618       ops::RandomUniform(scope.WithOpName("input"),
2619                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2620   auto filter =
2621       ops::RandomUniform(scope.WithOpName("filter"),
2622                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
2623   auto conv2d = ops::Conv2D(
2624       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2625       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2626   auto a = ops::RandomUniform(scope.WithOpName("a"), {1}, DT_FLOAT);
2627   auto mul =
2628       ops::Mul(scope.WithOpName("Mul").WithDevice("/device:GPU:0"), conv2d, a);
2629   auto z = ops::Identity(scope.WithOpName("z"), mul);
2630   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2631   TransposeContext context;
2632   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2633       item, virtual_cluster_.get(), &context));
2634   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2635 
2636   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2637   auto* c2d = context.graph_view->GetNode("conv2d");
2638   ASSERT_NE(c2d, nullptr);
2639   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2640 
2641   auto* mulop = context.graph_view->GetNode("Mul");
2642   ASSERT_NE(mulop, nullptr);
2643   BinaryOpTransposer binaryop_transposer;
2644   TF_ASSERT_OK(binaryop_transposer.TransposeNode(&context, mulop));
2645 
2646   auto* input_const_node =
2647       context.graph_view->GetNode("Mul-1-ReshapeConst-LayoutOptimizer");
2648   ASSERT_NE(input_const_node, nullptr);
2649   EXPECT_EQ(input_const_node->NumRegularFanins(), 0);
2650 
2651   auto* input_reshape_node =
2652       context.graph_view->GetNode("Mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
2653   ASSERT_NE(input_reshape_node, nullptr);
2654   ASSERT_EQ(input_reshape_node->NumRegularFanins(), 2);
2655   VerifyRegularFaninMatch(input_reshape_node, 0, "a", 0);
2656   VerifyRegularFaninMatch(input_reshape_node, 1, input_const_node->GetName(),
2657                           0);
2658 
2659   auto* input_transpose_node =
2660       context.graph_view->GetNode("Mul-0-TransposeNHWCToNCHW-LayoutOptimizer");
2661   ASSERT_NE(input_transpose_node, nullptr);
2662   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
2663   VerifyRegularFaninMatch(input_transpose_node, 0,
2664                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2665 
2666   auto* updated_mul = context.graph_view->GetNode("Mul");
2667   ASSERT_NE(updated_mul, nullptr);
2668   ASSERT_EQ(updated_mul->NumRegularFanins(), 2);
2669   VerifyRegularFaninMatch(updated_mul, 1, input_reshape_node->GetName(), 0);
2670   VerifyRegularFaninMatch(updated_mul, 0, input_transpose_node->GetName(), 0);
2671 
2672   auto* output_transpose_node = context.graph_view->GetNode(
2673       "Mul-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2674   ASSERT_NE(output_transpose_node, nullptr);
2675 
2676   auto* z_output_node = context.graph_view->GetNode("z");
2677   ASSERT_NE(z_output_node, nullptr);
2678   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2679   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2680                           0);
2681 }
2682 
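// Verifies that BinaryOpTransposer also handles Polygamma: the convolution
// operand is routed through an NHWC-to-NCHW transpose and the result through
// an NCHW-to-NHWC transpose.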
2683 TEST_F(TransposerTest, BinaryOpTransposerPolygamma) {
2684 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2685   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2686 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2687   GrapplerItem item;
2688   Scope scope = Scope::NewRootScope();
2689   auto input =
2690       ops::RandomUniform(scope.WithOpName("input"),
2691                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2692   auto filter =
2693       ops::RandomUniform(scope.WithOpName("filter"),
2694                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
2695   auto conv2d = ops::Conv2D(
2696       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2697       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2698   auto a = ops::RandomUniform(scope.WithOpName("a"),
2699                               {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
2700 
2701   auto polygamma = ops::Polygamma(
2702       scope.WithOpName("polygamma").WithDevice("/device:GPU:0"), conv2d, a);
2703   auto z = ops::Identity(scope.WithOpName("z"), polygamma);
2704   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2705   TransposeContext context;
2706   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2707       item, virtual_cluster_.get(), &context));
2708   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2709 
2710   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2711   auto* c2d = context.graph_view->GetNode("conv2d");
2712   ASSERT_NE(c2d, nullptr);
2713   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2714 
2715   BinaryOpTransposer binaryop_transposer;
2716   auto* polygamma_op = context.graph_view->GetNode("polygamma");
2717   ASSERT_NE(polygamma_op, nullptr);
2718   TF_ASSERT_OK(binaryop_transposer.TransposeNode(&context, polygamma_op));
2719 
2720   auto* input_transpose_node1 = context.graph_view->GetNode(
2721       "polygamma-0-TransposeNHWCToNCHW-LayoutOptimizer");
2722   ASSERT_NE(input_transpose_node1, nullptr);
2723   ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
2724   VerifyRegularFaninMatch(input_transpose_node1, 0,
2725                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2726 
2727   auto* updated_polygamma = context.graph_view->GetNode("polygamma");
2728   ASSERT_NE(updated_polygamma, nullptr);
2729   ASSERT_EQ(updated_polygamma->NumRegularFanins(), 2);
2730   VerifyRegularFaninMatch(updated_polygamma, 0,
2731                           input_transpose_node1->GetName(), 0);
2732 
2733   auto* output_transpose_node = context.graph_view->GetNode(
2734       "polygamma-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2735   ASSERT_NE(output_transpose_node, nullptr);
2736 
2737   auto* z_output_node = context.graph_view->GetNode("z");
2738   ASSERT_NE(z_output_node, nullptr);
2739   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2740   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2741                           0);
2742 }
2743 
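// Helper that builds a legacy Concat (V1) node, which takes the concat axis as
// its first input, directly with NodeBuilder; ops::Concat below produces a
// ConcatV2 node, where the axis is the last input.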
2744 bool CreateConcatV1Op(const Scope& scope, const InputList& tensors,
2745                       const Input& concat_axis, Output* output) {
2746   if (!scope.ok()) {
2747     return false;
2748   }
2749   auto values = ops::AsNodeOutList(scope, tensors);
2750   if (!scope.ok()) {
2751     return false;
2752   }
2753   auto axis = ops::AsNodeOut(scope, concat_axis);
2754   if (!scope.ok()) {
2755     return false;
2756   }
2757   Node* ret;
2758   const auto unique_name = scope.GetUniqueNameForOp("Concat");
2759   auto builder = NodeBuilder(unique_name, "Concat").Input(axis).Input(values);
2760   scope.UpdateBuilder(&builder);
2761   scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
2762   if (!scope.ok()) {
2763     return false;
2764   }
2765   scope.UpdateStatus(scope.DoShapeInference(ret));
2766   *output = Output(ret, 0);
2767   return true;
2768 }
2769 
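// Verifies ConcatOpTransposer on a V1 Concat node: the axis (input 0) goes
// through DataFormatDimMap and each tensor input through an NHWC-to-NCHW
// transpose.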
2770 TEST_F(TransposerTest, ConcatOpTransposerConcat) {
2771 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2772   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2773 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2774   GrapplerItem item;
2775   Scope scope = Scope::NewRootScope();
2776   Output input_1 = ops::RandomUniform(scope.WithOpName("input_1"),
2777                                       {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
2778   Output input_2 = ops::RandomUniform(scope.WithOpName("input_2"),
2779                                       {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
2780   auto input =
2781       ops::RandomUniform(scope.WithOpName("input"),
2782                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2783   auto filter =
2784       ops::RandomUniform(scope.WithOpName("filter"),
2785                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
2786   Output conv2d = ops::Conv2D(
2787       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2788       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2789   auto axis = ops::Const(scope.WithOpName("axis"), 2, {});
2790   Output concat_op;
2791   ASSERT_TRUE(
2792       CreateConcatV1Op(scope.WithOpName("concat").WithDevice("/device:GPU:0"),
2793                        {input_1, input_2, conv2d}, axis, &concat_op));
2794   auto z = ops::Identity(scope.WithOpName("z"), concat_op);
2795   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2796 
2797   TransposeContext context;
2798   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2799       item, virtual_cluster_.get(), &context));
2800   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2801 
2802   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2803   auto* c2d = context.graph_view->GetNode("conv2d");
2804   ASSERT_NE(c2d, nullptr);
2805   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2806 
2807   ConcatOpTransposer concat_transposer;
2808   auto* concat = context.graph_view->GetNode("concat");
2809   ASSERT_NE(concat, nullptr);
2810   TF_ASSERT_OK(concat_transposer.TransposeNode(&context, concat));
2811 
2812   auto* conv2d_transpose_node = context.graph_view->GetNode(
2813       "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2814   ASSERT_NE(conv2d_transpose_node, nullptr);
2815   auto* conv2d_concat_input_node = context.graph_view->GetNode(
2816       "concat-3-TransposeNHWCToNCHW-LayoutOptimizer");
2817   ASSERT_NE(conv2d_concat_input_node, nullptr);
2818   ASSERT_EQ(conv2d_concat_input_node->NumRegularFanins(), 2);
2819   VerifyRegularFaninMatch(conv2d_concat_input_node, 0,
2820                           conv2d_transpose_node->GetName(), 0);
2821 
2822   auto* axis_dim_node = context.graph_view->GetNode(
2823       "concat-0-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
2824   ASSERT_NE(axis_dim_node, nullptr);
2825 
2826   auto* updated_concat = context.graph_view->GetNode("concat");
2827   ASSERT_NE(updated_concat, nullptr);
2828   ASSERT_EQ(updated_concat->NumRegularFanins(), 4);
2829   VerifyRegularFaninMatch(updated_concat, 0, axis_dim_node->GetName(), 0);
2830   VerifyRegularFaninMatch(updated_concat, 1,
2831                           "concat-1-TransposeNHWCToNCHW-LayoutOptimizer", 0);
2832   VerifyRegularFaninMatch(updated_concat, 2,
2833                           "concat-2-TransposeNHWCToNCHW-LayoutOptimizer", 0);
2834   VerifyRegularFaninMatch(updated_concat, 3,
2835                           conv2d_concat_input_node->GetName(), 0);
2836 
2837   auto* output_transpose_node = context.graph_view->GetNode(
2838       "concat-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2839   ASSERT_NE(output_transpose_node, nullptr);
2840 
2841   auto* z_output_node = context.graph_view->GetNode("z");
2842   ASSERT_NE(z_output_node, nullptr);
2843   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2844   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2845                           0);
2846 }
2847 
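// Same as ConcatOpTransposerConcat, but for ConcatV2, where the axis is the
// last input (port 3).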
2848 TEST_F(TransposerTest, ConcatOpTransposerConcatV2) {
2849 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2850   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2851 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2852   GrapplerItem item;
2853   Scope scope = Scope::NewRootScope();
2854   Output input_1 = ops::RandomUniform(scope.WithOpName("input_1"),
2855                                       {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
2856   Output input_2 = ops::RandomUniform(scope.WithOpName("input_2"),
2857                                       {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
2858   auto input =
2859       ops::RandomUniform(scope.WithOpName("input"),
2860                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2861   auto filter =
2862       ops::RandomUniform(scope.WithOpName("filter"),
2863                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
2864   Output conv2d = ops::Conv2D(
2865       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2866       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2867   auto axis = ops::Const(scope.WithOpName("axis"), 2, {});
2868   auto concat_op =
2869       ops::Concat(scope.WithOpName("concat").WithDevice("/device:GPU:0"),
2870                   {input_1, input_2, conv2d}, axis);
2871   auto z = ops::Identity(scope.WithOpName("z"), concat_op);
2872   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2873 
2874   TransposeContext context;
2875   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2876       item, virtual_cluster_.get(), &context));
2877   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2878 
2879   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2880   auto* c2d = context.graph_view->GetNode("conv2d");
2881   ASSERT_NE(c2d, nullptr);
2882   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2883 
2884   ConcatOpTransposer concat_transposer;
2885   auto* concat = context.graph_view->GetNode("concat");
2886   ASSERT_NE(concat, nullptr);
2887   TF_ASSERT_OK(concat_transposer.TransposeNode(&context, concat));
2888 
2889   auto* conv2d_transpose_node = context.graph_view->GetNode(
2890       "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2891   ASSERT_NE(conv2d_transpose_node, nullptr);
2892   auto* conv2d_concat_input_node = context.graph_view->GetNode(
2893       "concat-2-TransposeNHWCToNCHW-LayoutOptimizer");
2894   ASSERT_NE(conv2d_concat_input_node, nullptr);
2895   ASSERT_EQ(conv2d_concat_input_node->NumRegularFanins(), 2);
2896   VerifyRegularFaninMatch(conv2d_concat_input_node, 0,
2897                           conv2d_transpose_node->GetName(), 0);
2898 
2899   auto* axis_dim_node = context.graph_view->GetNode(
2900       "concat-3-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
2901   ASSERT_NE(axis_dim_node, nullptr);
2902 
2903   auto* updated_concat = context.graph_view->GetNode("concat");
2904   ASSERT_NE(updated_concat, nullptr);
2905   ASSERT_EQ(updated_concat->NumRegularFanins(), 4);
2906   VerifyRegularFaninMatch(updated_concat, 0,
2907                           "concat-0-TransposeNHWCToNCHW-LayoutOptimizer", 0);
2908   VerifyRegularFaninMatch(updated_concat, 1,
2909                           "concat-1-TransposeNHWCToNCHW-LayoutOptimizer", 0);
2910   VerifyRegularFaninMatch(updated_concat, 2,
2911                           conv2d_concat_input_node->GetName(), 0);
2912   VerifyRegularFaninMatch(updated_concat, 3, axis_dim_node->GetName(), 0);
2913 
2914   auto* output_transpose_node = context.graph_view->GetNode(
2915       "concat-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2916   ASSERT_NE(output_transpose_node, nullptr);
2917 
2918   auto* z_output_node = context.graph_view->GetNode("z");
2919   ASSERT_NE(z_output_node, nullptr);
2920   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2921   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2922                           0);
2923 }
2924 
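// Verifies that ReverseV2Transposer transposes the tensor input and maps the
// axis constant through DataFormatDimMap.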
2925 TEST_F(TransposerTest, ReverseV2Transposer) {
2926 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2927   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2928 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2929   GrapplerItem item;
2930   Scope scope = Scope::NewRootScope();
2931 
2932   auto input =
2933       ops::RandomUniform(scope.WithOpName("input"),
2934                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
2935   auto filter =
2936       ops::RandomUniform(scope.WithOpName("filter"),
2937                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
2938   Output conv2d = ops::Conv2D(
2939       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
2940       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
2941   auto axis = ops::Const(scope.WithOpName("axis"), {0, 3}, {2});
2942   auto reverse_op = ops::Reverse(
2943       scope.WithOpName("reverse_v2").WithDevice("/device:GPU:0"), conv2d, axis);
2944   auto z = ops::Identity(scope.WithOpName("z"), reverse_op);
2945   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
2946 
2947   TransposeContext context;
2948   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
2949       item, virtual_cluster_.get(), &context));
2950   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
2951 
2952   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
2953   auto* c2d = context.graph_view->GetNode("conv2d");
2954   ASSERT_NE(c2d, nullptr);
2955   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
2956 
2957   ReverseV2Transposer reverse_v2_transposer;
2958   auto* reverse_v2 = context.graph_view->GetNode("reverse_v2");
2959   ASSERT_NE(reverse_v2, nullptr);
2960   TF_ASSERT_OK(reverse_v2_transposer.TransposeNode(&context, reverse_v2));
2961 
2962   auto* input_transpose_node = context.graph_view->GetNode(
2963       "reverse_v2-0-TransposeNHWCToNCHW-LayoutOptimizer");
2964   ASSERT_NE(input_transpose_node, nullptr);
2965   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
2966   VerifyRegularFaninMatch(input_transpose_node, 0,
2967                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
2968 
2969   auto* axis_node = context.graph_view->GetNode(
2970       "reverse_v2-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
2971   ASSERT_NE(axis_node, nullptr);
2972   ASSERT_EQ(axis_node->NumRegularFanins(), 1);
2973   VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
2974 
2975   auto* updated_reverse_v2 = context.graph_view->GetNode("reverse_v2");
2976   ASSERT_NE(updated_reverse_v2, nullptr);
2977   ASSERT_EQ(updated_reverse_v2->NumRegularFanins(), 2);
2978   VerifyRegularFaninMatch(updated_reverse_v2, 0,
2979                           input_transpose_node->GetName(), 0);
2980   VerifyRegularFaninMatch(updated_reverse_v2, 1, axis_node->GetName(), 0);
2981 
2982   auto* output_transpose_node = context.graph_view->GetNode(
2983       "reverse_v2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
2984   ASSERT_NE(output_transpose_node, nullptr);
2985 
2986   auto* z_output_node = context.graph_view->GetNode("z");
2987   ASSERT_NE(z_output_node, nullptr);
2988   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
2989   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
2990                           0);
2991 }
2992 
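// Verifies that TileTransposer transposes the tensor input and permutes the
// multiples vector with DataFormatVecPermute.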
2993 TEST_F(TransposerTest, TileTransposer) {
2994 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2995   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
2996 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
2997   GrapplerItem item;
2998   Scope scope = Scope::NewRootScope();
2999 
3000   auto input =
3001       ops::RandomUniform(scope.WithOpName("input"),
3002                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
3003   auto filter =
3004       ops::RandomUniform(scope.WithOpName("filter"),
3005                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
3006   Output conv2d = ops::Conv2D(
3007       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
3008       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3009   auto multiple = ops::Const(scope.WithOpName("multiple"), {1, 1, 2, 3}, {4});
3010   auto tile_op = ops::Tile(scope.WithOpName("tile").WithDevice("/device:GPU:0"),
3011                            conv2d, multiple);
3012   auto z = ops::Identity(scope.WithOpName("z"), tile_op);
3013   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
3014 
3015   TransposeContext context;
3016   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
3017       item, virtual_cluster_.get(), &context));
3018   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
3019 
3020   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
3021   auto* c2d = context.graph_view->GetNode("conv2d");
3022   ASSERT_NE(c2d, nullptr);
3023   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
3024 
3025   TileTransposer tile_transposer;
3026   auto* tile = context.graph_view->GetNode("tile");
3027   ASSERT_NE(tile, nullptr);
3028   TF_ASSERT_OK(tile_transposer.TransposeNode(&context, tile));
3029 
3030   auto* input_transpose_node =
3031       context.graph_view->GetNode("tile-0-TransposeNHWCToNCHW-LayoutOptimizer");
3032   ASSERT_NE(input_transpose_node, nullptr);
3033   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
3034   VerifyRegularFaninMatch(input_transpose_node, 0,
3035                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
3036 
3037   auto* multiple_node = context.graph_view->GetNode(
3038       "tile-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
3039   ASSERT_NE(multiple_node, nullptr);
3040   ASSERT_EQ(multiple_node->NumRegularFanins(), 1);
3041   VerifyRegularFaninMatch(multiple_node, 0, "multiple", 0);
3042 
3043   auto* updated_tile = context.graph_view->GetNode("tile");
3044   ASSERT_NE(updated_tile, nullptr);
3045   ASSERT_EQ(updated_tile->NumRegularFanins(), 2);
3046   VerifyRegularFaninMatch(updated_tile, 0, input_transpose_node->GetName(), 0);
3047   VerifyRegularFaninMatch(updated_tile, 1, multiple_node->GetName(), 0);
3048 
3049   auto* output_transpose_node = context.graph_view->GetNode(
3050       "tile-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
3051   ASSERT_NE(output_transpose_node, nullptr);
3052 
3053   auto* z_output_node = context.graph_view->GetNode("z");
3054   ASSERT_NE(z_output_node, nullptr);
3055   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
3056   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
3057                           0);
3058 }
3059 
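// Verifies that ShapeTransposer transposes the input and converts the emitted
// shape vector back to NHWC order with DataFormatVecPermute rather than a
// Transpose.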
3060 TEST_F(TransposerTest, ShapeTransposer) {
3061 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3062   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
3063 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3064   GrapplerItem item;
3065   Scope scope = Scope::NewRootScope();
3066   auto input =
3067       ops::RandomUniform(scope.WithOpName("input"),
3068                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
3069   auto filter =
3070       ops::RandomUniform(scope.WithOpName("filter"),
3071                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
3072   Output conv2d = ops::Conv2D(
3073       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
3074       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3075   auto shape =
3076       ops::Shape(scope.WithOpName("shape").WithDevice("/device:GPU:0"), conv2d);
3077   auto z = ops::Identity(scope.WithOpName("z"), shape);
3078 
3079   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
3080 
3081   TransposeContext context;
3082   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
3083       item, virtual_cluster_.get(), &context));
3084   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
3085 
3086   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
3087   auto* c2d = context.graph_view->GetNode("conv2d");
3088   ASSERT_NE(c2d, nullptr);
3089   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
3090 
3091   ShapeTransposer shape_transposer;
3092   auto* shape_node = context.graph_view->GetNode("shape");
3093   ASSERT_NE(shape_node, nullptr);
3094   TF_ASSERT_OK(shape_transposer.TransposeNode(&context, shape_node));
3095 
3096   auto* conv2d_transpose_node = context.graph_view->GetNode(
3097       "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
3098   ASSERT_NE(conv2d_transpose_node, nullptr);
3099 
3100   auto* shape_input_node = context.graph_view->GetNode(
3101       "shape-0-TransposeNHWCToNCHW-LayoutOptimizer");
3102   ASSERT_NE(shape_input_node, nullptr);
3103   ASSERT_EQ(shape_input_node->NumRegularFanins(), 2);
3104   VerifyRegularFaninMatch(shape_input_node, 0, conv2d_transpose_node->GetName(),
3105                           0);
3106 
3107   auto* output_vec_perm_node = context.graph_view->GetNode(
3108       "shape-0-0-DataFormatVecPermuteNCHWToNHWC-LayoutOptimizer");
3109   ASSERT_NE(output_vec_perm_node, nullptr);
3110 
3111   auto* z_output_node = context.graph_view->GetNode("z");
3112   ASSERT_NE(z_output_node, nullptr);
3113   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
3114   VerifyRegularFaninMatch(z_output_node, 0, output_vec_perm_node->GetName(), 0);
3115 }
3116 
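// Verifies that ShapeNTransposer only rewires the ShapeN inputs whose
// producers were transposed (conv2d_1 and conv2d_2); conv2d_3 and its shape
// output are left untouched.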
3117 TEST_F(TransposerTest, ShapeNTransposer) {
3118 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3119   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
3120 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3121   GrapplerItem item;
3122   Scope scope = Scope::NewRootScope();
3123   auto input =
3124       ops::RandomUniform(scope.WithOpName("input"),
3125                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
3126   auto filter =
3127       ops::RandomUniform(scope.WithOpName("filter"),
3128                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
3129   Output conv2d_1 = ops::Conv2D(
3130       scope.WithOpName("conv2d_1").WithDevice("/device:GPU:0"), input, filter,
3131       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3132   Output conv2d_2 = ops::Conv2D(
3133       scope.WithOpName("conv2d_2").WithDevice("/device:GPU:0"), input, filter,
3134       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3135   Output conv2d_3 = ops::Conv2D(
3136       scope.WithOpName("conv2d_3").WithDevice("/device:GPU:0"), input, filter,
3137       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3138   auto shape =
3139       ops::ShapeN(scope.WithOpName("shape").WithDevice("/device:GPU:0"),
3140                   {conv2d_1, conv2d_2, conv2d_3});
3141   auto z_1 = ops::Identity(scope.WithOpName("z_1"), shape.output[0]);
3142   auto z_2 = ops::Identity(scope.WithOpName("z_2"), shape.output[1]);
3143   auto z_3 = ops::Identity(scope.WithOpName("z_3"), shape.output[2]);
3144 
3145   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
3146 
3147   TransposeContext context;
3148   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
3149       item, virtual_cluster_.get(), &context));
3150   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
3151 
3152   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
3153   auto* c2d_1 = context.graph_view->GetNode("conv2d_1");
3154   ASSERT_NE(c2d_1, nullptr);
3155   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d_1));
3156   auto* c2d_2 = context.graph_view->GetNode("conv2d_2");
3157   ASSERT_NE(c2d_2, nullptr);
3158   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d_2));
3159 
3160   ShapeNTransposer shape_transposer;
3161   auto* shape_node = context.graph_view->GetNode("shape");
3162   ASSERT_NE(shape_node, nullptr);
3163   TF_ASSERT_OK(shape_transposer.TransposeNode(&context, shape_node));
3164 
3165   auto* conv2d_1_transpose_node = context.graph_view->GetNode(
3166       "conv2d_1-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
3167   ASSERT_NE(conv2d_1_transpose_node, nullptr);
3168   auto* conv2d_2_transpose_node = context.graph_view->GetNode(
3169       "conv2d_2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
3170   ASSERT_NE(conv2d_2_transpose_node, nullptr);
3171 
3172   auto* shape_input_1_node = context.graph_view->GetNode(
3173       "shape-0-TransposeNHWCToNCHW-LayoutOptimizer");
3174   ASSERT_NE(shape_input_1_node, nullptr);
3175   ASSERT_EQ(shape_input_1_node->NumRegularFanins(), 2);
3176   VerifyRegularFaninMatch(shape_input_1_node, 0,
3177                           conv2d_1_transpose_node->GetName(), 0);
3178 
3179   auto* shape_input_2_node = context.graph_view->GetNode(
3180       "shape-1-TransposeNHWCToNCHW-LayoutOptimizer");
3181   ASSERT_NE(shape_input_2_node, nullptr);
3182   ASSERT_EQ(shape_input_2_node->NumRegularFanins(), 2);
3183   VerifyRegularFaninMatch(shape_input_2_node, 0,
3184                           conv2d_2_transpose_node->GetName(), 0);
3185 
3186   auto* updated_shape_node = context.graph_view->GetNode("shape");
3187   ASSERT_NE(updated_shape_node, nullptr);
3188   ASSERT_EQ(updated_shape_node->NumRegularFanins(), 3);
3189   VerifyRegularFaninMatch(updated_shape_node, 0, shape_input_1_node->GetName(),
3190                           0);
3191   VerifyRegularFaninMatch(updated_shape_node, 1, shape_input_2_node->GetName(),
3192                           0);
3193   VerifyRegularFaninMatch(updated_shape_node, 2, "conv2d_3", 0);
3194 
3195   auto* output_vec_perm_node_1 = context.graph_view->GetNode(
3196       "shape-0-0-DataFormatVecPermuteNCHWToNHWC-LayoutOptimizer");
3197   ASSERT_NE(output_vec_perm_node_1, nullptr);
3198   auto* output_vec_perm_node_2 = context.graph_view->GetNode(
3199       "shape-1-0-DataFormatVecPermuteNCHWToNHWC-LayoutOptimizer");
3200   ASSERT_NE(output_vec_perm_node_2, nullptr);
3201 
3202   auto* z_output_node_1 = context.graph_view->GetNode("z_1");
3203   ASSERT_NE(z_output_node_1, nullptr);
3204   ASSERT_EQ(z_output_node_1->NumRegularFanins(), 1);
3205   VerifyRegularFaninMatch(z_output_node_1, 0, output_vec_perm_node_1->GetName(),
3206                           0);
3207 
3208   auto* z_output_node_2 = context.graph_view->GetNode("z_2");
3209   ASSERT_NE(z_output_node_2, nullptr);
3210   ASSERT_EQ(z_output_node_2->NumRegularFanins(), 1);
3211   VerifyRegularFaninMatch(z_output_node_2, 0, output_vec_perm_node_2->GetName(),
3212                           0);
3213 
3214   auto* z_output_node_3 = context.graph_view->GetNode("z_3");
3215   ASSERT_NE(z_output_node_3, nullptr);
3216   ASSERT_EQ(z_output_node_3->NumRegularFanins(), 1);
3217   VerifyRegularFaninMatch(z_output_node_3, 0, updated_shape_node->GetName(), 2);
3218 }
3219 
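// Verifies that FillOpTransposer permutes the dims input with
// DataFormatVecPermute and transposes the filled output back to NHWC.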
3220 TEST_F(TransposerTest, FillOpTransposer) {
3221 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3222   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
3223 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3224   GrapplerItem item;
3225   Scope scope = Scope::NewRootScope();
3226   auto input =
3227       ops::RandomUniform(scope.WithOpName("input"),
3228                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
3229   auto filter =
3230       ops::RandomUniform(scope.WithOpName("filter"),
3231                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
3232   Output conv2d = ops::Conv2D(
3233       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
3234       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3235   auto shape = ops::Shape(scope.WithOpName("conv2d"), conv2d);
3236   auto value = ops::Const(scope.WithOpName("value"), 0, {});
3237   auto fill = ops::Fill(scope.WithOpName("fill").WithDevice("/device:GPU:0"),
3238                         shape, value);
3239   auto z = ops::Identity(scope.WithOpName("z"), fill);
3240   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
3241 
3242   TransposeContext context;
3243   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
3244       item, virtual_cluster_.get(), &context));
3245   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
3246 
3247   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
3248   auto* c2d = context.graph_view->GetNode("conv2d");
3249   ASSERT_NE(c2d, nullptr);
3250   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
3251 
3252   FillOpTransposer fill_op_transposer;
3253   auto* fill_node = context.graph_view->GetNode("fill");
3254   ASSERT_NE(fill_node, nullptr);
3255   TF_ASSERT_OK(fill_op_transposer.TransposeNode(&context, fill_node));
3256 
3257   auto* input_node = context.graph_view->GetNode(
3258       "fill-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
3259   ASSERT_NE(input_node, nullptr);
3260 
3261   auto* updated_fill_node = context.graph_view->GetNode("fill");
3262   ASSERT_NE(updated_fill_node, nullptr);
3263   ASSERT_EQ(updated_fill_node->NumRegularFanins(), 2);
3264   VerifyRegularFaninMatch(updated_fill_node, 0, input_node->GetName(), 0);
3265   VerifyRegularFaninMatch(updated_fill_node, 1, "value", 0);
3266 
3267   auto* output_node = context.graph_view->GetNode(
3268       "fill-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
3269   ASSERT_NE(output_node, nullptr);
3270   ASSERT_EQ(output_node->NumRegularFanins(), 2);
3271   VerifyRegularFaninMatch(output_node, 0, updated_fill_node->GetName(), 0);
3272 
3273   auto* z_node = context.graph_view->GetNode("z");
3274   ASSERT_NE(z_node, nullptr);
3275   ASSERT_EQ(z_node->NumRegularFanins(), 1);
3276   VerifyRegularFaninMatch(z_node, 0, output_node->GetName(), 0);
3277 }
3278 
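// Verifies that SliceTransposer transposes the tensor input and permutes the
// begin and size vectors with DataFormatVecPermute.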
3279 TEST_F(TransposerTest, SliceTransposer) {
3280 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3281   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
3282 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3283   GrapplerItem item;
3284   Scope scope = Scope::NewRootScope();
3285 
3286   auto input =
3287       ops::RandomUniform(scope.WithOpName("input"),
3288                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
3289   auto filter =
3290       ops::RandomUniform(scope.WithOpName("filter"),
3291                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
3292   Output conv2d = ops::Conv2D(
3293       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
3294       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3295   auto begin = ops::Const(scope.WithOpName("begin"), {0, 0, 2, 1}, {4});
3296   auto size = ops::Const(scope.WithOpName("size"), {1, 1, 2, 3}, {4});
3297   auto slice_op =
3298       ops::Slice(scope.WithOpName("slice").WithDevice("/device:GPU:0"), conv2d,
3299                  begin, size);
3300   auto z = ops::Identity(scope.WithOpName("z"), slice_op);
3301   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
3302 
3303   TransposeContext context;
3304   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
3305       item, virtual_cluster_.get(), &context));
3306   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
3307 
3308   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
3309   auto* c2d = context.graph_view->GetNode("conv2d");
3310   ASSERT_NE(c2d, nullptr);
3311   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
3312 
3313   SliceTransposer slice_transposer;
3314   auto* slice = context.graph_view->GetNode("slice");
3315   ASSERT_NE(slice, nullptr);
3316   TF_ASSERT_OK(slice_transposer.TransposeNode(&context, slice));
3317 
3318   auto* input_transpose_node = context.graph_view->GetNode(
3319       "slice-0-TransposeNHWCToNCHW-LayoutOptimizer");
3320   ASSERT_NE(input_transpose_node, nullptr);
3321   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
3322   VerifyRegularFaninMatch(input_transpose_node, 0,
3323                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
3324 
3325   auto* begin_node = context.graph_view->GetNode(
3326       "slice-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
3327   ASSERT_NE(begin_node, nullptr);
3328   ASSERT_EQ(begin_node->NumRegularFanins(), 1);
3329   VerifyRegularFaninMatch(begin_node, 0, "begin", 0);
3330 
3331   auto* size_node = context.graph_view->GetNode(
3332       "slice-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
3333   ASSERT_NE(size_node, nullptr);
3334   ASSERT_EQ(size_node->NumRegularFanins(), 1);
3335   VerifyRegularFaninMatch(size_node, 0, "size", 0);
3336 
3337   auto* updated_slice_node = context.graph_view->GetNode("slice");
3338   ASSERT_NE(updated_slice_node, nullptr);
3339   ASSERT_EQ(updated_slice_node->NumRegularFanins(), 3);
3340   VerifyRegularFaninMatch(updated_slice_node, 0,
3341                           input_transpose_node->GetName(), 0);
3342   VerifyRegularFaninMatch(updated_slice_node, 1, begin_node->GetName(), 0);
3343   VerifyRegularFaninMatch(updated_slice_node, 2, size_node->GetName(), 0);
3344 
3345   auto* output_transpose_node = context.graph_view->GetNode(
3346       "slice-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
3347   ASSERT_NE(output_transpose_node, nullptr);
3348 
3349   auto* z_output_node = context.graph_view->GetNode("z");
3350   ASSERT_NE(z_output_node, nullptr);
3351   ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
3352   VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
3353                           0);
3354 }
3355 
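// Verifies that SplitTransposer maps the split axis (input 0) through
// DataFormatDimMap, transposes the tensor input, and transposes each of the
// three outputs back to NHWC.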
3356 TEST_F(TransposerTest, SplitTransposer) {
3357 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3358   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
3359 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
3360   GrapplerItem item;
3361   Scope scope = Scope::NewRootScope();
3362 
3363   auto input =
3364       ops::RandomUniform(scope.WithOpName("input"),
3365                          {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
3366   auto filter =
3367       ops::RandomUniform(scope.WithOpName("filter"),
3368                          {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
3369   Output conv2d = ops::Conv2D(
3370       scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
3371       {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
3372   auto axis = ops::Const(scope.WithOpName("axis"), 2, {});
3373   auto split_op = ops::Split(
3374       scope.WithOpName("split").WithDevice("/device:GPU:0"), axis, conv2d, 3);
3375   auto z_1 = ops::Identity(scope.WithOpName("z_1"), split_op.output[0]);
3376   auto z_2 = ops::Identity(scope.WithOpName("z_2"), split_op.output[1]);
3377   auto z_3 = ops::Identity(scope.WithOpName("z_3"), split_op.output[2]);
3378   TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
3379 
3380   TransposeContext context;
3381   TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
3382       item, virtual_cluster_.get(), &context));
3383   context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
3384 
3385   DefaultLayoutSensitiveOpTransposer conv2d_transposer;
3386   auto* c2d = context.graph_view->GetNode("conv2d");
3387   ASSERT_NE(c2d, nullptr);
3388   TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
3389 
3390   SplitTransposer split_transposer;
3391   auto* split = context.graph_view->GetNode("split");
3392   ASSERT_NE(split, nullptr);
3393   TF_ASSERT_OK(split_transposer.TransposeNode(&context, split));
3394 
3395   auto* input_transpose_node = context.graph_view->GetNode(
3396       "split-1-TransposeNHWCToNCHW-LayoutOptimizer");
3397   ASSERT_NE(input_transpose_node, nullptr);
3398   ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
3399   VerifyRegularFaninMatch(input_transpose_node, 0,
3400                           "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
3401 
3402   auto* axis_node = context.graph_view->GetNode(
3403       "split-0-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
3404   ASSERT_NE(axis_node, nullptr);
3405   ASSERT_EQ(axis_node->NumRegularFanins(), 1);
3406   VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
3407 
3408   auto* updated_split_node = context.graph_view->GetNode("split");
3409   ASSERT_NE(updated_split_node, nullptr);
3410   ASSERT_EQ(updated_split_node->NumRegularFanins(), 2);
3411   VerifyRegularFaninMatch(updated_split_node, 0, axis_node->GetName(), 0);
3412   VerifyRegularFaninMatch(updated_split_node, 1,
3413                           input_transpose_node->GetName(), 0);
3414 
3415   auto* output_transpose_node_1 = context.graph_view->GetNode(
3416       "split-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
3417   ASSERT_NE(output_transpose_node_1, nullptr);
3418   auto* output_transpose_node_2 = context.graph_view->GetNode(
3419       "split-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
3420   ASSERT_NE(output_transpose_node_2, nullptr);
3421   auto* output_transpose_node_3 = context.graph_view->GetNode(
3422       "split-2-0-TransposeNCHWToNHWC-LayoutOptimizer");
3423   ASSERT_NE(output_transpose_node_3, nullptr);
3424 
3425   auto* z_output_node_1 = context.graph_view->GetNode("z_1");
3426   ASSERT_NE(z_output_node_1, nullptr);
3427   ASSERT_EQ(z_output_node_1->NumRegularFanins(), 1);
3428   VerifyRegularFaninMatch(z_output_node_1, 0,
3429                           output_transpose_node_1->GetName(), 0);
3430   auto* z_output_node_2 = context.graph_view->GetNode("z_2");
3431   ASSERT_NE(z_output_node_2, nullptr);
3432   ASSERT_EQ(z_output_node_2->NumRegularFanins(), 1);
3433   VerifyRegularFaninMatch(z_output_node_2, 0,
3434                           output_transpose_node_2->GetName(), 0);
3435   auto* z_output_node_3 = context.graph_view->GetNode("z_3");
3436   ASSERT_NE(z_output_node_3, nullptr);
3437   ASSERT_EQ(z_output_node_3->NumRegularFanins(), 1);
3438   VerifyRegularFaninMatch(z_output_node_3, 0,
3439                           output_transpose_node_3->GetName(), 0);
3440 }
3441 
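// SplitV differs from Split in operand order: the data tensor is fanin 0 and
// the axis is fanin 2, while size_splits is passed through unchanged.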
TEST_F(TransposerTest, SplitVTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto axis = ops::Const(scope.WithOpName("axis"), 1, {});
  auto size_splits =
      ops::Const(scope.WithOpName("size_splits"), {2, 2, 1}, {3});
  auto splitv_op =
      ops::SplitV(scope.WithOpName("splitv").WithDevice("/device:GPU:0"),
                  conv2d, size_splits, axis, 3);
  auto z_1 = ops::Identity(scope.WithOpName("z_1"), splitv_op.output[0]);
  auto z_2 = ops::Identity(scope.WithOpName("z_2"), splitv_op.output[1]);
  auto z_3 = ops::Identity(scope.WithOpName("z_3"), splitv_op.output[2]);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SplitVTransposer splitv_transposer;
  auto* splitv = context.graph_view->GetNode("splitv");
  ASSERT_NE(splitv, nullptr);
  TF_ASSERT_OK(splitv_transposer.TransposeNode(&context, splitv));

  auto* input_transpose_node = context.graph_view->GetNode(
      "splitv-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* axis_node = context.graph_view->GetNode(
      "splitv-2-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(axis_node, nullptr);
  ASSERT_EQ(axis_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(axis_node, 0, "axis", 0);

  auto* updated_splitv_node = context.graph_view->GetNode("splitv");
  ASSERT_NE(updated_splitv_node, nullptr);
  ASSERT_EQ(updated_splitv_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(updated_splitv_node, 0,
                          input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_splitv_node, 1, "size_splits", 0);
  VerifyRegularFaninMatch(updated_splitv_node, 2, axis_node->GetName(), 0);

  auto* output_transpose_node_1 = context.graph_view->GetNode(
      "splitv-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_1, nullptr);
  auto* output_transpose_node_2 = context.graph_view->GetNode(
      "splitv-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_2, nullptr);
  auto* output_transpose_node_3 = context.graph_view->GetNode(
      "splitv-2-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_3, nullptr);

  auto* z_output_node_1 = context.graph_view->GetNode("z_1");
  ASSERT_NE(z_output_node_1, nullptr);
  ASSERT_EQ(z_output_node_1->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_1, 0,
                          output_transpose_node_1->GetName(), 0);
  auto* z_output_node_2 = context.graph_view->GetNode("z_2");
  ASSERT_NE(z_output_node_2, nullptr);
  ASSERT_EQ(z_output_node_2->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_2, 0,
                          output_transpose_node_2->GetName(), 0);
  auto* z_output_node_3 = context.graph_view->GetNode("z_3");
  ASSERT_NE(z_output_node_3, nullptr);
  ASSERT_EQ(z_output_node_3->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_3, 0,
                          output_transpose_node_3->GetName(), 0);
}

TEST_F(TransposerTest, StridedSliceTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  auto attrs = ops::StridedSlice::Attrs().BeginMask(0xB).EndMask(0x7);

  auto begin = ops::Const(scope.WithOpName("begin"), {2, 0, 2, 1}, {4});
  auto end = ops::Const(scope.WithOpName("end"), {34, 4, 3, 1}, {4});
  auto strides = ops::Const(scope.WithOpName("strides"), {7, 2, 1, 1}, {4});

  auto strided_slice_op = ops::StridedSlice(
      scope.WithOpName("stridedslice").WithDevice("/device:GPU:0"), conv2d,
      begin, end, strides, attrs);
  auto z = ops::Identity(scope.WithOpName("z"), strided_slice_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  StridedSliceTransposer stridedslice_transposer;
  auto* stridedslice = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(stridedslice, nullptr);
  TF_ASSERT_OK(stridedslice_transposer.TransposeNode(&context, stridedslice));

  auto* input_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* begin_node = context.graph_view->GetNode(
      "stridedslice-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(begin_node, nullptr);
  auto* end_node = context.graph_view->GetNode(
      "stridedslice-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(end_node, nullptr);
  auto* strides_node = context.graph_view->GetNode(
      "stridedslice-3-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(strides_node, nullptr);

  auto* updated_stridedslice_node = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(updated_stridedslice_node, nullptr);
  ASSERT_EQ(updated_stridedslice_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_stridedslice_node, 0,
                          input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 1, begin_node->GetName(),
                          0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 2, end_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 3, strides_node->GetName(),
                          0);
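  // begin_mask 0xB marks dims {N, H, C} of the NHWC spec; permuted to NCHW
  // those are dims {0, 1, 2}, i.e. 0x7. end_mask 0x7 marks {N, H, W}, which
  // map to NCHW dims {0, 2, 3}, i.e. 0xD.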
  const auto* begin_mask_attr =
      updated_stridedslice_node->GetAttr("begin_mask");
  ASSERT_NE(begin_mask_attr, nullptr);
  EXPECT_EQ(begin_mask_attr->i(), 0x7);
  const auto* end_mask_attr = updated_stridedslice_node->GetAttr("end_mask");
  ASSERT_NE(end_mask_attr, nullptr);
  EXPECT_EQ(end_mask_attr->i(), 0xD);

  auto* output_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}

TEST_F(TransposerTest, StridedSliceTransposerEllipsisMaskPresent) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  auto attrs =
      ops::StridedSlice::Attrs().BeginMask(0xB).EndMask(0x7).EllipsisMask(0x2);

  auto begin = ops::Const(scope.WithOpName("begin"), {2, 0, 2, 1}, {4});
  auto end = ops::Const(scope.WithOpName("end"), {34, 4, 3, 1}, {4});
  auto strides = ops::Const(scope.WithOpName("strides"), {7, 2, 1, 1}, {4});

  auto strided_slice_op = ops::StridedSlice(
      scope.WithOpName("stridedslice").WithDevice("/device:GPU:0"), conv2d,
      begin, end, strides, attrs);
  auto z = ops::Identity(scope.WithOpName("z"), strided_slice_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  StridedSliceTransposer stridedslice_transposer;
  auto* stridedslice = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(stridedslice, nullptr);
  TF_ASSERT_OK(stridedslice_transposer.TransposeNode(&context, stridedslice));

  // Expect StridedSlice Node to remain unchanged because of the ellipsis mask.
  auto* updated_stridedslice_node = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(updated_stridedslice_node, nullptr);
  ASSERT_EQ(updated_stridedslice_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_stridedslice_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 1, "begin", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 2, "end", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 3, "strides", 0);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0,
                          updated_stridedslice_node->GetName(), 0);
}

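// The begin/end/strides constants here have rank 3, which does not match the
// 4-D layout permutation, so the transposer is expected to leave the
// StridedSlice node and its masks untouched.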
TEST_F(TransposerTest, StridedSliceTransposerConstFaninBadRank) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  auto attrs = ops::StridedSlice::Attrs().BeginMask(0xB).EndMask(0x7);

  auto begin = ops::Const(scope.WithOpName("begin"), {2, 0, 2}, {3});
  auto end = ops::Const(scope.WithOpName("end"), {34, 4, 3}, {3});
  auto strides = ops::Const(scope.WithOpName("strides"), {7, 2, 1}, {3});

  auto strided_slice_op = ops::StridedSlice(
      scope.WithOpName("stridedslice").WithDevice("/device:GPU:0"), conv2d,
      begin, end, strides, attrs);
  auto z = ops::Identity(scope.WithOpName("z"), strided_slice_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  StridedSliceTransposer stridedslice_transposer;
  auto* stridedslice = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(stridedslice, nullptr);
  TF_ASSERT_OK(stridedslice_transposer.TransposeNode(&context, stridedslice));

  auto* input_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(input_transpose_node, nullptr);

  auto* begin_node = context.graph_view->GetNode(
      "stridedslice-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(begin_node, nullptr);
  auto* end_node = context.graph_view->GetNode(
      "stridedslice-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(end_node, nullptr);
  auto* strides_node = context.graph_view->GetNode(
      "stridedslice-3-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(strides_node, nullptr);

  auto* updated_stridedslice_node = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(updated_stridedslice_node, nullptr);
  ASSERT_EQ(updated_stridedslice_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_stridedslice_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 1, "begin", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 2, "end", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 3, "strides", 0);
  const auto* begin_mask_attr =
      updated_stridedslice_node->GetAttr("begin_mask");
  ASSERT_NE(begin_mask_attr, nullptr);
  EXPECT_EQ(begin_mask_attr->i(), 0xB);
  const auto* end_mask_attr = updated_stridedslice_node->GetAttr("end_mask");
  ASSERT_NE(end_mask_attr, nullptr);
  EXPECT_EQ(end_mask_attr->i(), 0x7);

  auto* output_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_EQ(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0,
                          updated_stridedslice_node->GetName(), 0);
}

TEST_F(TransposerTest, ReduceTransposerKeepDims) {
  ReduceTransposerKeepDims<int32>();
  ReduceTransposerKeepDims<int64_t>();
}

TEST_F(TransposerTest, ReduceTransposerValidAxisNode) {
  ReduceTransposerValidAxisNode<int32>();
  ReduceTransposerValidAxisNode<int64_t>();
}

TEST(PermutationTest, PermutesVector) {
  std::vector<int64_t> input{32, 16, 8, 4};
  std::vector<int64_t> expected{4, 8, 16, 32};
  TF_ASSERT_OK(PermuteSingle("test", {3, 2, 1, 0}, &input));
  ASSERT_EQ(input.size(), 4);
  for (int i = 0; i < input.size(); ++i) {
    EXPECT_EQ(input[i], expected[i]);
  }
}

TEST(PermutationTest, PermutesRepeatedField) {
  TensorShapeProto input_shape = MakeTensorShapeFromDimensions({1, 2, 3, 4});
  TensorShapeProto expected_shape = MakeTensorShapeFromDimensions({1, 4, 2, 3});

  TF_ASSERT_OK(PermuteSingle("test", {0, 3, 1, 2}, input_shape.mutable_dim()));
  EXPECT_EQ(input_shape.DebugString(), expected_shape.DebugString());
}

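// PermuteDouble treats the repeated field as consecutive pairs, one pair per
// dimension, and moves each pair as a unit under the dimension permutation.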
TEST(PermutationTest, PermutesDoubleRepeatedField) {
  {
    // NHWC -> NCHW
    TensorShapeProto input =
        MakeTensorShapeFromDimensions({1, 2, 3, 4, 5, 6, 7, 8});
    TensorShapeProto expected =
        MakeTensorShapeFromDimensions({1, 2, 7, 8, 3, 4, 5, 6});

    TF_ASSERT_OK(PermuteDouble("test", {0, 3, 1, 2}, input.mutable_dim()));
    EXPECT_EQ(input.DebugString(), expected.DebugString());
  }
  {
    // NCHW -> NHWC
    TensorShapeProto input =
        MakeTensorShapeFromDimensions({1, 2, 3, 4, 5, 6, 7, 8});
    TensorShapeProto expected =
        MakeTensorShapeFromDimensions({1, 2, 5, 6, 7, 8, 3, 4});
    TF_ASSERT_OK(PermuteDouble("test", {0, 2, 3, 1}, input.mutable_dim()));
    EXPECT_EQ(input.DebugString(), expected.DebugString());
  }
}

TEST(PermutationTest, PermutesDataFormat) {
  string input = "NHWC";
  string expected = "NCHW";
  TF_ASSERT_OK(PermuteSingle("test", {0, 3, 1, 2}, &input));
  EXPECT_EQ(input, expected);
}

TEST(PermutationTest, PermutesString) {
  string input = "ABCD";
  string expected = "ACBD";
  TF_ASSERT_OK(PermuteSingle("test", {0, 2, 1, 3}, &input));
  EXPECT_EQ(input, expected);
}

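// GetPermutation returns, for each position i of the destination format, the
// index in the source format of the dimension character dst_format[i].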
TEST(PermutationTest, GetNHWCToNCHWPermutation) {
  string src_format = "NHWC";
  absl::flat_hash_map<char, int> src_dim_indices =
      GetDimensionIndices(src_format);
  EXPECT_EQ(src_dim_indices.size(), 4);
  EXPECT_EQ(src_dim_indices['N'], 0);
  EXPECT_EQ(src_dim_indices['H'], 1);
  EXPECT_EQ(src_dim_indices['W'], 2);
  EXPECT_EQ(src_dim_indices['C'], 3);
  string dst_format = "NCHW";
  std::vector<int> permutation = GetPermutation(src_dim_indices, dst_format);
  ASSERT_EQ(permutation.size(), 4);
  EXPECT_EQ(permutation[0], 0);
  EXPECT_EQ(permutation[1], 3);
  EXPECT_EQ(permutation[2], 1);
  EXPECT_EQ(permutation[3], 2);
}

TEST(PermutationTest, GetNCHWToNHWCPermutation) {
  string src_format = "NCHW";
  absl::flat_hash_map<char, int> src_dim_indices =
      GetDimensionIndices(src_format);
  EXPECT_EQ(src_dim_indices.size(), 4);
  EXPECT_EQ(src_dim_indices['N'], 0);
  EXPECT_EQ(src_dim_indices['C'], 1);
  EXPECT_EQ(src_dim_indices['H'], 2);
  EXPECT_EQ(src_dim_indices['W'], 3);
  string dst_format = "NHWC";
  std::vector<int> permutation = GetPermutation(src_dim_indices, dst_format);
  ASSERT_EQ(permutation.size(), 4);
  EXPECT_EQ(permutation[0], 0);
  EXPECT_EQ(permutation[1], 2);
  EXPECT_EQ(permutation[2], 3);
  EXPECT_EQ(permutation[3], 1);
}

// TODO(yanzha): Add frame related tests.
}  // namespace
}  // namespace grappler
}  // namespace tensorflow