1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
17
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/util.h"
27
28 #include "tensorflow/compiler/xla/test_helpers.h"
29
30 namespace xla {
31 namespace cpu {
32
33 using ::testing::ElementsAre;
34
35 class ConvCanonicalizationTest : public HloTestBase {
36 public:
ConvCanonicalizationTest()37 ConvCanonicalizationTest() {
38 for (int i = 0; i < 2; ++i) {
39 auto dim = conv_window_.add_dimensions();
40 dim->set_size(kWindowSize);
41 dim->set_stride(1);
42 dim->set_padding_low(0);
43 dim->set_padding_high(0);
44 dim->set_window_dilation(1);
45 dim->set_base_dilation(1);
46 }
47 }
48
49 protected:
50 Window conv_window_;
51
52 static constexpr int kBatchSize = 50;
53 static constexpr int kInputSize = 28;
54 static constexpr int kWindowSize = 5;
55 static constexpr int kInputFeatureCount = 32;
56 static constexpr int kOutputFeatureCount = 64;
57 };
58
// Builds a convolution whose input/output use CNHW order and whose kernel uses
// OIHW order, runs ConvCanonicalization, and checks that the pass wraps the
// convolution in transposes: one on each operand and one on the result.
TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
  auto builder = HloComputation::Builder(TestName());
  // The input dimensions are in CNHW order.
  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D(Array4D<float>(
          kInputFeatureCount, kBatchSize, kInputSize, kInputSize))));
  // The kernel dimensions are in OIHW order.
  auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D(Array4D<float>(
          kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize))));

  ConvolutionDimensionNumbers dnums;
  // Input and output: feature = 0, batch = 1, spatial = {2, 3} (CNHW).
  dnums.set_input_batch_dimension(1);
  dnums.set_output_batch_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);
  dnums.set_input_feature_dimension(0);
  dnums.set_output_feature_dimension(0);
  // Kernel: output feature = 0, input feature = 1, spatial = {2, 3} (OIHW).
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.set_kernel_output_feature_dimension(0);
  // The window has stride 1 and no padding or dilation, so each spatial
  // output extent is input - window + 1.
  auto output_size = kInputSize - kWindowSize + 1;
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(
          F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      conv_window_, dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
      [](int64_t shape_size) {
        return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
      });
  ConvCanonicalization conv_canonicalization(&target_machine_features);
  // The remaining checks inspect the rewritten graph, so there is nothing
  // meaningful to verify if the pass reports that it changed nothing.
  ASSERT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());

  // The opcode checks are fatal (ASSERT_*): the statements that follow call
  // operand() and dimensions() on these instructions on the assumption that
  // each one is the expected transpose/convolution, so stop at the first
  // structural mismatch instead of continuing with the wrong instruction.
  const HloInstruction* output_reshape = entry_computation->root_instruction();
  ASSERT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
  const HloInstruction* canonical_conv = output_reshape->operand(0);
  ASSERT_EQ(HloOpcode::kConvolution, canonical_conv->opcode());
  const HloInstruction* input_reshape = canonical_conv->operand(0);
  ASSERT_EQ(HloOpcode::kTranspose, input_reshape->opcode());
  const HloInstruction* kernel_reshape = canonical_conv->operand(1);
  ASSERT_EQ(HloOpcode::kTranspose, kernel_reshape->opcode());

  // The input is in CNHW order. input_reshape should produce
  // NHWC for the convolution to hit the Eigen fast path.
  EXPECT_THAT(input_reshape->dimensions(), ElementsAre(1, 2, 3, 0));
  // The kernel is in OIHW order. kernel_reshape should produce
  // HWIO for the convolution to hit the Eigen fast path.
  EXPECT_THAT(kernel_reshape->dimensions(), ElementsAre(2, 3, 1, 0));
  // The output of the canonical convolution is in NHWC order (the same as
  // input_reshape's order). output_reshape should restore that order to the
  // order of the computation root (CNHW).
  EXPECT_THAT(output_reshape->dimensions(), ElementsAre(3, 0, 1, 2));
}
121
// Builds a convolution whose input/output are already in NHWC order and whose
// kernel is already in HWIO order, and checks that ConvCanonicalization
// reports that it made no change to the module.
TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
  auto computation_builder = HloComputation::Builder(TestName());

  // Input operand, NHWC order.
  auto nhwc_input =
      computation_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D(Array4D<float>(
              kBatchSize, kInputSize, kInputSize, kInputFeatureCount))));
  // Kernel operand, HWIO order.
  auto hwio_kernel =
      computation_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D(
              Array4D<float>(kWindowSize, kWindowSize, kInputFeatureCount,
                             kOutputFeatureCount))));

  ConvolutionDimensionNumbers dim_numbers;
  // Input: batch = 0, spatial = {1, 2}, feature = 3.
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.add_input_spatial_dimensions(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.set_input_feature_dimension(3);
  // Kernel: spatial = {0, 1}, input feature = 2, output feature = 3.
  dim_numbers.add_kernel_spatial_dimensions(0);
  dim_numbers.add_kernel_spatial_dimensions(1);
  dim_numbers.set_kernel_input_feature_dimension(2);
  dim_numbers.set_kernel_output_feature_dimension(3);
  // Output: batch = 0, spatial = {1, 2}, feature = 3.
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.add_output_spatial_dimensions(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.set_output_feature_dimension(3);

  // Spatial output extent for the stride-1, unpadded fixture window.
  constexpr int kOutputSize = kInputSize - kWindowSize + 1;
  computation_builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(
          F32, {kBatchSize, kOutputSize, kOutputSize, kOutputFeatureCount}),
      nhwc_input, hwio_kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv_window_, dim_numbers,
      DefaultPrecisionConfig(2)));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(computation_builder.Build());

  // Alignment callback that returns the same value for every shape size.
  const auto fake_alignment = [](int64_t /*shape_size*/) {
    return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
  };
  cpu::TargetMachineFeaturesWithFakeAlignmentLogic fake_features(
      fake_alignment);
  ConvCanonicalization canonicalization_pass(&fake_features);
  EXPECT_FALSE(canonicalization_pass.Run(hlo_module.get()).ValueOrDie());
}
163
164 } // namespace cpu
165 } // namespace xla
166