xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/util.h"
27 
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 
30 namespace xla {
31 namespace cpu {
32 
33 using ::testing::ElementsAre;
34 
35 class ConvCanonicalizationTest : public HloTestBase {
36  public:
ConvCanonicalizationTest()37   ConvCanonicalizationTest() {
38     for (int i = 0; i < 2; ++i) {
39       auto dim = conv_window_.add_dimensions();
40       dim->set_size(kWindowSize);
41       dim->set_stride(1);
42       dim->set_padding_low(0);
43       dim->set_padding_high(0);
44       dim->set_window_dilation(1);
45       dim->set_base_dilation(1);
46     }
47   }
48 
49  protected:
50   Window conv_window_;
51 
52   static constexpr int kBatchSize = 50;
53   static constexpr int kInputSize = 28;
54   static constexpr int kWindowSize = 5;
55   static constexpr int kInputFeatureCount = 32;
56   static constexpr int kOutputFeatureCount = 64;
57 };
58 
TEST_F(ConvCanonicalizationTest,NonCanonicalToCanonical)59 TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
60   auto builder = HloComputation::Builder(TestName());
61   // The input dimensions are in CNHW order.
62   auto input = builder.AddInstruction(HloInstruction::CreateConstant(
63       LiteralUtil::CreateR4FromArray4D(Array4D<float>(
64           kInputFeatureCount, kBatchSize, kInputSize, kInputSize))));
65   // The kernel dimensions are in OIHW order.
66   auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
67       LiteralUtil::CreateR4FromArray4D(Array4D<float>(
68           kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize))));
69 
70   ConvolutionDimensionNumbers dnums;
71   dnums.set_input_batch_dimension(1);
72   dnums.set_output_batch_dimension(1);
73   dnums.add_input_spatial_dimensions(2);
74   dnums.add_output_spatial_dimensions(2);
75   dnums.add_input_spatial_dimensions(3);
76   dnums.add_output_spatial_dimensions(3);
77   dnums.set_input_feature_dimension(0);
78   dnums.set_output_feature_dimension(0);
79   dnums.add_kernel_spatial_dimensions(2);
80   dnums.add_kernel_spatial_dimensions(3);
81   dnums.set_kernel_input_feature_dimension(1);
82   dnums.set_kernel_output_feature_dimension(0);
83   auto output_size = kInputSize - kWindowSize + 1;
84   builder.AddInstruction(HloInstruction::CreateConvolve(
85       ShapeUtil::MakeShape(
86           F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
87       input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
88       conv_window_, dnums, DefaultPrecisionConfig(2)));
89 
90   auto module = CreateNewVerifiedModule();
91   HloComputation* entry_computation =
92       module->AddEntryComputation(builder.Build());
93 
94   cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
95       [](int64_t shape_size) {
96         return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
97       });
98   ConvCanonicalization conv_canonicalization(&target_machine_features);
99   EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
100 
101   const HloInstruction* output_reshape = entry_computation->root_instruction();
102   EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
103   const HloInstruction* canonical_conv = output_reshape->operand(0);
104   EXPECT_EQ(HloOpcode::kConvolution, canonical_conv->opcode());
105   const HloInstruction* input_reshape = canonical_conv->operand(0);
106   EXPECT_EQ(HloOpcode::kTranspose, input_reshape->opcode());
107   const HloInstruction* kernel_reshape = canonical_conv->operand(1);
108   EXPECT_EQ(HloOpcode::kTranspose, kernel_reshape->opcode());
109 
110   // The input is in CNHW order. input_reshape should produce
111   // NHWC for the convolution to hit the Eigen fast path.
112   EXPECT_THAT(input_reshape->dimensions(), ElementsAre(1, 2, 3, 0));
113   // The kernel is in OIHW order. kernel_reshape should produce
114   // HWIO for the convolution to hit the Eigen fast path.
115   EXPECT_THAT(kernel_reshape->dimensions(), ElementsAre(2, 3, 1, 0));
116   // The output of the canonical convolution is in NHWC order (the same as
117   // input_reshape's order). output_reshape should restore that order to the
118   // order of the computation root (CNHW).
119   EXPECT_THAT(output_reshape->dimensions(), ElementsAre(3, 0, 1, 2));
120 }
121 
// Builds a convolution whose input is already NHWC and whose kernel is
// already HWIO, and checks that ConvCanonicalization reports no change.
TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
  // Dimension numbers describing an NHWC input/output and an HWIO kernel.
  ConvolutionDimensionNumbers conv_dims;
  // Input: N=0, H=1, W=2, C=3.
  conv_dims.set_input_batch_dimension(0);
  conv_dims.add_input_spatial_dimensions(1);
  conv_dims.add_input_spatial_dimensions(2);
  conv_dims.set_input_feature_dimension(3);
  // Kernel: H=0, W=1, I=2, O=3.
  conv_dims.add_kernel_spatial_dimensions(0);
  conv_dims.add_kernel_spatial_dimensions(1);
  conv_dims.set_kernel_input_feature_dimension(2);
  conv_dims.set_kernel_output_feature_dimension(3);
  // Output: N=0, H=1, W=2, C=3.
  conv_dims.set_output_batch_dimension(0);
  conv_dims.add_output_spatial_dimensions(1);
  conv_dims.add_output_spatial_dimensions(2);
  conv_dims.set_output_feature_dimension(3);

  auto computation_builder = HloComputation::Builder(TestName());

  // Input constant, NHWC.
  Array4D<float> input_values(kBatchSize, kInputSize, kInputSize,
                              kInputFeatureCount);
  auto lhs = computation_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D(input_values)));

  // Kernel constant, HWIO.
  Array4D<float> kernel_values(kWindowSize, kWindowSize, kInputFeatureCount,
                               kOutputFeatureCount);
  auto rhs = computation_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D(kernel_values)));

  // The window has no padding and unit stride, so each spatial output
  // dimension shrinks by kWindowSize - 1.
  const int spatial_out = kInputSize - kWindowSize + 1;
  const Shape result_shape = ShapeUtil::MakeShape(
      F32, {kBatchSize, spatial_out, spatial_out, kOutputFeatureCount});
  computation_builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, lhs, rhs, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv_window_, conv_dims,
      DefaultPrecisionConfig(2)));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(computation_builder.Build());

  // The fake alignment logic reports the Eigen-expected alignment for every
  // shape size.
  cpu::TargetMachineFeaturesWithFakeAlignmentLogic fake_features(
      [](int64_t /*shape_size*/) {
        return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
      });
  ConvCanonicalization pass(&fake_features);

  // The convolution is already canonical, so the pass must report no change.
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
163 
164 }  // namespace cpu
165 }  // namespace xla
166