/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace cpu {

int64_t GetMinimumAlignmentForArray(
    const Shape& shape, const TargetMachineFeatures& target_machine_features) {
  CHECK(LayoutUtil::IsDenseArray(shape));

  // We don't require a layout to be set on `shape`.  This only works on CPU
  // because we don't pad our tensors or otherwise have complicated data tiling
  // schemes.

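  // For example, an f32[3, 5] array occupies 3 * 5 * 4 = 60 bytes, so this
  // returns the target's minimum alignment for a 60-byte allocation.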
  int64_t allocation_size_bytes =
      ShapeUtil::ElementsIn(shape) *
      ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
  return target_machine_features.minimum_alignment_for_allocation(
      allocation_size_bytes);
}

bool PotentiallyImplementedAsEigenConvolution(
    const HloInstruction& convolution,
    const TargetMachineFeatures& target_machine_features) {
  // The following conditions are necessary (but not sufficient) for
  // implementing `convolution` with Eigen convolution:
  // - the input and kernel have a non-zero number of elements.
  // - the input is in NHWC order.
  // - the kernel is in HWIO order.
  //
  // To be sufficient, certain layout constraints need to be satisfied as well.
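  //
  // (In the 2D case, NHWC means the input is laid out as
  // [batch, height, width, input_features]; HWIO means the kernel is laid out
  // as [height, width, input_features, output_features].)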
  const Shape& input_shape = convolution.operand(0)->shape();
  const Shape& kernel_shape = convolution.operand(1)->shape();
  const Shape& output_shape = convolution.shape();

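  // Eigen expects the buffers it operates on to be suitably aligned, so check
  // that each operand and the output meet
  // TargetMachineFeatures::kEigenExpectedTensorAlignment.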
  auto is_aligned = [&](const Shape& shape) {
    return GetMinimumAlignmentForArray(shape, target_machine_features) >=
           TargetMachineFeatures::kEigenExpectedTensorAlignment;
  };

  if (!is_aligned(input_shape) || !is_aligned(kernel_shape) ||
      !is_aligned(output_shape)) {
    return false;
  }

  if (ShapeUtil::IsZeroElementArray(input_shape) ||
      ShapeUtil::IsZeroElementArray(kernel_shape)) {
    return false;
  }
  // Make sure the input and kernel have the same data type.
  CHECK(
      ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape));
  // TODO(b/65408531): Explore using Eigen dot for complex64 type.
  PrimitiveType primitive_type = input_shape.element_type();
  if (primitive_type != F16 && primitive_type != F32) {
    return false;
  }
  if (window_util::HasWindowReversal(convolution.window())) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  // Only 1D through 3D convolutions are supported at the moment.
  const int64_t num_spatial_dims = dnums.output_spatial_dimensions_size();
  if (num_spatial_dims < 1 || num_spatial_dims > 3) {
    return false;
  }

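  // In NHWC/HWIO order the spatial dimensions are contiguous: dimensions
  // 1..num_spatial_dims of the input and output, and dimensions
  // 0..num_spatial_dims-1 of the kernel.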
  for (int64_t i = 0; i < num_spatial_dims; ++i) {
    if (dnums.input_spatial_dimensions(i) != i + 1) {
      return false;
    }
    if (dnums.kernel_spatial_dimensions(i) != i) {
      return false;
    }
    if (dnums.output_spatial_dimensions(i) != i + 1) {
      return false;
    }
  }

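  // Finally, the batch and feature dimensions must sit where NHWC/HWIO place
  // them: batch first and features last for the input and output, and the
  // kernel's input/output feature dimensions in its last two positions.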
  return dnums.input_batch_dimension() == 0 &&
         dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 &&
         dnums.output_batch_dimension() == 0 &&
         dnums.output_feature_dimension() ==
             output_shape.dimensions_size() - 1 &&
         dnums.kernel_input_feature_dimension() ==
             kernel_shape.dimensions_size() - 2 &&
         dnums.kernel_output_feature_dimension() ==
             kernel_shape.dimensions_size() - 1;
}

}  // namespace cpu
}  // namespace xla