1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"
17
18 #include "absl/strings/string_view.h"
19 #include "absl/types/span.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/cc/ops/math_ops_internal.h"
22 #include "tensorflow/cc/ops/nn_ops.h"
23 #include "tensorflow/cc/ops/nn_ops_internal.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/grappler/clusters/cluster.h"
29 #include "tensorflow/core/grappler/clusters/single_machine.h"
30 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
31 #include "tensorflow/core/grappler/devices.h"
32 #include "tensorflow/core/grappler/grappler_item.h"
33 #include "tensorflow/core/grappler/utils/graph_view.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36
37 namespace tensorflow {
38 namespace grappler {
39
40 namespace {
41
42 using ::tensorflow::test::ExpectTensorEqual;
43
44 constexpr int kBatchSize = 32;
45 constexpr int kWidth = 10;
46 constexpr int kHeight = 10;
47 constexpr int kDepthIn = 8;
48 constexpr int kKernel = 2;
49 constexpr int kStride1 = 2;
50 constexpr int kStride2 = 4;
51 constexpr int kOutWidth = 5;
52 constexpr int kOutHeight = 5;
53 constexpr int kDepthOut = 16;
54 constexpr int kDilation = 2;
55 constexpr int kPaddingTop = 1;
56 constexpr int kPaddingBottom = 2;
57 constexpr int kPaddingLeft = 3;
58 constexpr int kPaddingRight = 4;
59 constexpr char kSrcFormat[] = "NHWC";
60 constexpr char kDstFormat[] = "NCHW";
61 constexpr char kGPU[] = "GPU";
62 constexpr char kAttrOutputShapes[] = "_output_shapes";
63 constexpr char kAttrDataFormat[] = "data_format";
64 constexpr char kOpTranspose[] = "Transpose";
65
66 class TransposerImpl : public Transposer {
67 public:
TransposerImpl()68 explicit TransposerImpl() : Transposer() {}
TransposeNode(TransposeContext *,utils::MutableNodeView *)69 Status TransposeNode(TransposeContext*, utils::MutableNodeView*) override {
70 return OkStatus();
71 }
72 };
73
VerifyRegularFaninMatch(const utils::MutableNodeView * node,int port,absl::string_view fanin_name,int fanin_port)74 void VerifyRegularFaninMatch(const utils::MutableNodeView* node, int port,
75 absl::string_view fanin_name, int fanin_port) {
76 ASSERT_GT(node->NumRegularFanins(), port);
77 const auto& fanin = node->GetRegularFanin(port);
78 EXPECT_EQ(fanin.node_view()->GetName(), fanin_name);
79 EXPECT_EQ(fanin.index(), fanin_port);
80 }
81
VerifyShapeAttributeMatch(const utils::MutableNodeView * node,absl::string_view attr_value)82 void VerifyShapeAttributeMatch(const utils::MutableNodeView* node,
83 absl::string_view attr_value) {
84 const auto* attr = node->GetAttr(kAttrOutputShapes);
85 ASSERT_NE(attr, nullptr);
86 EXPECT_EQ(attr->shape().DebugString(), attr_value);
87 }
88
VerifyShapeAttributeMatch(const utils::MutableNodeView * node,int shape_index,absl::string_view attr_value)89 void VerifyShapeAttributeMatch(const utils::MutableNodeView* node,
90 int shape_index, absl::string_view attr_value) {
91 const auto* attr = node->GetAttr(kAttrOutputShapes);
92 ASSERT_NE(attr, nullptr);
93 ASSERT_GT(attr->list().shape_size(), shape_index);
94 EXPECT_EQ(attr->list().shape(shape_index).DebugString(), attr_value);
95 }
96
VerifyDataFormatAttributeMatch(const utils::MutableNodeView * node,absl::string_view attr_value)97 void VerifyDataFormatAttributeMatch(const utils::MutableNodeView* node,
98 absl::string_view attr_value) {
99 const auto* attr = node->GetAttr(kAttrDataFormat);
100 ASSERT_NE(attr, nullptr);
101 EXPECT_EQ(attr->s(), attr_value);
102 }
103
SimpleConv2D(const Scope * scope,const DataType & data_type=DT_FLOAT)104 Output SimpleConv2D(const Scope* scope, const DataType& data_type = DT_FLOAT) {
105 auto input =
106 ops::RandomUniform(scope->WithOpName("input"),
107 {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
108 auto filter =
109 ops::RandomUniform(scope->WithOpName("filter"),
110 {kHeight, kWidth, kDepthIn, kDepthOut}, data_type);
111 auto conv2d = ops::Conv2D(
112 scope->WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
113 {1, kStride1, kStride2, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
114
115 return conv2d;
116 }
117
CreateSimpleConv2DGraph(GraphDef * graph,const DataType & data_type=DT_FLOAT)118 Status CreateSimpleConv2DGraph(GraphDef* graph,
119 const DataType& data_type = DT_FLOAT) {
120 Scope scope = Scope::NewRootScope();
121 auto conv2d = SimpleConv2D(&scope, data_type);
122 auto output = ops::Identity(scope.WithOpName("output"), conv2d);
123
124 return scope.ToGraphDef(graph);
125 }
126
CreateSimpleFusedBatchNorm(GraphDef * graph,const DataType & data_type=DT_FLOAT)127 Status CreateSimpleFusedBatchNorm(GraphDef* graph,
128 const DataType& data_type = DT_FLOAT) {
129 Scope scope = Scope::NewRootScope();
130 auto x =
131 ops::RandomUniform(scope.WithOpName("x"),
132 {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
133 auto scale =
134 ops::RandomUniform(scope.WithOpName("scale"), {kDepthIn}, DT_FLOAT);
135 auto offset =
136 ops::RandomUniform(scope.WithOpName("offset"), {kDepthIn}, DT_FLOAT);
137 auto mean =
138 ops::RandomUniform(scope.WithOpName("mean"), {kDepthIn}, DT_FLOAT);
139 auto var = ops::RandomUniform(scope.WithOpName("var"), {kDepthIn}, DT_FLOAT);
140 auto batch_norm = ops::FusedBatchNormV2(
141 scope.WithOpName("bn").WithDevice("/device:GPU:0"), x, scale, offset,
142 mean, var, ops::FusedBatchNormV2::IsTraining(false).Epsilon(0.1f));
143
144 auto output_y = ops::Identity(scope.WithOpName("output_y"), batch_norm.y);
145 auto output_mean =
146 ops::Identity(scope.WithOpName("output_mean"), batch_norm.batch_mean);
147 auto output_variance = ops::Identity(scope.WithOpName("output_variance"),
148 batch_norm.batch_variance);
149
150 return scope.ToGraphDef(graph);
151 }
152
CreateSimpleMaxPoolGrad(GraphDef * graph,bool use_grad_grad)153 Status CreateSimpleMaxPoolGrad(GraphDef* graph, bool use_grad_grad) {
154 Scope scope = Scope::NewRootScope();
155 auto input =
156 ops::RandomUniform(scope.WithOpName("orig_input"),
157 {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
158 auto output_data = ops::RandomUniform(
159 scope.WithOpName("orig_output"),
160 {kBatchSize, kOutHeight, kOutWidth, kDepthIn}, DT_FLOAT);
161 auto output_grad =
162 ops::RandomUniform(scope.WithOpName("grad"),
163 {kBatchSize, use_grad_grad ? kHeight : kOutHeight,
164 use_grad_grad ? kWidth : kOutWidth, kDepthIn},
165 DT_FLOAT);
166 Output maxpool_grad;
167 if (use_grad_grad) {
168 maxpool_grad = ops::MaxPoolGradGrad(
169 scope.WithOpName("maxpool_grad").WithDevice("/device:GPU:0"), input,
170 output_data, output_grad, {1, kKernel, kKernel, 1},
171 {1, kStride1, kStride1, 1}, "VALID");
172 } else {
173 maxpool_grad = ops::internal::MaxPoolGrad(
174 scope.WithOpName("maxpool_grad").WithDevice("/device:GPU:0"), input,
175 output_data, output_grad, {1, kKernel, kKernel, 1},
176 {1, kStride1, kStride1, 1}, "VALID");
177 }
178
179 auto output = ops::Identity(scope.WithOpName("output"), maxpool_grad);
180
181 return scope.ToGraphDef(graph);
182 }
183
CreateSimpleBiasAddGrad(GraphDef * graph,const Input & shape)184 Status CreateSimpleBiasAddGrad(GraphDef* graph, const Input& shape) {
185 Scope scope = Scope::NewRootScope();
186 auto input = ops::RandomUniform(scope.WithOpName("input"), shape, DT_FLOAT);
187 auto bag =
188 ops::BiasAddGrad(scope.WithOpName("bag").WithDevice("/device:GPU:0"),
189 input, ops::BiasAddGrad::DataFormat(kSrcFormat));
190 auto output = ops::Identity(scope.WithOpName("output"), bag);
191
192 return scope.ToGraphDef(graph);
193 }
194
CreateSimpleConv2DBackpropFilter(GraphDef * graph,const DataType & data_type=DT_FLOAT,absl::string_view padding="SAME")195 Status CreateSimpleConv2DBackpropFilter(GraphDef* graph,
196 const DataType& data_type = DT_FLOAT,
197 absl::string_view padding = "SAME") {
198 Scope scope = Scope::NewRootScope();
199 auto input =
200 ops::RandomUniform(scope.WithOpName("input"),
201 {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
202 auto out_backprop =
203 ops::RandomUniform(scope.WithOpName("out_backprop"),
204 {kBatchSize, kHeight, kWidth, kDepthOut}, data_type);
205 if (padding == "EXPLICIT") {
206 auto conv2d_backprop_filter = ops::Conv2DBackpropFilter(
207 scope.WithOpName("conv2d_backprop_filter").WithDevice("/device:GPU:0"),
208 input, {kHeight, kWidth, kDepthIn, kDepthOut}, out_backprop,
209 {1, 2, 4, 1}, padding,
210 ops::Conv2DBackpropFilter::Attrs()
211 .Dilations({1, kDilation, kDilation, 1})
212 .ExplicitPaddings({0, 0, kPaddingTop, kPaddingBottom, kPaddingLeft,
213 kPaddingRight, 0, 0})
214 .DataFormat(kSrcFormat));
215 auto output =
216 ops::Identity(scope.WithOpName("output"), conv2d_backprop_filter);
217 } else {
218 auto conv2d_backprop_filter = ops::Conv2DBackpropFilter(
219 scope.WithOpName("conv2d_backprop_filter").WithDevice("/device:GPU:0"),
220 input, {kHeight, kWidth, kDepthIn, kDepthOut}, out_backprop,
221 {1, 2, 4, 1}, padding,
222 ops::Conv2DBackpropFilter::DataFormat(kSrcFormat));
223 auto output =
224 ops::Identity(scope.WithOpName("output"), conv2d_backprop_filter);
225 }
226
227 return scope.ToGraphDef(graph);
228 }
229
// Builds Conv2DBackpropInput("conv2d_backprop_input") on GPU with an
// Identity("output") consumer. Note: the "input" RandomUniform node is added
// to the graph but is NOT consumed by the backprop op (which takes
// input_sizes, filter, and out_backprop) — it exists only as an extra node in
// the test graph.
Status CreateSimpleConv2DBackpropInput(GraphDef* graph,
                                       const DataType& data_type = DT_FLOAT) {
  Scope scope = Scope::NewRootScope();
  // Conv2DBackpropInput takes the desired input shape as a Const tensor.
  auto input_sizes = ops::Const(scope.WithOpName("input_sizes"),
                                {kBatchSize, kHeight, kWidth, kDepthIn});
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, data_type);
  auto out_backprop =
      ops::RandomUniform(scope.WithOpName("out_backprop"),
                         {kBatchSize, kHeight, kWidth, kDepthOut}, data_type);
  auto conv2d_backprop_input = ops::Conv2DBackpropInput(
      scope.WithOpName("conv2d_backprop_input").WithDevice("/device:GPU:0"),
      input_sizes, filter, out_backprop, {1, kStride1, kStride1, 1}, "VALID");
  auto output =
      ops::Identity(scope.WithOpName("output"), conv2d_backprop_input);

  return scope.ToGraphDef(graph);
}
252
CreateSimpleFusedBatchNormGrad(GraphDef * graph,bool is_training,const DataType & data_type=DT_FLOAT)253 Status CreateSimpleFusedBatchNormGrad(GraphDef* graph, bool is_training,
254 const DataType& data_type = DT_FLOAT) {
255 Scope scope = Scope::NewRootScope();
256 auto y_backprop =
257 ops::RandomUniform(scope.WithOpName("y_backprop"),
258 {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
259 auto x =
260 ops::RandomUniform(scope.WithOpName("x"),
261 {kBatchSize, kHeight, kWidth, kDepthIn}, data_type);
262 auto scale =
263 ops::RandomUniform(scope.WithOpName("scale"), {kDepthIn}, DT_FLOAT);
264 auto reserve_space_1 = ops::RandomUniform(scope.WithOpName("reserve_space_1"),
265 {kDepthIn}, DT_FLOAT);
266 auto reserve_space_2 = ops::RandomUniform(scope.WithOpName("reserve_space_2"),
267 {kDepthIn}, DT_FLOAT);
268 auto fused_batch_norm_grad = ops::FusedBatchNormGradV2(
269 scope.WithOpName("fused_batch_norm_grad").WithDevice("/device:GPU:0"),
270 y_backprop, x, scale, reserve_space_1, reserve_space_2,
271 ops::FusedBatchNormGradV2::DataFormat(kSrcFormat)
272 .IsTraining(is_training)
273 .Epsilon(0.1f));
274 auto x_backprop = ops::Identity(scope.WithOpName("x_backprop"),
275 fused_batch_norm_grad.x_backprop);
276 auto scale_backprop = ops::Identity(scope.WithOpName("scale_backprop"),
277 fused_batch_norm_grad.scale_backprop);
278 auto offset_backprop = ops::Identity(scope.WithOpName("offset_backprop"),
279 fused_batch_norm_grad.offset_backprop);
280 auto reserve_space_3 = ops::Identity(scope.WithOpName("reserve_space_3"),
281 fused_batch_norm_grad.reserve_space_3);
282 auto reserve_space_4 = ops::Identity(scope.WithOpName("reserve_space_4"),
283 fused_batch_norm_grad.reserve_space_4);
284
285 return scope.ToGraphDef(graph);
286 }
287
CreateSimpleAddN(GraphDef * graph)288 Status CreateSimpleAddN(GraphDef* graph) {
289 Scope scope = Scope::NewRootScope();
290 auto input =
291 ops::RandomUniform(scope.WithOpName("input"),
292 {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
293 auto filter =
294 ops::RandomUniform(scope.WithOpName("filter"),
295 {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
296 Output conv2d = ops::Conv2D(
297 scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
298 {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
299 Output a = ops::RandomUniform(scope.WithOpName("a"),
300 {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
301 Output b = ops::RandomUniform(scope.WithOpName("b"),
302 {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
303 Output c = ops::RandomUniform(scope.WithOpName("c"),
304 {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
305 auto add_n = ops::AddN(scope.WithOpName("add_n").WithDevice("/device:GPU:0"),
306 {a, b, c, conv2d});
307 auto output = ops::Identity(scope.WithOpName("output"), add_n);
308
309 return scope.ToGraphDef(graph);
310 }
311
CreateSimpleIdentityN(GraphDef * graph)312 Status CreateSimpleIdentityN(GraphDef* graph) {
313 Scope scope = Scope::NewRootScope();
314 auto conv2d_1_input =
315 ops::RandomUniform(scope.WithOpName("conv2d_1_input"),
316 {kBatchSize, kDepthIn, kHeight, kWidth}, DT_FLOAT);
317 auto conv2d_1_filter =
318 ops::RandomUniform(scope.WithOpName("conv2d_1_filter"),
319 {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
320 Output conv2d_1 =
321 ops::Conv2D(scope.WithOpName("conv2d_1").WithDevice("/device:GPU:0"),
322 conv2d_1_input, conv2d_1_filter, {1, 1, 2, 4}, "SAME",
323 ops::Conv2D::DataFormat(kDstFormat));
324 auto conv2d_2_input =
325 ops::RandomUniform(scope.WithOpName("conv2d_2_input"),
326 {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
327 auto conv2d_2_filter =
328 ops::RandomUniform(scope.WithOpName("conv2d_2_filter"),
329 {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
330 Output conv2d_2 =
331 ops::Conv2D(scope.WithOpName("conv2d_2").WithDevice("/device:GPU:0"),
332 conv2d_2_input, conv2d_2_filter, {1, 2, 4, 1}, "SAME",
333 ops::Conv2D::DataFormat(kSrcFormat));
334 Output a = ops::RandomUniform(
335 scope.WithOpName("a"), {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
336 Output b = ops::RandomUniform(scope.WithOpName("b"), {kBatchSize, kDepthIn},
337 DT_FLOAT);
338 auto identity_n =
339 ops::IdentityN(scope.WithOpName("identity_n").WithDevice("/device:GPU:0"),
340 {conv2d_1, conv2d_2, a, b});
341 auto conv2d_1_output =
342 ops::Identity(scope.WithOpName("conv2d_1_output"), identity_n.output[0]);
343 auto conv2d_2_output =
344 ops::Identity(scope.WithOpName("conv2d_2_output"), identity_n.output[1]);
345 auto a_output =
346 ops::Identity(scope.WithOpName("a_output"), identity_n.output[2]);
347 auto b_output =
348 ops::Identity(scope.WithOpName("b_output"), identity_n.output[3]);
349
350 return scope.ToGraphDef(graph);
351 }
352
// Test fixture that provisions a cluster (real SingleMachine when a GPU is
// available, otherwise a VirtualCluster simulating one GPU) so that
// TransposeContext::InitializeTransposeContext has device information.
class TransposerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    bool gpu_available = GetNumAvailableGPUs() > 0;

    if (gpu_available) {
      // Real machine with the local GPU.
      virtual_cluster_ =
          std::make_unique<SingleMachine>(/*timeout_s=*/10, 1, 1);
    } else {
      // Simulated GPU device with environment attribute "architecture" = "6".
      DeviceProperties gpu_device;
      gpu_device.set_type(kGPU);
      gpu_device.mutable_environment()->insert({"architecture", "6"});
      virtual_cluster_ =
          absl::WrapUnique(new VirtualCluster({{"/GPU:1", gpu_device}}));
    }
    TF_ASSERT_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }

  // Shared test body for ReduceTransposer with KeepDims(true). T is the dtype
  // of the reduction axis tensor.
  template <typename T>
  void ReduceTransposerKeepDims() {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GrapplerItem item;
    Scope scope = Scope::NewRootScope();

    auto input =
        ops::RandomUniform(scope.WithOpName("input"),
                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
    auto filter =
        ops::RandomUniform(scope.WithOpName("filter"),
                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
    Output conv2d = ops::Conv2D(
        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
    auto attrs = ops::Sum::Attrs().KeepDims(true);
    auto sum_op = ops::Sum(scope.WithOpName("sum").WithDevice("/device:GPU:0"),
                           conv2d, axis, attrs);

    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

    TransposeContext context;
    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
        item, virtual_cluster_.get(), &context));
    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

    // Transpose the Conv2D first so the reduction's data fanin is NCHW.
    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
    auto* c2d = context.graph_view->GetNode("conv2d");
    ASSERT_NE(c2d, nullptr);
    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

    ReduceTransposer reducer_transposer;
    auto* sum = context.graph_view->GetNode("sum");
    ASSERT_NE(sum, nullptr);
    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, sum));

    // A Transpose should be inserted on the data fanin (port 0) ...
    auto* input_transpose_node = context.graph_view->GetNode(
        "sum-0-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node, nullptr);

    auto* updated_sum_node = context.graph_view->GetNode("sum");
    ASSERT_NE(updated_sum_node, nullptr);
    ASSERT_EQ(updated_sum_node->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(updated_sum_node, 0,
                            input_transpose_node->GetName(), 0);

    // ... the axis operand (port 1) should be remapped via DataFormatDimMap...
    auto* axis_node = context.graph_view->GetNode(
        "sum-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(axis_node, nullptr);
    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);

    // ... and, with KeepDims(true), the output transposed back to NHWC.
    auto* output_transpose_node = context.graph_view->GetNode(
        "sum-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
    ASSERT_NE(output_transpose_node, nullptr);

    auto* z_output_node = context.graph_view->GetNode("z");
    ASSERT_NE(z_output_node, nullptr);
    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                            0);
  }

  // Shared test body for ReduceTransposer without KeepDims. T is the dtype of
  // the reduction axis tensor.
  template <typename T>
  void ReduceTransposerValidAxisNode() {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
    GrapplerItem item;
    Scope scope = Scope::NewRootScope();

    auto input =
        ops::RandomUniform(scope.WithOpName("input"),
                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
    auto filter =
        ops::RandomUniform(scope.WithOpName("filter"),
                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
    Output conv2d = ops::Conv2D(
        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
    auto sum_op = ops::Max(scope.WithOpName("max").WithDevice("/device:GPU:0"),
                           conv2d, axis);

    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

    TransposeContext context;
    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
        item, virtual_cluster_.get(), &context));
    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
    auto* c2d = context.graph_view->GetNode("conv2d");
    ASSERT_NE(c2d, nullptr);
    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

    ReduceTransposer reducer_transposer;
    auto* max = context.graph_view->GetNode("max");
    ASSERT_NE(max, nullptr);
    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, max));

    auto* input_transpose_node = context.graph_view->GetNode(
        "max-0-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node, nullptr);

    auto* updated_max_node = context.graph_view->GetNode("max");
    ASSERT_NE(updated_max_node, nullptr);
    ASSERT_EQ(updated_max_node->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(updated_max_node, 0,
                            input_transpose_node->GetName(), 0);

    auto* axis_node = context.graph_view->GetNode(
        "max-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(axis_node, nullptr);
    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);

    // Unlike the KeepDims case, "z" consumes the reducer directly: no output
    // transpose node is asserted here.
    auto* z_output_node = context.graph_view->GetNode("z");
    ASSERT_NE(z_output_node, nullptr);
    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(z_output_node, 0, updated_max_node->GetName(), 0);
  }

  std::unique_ptr<Cluster> virtual_cluster_;
};
505
// Verifies that Transposer::CreateConstPermNode adds a Const node on the
// requested device holding the given permutation vector.
TEST_F(TransposerTest, CreateConstPermNode) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  TransposerImpl transposer;
  constexpr char kNodeName[] = "const_perm_node";
  constexpr char kDevice[] = "/device:GPU:0";
  utils::MutationNewNode added_node;
  EXPECT_FALSE(context.graph_view->HasNode(kNodeName));
  TF_ASSERT_OK(transposer.CreateConstPermNode(&context, kNodeName, kDevice,
                                              {0, 3, 1, 2}, "", &added_node));
  // The node only materializes in the graph view once the mutation is applied.
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  utils::MutableNodeView* const_perm_node =
      context.graph_view->GetNode(kNodeName);
  EXPECT_EQ(const_perm_node->GetName(), kNodeName);
  EXPECT_EQ(const_perm_node->GetDevice(), kDevice);
  const auto* value_attr = const_perm_node->GetAttr("value");
  ASSERT_NE(value_attr, nullptr);

  // The "value" attr must contain exactly the permutation passed in above.
  Tensor tensor;
  ASSERT_TRUE(tensor.FromProto(value_attr->tensor()));
  Tensor expected(DT_INT32, {4});
  ::tensorflow::test::FillValues<int32>(&expected, {0, 3, 1, 2});
  ExpectTensorEqual<int32>(tensor, expected);
}
539
MakeTensorShapeFromDimensions(absl::Span<const int> dims)540 TensorShapeProto MakeTensorShapeFromDimensions(absl::Span<const int> dims) {
541 TensorShapeProto shape_proto = TensorShapeProto();
542 for (const int dim : dims) {
543 TensorShapeProto_Dim dim_proto = TensorShapeProto_Dim();
544 dim_proto.set_size(dim);
545 *shape_proto.add_dim() = std::move(dim_proto);
546 }
547 return shape_proto;
548 }
549
// Verifies that Transposer::CreateTransposeNode produces a correctly named
// Transpose node whose "_output_shapes" attr reflects the permuted shape.
TEST_F(TransposerTest, CreateTransposeNode) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  TransposerImpl transposer;
  // "$0" in the name format is substituted with the op name ("Transpose").
  constexpr char kNodeNameFormat[] =
      "transpose_node-0-$0-NWCHToNCWH-LayoutOptimizer";
  constexpr char kDevice[] = "/device:GPU:0";
  TensorShapeProto input_shape = MakeTensorShapeFromDimensions({1, 2, 3, 4});
  TensorShapeProto expected_shape = MakeTensorShapeFromDimensions({1, 4, 2, 3});
  utils::MutationNewNode added_node;
  string transpose_node_name;
  TF_ASSERT_OK(transposer.CreateTransposeNode(
      &context, kNodeNameFormat, DT_DOUBLE, kDevice, input_shape, {0, 3, 1, 2},
      "", &added_node, &transpose_node_name));

  EXPECT_EQ(transpose_node_name,
            "transpose_node-0-Transpose-NWCHToNCWH-LayoutOptimizer");
  utils::Mutation* mutation = context.graph_view->GetMutationBuilder();
  Status status;
  // Placeholder node with empty name as transpose node is created with it's
  // first input not set.
  mutation->AddNode({}, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
  auto* transpose_node = context.graph_view->GetNode(transpose_node_name);
  ASSERT_NE(transpose_node, nullptr);
  EXPECT_EQ(transpose_node->GetDevice(), kDevice);
  const auto* output_shapes_attr = transpose_node->GetAttr("_output_shapes");
  // Fix: guard against a missing attr before dereferencing it, matching the
  // null checks done everywhere else in this file.
  ASSERT_NE(output_shapes_attr, nullptr);
  EXPECT_EQ(output_shapes_attr->list().shape(0).DebugString(),
            expected_shape.DebugString());
}
589
// Verifies that DefaultLayoutSensitiveOpTransposer::UpdateNode rewrites the
// Conv2D node's "data_format" attr from the source to the destination format.
TEST_F(TransposerTest, UpdateNode) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer transposer;
  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  // After the mutation is applied the node should carry the NCHW format.
  auto* updated_conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(updated_conv2d, nullptr);
  VerifyDataFormatAttributeMatch(updated_conv2d, kDstFormat);
}
611
MakeAttrValueListValueFromVector(absl::Span<const int> vec)612 AttrValue_ListValue MakeAttrValueListValueFromVector(
613 absl::Span<const int> vec) {
614 AttrValue_ListValue list_proto = AttrValue_ListValue();
615 for (const int i : vec) {
616 list_proto.add_i(i);
617 }
618 return list_proto;
619 }
620
// Verifies that UpdateNode permutes the "strides" attr according to the
// assigned src/dst format pair: "ABCD" -> "ACBD" swaps the two middle entries.
TEST_F(TransposerTest, UpdateStrides) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  // Synthetic formats make the expected permutation unambiguous.
  context.AssignDeviceAndDataFormats(kGPU, "ABCD", "ACBD");

  AttrValue_ListValue expected_original_strides =
      MakeAttrValueListValueFromVector({1, 2, 4, 1});
  AttrValue_ListValue expected_updated_strides =
      MakeAttrValueListValueFromVector({1, 4, 2, 1});
  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  const auto& strides_attr = conv2d->GetAttr("strides");
  ASSERT_NE(strides_attr, nullptr);
  EXPECT_EQ(strides_attr->list().DebugString(),
            expected_original_strides.DebugString());
  // Rewrite the node's data_format to "ABCD" so the transposer treats the
  // node as being in the assigned source format.
  AttrValue data_format_attr;
  data_format_attr.set_s("ABCD");
  context.graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
      conv2d, "data_format", data_format_attr);
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  DefaultLayoutSensitiveOpTransposer transposer;
  TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  auto* updated_conv2d = context.graph_view->GetNode("conv2d");
  const auto& updated_strides_attr = updated_conv2d->GetAttr("strides");
  ASSERT_NE(updated_strides_attr, nullptr);
  EXPECT_EQ(updated_strides_attr->list().DebugString(),
            expected_updated_strides.DebugString());
}
658
// Verifies that UpdateFaninEdgesWithOp inserts Transpose + PermConst nodes on
// the requested fanin ports (0 and 1) of FusedBatchNormGradV2 and rewires the
// node's fanins through the new Transpose nodes.
TEST_F(TransposerTest, UpdateFaninEdgesTranspose) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  FusedBatchNormGradTransposer transposer;
  auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(fbng, nullptr);
  const auto& fbng_output_shapes_attr = fbng->GetAttr("_output_shapes");
  ASSERT_NE(fbng_output_shapes_attr, nullptr);
  const TensorShapeProto& expected_shape = fbng_output_shapes_attr->shape();
  TF_ASSERT_OK(
      transposer.UpdateFaninEdgesWithOp(&context, {0, 1}, fbng, kOpTranspose));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  // Verify each inserted transpose node's output shape matches the original
  // (fanin) shape.
  auto* transpose_node1 = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_node1, nullptr);
  VerifyShapeAttributeMatch(transpose_node1, expected_shape.DebugString());
  auto* transpose_node2 = context.graph_view->GetNode(
      "fused_batch_norm_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_node2, nullptr);
  VerifyShapeAttributeMatch(transpose_node2, expected_shape.DebugString());

  // Validate a const perm node is created for each transpose.
  auto* const_node1 = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-PermConstNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(const_node1, nullptr);
  auto* const_node2 = context.graph_view->GetNode(
      "fused_batch_norm_grad-1-PermConstNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(const_node2, nullptr);

  // Validate nodes connected correctly: original fanin -> transpose (with its
  // perm const) -> fused_batch_norm_grad.
  auto* y_backprop = context.graph_view->GetNode("y_backprop");
  ASSERT_NE(y_backprop, nullptr);
  ASSERT_EQ(transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_node1, 0, y_backprop->GetName(), 0);
  VerifyRegularFaninMatch(transpose_node1, 1, const_node1->GetName(), 0);

  auto* x = context.graph_view->GetNode("x");
  ASSERT_NE(x, nullptr);
  ASSERT_EQ(transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_node2, 0, x->GetName(), 0);
  VerifyRegularFaninMatch(transpose_node2, 1, const_node2->GetName(), 0);

  auto* updated_fbng = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(updated_fbng, nullptr);
  ASSERT_EQ(updated_fbng->NumRegularFanins(), 5);
  VerifyRegularFaninMatch(updated_fbng, 0, transpose_node1->GetName(), 0);
  VerifyRegularFaninMatch(updated_fbng, 1, transpose_node2->GetName(), 0);
}
717
// Verifies that UpdateFanoutEdgesWithOp inserts a Transpose (plus its perm
// const) after Conv2D output 0, updates the node's "_output_shapes" to the
// permuted shape, and reroutes the downstream consumer through the transpose.
TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  TransposerImpl transposer;
  TensorShapeProto expected_original_shape =
      MakeTensorShapeFromDimensions({32, 5, 3, 16});
  TensorShapeProto expected_updated_shape =
      MakeTensorShapeFromDimensions({32, 16, 5, 3});

  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  VerifyShapeAttributeMatch(conv2d, 0, expected_original_shape.DebugString());

  TF_ASSERT_OK(
      transposer.UpdateFanoutEdgesWithOp(&context, {0}, conv2d, kOpTranspose));
  TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());

  // The Conv2D's recorded output shape is now the NCHW-permuted one.
  auto* updated_conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(updated_conv2d, nullptr);
  VerifyShapeAttributeMatch(updated_conv2d, 0,
                            expected_updated_shape.DebugString());

  // Verify output shape matches original shape.
  auto* transpose_node = context.graph_view->GetNode(
      "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(transpose_node, nullptr);
  VerifyShapeAttributeMatch(transpose_node, 0,
                            expected_original_shape.DebugString());

  // Verify a const perm node is created for transpose node.
  auto* const_node = context.graph_view->GetNode(
      "conv2d-0-0-PermConstNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(const_node, nullptr);

  // Verify nodes connected correctly.
  ASSERT_EQ(transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_node, 0, updated_conv2d->GetName(), 0);
  VerifyRegularFaninMatch(transpose_node, 1, const_node->GetName(), 0);

  auto* output = context.graph_view->GetNode("output");
  ASSERT_NE(output, nullptr);
  ASSERT_EQ(output->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output, 0, transpose_node->GetName(), 0);
}
770
// Verifies DefaultLayoutSensitiveOpTransposer on FusedBatchNorm: a transpose
// pair is inserted around the data input/output (fanout 0), the data_format
// attribute flips to the destination layout, and the non-layout outputs
// (mean/variance) stay wired directly to the op.
TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestFusedBatchNorm) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Use FusedBatchNorm for default transposer test
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleFusedBatchNorm(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer transposer;
  auto* bn = context.graph_view->GetNode("bn");
  TF_ASSERT_OK(transposer.TransposeNode(&context, bn));

  // The expected optimized graph contains 2 extra sets of Transpose nodes and
  // has the FusedBatchNorm's data_format set to "NCHW".
  auto* input_transpose_node =
      context.graph_view->GetNode("bn-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "x", 0);

  auto* bn_node = context.graph_view->GetNode("bn");
  ASSERT_NE(bn_node, nullptr);
  ASSERT_EQ(bn_node->NumRegularFanins(), 5);
  VerifyRegularFaninMatch(bn_node, 0, input_transpose_node->GetName(), 0);
  // Vector-valued inputs (scale/offset/mean/var) are layout-independent and
  // must remain untouched.
  VerifyRegularFaninMatch(bn_node, 1, "scale", 0);
  VerifyRegularFaninMatch(bn_node, 2, "offset", 0);
  VerifyRegularFaninMatch(bn_node, 3, "mean", 0);
  VerifyRegularFaninMatch(bn_node, 4, "var", 0);
  VerifyDataFormatAttributeMatch(bn_node, kDstFormat);

  auto* output_transpose_node =
      context.graph_view->GetNode("bn-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, bn_node->GetName(), 0);

  auto* output_y = context.graph_view->GetNode("output_y");
  ASSERT_NE(output_y, nullptr);
  ASSERT_EQ(output_y->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_y, 0, output_transpose_node->GetName(), 0);

  // Outputs 1 and 2 (mean/variance) are 1-D, so they bypass the transposes.
  auto* output_mean = context.graph_view->GetNode("output_mean");
  ASSERT_NE(output_mean, nullptr);
  ASSERT_EQ(output_mean->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_mean, 0, bn_node->GetName(), 1);

  auto* output_variance = context.graph_view->GetNode("output_variance");
  ASSERT_NE(output_variance, nullptr);
  ASSERT_EQ(output_variance->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_variance, 0, bn_node->GetName(), 2);
}
826
// Verifies DefaultLayoutSensitiveOpTransposer on Conv2D: transposes are
// inserted around the data input/output, data_format becomes the destination
// layout, and the strides attribute is permuted NHWC->NCHW.
TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestConv2D) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Use Conv2D for default transposer test
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer transposer;
  auto* conv2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d));

  // The expected optimized graph contains 2 extra sets of Transpose nodes and
  // has the Conv2D's data_format set to "NCHW".
  auto* input_transpose_node = context.graph_view->GetNode(
      "conv2d-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "input", 0);

  auto* conv2d_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_node, nullptr);
  ASSERT_EQ(conv2d_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv2d_node, 0, input_transpose_node->GetName(), 0);
  // The filter input is layout-independent and must stay untouched.
  VerifyRegularFaninMatch(conv2d_node, 1, "filter", 0);
  VerifyDataFormatAttributeMatch(conv2d_node, kDstFormat);
  // Strides were [1, kStride1, kStride2, 1] in NHWC; expect the NCHW
  // permutation [1, 1, kStride1, kStride2].
  const auto* strides_attr = conv2d_node->GetAttr("strides");
  ASSERT_NE(strides_attr, nullptr);
  ASSERT_EQ(strides_attr->list().i_size(), 4);
  EXPECT_EQ(strides_attr->list().i(0), 1);
  EXPECT_EQ(strides_attr->list().i(1), 1);
  EXPECT_EQ(strides_attr->list().i(2), kStride1);
  EXPECT_EQ(strides_attr->list().i(3), kStride2);

  auto* output_transpose_node = context.graph_view->GetNode(
      "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, conv2d_node->GetName(), 0);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
}
877
// Verifies MaxPoolGradTransposer for both MaxPoolGrad and MaxPoolGradGrad:
// all three 4-D inputs get input transposes, the output gets an output
// transpose, and data_format flips to the destination layout.
TEST_F(TransposerTest, MaxPoolGradTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  for (bool use_grad_grad : {false, true}) {
    GrapplerItem item;
    TransposeContext context;
    TF_ASSERT_OK(CreateSimpleMaxPoolGrad(&item.graph, use_grad_grad));
    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
        item, virtual_cluster_.get(), &context));
    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

    MaxPoolGradTransposer transposer;
    auto* maxpool_grad = context.graph_view->GetNode("maxpool_grad");
    ASSERT_NE(maxpool_grad, nullptr);
    TF_ASSERT_OK(transposer.TransposeNode(&context, maxpool_grad));

    // Each of the three regular fanins (orig_input, orig_output, grad) is
    // 4-D, so each receives its own NHWC->NCHW transpose.
    auto* input_transpose_node1 = context.graph_view->GetNode(
        "maxpool_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node1, nullptr);
    ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(input_transpose_node1, 0, "orig_input", 0);

    auto* input_transpose_node2 = context.graph_view->GetNode(
        "maxpool_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node2, nullptr);
    ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(input_transpose_node2, 0, "orig_output", 0);

    auto* input_transpose_node3 = context.graph_view->GetNode(
        "maxpool_grad-2-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(input_transpose_node3, nullptr);
    ASSERT_EQ(input_transpose_node3->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(input_transpose_node3, 0, "grad", 0);

    auto* updated_maxpool_grad = context.graph_view->GetNode("maxpool_grad");
    VerifyDataFormatAttributeMatch(updated_maxpool_grad, kDstFormat);
    ASSERT_EQ(updated_maxpool_grad->NumRegularFanins(), 3);
    VerifyRegularFaninMatch(updated_maxpool_grad, 0,
                            input_transpose_node1->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpool_grad, 1,
                            input_transpose_node2->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpool_grad, 2,
                            input_transpose_node3->GetName(), 0);

    auto* output_transpose_node = context.graph_view->GetNode(
        "maxpool_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
    ASSERT_NE(output_transpose_node, nullptr);
    ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
    VerifyRegularFaninMatch(output_transpose_node, 0,
                            updated_maxpool_grad->GetName(), 0);
  }
}
931
// Verifies BiasAddGradTransposer on a 4-D input: only an input transpose is
// added (BiasAddGrad emits a 1-D bias gradient, so no output transpose) and
// data_format flips to the destination layout.
TEST_F(TransposerTest, BiasAddGradTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleBiasAddGrad(
      &item.graph, {kBatchSize, kHeight, kWidth, kDepthIn}));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  BiasAddGradTransposer transposer;
  auto* bag = context.graph_view->GetNode("bag");
  ASSERT_NE(bag, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, bag));

  // The expected optimized graph contains 1 extra Transpose node and has the
  // BiasAddGrad's data_format set to "NCHW".
  auto* input_transpose_node =
      context.graph_view->GetNode("bag-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "input", 0);

  auto* bag_node = context.graph_view->GetNode("bag");
  ASSERT_NE(bag_node, nullptr);
  VerifyDataFormatAttributeMatch(bag_node, kDstFormat);

  // The output is 1-D, so no output transpose should exist.
  auto* output_transpose_node = context.graph_view->GetNode(
      "bag-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, bag_node->GetName(), 0);
}
970
// Negative test: BiasAddGradTransposer must be a no-op when the input is not
// 4-D (here 3-D) — no transposes inserted, data_format left at the source
// layout, and fanins unchanged.
TEST_F(TransposerTest, BiasAddGradTransposerIncorrectInputTest) {
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(
      CreateSimpleBiasAddGrad(&item.graph, {kHeight, kWidth, kDepthIn}));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  BiasAddGradTransposer transposer;
  auto* bag = context.graph_view->GetNode("bag");
  ASSERT_NE(bag, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, bag));

  // Optimization should not occur because of incorrect input dimensions.
  auto* input_transpose_node =
      context.graph_view->GetNode("bag-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node, nullptr);

  auto* bag_node = context.graph_view->GetNode("bag");
  ASSERT_NE(bag_node, nullptr);
  VerifyDataFormatAttributeMatch(bag_node, kSrcFormat);

  auto* output_transpose_node = context.graph_view->GetNode(
      "bag-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, bag_node->GetName(), 0);
}
1003
// Verifies Conv2DBackpropFilterTransposer: the 4-D activations (fanin 0) and
// out_backprop (fanin 2) each get an input transpose, the filter_sizes vector
// (fanin 1) is left alone, and the filter-gradient output — which is in
// filter layout, not data layout — gets no output transpose.
TEST_F(TransposerTest, Conv2DBackpropFilterTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DBackpropFilter(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  Conv2DBackpropFilterTransposer transposer;
  auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter");
  ASSERT_NE(conv2d_bf, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d_bf));

  // The expected optimized graph contains 2 extra sets of Transpose nodes and
  // has the Conv2DBackpropFilter's data_format set to "NCHW".
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "conv2d_backprop_filter-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0, "input", 0);

  // filter_sizes (fanin 1) is a shape vector, not a 4-D tensor: no transpose.
  auto* input_transpose_node_filter_sizes = context.graph_view->GetNode(
      "conv2d_backprop_filter-1-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node_filter_sizes, nullptr);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "conv2d_backprop_filter-2-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0, "out_backprop", 0);

  auto* conv2d_bf_node = context.graph_view->GetNode("conv2d_backprop_filter");
  ASSERT_NE(conv2d_bf_node, nullptr);
  ASSERT_EQ(conv2d_bf_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(conv2d_bf_node, 0, input_transpose_node1->GetName(),
                          0);
  VerifyRegularFaninMatch(conv2d_bf_node, 2, input_transpose_node2->GetName(),
                          0);
  VerifyDataFormatAttributeMatch(conv2d_bf_node, kDstFormat);

  // The output is a filter gradient (HWIO layout); no output transpose.
  auto* output_transpose_node = context.graph_view->GetNode(
      "conv2d_backprop_filter-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, conv2d_bf_node->GetName(), 0);
}
1056
// Verifies that layout-dependent list attributes — dilations and
// explicit_paddings — are permuted from NHWC to NCHW order when
// Conv2DBackpropFilter is transposed with EXPLICIT padding.
TEST_F(TransposerTest, NodeAttributes) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(
      CreateSimpleConv2DBackpropFilter(&item.graph, DT_FLOAT, "EXPLICIT"));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  Conv2DBackpropFilterTransposer transposer;
  auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter");
  ASSERT_NE(conv2d_bf, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d_bf));

  auto* conv2d_bf_node = context.graph_view->GetNode("conv2d_backprop_filter");
  ASSERT_NE(conv2d_bf_node, nullptr);
  ASSERT_EQ(conv2d_bf_node->NumRegularFanins(), 3);
  VerifyDataFormatAttributeMatch(conv2d_bf_node, kDstFormat);
  // Dilations were [1, kDilation, kDilation, 1] in NHWC; expect the NCHW
  // permutation [1, 1, kDilation, kDilation].
  auto* dilations_attr = conv2d_bf_node->GetAttr("dilations");
  ASSERT_NE(dilations_attr, nullptr);
  ASSERT_EQ(dilations_attr->list().i_size(), 4);
  EXPECT_EQ(dilations_attr->list().i(0), 1);
  EXPECT_EQ(dilations_attr->list().i(1), 1);
  EXPECT_EQ(dilations_attr->list().i(2), kDilation);
  EXPECT_EQ(dilations_attr->list().i(3), kDilation);
  // explicit_paddings holds (before, after) pairs per dimension; in NCHW the
  // H/W pads move to positions 4..7 while N/C pads (0..3) are zero.
  auto* explicit_paddings_attr = conv2d_bf_node->GetAttr("explicit_paddings");
  ASSERT_NE(explicit_paddings_attr, nullptr);
  ASSERT_EQ(explicit_paddings_attr->list().i_size(), 8);
  EXPECT_EQ(explicit_paddings_attr->list().i(0), 0);
  EXPECT_EQ(explicit_paddings_attr->list().i(1), 0);
  EXPECT_EQ(explicit_paddings_attr->list().i(2), 0);
  EXPECT_EQ(explicit_paddings_attr->list().i(3), 0);
  EXPECT_EQ(explicit_paddings_attr->list().i(4), kPaddingTop);
  EXPECT_EQ(explicit_paddings_attr->list().i(5), kPaddingBottom);
  EXPECT_EQ(explicit_paddings_attr->list().i(6), kPaddingLeft);
  EXPECT_EQ(explicit_paddings_attr->list().i(7), kPaddingRight);
}
1097
// Verifies Conv2DBackpropInputTransposer: the input_sizes shape vector
// (fanin 0) gets a DataFormatVecPermute instead of a Transpose, out_backprop
// (fanin 2) gets a regular input transpose, the filter is untouched, and the
// 4-D output gets an output transpose back to the source layout.
TEST_F(TransposerTest, Conv2DBackpropInputTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleConv2DBackpropInput(&item.graph));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  Conv2DBackpropInputTransposer transposer;
  auto* conv2d_i = context.graph_view->GetNode("conv2d_backprop_input");
  ASSERT_NE(conv2d_i, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, conv2d_i));

  // The expected optimized graph contains 1 extra set of Transpose nodes,
  // 1 DataFormatVecPermute node and has the Conv2DBackpropInput's data_format
  // set to "NCHW".
  auto* input_vec_permute_node = context.graph_view->GetNode(
      "conv2d_backprop_input-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_vec_permute_node, nullptr);
  ASSERT_EQ(input_vec_permute_node->NumRegularFanins(), 1);
  // The permute node must carry the src/dst formats it converts between.
  const auto* src_format_attr = input_vec_permute_node->GetAttr(kAttrSrcFormat);
  ASSERT_NE(src_format_attr, nullptr);
  EXPECT_EQ(src_format_attr->s(), kSrcFormat);
  const auto* dst_format_attr = input_vec_permute_node->GetAttr(kAttrDstFormat);
  ASSERT_NE(dst_format_attr, nullptr);
  EXPECT_EQ(dst_format_attr->s(), kDstFormat);

  auto* input_transpose_node = context.graph_view->GetNode(
      "conv2d_backprop_input-2-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0, "out_backprop", 0);

  auto* conv2d_i_node = context.graph_view->GetNode("conv2d_backprop_input");
  ASSERT_NE(conv2d_i_node, nullptr);
  ASSERT_EQ(conv2d_i_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(conv2d_i_node, 0, input_vec_permute_node->GetName(),
                          0);
  VerifyRegularFaninMatch(conv2d_i_node, 1, "filter", 0);
  VerifyRegularFaninMatch(conv2d_i_node, 2, input_transpose_node->GetName(), 0);
  VerifyDataFormatAttributeMatch(conv2d_i_node, kDstFormat);

  auto* output_transpose_node = context.graph_view->GetNode(
      "conv2d_backprop_input-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, conv2d_i_node->GetName(),
                          0);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
}
1155
// Verifies FusedBatchNormGradTransposer when is_training=true: both 4-D
// inputs (y_backprop, x) get input transposes, only output 0 (x_backprop)
// gets an output transpose, and the 1-D inputs/outputs stay wired directly.
TEST_F(TransposerTest, FusedBatchNormGradTransposerIsTrainingTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  FusedBatchNormGradTransposer transposer;
  auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(fbng, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, fbng));

  // The expected optimized graph contains 3 extra sets of Transpose nodes and
  // has the FusedBatchNormGrad's data_format set to "NCHW".
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0, "y_backprop", 0);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "fused_batch_norm_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0, "x", 0);

  auto* fbng_node = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(fbng_node, nullptr);
  ASSERT_EQ(fbng_node->NumRegularFanins(), 5);
  VerifyRegularFaninMatch(fbng_node, 0, input_transpose_node1->GetName(), 0);
  VerifyRegularFaninMatch(fbng_node, 1, input_transpose_node2->GetName(), 0);
  // 1-D inputs are layout-independent and must remain untouched.
  VerifyRegularFaninMatch(fbng_node, 2, "scale", 0);
  VerifyRegularFaninMatch(fbng_node, 3, "reserve_space_1", 0);
  VerifyRegularFaninMatch(fbng_node, 4, "reserve_space_2", 0);
  VerifyDataFormatAttributeMatch(fbng_node, kDstFormat);

  auto* output_transpose_node = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, fbng_node->GetName(), 0);

  // Only x_backprop (output 0) is 4-D; it goes through the output transpose.
  auto* x_backprop = context.graph_view->GetNode("x_backprop");
  ASSERT_NE(x_backprop, nullptr);
  ASSERT_EQ(x_backprop->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(x_backprop, 0, output_transpose_node->GetName(), 0);

  // Remaining outputs (1..4) are 1-D and connect straight to the op.
  auto* scale_backprop = context.graph_view->GetNode("scale_backprop");
  ASSERT_NE(scale_backprop, nullptr);
  ASSERT_EQ(scale_backprop->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(scale_backprop, 0, fbng_node->GetName(), 1);

  auto* offset_backprop = context.graph_view->GetNode("offset_backprop");
  ASSERT_NE(offset_backprop, nullptr);
  ASSERT_EQ(offset_backprop->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(offset_backprop, 0, fbng_node->GetName(), 2);

  auto* reserve_space_3 = context.graph_view->GetNode("reserve_space_3");
  ASSERT_NE(reserve_space_3, nullptr);
  ASSERT_EQ(reserve_space_3->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(reserve_space_3, 0, fbng_node->GetName(), 3);

  auto* reserve_space_4 = context.graph_view->GetNode("reserve_space_4");
  ASSERT_NE(reserve_space_4, nullptr);
  ASSERT_EQ(reserve_space_4->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(reserve_space_4, 0, fbng_node->GetName(), 4);
}
1227
// Negative test: FusedBatchNormGradTransposer must be a no-op when
// is_training=false — no transposes inserted, data_format left at the source
// layout, and all fanins/fanouts unchanged.
TEST_F(TransposerTest, FusedBatchNormGradTransposerNotTrainingTest) {
  GrapplerItem item;
  TransposeContext context;
  TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, false));
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  FusedBatchNormGradTransposer transposer;
  auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(fbng, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, fbng));

  // Optimization should not occur because FusedBatchNormGrad is not set to
  // training.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node1, nullptr);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "fused_batch_norm_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node2, nullptr);

  auto* fbng_node = context.graph_view->GetNode("fused_batch_norm_grad");
  ASSERT_NE(fbng_node, nullptr);
  ASSERT_EQ(fbng_node->NumRegularFanins(), 5);
  VerifyRegularFaninMatch(fbng_node, 0, "y_backprop", 0);
  VerifyRegularFaninMatch(fbng_node, 1, "x", 0);
  VerifyRegularFaninMatch(fbng_node, 2, "scale", 0);
  VerifyRegularFaninMatch(fbng_node, 3, "reserve_space_1", 0);
  VerifyRegularFaninMatch(fbng_node, 4, "reserve_space_2", 0);
  VerifyDataFormatAttributeMatch(fbng_node, kSrcFormat);

  auto* output_transpose_node = context.graph_view->GetNode(
      "fused_batch_norm_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  // All consumers still read directly from the op's original outputs.
  auto* x_backprop = context.graph_view->GetNode("x_backprop");
  ASSERT_NE(x_backprop, nullptr);
  ASSERT_EQ(x_backprop->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(x_backprop, 0, fbng_node->GetName(), 0);

  auto* scale_backprop = context.graph_view->GetNode("scale_backprop");
  ASSERT_NE(scale_backprop, nullptr);
  ASSERT_EQ(scale_backprop->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(scale_backprop, 0, fbng_node->GetName(), 1);

  auto* offset_backprop = context.graph_view->GetNode("offset_backprop");
  ASSERT_NE(offset_backprop, nullptr);
  ASSERT_EQ(offset_backprop->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(offset_backprop, 0, fbng_node->GetName(), 2);

  auto* reserve_space_3 = context.graph_view->GetNode("reserve_space_3");
  ASSERT_NE(reserve_space_3, nullptr);
  ASSERT_EQ(reserve_space_3->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(reserve_space_3, 0, fbng_node->GetName(), 3);

  auto* reserve_space_4 = context.graph_view->GetNode("reserve_space_4");
  ASSERT_NE(reserve_space_4, nullptr);
  ASSERT_EQ(reserve_space_4->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(reserve_space_4, 0, fbng_node->GetName(), 4);
}
1290
// Verifies DefaultLayoutAgnosticOpTransposer on an Identity fed by an
// already-transposed Conv2D: the Identity is wrapped in its own transpose
// pair so the layout-agnostic op continues to see NCHW data internally while
// downstream consumers still receive NHWC.
TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  auto identity = ops::Identity(
      scope.WithOpName("identity").WithDevice("/device:GPU:0"), conv2d);
  auto output = ops::Identity(scope.WithOpName("output"), identity);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // First transpose the layout-sensitive Conv2D so Identity's fanin is the
  // Conv2D's NCHW->NHWC output transpose.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  DefaultLayoutAgnosticOpTransposer transposer;
  auto* i = context.graph_view->GetNode("identity");
  ASSERT_NE(i, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, i));

  // The expected optimized graph contains 2 extra sets of Transpose nodes.
  auto* input_transpose_node = context.graph_view->GetNode(
      "identity-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* i_node = context.graph_view->GetNode("identity");
  ASSERT_NE(i_node, nullptr);
  ASSERT_EQ(i_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(i_node, 0, input_transpose_node->GetName(), 0);

  auto* output_transpose_node = context.graph_view->GetNode(
      "identity-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, i_node->GetName(), 0);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
}
1341
// Negative test: DefaultLayoutAgnosticOpTransposer must be a no-op when the
// op's input is not 4-D (here a 2-D Sum reduction of the Conv2D output) —
// no transposes inserted and fanins unchanged.
TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityBadInputTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  // Reducing over axes {0, 1} leaves a 2-D tensor, which is not eligible for
  // a 4-D layout transpose.
  auto sum = ops::Sum(scope.WithOpName("sum"), conv2d, {0, 1});
  auto identity = ops::Identity(
      scope.WithOpName("identity").WithDevice("/device:GPU:0"), sum);
  auto output = ops::Identity(scope.WithOpName("output"), identity);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  DefaultLayoutAgnosticOpTransposer transposer;
  auto* i = context.graph_view->GetNode("identity");
  ASSERT_NE(i, nullptr);
  TF_ASSERT_OK(transposer.TransposeNode(&context, i));

  // Optimization should not occur because input is not the right shape (needs
  // to be 4D).
  auto* input_transpose_node = context.graph_view->GetNode(
      "identity-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node, nullptr);

  auto* i_node = context.graph_view->GetNode("identity");
  ASSERT_NE(i_node, nullptr);
  ASSERT_EQ(i_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(i_node, 0, "sum", 0);

  auto* output_transpose_node = context.graph_view->GetNode(
      "identity-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, i_node->GetName(), 0);
}
1389
// Checks that AddNTransposer rewrites an AddN that consumes a transposed
// Conv2D: each of the four inputs gains an NHWC->NCHW Transpose, and the
// AddN output is converted back to NHWC for downstream consumers.
TEST_F(TransposerTest, AddNTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Convert the layout-sensitive Conv2D node to the destination format first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_node, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_node));

  // Then run the transposer under test on the AddN.
  AddNTransposer add_n_transposer;
  auto* add_n = context.graph_view->GetNode("add_n");
  ASSERT_NE(add_n, nullptr);
  TF_ASSERT_OK(add_n_transposer.TransposeNode(&context, add_n));

  // Every AddN fanin should now pass through its own NHWC->NCHW Transpose.
  auto* transpose_for_a = context.graph_view->GetNode(
      "add_n-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_for_a, nullptr);
  ASSERT_EQ(transpose_for_a->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_for_a, 0, "a", 0);

  auto* transpose_for_b = context.graph_view->GetNode(
      "add_n-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_for_b, nullptr);
  ASSERT_EQ(transpose_for_b->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_for_b, 0, "b", 0);

  auto* transpose_for_c = context.graph_view->GetNode(
      "add_n-2-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_for_c, nullptr);
  ASSERT_EQ(transpose_for_c->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_for_c, 0, "c", 0);

  auto* transpose_for_conv = context.graph_view->GetNode(
      "add_n-3-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_for_conv, nullptr);
  ASSERT_EQ(transpose_for_conv->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_for_conv, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // The AddN itself reads from the four freshly inserted Transposes.
  auto* rewritten_add_n = context.graph_view->GetNode("add_n");
  ASSERT_NE(rewritten_add_n, nullptr);
  ASSERT_EQ(rewritten_add_n->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(rewritten_add_n, 0, transpose_for_a->GetName(), 0);
  VerifyRegularFaninMatch(rewritten_add_n, 1, transpose_for_b->GetName(), 0);
  VerifyRegularFaninMatch(rewritten_add_n, 2, transpose_for_c->GetName(), 0);
  VerifyRegularFaninMatch(rewritten_add_n, 3, transpose_for_conv->GetName(), 0);

  // The AddN output is restored to the source layout.
  auto* back_to_nhwc = context.graph_view->GetNode(
      "add_n-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(back_to_nhwc, nullptr);
  ASSERT_EQ(back_to_nhwc->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(back_to_nhwc, 0, rewritten_add_n->GetName(), 0);

  // The fetch node consumes the restored-layout tensor.
  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, back_to_nhwc->GetName(), 0);
}
1456
// Verifies that AddNTransposer leaves the graph unchanged when the AddN is
// not downstream of a layout-transposed node (the Conv2D transposer is
// deliberately not run first here).
// NOTE(review): unlike the sibling tests, this one has no
// GTEST_SKIP-on-non-CUDA/ROCm guard; presumably acceptable because it only
// asserts that no rewrite happens — confirm.
TEST_F(TransposerTest, AddNTransposerNotAfterTransformTest) {
  GrapplerItem item;
  TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  AddNTransposer addn_transposer;
  auto* an = context.graph_view->GetNode("add_n");
  ASSERT_NE(an, nullptr);
  TF_ASSERT_OK(addn_transposer.TransposeNode(&context, an));

  // Optimization should not occur because AddN does not follow a transform.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "add_n-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node1, nullptr);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "add_n-1-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node2, nullptr);

  auto* input_transpose_node3 = context.graph_view->GetNode(
      "add_n-2-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node3, nullptr);

  auto* input_transpose_node4 = context.graph_view->GetNode(
      "add_n-3-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node4, nullptr);

  // The AddN keeps its original fanins.
  auto* an_node = context.graph_view->GetNode("add_n");
  ASSERT_NE(an_node, nullptr);
  ASSERT_EQ(an_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(an_node, 0, "a", 0);
  VerifyRegularFaninMatch(an_node, 1, "b", 0);
  VerifyRegularFaninMatch(an_node, 2, "c", 0);
  VerifyRegularFaninMatch(an_node, 3, "conv2d", 0);

  // No output Transpose either; "output" still reads straight from the AddN.
  auto* output_transpose_node = context.graph_view->GetNode(
      "add_n-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("output");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, an_node->GetName(), 0);
}
1504
// Verifies IdentityNTransposer on an IdentityN with heterogeneous inputs:
// two Conv2D outputs (both transposed beforehand) and two other tensors
// ("a" and "b"). Only port 1 ends up wrapped in an input/output Transpose
// pair; ports 0, 2 and 3 keep their original fanins and output ports.
// NOTE(review): why port 0 (fed by conv2d_1) is left untouched while port 1
// (fed by conv2d_2) is wrapped is not evident from this test alone — confirm
// against IdentityNTransposer's implementation.
TEST_F(TransposerTest, IdentityNTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  TF_ASSERT_OK(CreateSimpleIdentityN(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose both layout-sensitive Conv2D nodes first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_1 = context.graph_view->GetNode("conv2d_1");
  ASSERT_NE(conv2d_1, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_1));
  auto* conv2d_2 = context.graph_view->GetNode("conv2d_2");
  ASSERT_NE(conv2d_2, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_2));

  IdentityNTransposer identityn_transposer;
  auto* in = context.graph_view->GetNode("identity_n");
  ASSERT_NE(in, nullptr);
  TF_ASSERT_OK(identityn_transposer.TransposeNode(&context, in));

  // The expected optimized graph contains 4 extra sets of Transpose nodes.
  // No input Transpose is inserted for port 0.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "identity_n-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node1, nullptr);

  // Port 1 reads conv2d_2's restored output through a new NHWC->NCHW
  // Transpose.
  auto* input_transpose_node2 = context.graph_view->GetNode(
      "identity_n-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0,
                          "conv2d_2-0-0-TransposeNCHWToNHWC-LayoutOptimizer",
                          0);

  // Ports 2 and 3 ("a" and "b") are untouched.
  auto* input_transpose_node3 = context.graph_view->GetNode(
      "identity_n-2-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node3, nullptr);

  auto* input_transpose_node4 = context.graph_view->GetNode(
      "identity_n-3-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node4, nullptr);

  auto* in_node = context.graph_view->GetNode("identity_n");
  ASSERT_NE(in_node, nullptr);
  ASSERT_EQ(in_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(in_node, 0, "conv2d_1", 0);
  VerifyRegularFaninMatch(in_node, 1, input_transpose_node2->GetName(), 0);
  VerifyRegularFaninMatch(in_node, 2, "a", 0);
  VerifyRegularFaninMatch(in_node, 3, "b", 0);

  // Only output port 1 gains a matching NCHW->NHWC restore Transpose.
  auto* output_transpose_node1 = context.graph_view->GetNode(
      "identity_n-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node1, nullptr);

  auto* output_transpose_node2 = context.graph_view->GetNode(
      "identity_n-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node2, nullptr);
  ASSERT_EQ(output_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node2, 0, in_node->GetName(), 1);

  auto* output_transpose_node3 = context.graph_view->GetNode(
      "identity_n-2-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node3, nullptr);

  auto* output_transpose_node4 = context.graph_view->GetNode(
      "identity_n-3-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node4, nullptr);

  // Consumers: port 1's consumer reads the restored tensor; all others read
  // from identity_n directly at their original output ports.
  auto* conv2d_1_output_node = context.graph_view->GetNode("conv2d_1_output");
  ASSERT_NE(conv2d_1_output_node, nullptr);
  ASSERT_EQ(conv2d_1_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(conv2d_1_output_node, 0, in_node->GetName(), 0);

  auto* conv2d_2_output_node = context.graph_view->GetNode("conv2d_2_output");
  ASSERT_NE(conv2d_2_output_node, nullptr);
  ASSERT_EQ(conv2d_2_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(conv2d_2_output_node, 0,
                          output_transpose_node2->GetName(), 0);

  auto* a_output_node = context.graph_view->GetNode("a_output");
  ASSERT_NE(a_output_node, nullptr);
  ASSERT_EQ(a_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(a_output_node, 0, in_node->GetName(), 2);

  auto* b_output_node = context.graph_view->GetNode("b_output");
  ASSERT_NE(b_output_node, nullptr);
  ASSERT_EQ(b_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(b_output_node, 0, in_node->GetName(), 3);
}
1597
// Verifies MergeTransposer when both Merge inputs are convertible: one comes
// from the transposed Conv2D directly and the other from an Identity of it.
// Both fanins gain NHWC->NCHW Transposes and the Merge output is restored to
// NHWC.
TEST_F(TransposerTest, MergeTransposerTestMergeBothInputsConvertible) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  Output i1 = ops::Identity(scope.WithOpName("i1"), conv2d);
  auto merge = ops::Merge(scope.WithOpName("merge").WithDevice("/device:GPU:0"),
                          {conv2d, i1});
  auto i2 = ops::Identity(scope.WithOpName("i2"), merge.output);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  MergeTransposer merge_transposer;
  auto* m = context.graph_view->GetNode("merge");
  ASSERT_NE(m, nullptr);
  TF_ASSERT_OK(merge_transposer.TransposeNode(&context, m));

  // The expected optimized graph contains 3 extra sets of Transpose nodes.
  // Fanin 0 reads conv2d via a "0-1"-suffixed restore Transpose — presumably
  // a second fanout copy of conv2d's restored output (the first feeds i1);
  // confirm against the transpose-naming scheme.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "merge-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0,
                          "conv2d-0-1-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // Fanin 1 reads the Identity through its own new Transpose.
  auto* input_transpose_node2 = context.graph_view->GetNode(
      "merge-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0, "i1", 0);

  auto* m_node = context.graph_view->GetNode("merge");
  ASSERT_NE(m_node, nullptr);
  ASSERT_EQ(m_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(m_node, 0, input_transpose_node1->GetName(), 0);
  VerifyRegularFaninMatch(m_node, 1, input_transpose_node2->GetName(), 0);

  // The Merge output is converted back to NHWC for its consumer.
  auto* output_transpose_node = context.graph_view->GetNode(
      "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, m_node->GetName(), 0);

  auto* output_node = context.graph_view->GetNode("i2");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0);
}
1656
// Verifies that MergeTransposer makes no changes when one Merge input (a
// plain 4D Const) is neither a transform nor downstream of one — per the
// assertions below, every fanin must be convertible for the Merge to be
// rewritten.
TEST_F(TransposerTest, MergeTransposerTestMergeOneInputNotConvertible) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  // A 4D constant that does not come from any layout transform.
  auto tensor_4d =
      ops::Const(scope.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto merge = ops::Merge(scope.WithOpName("merge").WithDevice("/device:GPU:0"),
                          {conv2d, tensor_4d});
  auto i2 = ops::Identity(scope.WithOpName("i2"), merge.output);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  MergeTransposer merge_transposer;
  auto* m = context.graph_view->GetNode("merge");
  ASSERT_NE(m, nullptr);
  TF_ASSERT_OK(merge_transposer.TransposeNode(&context, m));

  // Optimization should not occur because not every input is a transform or
  // after transform.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "merge-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node1, nullptr);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "merge-1-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(input_transpose_node2, nullptr);

  // The Merge keeps its original fanins (conv2d's restored output plus the
  // constant).
  auto* m_node = context.graph_view->GetNode("merge");
  ASSERT_NE(m_node, nullptr);
  ASSERT_EQ(m_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(m_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(m_node, 1, "tensor_4d", 0);

  auto* output_transpose_node = context.graph_view->GetNode(
      "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* output_node = context.graph_view->GetNode("i2");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(output_node, 0, m_node->GetName(), 0);
}
1711
// Exercises PadTransposer: after the producing Conv2D is converted to NCHW,
// the Pad's data input gains an NHWC->NCHW Transpose, its paddings input
// gains a DataFormatVecPermute, and the Pad output is transposed back to
// NHWC for its consumer.
TEST_F(TransposerTest, PadTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  // 4x2 paddings matrix (one (before, after) pair per dimension).
  auto paddings =
      ops::Const(scope.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
  auto pad_op = ops::Pad(scope.WithOpName("p").WithDevice("/device:GPU:0"),
                         conv2d, paddings);
  auto out = ops::Identity(scope.WithOpName("o"), pad_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Convert the layout-sensitive Conv2D before running the Pad transposer.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_view = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_view, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_view));

  PadTransposer pad_transposer;
  auto* pad_view = context.graph_view->GetNode("p");
  ASSERT_NE(pad_view, nullptr);
  TF_ASSERT_OK(pad_transposer.TransposeNode(&context, pad_view));

  // Two Transpose pairs plus one DataFormatVecPermute node are expected.
  auto* data_transpose =
      context.graph_view->GetNode("p-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(data_transpose, nullptr);
  ASSERT_EQ(data_transpose->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(data_transpose, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // The paddings vector is permuted rather than transposed.
  auto* paddings_permute = context.graph_view->GetNode(
      "p-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(paddings_permute, nullptr);
  ASSERT_EQ(paddings_permute->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(paddings_permute, 0, "c", 0);

  auto* pad_node = context.graph_view->GetNode("p");
  ASSERT_NE(pad_node, nullptr);
  ASSERT_EQ(pad_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(pad_node, 0, data_transpose->GetName(), 0);
  VerifyRegularFaninMatch(pad_node, 1, paddings_permute->GetName(), 0);

  // The Pad result is restored to NHWC before reaching the Identity sink.
  auto* restore_transpose =
      context.graph_view->GetNode("p-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(restore_transpose, nullptr);
  ASSERT_EQ(restore_transpose->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(restore_transpose, 0, pad_node->GetName(), 0);

  auto* sink_node = context.graph_view->GetNode("o");
  ASSERT_NE(sink_node, nullptr);
  ASSERT_EQ(sink_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(sink_node, 0, restore_transpose->GetName(), 0);
}
1771
// Verifies SwitchTransposer: the data input is rerouted through an
// NHWC->NCHW Transpose, the boolean predicate "ctrl" is left untouched, and
// each Switch output port (0 = output_false, 1 = output_true) gets its own
// NCHW->NHWC restore Transpose feeding i1 / i2 respectively.
TEST_F(TransposerTest, SwitchTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  ops::Variable ctrl(scope.WithOpName("ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(scope.WithOpName("switch").WithDevice("/device:GPU:0"),
                        conv2d, ctrl);
  auto i1 = ops::Identity(scope.WithOpName("i1"), sw.output_false);
  auto i2 = ops::Identity(scope.WithOpName("i2"), sw.output_true);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SwitchTransposer switch_transposer;
  auto* sw_node = context.graph_view->GetNode("switch");
  ASSERT_NE(sw_node, nullptr);
  TF_ASSERT_OK(switch_transposer.TransposeNode(&context, sw_node));

  // The expected optimized graph contains 3 extra sets of Transpose nodes.
  auto* input_transpose_node = context.graph_view->GetNode(
      "switch-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // The Switch reads the transposed data; the predicate fanin is unchanged.
  auto* switch_node = context.graph_view->GetNode("switch");
  ASSERT_NE(switch_node, nullptr);
  ASSERT_EQ(switch_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(switch_node, 0, input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(switch_node, 1, "ctrl", 0);

  // Restore Transpose for output port 0 (false branch).
  auto* output_transpose_node1 = context.graph_view->GetNode(
      "switch-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node1, nullptr);
  ASSERT_EQ(output_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node1, 0, switch_node->GetName(), 0);

  // Restore Transpose for output port 1 (true branch).
  auto* output_transpose_node2 = context.graph_view->GetNode(
      "switch-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node2, nullptr);
  ASSERT_EQ(output_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node2, 0, switch_node->GetName(), 1);

  auto* i1_node = context.graph_view->GetNode("i1");
  ASSERT_NE(i1_node, nullptr);
  ASSERT_EQ(i1_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(i1_node, 0, output_transpose_node1->GetName(), 0);

  auto* i2_node = context.graph_view->GetNode("i2");
  ASSERT_NE(i2_node, nullptr);
  ASSERT_EQ(i2_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(i2_node, 0, output_transpose_node2->GetName(), 0);
}
1836
// Verifies TernaryOpTransposer on Betainc: all three 4D inputs (two
// RandomUniforms and the transposed Conv2D output) receive NHWC->NCHW
// Transposes, and the Betainc output is restored to NHWC for its consumer.
TEST_F(TransposerTest, TernaryOpTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  auto a = ops::RandomUniform(scope.WithOpName("a"),
                              {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  auto b = ops::RandomUniform(scope.WithOpName("b"),
                              {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  auto beta_inc = ops::Betainc(
      scope.WithOpName("beta_inc").WithDevice("/device:GPU:0"), a, b, conv2d);
  auto z = ops::Identity(scope.WithOpName("z"), beta_inc);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  TernaryOpTransposer ternary_op_transposer;
  auto* bi = context.graph_view->GetNode("beta_inc");
  ASSERT_NE(bi, nullptr);
  TF_ASSERT_OK(ternary_op_transposer.TransposeNode(&context, bi));

  // The expected optimized graph contains 4 extra sets of Transpose nodes.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "beta_inc-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0, "a", 0);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "beta_inc-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0, "b", 0);

  // The third fanin reads the Conv2D's restored output through a new
  // Transpose.
  auto* input_transpose_node3 = context.graph_view->GetNode(
      "beta_inc-2-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node3, nullptr);
  ASSERT_EQ(input_transpose_node3->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node3, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* bi_node = context.graph_view->GetNode("beta_inc");
  ASSERT_NE(bi_node, nullptr);
  ASSERT_EQ(bi_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(bi_node, 0, input_transpose_node1->GetName(), 0);
  VerifyRegularFaninMatch(bi_node, 1, input_transpose_node2->GetName(), 0);
  VerifyRegularFaninMatch(bi_node, 2, input_transpose_node3->GetName(), 0);

  // The Betainc output is converted back to NHWC.
  auto* output_transpose_node = context.graph_view->GetNode(
      "beta_inc-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, bi_node->GetName(), 0);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
1906
// Verifies UnaryGradTransposer on TanhGrad: both fanins (the transposed
// Conv2D output and the incoming gradient "a") receive NHWC->NCHW
// Transposes, and the grad output is restored to NHWC for its consumer.
TEST_F(TransposerTest, UnaryGradTransposerTestTanhGrad) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  auto a = ops::RandomUniform(scope.WithOpName("a"),
                              {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  auto tanh_grad_op = ops::internal::TanhGrad(
      scope.WithOpName("tanh_grad").WithDevice("/device:GPU:0"), conv2d, a);
  auto z = ops::Identity(scope.WithOpName("z"), tanh_grad_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  UnaryGradTransposer unary_grad_transposer;
  auto* tanh_grad = context.graph_view->GetNode("tanh_grad");
  ASSERT_NE(tanh_grad, nullptr);
  TF_ASSERT_OK(unary_grad_transposer.TransposeNode(&context, tanh_grad));

  // The expected optimized graph contains 4 extra sets of Transpose nodes.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "tanh_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "tanh_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0, "a", 0);

  auto* tanh_grad_node = context.graph_view->GetNode("tanh_grad");
  ASSERT_NE(tanh_grad_node, nullptr);
  ASSERT_EQ(tanh_grad_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(tanh_grad_node, 0, input_transpose_node1->GetName(),
                          0);
  VerifyRegularFaninMatch(tanh_grad_node, 1, input_transpose_node2->GetName(),
                          0);

  // The grad output is converted back to NHWC.
  auto* output_transpose_node = context.graph_view->GetNode(
      "tanh_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, tanh_grad_node->GetName(),
                          0);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
1970
// Verifies UnaryGradTransposer on a second unary-grad op fed by a transposed
// Conv2D, mirroring the TanhGrad test above.
// NOTE(review): despite the test name ("Relu6Grad"), the graph actually
// builds an ops::internal::SigmoidGrad op (the node name string is still
// "relu6_grad") — confirm whether Relu6Grad was intended here.
TEST_F(TransposerTest, UnaryGradTransposerTestRelu6Grad) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto conv2d = SimpleConv2D(&scope);
  auto a = ops::RandomUniform(scope.WithOpName("a"),
                              {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  auto relu6_grad_op = ops::internal::SigmoidGrad(
      scope.WithOpName("relu6_grad").WithDevice("/device:GPU:0"), conv2d, a);
  auto z = ops::Identity(scope.WithOpName("z"), relu6_grad_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  UnaryGradTransposer unary_grad_transposer;
  auto* relu6_grad = context.graph_view->GetNode("relu6_grad");
  ASSERT_NE(relu6_grad, nullptr);
  TF_ASSERT_OK(unary_grad_transposer.TransposeNode(&context, relu6_grad));

  // The expected optimized graph contains 4 extra sets of Transpose nodes.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "relu6_grad-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* input_transpose_node2 = context.graph_view->GetNode(
      "relu6_grad-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  ASSERT_EQ(input_transpose_node2->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node2, 0, "a", 0);

  auto* relu6_grad_node = context.graph_view->GetNode("relu6_grad");
  ASSERT_NE(relu6_grad_node, nullptr);
  ASSERT_EQ(relu6_grad_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(relu6_grad_node, 0, input_transpose_node1->GetName(),
                          0);
  VerifyRegularFaninMatch(relu6_grad_node, 1, input_transpose_node2->GetName(),
                          0);

  // The grad output is converted back to NHWC.
  auto* output_transpose_node = context.graph_view->GetNode(
      "relu6_grad-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);
  ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_transpose_node, 0, relu6_grad_node->GetName(),
                          0);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
2034
// Exercises SqueezeTransposer: the Squeeze's input is rerouted through an
// NHWC->NCHW Transpose, but no NCHW->NHWC Transpose is appended to its
// output — squeezing the 32x1x1x16 conv result yields a rank-2 tensor, which
// needs no layout restoration.
TEST_F(TransposerTest, SqueezeTransposerTest) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  // Spatial dims are 1x1 so that Squeeze drops both H and W.
  auto in_tensor =
      ops::RandomUniform(scope.WithOpName("input"), {32, 1, 1, 8}, DT_FLOAT);
  auto filter_tensor =
      ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 16}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), in_tensor,
      filter_tensor, {1, 1, 1, 1}, "SAME",
      ops::Conv2D::DataFormat(kSrcFormat));

  auto squeeze_op = ops::Squeeze(
      scope.WithOpName("squeeze").WithDevice("/device:GPU:0"), conv2d);
  auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Convert the Conv2D first, as the full optimizer pass would.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_view = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_view, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_view));

  SqueezeTransposer squeeze_transposer;
  auto* squeeze_view = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_view, nullptr);
  TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze_view));

  // The Squeeze input now flows through a freshly added NHWC->NCHW Transpose.
  auto* pre_transpose = context.graph_view->GetNode(
      "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(pre_transpose, nullptr);
  ASSERT_EQ(pre_transpose->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(pre_transpose, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* squeeze_node = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_node, nullptr);
  ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(squeeze_node, 0, pre_transpose->GetName(), 0);

  // No restore Transpose on the output: the squeezed tensor is not 4D.
  auto* post_transpose = context.graph_view->GetNode(
      "squeeze-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(post_transpose, nullptr);

  // The sink therefore consumes the Squeeze directly.
  auto* sink_node = context.graph_view->GetNode("z");
  ASSERT_NE(sink_node, nullptr);
  ASSERT_EQ(sink_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(sink_node, 0, squeeze_node->GetName(), 0);
}
2089
TEST_F(TransposerTest, SqueezeTransposerTestUnsupportedInputShape) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Here the Conv2D output has 5x5 (non-unit) spatial dims, a shape the
  // SqueezeTransposer does not handle; the Squeeze must be left untouched.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"), {32, 5, 5, 8}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"), {5, 5, 8, 16}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto squeeze_op = ops::Squeeze(
      scope.WithOpName("squeeze").WithDevice("/device:GPU:0"), conv2d);
  auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_node, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_node));

  SqueezeTransposer squeeze_transposer;
  auto* squeeze_node = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_node, nullptr);
  TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze_node));

  // Expect no changes to the input edge.
  auto* squeeze_input_transpose = context.graph_view->GetNode(
      "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(squeeze_input_transpose, nullptr);
}
2128
TEST_F(TransposerTest, SqueezeTransposerTestInvalidHWAxis) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // The Squeeze removes only axis 1 (H in NHWC). Squeezing a single spatial
  // axis is rejected by the transposer, so the graph must stay unmodified.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"), {32, 1, 1, 8}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 16}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto squeeze_op =
      ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
                   conv2d, ops::Squeeze::Attrs().Axis({1}));
  auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_node, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_node));

  SqueezeTransposer squeeze_transposer;
  auto* squeeze_node = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_node, nullptr);
  TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze_node));

  // Expect no changes to the input edge.
  auto* squeeze_input_transpose = context.graph_view->GetNode(
      "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(squeeze_input_transpose, nullptr);
}
2168
TEST_F(TransposerTest, SqueezeTransposerTestInvalidNHWAxis) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // The Squeeze removes axes {1, 2, 3} — more than just the H/W pair the
  // transposer supports — so the node must be left unmodified. The filter's
  // depth of 1 makes axis 3 squeezable in the first place.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"), {32, 1, 1, 8}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto squeeze_op =
      ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
                   conv2d, ops::Squeeze::Attrs().Axis({1, 2, 3}));
  auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_node, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_node));

  SqueezeTransposer squeeze_transposer;
  auto* squeeze_node = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_node, nullptr);
  TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze_node));

  // Expect no changes to the input edge.
  auto* squeeze_input_transpose = context.graph_view->GetNode(
      "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(squeeze_input_transpose, nullptr);
}
2208
TEST_F(TransposerTest, SqueezeTransposerTestSqueezeDimsUpdated) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Builds Conv2D -> Squeeze(axis={1, 2}) -> Identity. Transposing the
  // Squeeze from NHWC to NCHW must remap its "squeeze_dims" attribute from
  // the NHWC H/W axes {1, 2} to the NCHW H/W axes {2, 3}.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"), {1, 1, 1, 8}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  auto squeeze_op =
      ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
                   conv2d, ops::Squeeze::Attrs().Axis({1, 2}));
  auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the Conv2D first so the Squeeze has a transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SqueezeTransposer squeeze_transposer;
  auto* squeeze = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze, nullptr);
  TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));

  // A transpose must have been inserted on the Squeeze's data input.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* squeeze_node = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_node, nullptr);
  ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(squeeze_node, 0, input_transpose_node1->GetName(), 0);
  // squeeze_dims must now name the NCHW spatial axes {2, 3}.
  const auto* squeeze_dims_attr = squeeze_node->GetAttr("squeeze_dims");
  const auto& list = squeeze_dims_attr->list();
  ASSERT_EQ(list.i_size(), 2);
  EXPECT_EQ(list.i(0), 2);
  EXPECT_EQ(list.i(1), 3);

  // The squeezed output is no longer 4D, so no output transpose is expected.
  auto* output_transpose_node = context.graph_view->GetNode(
      "squeeze-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, squeeze_node->GetName(), 0);
}
2269
2270 // Same as SqueezeTransposerTestSqueezeDimsUpdated but with squeeze dims
2271 // specified with negative values.
TEST_F(TransposerTest, SqueezeTransposerTestNegativeSqueezeDimsUpdated) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // The squeeze axes {-3, -2} are the negative-index spelling of the NHWC
  // H/W axes {1, 2}; after transposition the attribute must hold the
  // normalized NCHW axes {2, 3}.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"), {1, 1, 1, 8}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  auto squeeze_op =
      ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
                   conv2d, ops::Squeeze::Attrs().Axis({-3, -2}));
  auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the Conv2D first so the Squeeze has a transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SqueezeTransposer squeeze_transposer;
  auto* squeeze = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze, nullptr);
  TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));

  // A transpose must have been inserted on the Squeeze's data input.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "squeeze-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* squeeze_node = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_node, nullptr);
  ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(squeeze_node, 0, input_transpose_node1->GetName(), 0);
  // Negative axes are normalized and remapped: {-3, -2} -> {2, 3}.
  const auto* squeeze_dims_attr = squeeze_node->GetAttr("squeeze_dims");
  const auto& list = squeeze_dims_attr->list();
  ASSERT_EQ(list.i_size(), 2);
  EXPECT_EQ(list.i(0), 2);
  EXPECT_EQ(list.i(1), 3);

  // The squeezed output is no longer 4D, so no output transpose is expected.
  auto* output_transpose_node = context.graph_view->GetNode(
      "squeeze-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, squeeze_node->GetName(), 0);
}
2332
2333 // Same as SqueezeTransposerTestSqueezeDimsUpdated but with the source and
2334 // destination formats swapped (as is used in some cases when the data type is
2335 // DT_HALF).
TEST_F(TransposerTest, SqueezeTransposerTestNCHWToNHWCSqueezeDimsUpdated) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Reverse-direction variant: the graph starts in kDstFormat (NCHW) and is
  // transposed to kSrcFormat (NHWC), so squeeze_dims must go from the NCHW
  // H/W axes {2, 3} to the NHWC H/W axes {1, 2}.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"), {1, 8, 1, 1}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"), {1, 1, 8, 1}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 1, 1, 1}, "SAME", ops::Conv2D::DataFormat(kDstFormat));

  auto squeeze_op =
      ops::Squeeze(scope.WithOpName("squeeze").WithDevice("/device:GPU:0"),
                   conv2d, ops::Squeeze::Attrs().Axis({2, 3}));
  auto z = ops::Identity(scope.WithOpName("z"), squeeze_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  // Note the swapped formats: src=kDstFormat, dst=kSrcFormat.
  context.AssignDeviceAndDataFormats(kGPU, kDstFormat, kSrcFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SqueezeTransposer squeeze_transposer;
  auto* squeeze = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze, nullptr);
  TF_ASSERT_OK(squeeze_transposer.TransposeNode(&context, squeeze));

  // A transpose (NCHW -> NHWC this time) feeds the Squeeze's data input.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "squeeze-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0,
                          "conv2d-0-0-TransposeNHWCToNCHW-LayoutOptimizer", 0);

  auto* squeeze_node = context.graph_view->GetNode("squeeze");
  ASSERT_NE(squeeze_node, nullptr);
  ASSERT_EQ(squeeze_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(squeeze_node, 0, input_transpose_node1->GetName(), 0);
  // squeeze_dims remapped from {2, 3} (NCHW) to {1, 2} (NHWC).
  const auto* squeeze_dims_attr = squeeze_node->GetAttr("squeeze_dims");
  const auto& list = squeeze_dims_attr->list();
  ASSERT_EQ(list.i_size(), 2);
  EXPECT_EQ(list.i(0), 1);
  EXPECT_EQ(list.i(1), 2);

  // The squeezed output is no longer 4D, so no output transpose is expected.
  auto* output_transpose_node = context.graph_view->GetNode(
      "squeeze-0-0-TransposeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, squeeze_node->GetName(), 0);
}
2396
TEST_F(TransposerTest, MaxPoolV2Transposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // MaxPoolV2 takes ksize and strides as tensor inputs rather than
  // attributes, so transposing it must insert DataFormatVecPermute nodes on
  // those fanins in addition to transposing the data input.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kWidth, kHeight, kDepthIn}, DT_FLOAT);
  auto ksize = ops::Const(scope.WithOpName("ksize"), {1, kKernel, kKernel, 1});
  auto strides =
      ops::Const(scope.WithOpName("strides"), {1, kKernel, kKernel, 1});
  auto maxpool_op =
      ops::MaxPoolV2(scope.WithOpName("maxpoolv2").WithDevice("/device:GPU:0"),
                     input, ksize, strides, "VALID");
  auto z = ops::Identity(scope.WithOpName("z"), maxpool_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  MaxPoolV2Transposer maxpool_transposer;
  auto* maxpool = context.graph_view->GetNode("maxpoolv2");
  ASSERT_NE(maxpool, nullptr);
  TF_ASSERT_OK(maxpool_transposer.TransposeNode(&context, maxpool));

  // Fanin 0 (data) gets a Transpose; fanins 1 and 2 (the ksize and strides
  // vectors) each get a DataFormatVecPermute.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "maxpoolv2-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  auto* input_transpose_node2 = context.graph_view->GetNode(
      "maxpoolv2-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node2, nullptr);
  auto* input_transpose_node3 = context.graph_view->GetNode(
      "maxpoolv2-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node3, nullptr);

  auto* updated_maxpool = context.graph_view->GetNode("maxpoolv2");
  ASSERT_NE(updated_maxpool, nullptr);
  ASSERT_EQ(updated_maxpool->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(updated_maxpool, 0, input_transpose_node1->GetName(),
                          0);
  VerifyRegularFaninMatch(updated_maxpool, 1, input_transpose_node2->GetName(),
                          0);
  VerifyRegularFaninMatch(updated_maxpool, 2, input_transpose_node3->GetName(),
                          0);

  // The 4D pooled output is transposed back, and "z" reads from that node.
  auto* output_transpose_node = context.graph_view->GetNode(
      "maxpoolv2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
2454
TEST_F(TransposerTest, MaxPoolGradV2Transposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Covers both MaxPoolGradV2 (use_grad_grad == false) and MaxPoolGradGradV2
  // (use_grad_grad == true). Both have the same fanin layout: three 4D
  // tensors (orig_input, orig_output, grad) that need full transposes, plus
  // ksize/strides vectors that need DataFormatVecPermute nodes.
  for (bool use_grad_grad : {false, true}) {
    GrapplerItem item;
    Scope scope = Scope::NewRootScope();
    auto orig_input =
        ops::RandomUniform(scope.WithOpName("orig_input"),
                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
    // Only the grad-grad variant uses the pooled output dims
    // (kOutHeight/kOutWidth) for its second input.
    auto orig_output =
        ops::RandomUniform(scope.WithOpName("orig_output"),
                           {kBatchSize, use_grad_grad ? kOutHeight : kHeight,
                            use_grad_grad ? kOutWidth : kWidth, kDepthIn},
                           DT_FLOAT);
    auto grad =
        ops::RandomUniform(scope.WithOpName("grad_input"),
                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
    auto ksize =
        ops::Const(scope.WithOpName("ksize"), {1, kKernel, kKernel, 1});
    auto strides =
        ops::Const(scope.WithOpName("strides"), {1, kKernel, kKernel, 1});
    Output maxpoolgrad_op;
    if (use_grad_grad) {
      maxpoolgrad_op = ops::MaxPoolGradGradV2(
          scope.WithOpName("maxpoolgradv2").WithDevice("/device:GPU:0"),
          orig_input, orig_output, grad, ksize, strides, "VALID");
    } else {
      maxpoolgrad_op = ops::MaxPoolGradV2(
          scope.WithOpName("maxpoolgradv2").WithDevice("/device:GPU:0"),
          orig_input, orig_output, grad, ksize, strides, "VALID");
    }
    auto z = ops::Identity(scope.WithOpName("z"), maxpoolgrad_op);
    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
    TransposeContext context;
    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
        item, virtual_cluster_.get(), &context));
    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

    MaxPoolGradV2Transposer maxpoolgrad_transposer;
    auto* maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2");
    ASSERT_NE(maxpoolgrad, nullptr);
    TF_ASSERT_OK(maxpoolgrad_transposer.TransposeNode(&context, maxpoolgrad));

    // Fanins 0-2 get Transpose nodes; fanins 3-4 get DataFormatVecPermute.
    auto* orig_input_transpose_node = context.graph_view->GetNode(
        "maxpoolgradv2-0-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(orig_input_transpose_node, nullptr);
    auto* orig_output_transpose_node = context.graph_view->GetNode(
        "maxpoolgradv2-1-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(orig_output_transpose_node, nullptr);
    auto* grad_input_transpose_node = context.graph_view->GetNode(
        "maxpoolgradv2-2-TransposeNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(grad_input_transpose_node, nullptr);
    auto* size_node = context.graph_view->GetNode(
        "maxpoolgradv2-3-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(size_node, nullptr);
    auto* stride_node = context.graph_view->GetNode(
        "maxpoolgradv2-4-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
    ASSERT_NE(stride_node, nullptr);

    auto* updated_maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2");
    ASSERT_NE(updated_maxpoolgrad, nullptr);
    ASSERT_EQ(updated_maxpoolgrad->NumRegularFanins(), 5);
    VerifyRegularFaninMatch(updated_maxpoolgrad, 0,
                            orig_input_transpose_node->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpoolgrad, 1,
                            orig_output_transpose_node->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpoolgrad, 2,
                            grad_input_transpose_node->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpoolgrad, 3, size_node->GetName(), 0);
    VerifyRegularFaninMatch(updated_maxpoolgrad, 4, stride_node->GetName(), 0);

    // The 4D output is transposed back, and "z" reads from that node.
    auto* output_transpose_node = context.graph_view->GetNode(
        "maxpoolgradv2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
    ASSERT_NE(output_transpose_node, nullptr);

    auto* z_output_node = context.graph_view->GetNode("z");
    ASSERT_NE(z_output_node, nullptr);
    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
    VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                            0);
  }
}
2538
TEST_F(TransposerTest, BinaryOpTransposerAdd) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Builds Add(a, conv2d) where "a" has shape {1}. When the Add is
  // transposed, the low-rank operand is expanded through a Reshape (fed by a
  // generated shape Const) while the 4D operand gets a regular Transpose, so
  // the broadcast stays valid in the destination layout.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto a = ops::RandomUniform(scope.WithOpName("a"), {1}, DT_FLOAT);
  auto add =
      ops::Add(scope.WithOpName("Add").WithDevice("/device:GPU:0"), a, conv2d);
  auto z = ops::Identity(scope.WithOpName("z"), add);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the Conv2D first so the Add has a transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  auto* addop = context.graph_view->GetNode("Add");
  ASSERT_NE(addop, nullptr);
  BinaryOpTransposer binaryop_transposer;
  TF_ASSERT_OK(binaryop_transposer.TransposeNode(&context, addop));

  // Fanin 0 ("a", shape {1}) is rewired through a generated Const + Reshape.
  auto* input_const_node =
      context.graph_view->GetNode("Add-0-ReshapeConst-LayoutOptimizer");
  ASSERT_NE(input_const_node, nullptr);
  EXPECT_EQ(input_const_node->NumRegularFanins(), 0);

  auto* input_reshape_node =
      context.graph_view->GetNode("Add-0-ReshapeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_reshape_node, nullptr);
  ASSERT_EQ(input_reshape_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_reshape_node, 0, "a", 0);
  VerifyRegularFaninMatch(input_reshape_node, 1, input_const_node->GetName(),
                          0);

  // Fanin 1 (the 4D conv2d output) gets a regular Transpose.
  auto* input_transpose_node =
      context.graph_view->GetNode("Add-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* updated_add = context.graph_view->GetNode("Add");
  ASSERT_NE(updated_add, nullptr);
  ASSERT_EQ(updated_add->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(updated_add, 0, input_reshape_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_add, 1, input_transpose_node->GetName(), 0);

  // The 4D result is transposed back, and "z" reads from that node.
  auto* output_transpose_node = context.graph_view->GetNode(
      "Add-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
2610
TEST_F(TransposerTest, BinaryOpTransposerMul) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Builds Mul(conv2d, a) where "a" has shape {1}. The operand order is the
  // mirror of BinaryOpTransposerAdd: the 4D operand sits at fanin 0 (gets a
  // Transpose) and the vector operand at fanin 1 (gets a Const + Reshape).
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto a = ops::RandomUniform(scope.WithOpName("a"), {1}, DT_FLOAT);
  auto mul =
      ops::Mul(scope.WithOpName("Mul").WithDevice("/device:GPU:0"), conv2d, a);
  auto z = ops::Identity(scope.WithOpName("z"), mul);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* conv2d_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv2d_node, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, conv2d_node));

  BinaryOpTransposer binaryop_transposer;
  auto* mul_node = context.graph_view->GetNode("Mul");
  ASSERT_NE(mul_node, nullptr);
  TF_ASSERT_OK(binaryop_transposer.TransposeNode(&context, mul_node));

  // The vector operand "a" (fanin 1) is expanded via a generated shape Const
  // feeding a Reshape.
  auto* reshape_const =
      context.graph_view->GetNode("Mul-1-ReshapeConst-LayoutOptimizer");
  ASSERT_NE(reshape_const, nullptr);
  EXPECT_EQ(reshape_const->NumRegularFanins(), 0);

  auto* reshape_node =
      context.graph_view->GetNode("Mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(reshape_node, nullptr);
  ASSERT_EQ(reshape_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(reshape_node, 0, "a", 0);
  VerifyRegularFaninMatch(reshape_node, 1, reshape_const->GetName(), 0);

  // The 4D operand (fanin 0) is transposed to the destination layout.
  auto* transpose_node =
      context.graph_view->GetNode("Mul-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(transpose_node, nullptr);
  ASSERT_EQ(transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* updated_mul = context.graph_view->GetNode("Mul");
  ASSERT_NE(updated_mul, nullptr);
  ASSERT_EQ(updated_mul->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(updated_mul, 0, transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_mul, 1, reshape_node->GetName(), 0);

  // The 4D result is transposed back, and "z" reads from that node.
  auto* output_transpose =
      context.graph_view->GetNode("Mul-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose, nullptr);

  auto* sink_node = context.graph_view->GetNode("z");
  ASSERT_NE(sink_node, nullptr);
  ASSERT_EQ(sink_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(sink_node, 0, output_transpose->GetName(), 0);
}
2682
TEST_F(TransposerTest, BinaryOpTransposerPolygamma) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  // Exercises BinaryOpTransposer on a non-arithmetic binary op (Polygamma)
  // whose second operand is a 4D tensor sized to match the Conv2D output.
  // Only the wiring of fanin 0 and the output transpose are asserted here.
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  auto conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto a = ops::RandomUniform(scope.WithOpName("a"),
                              {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);

  auto polygamma = ops::Polygamma(
      scope.WithOpName("polygamma").WithDevice("/device:GPU:0"), conv2d, a);
  auto z = ops::Identity(scope.WithOpName("z"), polygamma);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the Conv2D first so the Polygamma has a transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  BinaryOpTransposer binaryop_transposer;
  auto* polygamma_op = context.graph_view->GetNode("polygamma");
  ASSERT_NE(polygamma_op, nullptr);
  TF_ASSERT_OK(binaryop_transposer.TransposeNode(&context, polygamma_op));

  // Fanin 0 (the 4D conv2d output) gets a Transpose node.
  auto* input_transpose_node1 = context.graph_view->GetNode(
      "polygamma-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node1, nullptr);
  ASSERT_EQ(input_transpose_node1->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node1, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* updated_polygamma = context.graph_view->GetNode("polygamma");
  ASSERT_NE(updated_polygamma, nullptr);
  ASSERT_EQ(updated_polygamma->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(updated_polygamma, 0,
                          input_transpose_node1->GetName(), 0);

  // The 4D result is transposed back, and "z" reads from that node.
  auto* output_transpose_node = context.graph_view->GetNode(
      "polygamma-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
2743
CreateConcatV1Op(const Scope & scope,const InputList & tensors,const Input & concat_axis,Output * output)2744 bool CreateConcatV1Op(const Scope& scope, const InputList& tensors,
2745 const Input& concat_axis, Output* output) {
2746 if (!scope.ok()) {
2747 return false;
2748 }
2749 auto values = ops::AsNodeOutList(scope, tensors);
2750 if (!scope.ok()) {
2751 return false;
2752 }
2753 auto axis = ops::AsNodeOut(scope, concat_axis);
2754 if (!scope.ok()) {
2755 return false;
2756 }
2757 Node* ret;
2758 const auto unique_name = scope.GetUniqueNameForOp("Concat");
2759 auto builder = NodeBuilder(unique_name, "Concat").Input(axis).Input(values);
2760 scope.UpdateBuilder(&builder);
2761 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
2762 if (!scope.ok()) {
2763 return false;
2764 }
2765 scope.UpdateStatus(scope.DoShapeInference(ret));
2766 *output = Output(ret, 0);
2767 return true;
2768 }
2769
// Verifies ConcatOpTransposer on a legacy Concat (axis as regular fanin 0):
// after transposing the Conv2D fanin, each tensor input of the concat gets a
// NHWC->NCHW Transpose, the axis gets a DataFormatDimMap, and the concat
// output gets a NCHW->NHWC Transpose feeding the consumer.
TEST_F(TransposerTest, ConcatOpTransposerConcat) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  Output input_1 = ops::RandomUniform(scope.WithOpName("input_1"),
                                      {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  Output input_2 = ops::RandomUniform(scope.WithOpName("input_2"),
                                      {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto axis = ops::Const(scope.WithOpName("axis"), 2, {});
  Output concat_op;
  // Build the v1 Concat by hand: fanin 0 is the axis, fanins 1..3 are the
  // tensors (the cc ops library only exposes ConcatV2).
  ASSERT_TRUE(
      CreateConcatV1Op(scope.WithOpName("concat").WithDevice("/device:GPU:0"),
                       {input_1, input_2, conv2d}, axis, &concat_op));
  auto z = ops::Identity(scope.WithOpName("z"), concat_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first so the concat sees a
  // transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  ConcatOpTransposer concat_transposer;
  auto* concat = context.graph_view->GetNode("concat");
  ASSERT_NE(concat, nullptr);
  TF_ASSERT_OK(concat_transposer.TransposeNode(&context, concat));

  // The conv2d fanin (regular fanin index 3 of the v1 concat) must route
  // through conv2d's output transpose and then the concat's input transpose.
  auto* conv2d_transpose_node = context.graph_view->GetNode(
      "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(conv2d_transpose_node, nullptr);
  auto* conv2d_concat_input_node = context.graph_view->GetNode(
      "concat-3-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(conv2d_concat_input_node, nullptr);
  ASSERT_EQ(conv2d_concat_input_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv2d_concat_input_node, 0,
                          conv2d_transpose_node->GetName(), 0);

  // The axis (fanin 0) is remapped via DataFormatDimMap rather than a
  // Transpose, since it is a scalar dimension index.
  auto* axis_dim_node = context.graph_view->GetNode(
      "concat-0-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(axis_dim_node, nullptr);

  auto* updated_concat = context.graph_view->GetNode("concat");
  ASSERT_NE(updated_concat, nullptr);
  ASSERT_EQ(updated_concat->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_concat, 0, axis_dim_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_concat, 1,
                          "concat-1-TransposeNHWCToNCHW-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_concat, 2,
                          "concat-2-TransposeNHWCToNCHW-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_concat, 3,
                          conv2d_concat_input_node->GetName(), 0);

  // The concat output is transposed back to the source format before
  // reaching the Identity consumer.
  auto* output_transpose_node = context.graph_view->GetNode(
      "concat-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
2847
// Verifies ConcatOpTransposer on ConcatV2, where the axis is the LAST regular
// fanin (index 3) rather than the first as in legacy Concat. Tensor inputs
// get NHWC->NCHW Transposes, the axis gets a DataFormatDimMap, and the
// output gets a NCHW->NHWC Transpose.
TEST_F(TransposerTest, ConcatOpTransposerConcatV2) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  Output input_1 = ops::RandomUniform(scope.WithOpName("input_1"),
                                      {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  Output input_2 = ops::RandomUniform(scope.WithOpName("input_2"),
                                      {kBatchSize, 5, 3, kDepthOut}, DT_FLOAT);
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto axis = ops::Const(scope.WithOpName("axis"), 2, {});
  // ops::Concat generates a ConcatV2 node: values first, axis last.
  auto concat_op =
      ops::Concat(scope.WithOpName("concat").WithDevice("/device:GPU:0"),
                  {input_1, input_2, conv2d}, axis);
  auto z = ops::Identity(scope.WithOpName("z"), concat_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first so the concat sees a
  // transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  ConcatOpTransposer concat_transposer;
  auto* concat = context.graph_view->GetNode("concat");
  ASSERT_NE(concat, nullptr);
  TF_ASSERT_OK(concat_transposer.TransposeNode(&context, concat));

  // The conv2d value is fanin 2 here (values occupy fanins 0..2).
  auto* conv2d_transpose_node = context.graph_view->GetNode(
      "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(conv2d_transpose_node, nullptr);
  auto* conv2d_concat_input_node = context.graph_view->GetNode(
      "concat-2-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(conv2d_concat_input_node, nullptr);
  ASSERT_EQ(conv2d_concat_input_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(conv2d_concat_input_node, 0,
                          conv2d_transpose_node->GetName(), 0);

  // The axis sits at fanin 3 for ConcatV2 and is remapped via
  // DataFormatDimMap (scalar dimension index, not a tensor to transpose).
  auto* axis_dim_node = context.graph_view->GetNode(
      "concat-3-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(axis_dim_node, nullptr);

  auto* updated_concat = context.graph_view->GetNode("concat");
  ASSERT_NE(updated_concat, nullptr);
  ASSERT_EQ(updated_concat->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_concat, 0,
                          "concat-0-TransposeNHWCToNCHW-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_concat, 1,
                          "concat-1-TransposeNHWCToNCHW-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_concat, 2,
                          conv2d_concat_input_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_concat, 3, axis_dim_node->GetName(), 0);

  // Output routes back through a NCHW->NHWC Transpose to the consumer.
  auto* output_transpose_node = context.graph_view->GetNode(
      "concat-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
2924
// Verifies ReverseV2Transposer: the data input gets a NHWC->NCHW Transpose,
// the axis vector gets a DataFormatDimMap, and the output gets a NCHW->NHWC
// Transpose before its consumer.
TEST_F(TransposerTest, ReverseV2Transposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto axis = ops::Const(scope.WithOpName("axis"), {0, 3}, {2});
  auto reversed = ops::Reverse(
      scope.WithOpName("reverse_v2").WithDevice("/device:GPU:0"), conv2d, axis);
  auto z = ops::Identity(scope.WithOpName("z"), reversed);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // The Conv2D must be transposed first so the ReverseV2 sees a transposed
  // fanin.
  DefaultLayoutSensitiveOpTransposer conv_transposer;
  auto* conv_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv_node, nullptr);
  TF_ASSERT_OK(conv_transposer.TransposeNode(&context, conv_node));

  ReverseV2Transposer reverse_transposer;
  auto* reverse_node = context.graph_view->GetNode("reverse_v2");
  ASSERT_NE(reverse_node, nullptr);
  TF_ASSERT_OK(reverse_transposer.TransposeNode(&context, reverse_node));

  // Data fanin goes through conv2d's output transpose, then the reverse's
  // input transpose.
  auto* data_transpose = context.graph_view->GetNode(
      "reverse_v2-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(data_transpose, nullptr);
  ASSERT_EQ(data_transpose->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(data_transpose, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // The axis vector holds dimension indices, so it is remapped via
  // DataFormatDimMap rather than transposed.
  auto* dim_map_node = context.graph_view->GetNode(
      "reverse_v2-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(dim_map_node, nullptr);
  ASSERT_EQ(dim_map_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(dim_map_node, 0, "axis", 0);

  auto* reverse_after = context.graph_view->GetNode("reverse_v2");
  ASSERT_NE(reverse_after, nullptr);
  ASSERT_EQ(reverse_after->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(reverse_after, 0, data_transpose->GetName(), 0);
  VerifyRegularFaninMatch(reverse_after, 1, dim_map_node->GetName(), 0);

  auto* out_transpose = context.graph_view->GetNode(
      "reverse_v2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(out_transpose, nullptr);

  auto* sink_node = context.graph_view->GetNode("z");
  ASSERT_NE(sink_node, nullptr);
  ASSERT_EQ(sink_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(sink_node, 0, out_transpose->GetName(), 0);
}
2992
// Verifies TileTransposer: the data input gets a NHWC->NCHW Transpose, the
// multiples vector gets a DataFormatVecPermute (its entries are per-dimension
// counts, so they are permuted rather than dim-mapped), and the output gets a
// NCHW->NHWC Transpose before its consumer.
TEST_F(TransposerTest, TileTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto multiple = ops::Const(scope.WithOpName("multiple"), {1, 1, 2, 3}, {4});
  auto tile_op = ops::Tile(scope.WithOpName("tile").WithDevice("/device:GPU:0"),
                           conv2d, multiple);
  auto z = ops::Identity(scope.WithOpName("z"), tile_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first so the Tile sees a
  // transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  TileTransposer tile_transposer;
  auto* tile = context.graph_view->GetNode("tile");
  ASSERT_NE(tile, nullptr);
  TF_ASSERT_OK(tile_transposer.TransposeNode(&context, tile));

  // Data fanin routes through conv2d's output transpose, then the tile's
  // input transpose.
  auto* input_transpose_node =
      context.graph_view->GetNode("tile-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* multiple_node = context.graph_view->GetNode(
      "tile-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(multiple_node, nullptr);
  ASSERT_EQ(multiple_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(multiple_node, 0, "multiple", 0);

  auto* updated_tile = context.graph_view->GetNode("tile");
  ASSERT_NE(updated_tile, nullptr);
  ASSERT_EQ(updated_tile->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(updated_tile, 0, input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_tile, 1, multiple_node->GetName(), 0);

  auto* output_transpose_node = context.graph_view->GetNode(
      "tile-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
3059
// Verifies ShapeTransposer: the input tensor gets a NHWC->NCHW Transpose,
// while the produced shape VECTOR is converted back with a
// DataFormatVecPermute (permuting dimension sizes) instead of a Transpose.
TEST_F(TransposerTest, ShapeTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto shape =
      ops::Shape(scope.WithOpName("shape").WithDevice("/device:GPU:0"), conv2d);
  auto z = ops::Identity(scope.WithOpName("z"), shape);

  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // The Conv2D must be transposed first so the Shape op sees a transposed
  // fanin.
  DefaultLayoutSensitiveOpTransposer conv_transposer;
  auto* conv_node = context.graph_view->GetNode("conv2d");
  ASSERT_NE(conv_node, nullptr);
  TF_ASSERT_OK(conv_transposer.TransposeNode(&context, conv_node));

  ShapeTransposer shape_transposer;
  auto* shape_op_node = context.graph_view->GetNode("shape");
  ASSERT_NE(shape_op_node, nullptr);
  TF_ASSERT_OK(shape_transposer.TransposeNode(&context, shape_op_node));

  auto* conv_out_transpose = context.graph_view->GetNode(
      "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(conv_out_transpose, nullptr);

  // Shape's data fanin routes through conv2d's output transpose, then the
  // shape's own input transpose.
  auto* shape_in_transpose = context.graph_view->GetNode(
      "shape-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(shape_in_transpose, nullptr);
  ASSERT_EQ(shape_in_transpose->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(shape_in_transpose, 0, conv_out_transpose->GetName(),
                          0);

  auto* shape_out_permute = context.graph_view->GetNode(
      "shape-0-0-DataFormatVecPermuteNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(shape_out_permute, nullptr);

  auto* sink_node = context.graph_view->GetNode("z");
  ASSERT_NE(sink_node, nullptr);
  ASSERT_EQ(sink_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(sink_node, 0, shape_out_permute->GetName(), 0);
}
3116
// Verifies ShapeNTransposer with three Conv2D fanins of which only the first
// two are layout-transposed: those two inputs/outputs are rewritten
// (Transpose on the input, DataFormatVecPermute on the shape output), while
// the third, untransposed conv2d_3 keeps its direct fanin and its shape
// output flows to its consumer unchanged.
TEST_F(TransposerTest, ShapeNTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d_1 = ops::Conv2D(
      scope.WithOpName("conv2d_1").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  Output conv2d_2 = ops::Conv2D(
      scope.WithOpName("conv2d_2").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  Output conv2d_3 = ops::Conv2D(
      scope.WithOpName("conv2d_3").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto shape =
      ops::ShapeN(scope.WithOpName("shape").WithDevice("/device:GPU:0"),
                  {conv2d_1, conv2d_2, conv2d_3});
  auto z_1 = ops::Identity(scope.WithOpName("z_1"), shape.output[0]);
  auto z_2 = ops::Identity(scope.WithOpName("z_2"), shape.output[1]);
  auto z_3 = ops::Identity(scope.WithOpName("z_3"), shape.output[2]);

  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Only conv2d_1 and conv2d_2 are transposed; conv2d_3 is deliberately
  // left in the source format.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d_1 = context.graph_view->GetNode("conv2d_1");
  ASSERT_NE(c2d_1, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d_1));
  auto* c2d_2 = context.graph_view->GetNode("conv2d_2");
  ASSERT_NE(c2d_2, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d_2));

  ShapeNTransposer shape_transposer;
  auto* shape_node = context.graph_view->GetNode("shape");
  ASSERT_NE(shape_node, nullptr);
  TF_ASSERT_OK(shape_transposer.TransposeNode(&context, shape_node));

  auto* conv2d_1_transpose_node = context.graph_view->GetNode(
      "conv2d_1-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(conv2d_1_transpose_node, nullptr);
  auto* conv2d_2_transpose_node = context.graph_view->GetNode(
      "conv2d_2-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(conv2d_2_transpose_node, nullptr);

  // Fanins 0 and 1 are rerouted through per-input NHWC->NCHW Transposes.
  auto* shape_input_1_node = context.graph_view->GetNode(
      "shape-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(shape_input_1_node, nullptr);
  ASSERT_EQ(shape_input_1_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(shape_input_1_node, 0,
                          conv2d_1_transpose_node->GetName(), 0);

  auto* shape_input_2_node = context.graph_view->GetNode(
      "shape-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(shape_input_2_node, nullptr);
  ASSERT_EQ(shape_input_2_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(shape_input_2_node, 0,
                          conv2d_2_transpose_node->GetName(), 0);

  // The third fanin keeps its direct edge to the untransposed conv2d_3.
  auto* updated_shape_node = context.graph_view->GetNode("shape");
  ASSERT_NE(updated_shape_node, nullptr);
  ASSERT_EQ(updated_shape_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(updated_shape_node, 0, shape_input_1_node->GetName(),
                          0);
  VerifyRegularFaninMatch(updated_shape_node, 1, shape_input_2_node->GetName(),
                          0);
  VerifyRegularFaninMatch(updated_shape_node, 2, "conv2d_3", 0);

  // Shape VECTORS on outputs 0 and 1 are converted back with
  // DataFormatVecPermute (dimension sizes are permuted, not transposed).
  auto* output_vec_perm_node_1 = context.graph_view->GetNode(
      "shape-0-0-DataFormatVecPermuteNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_vec_perm_node_1, nullptr);
  auto* output_vec_perm_node_2 = context.graph_view->GetNode(
      "shape-1-0-DataFormatVecPermuteNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_vec_perm_node_2, nullptr);

  auto* z_output_node_1 = context.graph_view->GetNode("z_1");
  ASSERT_NE(z_output_node_1, nullptr);
  ASSERT_EQ(z_output_node_1->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_1, 0, output_vec_perm_node_1->GetName(),
                          0);

  auto* z_output_node_2 = context.graph_view->GetNode("z_2");
  ASSERT_NE(z_output_node_2, nullptr);
  ASSERT_EQ(z_output_node_2->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_2, 0, output_vec_perm_node_2->GetName(),
                          0);

  // z_3 still reads shape:2 directly — no permute inserted on that port.
  auto* z_output_node_3 = context.graph_view->GetNode("z_3");
  ASSERT_NE(z_output_node_3, nullptr);
  ASSERT_EQ(z_output_node_3->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_3, 0, updated_shape_node->GetName(), 2);
}
3219
// Verifies FillOpTransposer: the dims (shape) input gets a
// DataFormatVecPermute (a vector of dimension sizes, permuted rather than
// transposed), the scalar value fanin stays untouched, and the output gets a
// NCHW->NHWC Transpose before its consumer.
TEST_F(TransposerTest, FillOpTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  // Give the Shape node its own name. The previous code passed
  // WithOpName("conv2d"), colliding with the Conv2D node above; it only
  // worked because Scope silently uniquified the duplicate to "conv2d_1",
  // which obscured the graph and risked GetNode("conv2d") ambiguity.
  auto shape = ops::Shape(scope.WithOpName("shape"), conv2d);
  auto value = ops::Const(scope.WithOpName("value"), 0, {});
  auto fill = ops::Fill(scope.WithOpName("fill").WithDevice("/device:GPU:0"),
                        shape, value);
  auto z = ops::Identity(scope.WithOpName("z"), fill);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  FillOpTransposer fill_op_transposer;
  auto* fill_node = context.graph_view->GetNode("fill");
  ASSERT_NE(fill_node, nullptr);
  TF_ASSERT_OK(fill_op_transposer.TransposeNode(&context, fill_node));

  // The dims fanin is permuted to the destination format.
  auto* input_node = context.graph_view->GetNode(
      "fill-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_node, nullptr);

  auto* updated_fill_node = context.graph_view->GetNode("fill");
  ASSERT_NE(updated_fill_node, nullptr);
  ASSERT_EQ(updated_fill_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(updated_fill_node, 0, input_node->GetName(), 0);
  // The scalar fill value needs no layout handling.
  VerifyRegularFaninMatch(updated_fill_node, 1, "value", 0);

  // The filled tensor is transposed back to NHWC for downstream consumers.
  auto* output_node = context.graph_view->GetNode(
      "fill-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_node, nullptr);
  ASSERT_EQ(output_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(output_node, 0, updated_fill_node->GetName(), 0);

  auto* z_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_node, nullptr);
  ASSERT_EQ(z_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_node, 0, output_node->GetName(), 0);
}
3278
// Verifies SliceTransposer: the data input gets a NHWC->NCHW Transpose while
// both the begin and size vectors get DataFormatVecPermute nodes (they are
// per-dimension vectors), and the output gets a NCHW->NHWC Transpose.
TEST_F(TransposerTest, SliceTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto begin = ops::Const(scope.WithOpName("begin"), {0, 0, 2, 1}, {4});
  auto size = ops::Const(scope.WithOpName("size"), {1, 1, 2, 3}, {4});
  auto slice_op =
      ops::Slice(scope.WithOpName("slice").WithDevice("/device:GPU:0"), conv2d,
                 begin, size);
  auto z = ops::Identity(scope.WithOpName("z"), slice_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first so the Slice sees a
  // transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SliceTransposer slice_transposer;
  auto* slice = context.graph_view->GetNode("slice");
  ASSERT_NE(slice, nullptr);
  TF_ASSERT_OK(slice_transposer.TransposeNode(&context, slice));

  // Data fanin routes through conv2d's output transpose, then the slice's
  // input transpose.
  auto* input_transpose_node = context.graph_view->GetNode(
      "slice-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  auto* begin_node = context.graph_view->GetNode(
      "slice-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(begin_node, nullptr);
  ASSERT_EQ(begin_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(begin_node, 0, "begin", 0);

  auto* size_node = context.graph_view->GetNode(
      "slice-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(size_node, nullptr);
  ASSERT_EQ(size_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(size_node, 0, "size", 0);

  auto* updated_slice_node = context.graph_view->GetNode("slice");
  ASSERT_NE(updated_slice_node, nullptr);
  ASSERT_EQ(updated_slice_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(updated_slice_node, 0,
                          input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_slice_node, 1, begin_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_slice_node, 2, size_node->GetName(), 0);

  auto* output_transpose_node = context.graph_view->GetNode(
      "slice-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
3355
// Verifies SplitTransposer on a 3-way Split. For Split, the axis is regular
// fanin 0 and the data is fanin 1: the data gets a NHWC->NCHW Transpose, the
// scalar axis gets a DataFormatDimMap, and EACH of the three outputs gets its
// own NCHW->NHWC Transpose before its consumer.
TEST_F(TransposerTest, SplitTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto axis = ops::Const(scope.WithOpName("axis"), 2, {});
  auto split_op = ops::Split(
      scope.WithOpName("split").WithDevice("/device:GPU:0"), axis, conv2d, 3);
  auto z_1 = ops::Identity(scope.WithOpName("z_1"), split_op.output[0]);
  auto z_2 = ops::Identity(scope.WithOpName("z_2"), split_op.output[1]);
  auto z_3 = ops::Identity(scope.WithOpName("z_3"), split_op.output[2]);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the layout-sensitive Conv2D first so the Split sees a
  // transposed fanin.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SplitTransposer split_transposer;
  auto* split = context.graph_view->GetNode("split");
  ASSERT_NE(split, nullptr);
  TF_ASSERT_OK(split_transposer.TransposeNode(&context, split));

  // Data is fanin 1 for Split (axis comes first), hence "split-1-...".
  auto* input_transpose_node = context.graph_view->GetNode(
      "split-1-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // The scalar axis is remapped with DataFormatDimMap.
  auto* axis_node = context.graph_view->GetNode(
      "split-0-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(axis_node, nullptr);
  ASSERT_EQ(axis_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(axis_node, 0, "axis", 0);

  auto* updated_split_node = context.graph_view->GetNode("split");
  ASSERT_NE(updated_split_node, nullptr);
  ASSERT_EQ(updated_split_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(updated_split_node, 0, axis_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_split_node, 1,
                          input_transpose_node->GetName(), 0);

  // One output transpose per split output port (0, 1, 2).
  auto* output_transpose_node_1 = context.graph_view->GetNode(
      "split-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_1, nullptr);
  auto* output_transpose_node_2 = context.graph_view->GetNode(
      "split-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_2, nullptr);
  auto* output_transpose_node_3 = context.graph_view->GetNode(
      "split-2-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_3, nullptr);

  auto* z_output_node_1 = context.graph_view->GetNode("z_1");
  ASSERT_NE(z_output_node_1, nullptr);
  ASSERT_EQ(z_output_node_1->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_1, 0,
                          output_transpose_node_1->GetName(), 0);
  auto* z_output_node_2 = context.graph_view->GetNode("z_2");
  ASSERT_NE(z_output_node_2, nullptr);
  ASSERT_EQ(z_output_node_2->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_2, 0,
                          output_transpose_node_2->GetName(), 0);
  auto* z_output_node_3 = context.graph_view->GetNode("z_3");
  ASSERT_NE(z_output_node_3, nullptr);
  ASSERT_EQ(z_output_node_3->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_3, 0,
                          output_transpose_node_3->GetName(), 0);
}
3441
TEST_F(TransposerTest, SplitVTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  // Graph under test: Conv2D (layout sensitive, placed on GPU) feeding a
  // SplitV that splits along NHWC axis 1 into pieces of sizes {2, 2, 1};
  // each split output is consumed by an Identity node.
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
  auto axis = ops::Const(scope.WithOpName("axis"), 1, {});
  auto size_splits =
      ops::Const(scope.WithOpName("size_splits"), {2, 2, 1}, {3});
  auto splitv_op =
      ops::SplitV(scope.WithOpName("splitv").WithDevice("/device:GPU:0"),
                  conv2d, size_splits, axis, 3);
  auto z_1 = ops::Identity(scope.WithOpName("z_1"), splitv_op.output[0]);
  auto z_2 = ops::Identity(scope.WithOpName("z_2"), splitv_op.output[1]);
  auto z_3 = ops::Identity(scope.WithOpName("z_3"), splitv_op.output[2]);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the Conv2D first so that the SplitV sees a layout-converted
  // fanin, then run the transposer under test on the SplitV node.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  SplitVTransposer splitv_transposer;
  auto* splitv = context.graph_view->GetNode("splitv");
  ASSERT_NE(splitv, nullptr);
  TF_ASSERT_OK(splitv_transposer.TransposeNode(&context, splitv));

  // A Transpose must have been inserted on SplitV's data input (port 0),
  // cancelling out the NCHW->NHWC transpose the Conv2D pass appended.
  auto* input_transpose_node = context.graph_view->GetNode(
      "splitv-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // SplitV takes its split dimension as fanin 2, so a DataFormatDimMap node
  // must be inserted on port 2 to remap the axis from NHWC to NCHW.
  auto* axis_node = context.graph_view->GetNode(
      "splitv-2-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(axis_node, nullptr);
  ASSERT_EQ(axis_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(axis_node, 0, "axis", 0);

  // SplitV's fanins are (value, size_splits, split_dim); only ports 0 and 2
  // are rewired, size_splits (port 1) is left untouched.
  auto* updated_splitv_node = context.graph_view->GetNode("splitv");
  ASSERT_NE(updated_splitv_node, nullptr);
  ASSERT_EQ(updated_splitv_node->NumRegularFanins(), 3);
  VerifyRegularFaninMatch(updated_splitv_node, 0,
                          input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_splitv_node, 1, "size_splits", 0);
  VerifyRegularFaninMatch(updated_splitv_node, 2, axis_node->GetName(), 0);

  // Each of the three SplitV outputs gets its own NCHW->NHWC transpose back
  // to the source layout, feeding the corresponding Identity consumer.
  auto* output_transpose_node_1 = context.graph_view->GetNode(
      "splitv-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_1, nullptr);
  auto* output_transpose_node_2 = context.graph_view->GetNode(
      "splitv-1-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_2, nullptr);
  auto* output_transpose_node_3 = context.graph_view->GetNode(
      "splitv-2-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node_3, nullptr);

  auto* z_output_node_1 = context.graph_view->GetNode("z_1");
  ASSERT_NE(z_output_node_1, nullptr);
  ASSERT_EQ(z_output_node_1->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_1, 0,
                          output_transpose_node_1->GetName(), 0);
  auto* z_output_node_2 = context.graph_view->GetNode("z_2");
  ASSERT_NE(z_output_node_2, nullptr);
  ASSERT_EQ(z_output_node_2->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_2, 0,
                          output_transpose_node_2->GetName(), 0);
  auto* z_output_node_3 = context.graph_view->GetNode("z_3");
  ASSERT_NE(z_output_node_3, nullptr);
  ASSERT_EQ(z_output_node_3->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node_3, 0,
                          output_transpose_node_3->GetName(), 0);
}
3531
TEST_F(TransposerTest, StridedSliceTransposer) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  // Graph under test: Conv2D (layout sensitive, on GPU) feeding a
  // StridedSlice with rank-4 begin/end/strides and begin/end masks set.
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  // Masks are per-dimension bit sets in NHWC order:
  // begin_mask 0xB = 0b1011 -> {N, H, C}, end_mask 0x7 = 0b0111 -> {N, H, W}.
  auto attrs = ops::StridedSlice::Attrs().BeginMask(0xB).EndMask(0x7);

  auto begin = ops::Const(scope.WithOpName("begin"), {2, 0, 2, 1}, {4});
  auto end = ops::Const(scope.WithOpName("end"), {34, 4, 3, 1}, {4});
  auto strides = ops::Const(scope.WithOpName("strides"), {7, 2, 1, 1}, {4});

  auto strided_slice_op = ops::StridedSlice(
      scope.WithOpName("stridedslice").WithDevice("/device:GPU:0"), conv2d,
      begin, end, strides, attrs);
  auto z = ops::Identity(scope.WithOpName("z"), strided_slice_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  // Transpose the Conv2D first, then run the transposer under test.
  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  StridedSliceTransposer stridedslice_transposer;
  auto* stridedslice = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(stridedslice, nullptr);
  TF_ASSERT_OK(stridedslice_transposer.TransposeNode(&context, stridedslice));

  // Data input (port 0) gets an NHWC->NCHW Transpose inserted.
  auto* input_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(input_transpose_node, nullptr);
  ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2);
  VerifyRegularFaninMatch(input_transpose_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);

  // begin/end/strides (ports 1-3) are per-dimension vectors, so each gets a
  // DataFormatVecPermute node to reorder its elements into NCHW order.
  auto* begin_node = context.graph_view->GetNode(
      "stridedslice-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(begin_node, nullptr);
  auto* end_node = context.graph_view->GetNode(
      "stridedslice-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(end_node, nullptr);
  auto* strides_node = context.graph_view->GetNode(
      "stridedslice-3-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_NE(strides_node, nullptr);

  auto* updated_stridedslice_node = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(updated_stridedslice_node, nullptr);
  ASSERT_EQ(updated_stridedslice_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_stridedslice_node, 0,
                          input_transpose_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 1, begin_node->GetName(),
                          0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 2, end_node->GetName(), 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 3, strides_node->GetName(),
                          0);
  // The mask bits must be permuted along with the dimensions:
  // {N, H, C} in NCHW order -> bits {N, C, H} = 0b0111 = 0x7, and
  // {N, H, W} in NCHW order -> bits {N, H, W} = 0b1101 = 0xD.
  const auto* begin_mask_attr =
      updated_stridedslice_node->GetAttr("begin_mask");
  ASSERT_NE(begin_mask_attr, nullptr);
  EXPECT_EQ(begin_mask_attr->i(), 0x7);
  const auto* end_mask_attr = updated_stridedslice_node->GetAttr("end_mask");
  ASSERT_NE(end_mask_attr, nullptr);
  EXPECT_EQ(end_mask_attr->i(), 0xD);

  // The output is transposed back to NHWC before reaching the consumer.
  auto* output_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_NE(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
                          0);
}
3621
TEST_F(TransposerTest, StridedSliceTransposerEllipsisMaskPresent) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  // Same graph as StridedSliceTransposer, except the StridedSlice also sets
  // an ellipsis_mask. With an ellipsis the per-dimension masks/vectors no
  // longer map 1:1 to input dimensions, so the transposer must bail out.
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  auto attrs =
      ops::StridedSlice::Attrs().BeginMask(0xB).EndMask(0x7).EllipsisMask(0x2);

  auto begin = ops::Const(scope.WithOpName("begin"), {2, 0, 2, 1}, {4});
  auto end = ops::Const(scope.WithOpName("end"), {34, 4, 3, 1}, {4});
  auto strides = ops::Const(scope.WithOpName("strides"), {7, 2, 1, 1}, {4});

  auto strided_slice_op = ops::StridedSlice(
      scope.WithOpName("stridedslice").WithDevice("/device:GPU:0"), conv2d,
      begin, end, strides, attrs);
  auto z = ops::Identity(scope.WithOpName("z"), strided_slice_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  StridedSliceTransposer stridedslice_transposer;
  auto* stridedslice = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(stridedslice, nullptr);
  TF_ASSERT_OK(stridedslice_transposer.TransposeNode(&context, stridedslice));

  // Expect StridedSlice Node to remain unchanged because of the ellipsis mask.
  // Its fanins still point at the (already transposed) Conv2D output and the
  // original begin/end/strides constants.
  auto* updated_stridedslice_node = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(updated_stridedslice_node, nullptr);
  ASSERT_EQ(updated_stridedslice_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_stridedslice_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 1, "begin", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 2, "end", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 3, "strides", 0);

  // No output transpose was inserted either: z consumes the op directly.
  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0,
                          updated_stridedslice_node->GetName(), 0);
}
3683
TEST_F(TransposerTest, StridedSliceTransposerConstFaninBadRank) {
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
  GrapplerItem item;
  Scope scope = Scope::NewRootScope();

  // Same graph as StridedSliceTransposer, but begin/end/strides have rank 3
  // while the sliced tensor is rank 4, so they cannot be permuted with a
  // 4-element data-format permutation and the node must be left alone.
  auto input =
      ops::RandomUniform(scope.WithOpName("input"),
                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
  auto filter =
      ops::RandomUniform(scope.WithOpName("filter"),
                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
  Output conv2d = ops::Conv2D(
      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));

  auto attrs = ops::StridedSlice::Attrs().BeginMask(0xB).EndMask(0x7);

  // Rank-3 slice parameters (mismatched with the rank-4 input).
  auto begin = ops::Const(scope.WithOpName("begin"), {2, 0, 2}, {3});
  auto end = ops::Const(scope.WithOpName("end"), {34, 4, 3}, {3});
  auto strides = ops::Const(scope.WithOpName("strides"), {7, 2, 1}, {3});

  auto strided_slice_op = ops::StridedSlice(
      scope.WithOpName("stridedslice").WithDevice("/device:GPU:0"), conv2d,
      begin, end, strides, attrs);
  auto z = ops::Identity(scope.WithOpName("z"), strided_slice_op);
  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));

  TransposeContext context;
  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
      item, virtual_cluster_.get(), &context));
  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);

  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
  auto* c2d = context.graph_view->GetNode("conv2d");
  ASSERT_NE(c2d, nullptr);
  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));

  StridedSliceTransposer stridedslice_transposer;
  auto* stridedslice = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(stridedslice, nullptr);
  TF_ASSERT_OK(stridedslice_transposer.TransposeNode(&context, stridedslice));

  // None of the helper nodes a successful transpose would create may exist.
  auto* input_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-TransposeNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(input_transpose_node, nullptr);

  auto* begin_node = context.graph_view->GetNode(
      "stridedslice-1-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(begin_node, nullptr);
  auto* end_node = context.graph_view->GetNode(
      "stridedslice-2-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(end_node, nullptr);
  auto* strides_node = context.graph_view->GetNode(
      "stridedslice-3-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_EQ(strides_node, nullptr);

  // The StridedSlice keeps its original fanins and mask attributes.
  auto* updated_stridedslice_node = context.graph_view->GetNode("stridedslice");
  ASSERT_NE(updated_stridedslice_node, nullptr);
  ASSERT_EQ(updated_stridedslice_node->NumRegularFanins(), 4);
  VerifyRegularFaninMatch(updated_stridedslice_node, 0,
                          "conv2d-0-0-TransposeNCHWToNHWC-LayoutOptimizer", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 1, "begin", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 2, "end", 0);
  VerifyRegularFaninMatch(updated_stridedslice_node, 3, "strides", 0);
  const auto* begin_mask_attr =
      updated_stridedslice_node->GetAttr("begin_mask");
  ASSERT_NE(begin_mask_attr, nullptr);
  EXPECT_EQ(begin_mask_attr->i(), 0xB);
  const auto* end_mask_attr = updated_stridedslice_node->GetAttr("end_mask");
  ASSERT_NE(end_mask_attr, nullptr);
  EXPECT_EQ(end_mask_attr->i(), 0x7);

  auto* output_transpose_node = context.graph_view->GetNode(
      "stridedslice-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_EQ(output_transpose_node, nullptr);

  auto* z_output_node = context.graph_view->GetNode("z");
  ASSERT_NE(z_output_node, nullptr);
  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
  VerifyRegularFaninMatch(z_output_node, 0,
                          updated_stridedslice_node->GetName(), 0);
}
3768
TEST_F(TransposerTest, ReduceTransposerKeepDims) {
  // Runs the templated test body (defined earlier in this file) for both
  // int32 and int64_t — presumably the reduction-axis index dtype; see the
  // helper's definition for the actual assertions.
  ReduceTransposerKeepDims<int32>();
  ReduceTransposerKeepDims<int64_t>();
}
3773
TEST_F(TransposerTest, ReduceTransposerValidAxisNode) {
  // Runs the templated test body (defined earlier in this file) for both
  // int32 and int64_t — presumably the reduction-axis index dtype; see the
  // helper's definition for the actual assertions.
  ReduceTransposerValidAxisNode<int32>();
  ReduceTransposerValidAxisNode<int64_t>();
}
3778
TEST(PermutationTest, PermutesVector) {
  // PermuteSingle reorders the vector in place so that
  // output[i] == original[permutation[i]]; here {3, 2, 1, 0} reverses it.
  std::vector<int64_t> input{32, 16, 8, 4};
  std::vector<int64_t> expected{4, 8, 16, 32};
  TF_ASSERT_OK(PermuteSingle("test", {3, 2, 1, 0}, &input));
  ASSERT_EQ(input.size(), 4);
  // Use an unsigned index: `int i < input.size()` mixed signed/unsigned.
  for (size_t i = 0; i < input.size(); ++i) {
    EXPECT_EQ(input[i], expected[i]);
  }
}
3788
TEST(PermutationTest, PermutesRepeatedField) {
  // PermuteSingle also works on a proto repeated field: permuting the dims
  // of a shape proto with the NHWC->NCHW permutation {0, 3, 1, 2} moves the
  // last dimension (4) into position 1.
  TensorShapeProto input_shape = MakeTensorShapeFromDimensions({1, 2, 3, 4});
  TensorShapeProto expected_shape = MakeTensorShapeFromDimensions({1, 4, 2, 3});

  TF_ASSERT_OK(PermuteSingle("test", {0, 3, 1, 2}, input_shape.mutable_dim()));
  // Proto equality via DebugString keeps the failure message readable.
  EXPECT_EQ(input_shape.DebugString(), expected_shape.DebugString());
}
3796
TEST(PermutationTest, PermutesDoubleRepeatedField) {
  // PermuteDouble treats consecutive pairs of elements as one logical
  // dimension: an 8-element field is permuted as 4 pairs.
  {
    // NHWC -> NCHW: pair 3 ({7, 8}) moves to logical position 1.
    TensorShapeProto input =
        MakeTensorShapeFromDimensions({1, 2, 3, 4, 5, 6, 7, 8});
    TensorShapeProto expected =
        MakeTensorShapeFromDimensions({1, 2, 7, 8, 3, 4, 5, 6});

    TF_ASSERT_OK(PermuteDouble("test", {0, 3, 1, 2}, input.mutable_dim()));
    EXPECT_EQ(input.DebugString(), expected.DebugString());
  }
  {
    // NCHW -> NHWC: pair 1 ({3, 4}) moves to the last logical position.
    TensorShapeProto input =
        MakeTensorShapeFromDimensions({1, 2, 3, 4, 5, 6, 7, 8});
    TensorShapeProto expected =
        MakeTensorShapeFromDimensions({1, 2, 5, 6, 7, 8, 3, 4});
    TF_ASSERT_OK(PermuteDouble("test", {0, 2, 3, 1}, input.mutable_dim()));
    EXPECT_EQ(input.DebugString(), expected.DebugString());
  }
}
3818
TEST(PermutationTest, PermutesDataFormat) {
  // A data-format string is permuted character-wise: applying the
  // NHWC->NCHW permutation {0, 3, 1, 2} to "NHWC" must yield "NCHW".
  string format = "NHWC";
  TF_ASSERT_OK(PermuteSingle("test", {0, 3, 1, 2}, &format));
  const string want = "NCHW";
  EXPECT_EQ(format, want);
}
3825
TEST(PermutationTest, PermutesString) {
  // Any string can be permuted, not just data formats: swapping positions
  // 1 and 2 of "ABCD" gives "ACBD".
  string text = "ABCD";
  TF_ASSERT_OK(PermuteSingle("test", {0, 2, 1, 3}, &text));
  const string want = "ACBD";
  EXPECT_EQ(text, want);
}
3832
TEST(PermutationTest, GetNHWCToNCHWPermutation) {
  // GetDimensionIndices maps each format character to its position in the
  // source format string.
  string src_format = "NHWC";
  absl::flat_hash_map<char, int> src_dim_indices =
      GetDimensionIndices(src_format);
  EXPECT_EQ(src_dim_indices.size(), 4);
  const std::vector<std::pair<char, int>> expected_indices = {
      {'N', 0}, {'H', 1}, {'W', 2}, {'C', 3}};
  for (const auto& [dim, index] : expected_indices) {
    EXPECT_EQ(src_dim_indices[dim], index);
  }
  // GetPermutation gives, for each destination dimension, the index of the
  // corresponding source dimension: NHWC -> NCHW is {0, 3, 1, 2}.
  string dst_format = "NCHW";
  std::vector<int> permutation = GetPermutation(src_dim_indices, dst_format);
  ASSERT_EQ(permutation.size(), 4);
  const std::vector<int> expected_permutation = {0, 3, 1, 2};
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(permutation[i], expected_permutation[i]);
  }
}
3850
TEST(PermutationTest, GetNCHWToNHWCPermutation) {
  // GetDimensionIndices maps each format character to its position in the
  // source format string.
  string src_format = "NCHW";
  absl::flat_hash_map<char, int> src_dim_indices =
      GetDimensionIndices(src_format);
  EXPECT_EQ(src_dim_indices.size(), 4);
  const std::vector<std::pair<char, int>> expected_indices = {
      {'N', 0}, {'C', 1}, {'H', 2}, {'W', 3}};
  for (const auto& [dim, index] : expected_indices) {
    EXPECT_EQ(src_dim_indices[dim], index);
  }
  // GetPermutation gives, for each destination dimension, the index of the
  // corresponding source dimension: NCHW -> NHWC is {0, 2, 3, 1}.
  string dst_format = "NHWC";
  std::vector<int> permutation = GetPermutation(src_dim_indices, dst_format);
  ASSERT_EQ(permutation.size(), 4);
  const std::vector<int> expected_permutation = {0, 2, 3, 1};
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(permutation[i], expected_permutation[i]);
  }
}
3868
3869 // TODO(yanzha): Add frame related tests.
3870 } // namespace
3871 } // namespace grappler
3872 } // namespace tensorflow
3873