xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/attribute_importer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
17 
18 #include <sys/types.h>
19 
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 
26 namespace xla {
27 
ConvertPrecisionConfig(const PrecisionConfig * config,mlir::Builder * builder)28 mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
29                                        mlir::Builder* builder) {
30   if (!config) return {};
31 
32   // TODO(b/129709049) The HLO text format elides this in the all DEFAULT
33   // case and the parser sticks it in. Maybe we should too.
34   llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
35 
36   for (auto prec : config->operand_precision()) {
37     operand_precision_attrs.push_back(mlir::mhlo::PrecisionAttr::get(
38         builder->getContext(),
39         mlir::mhlo::symbolizePrecision(PrecisionConfig_Precision_Name(prec))
40             .getValue()));
41   }
42   return builder->getArrayAttr(operand_precision_attrs);
43 }
44 
45 // Converts the gather dimensions to attributes.
ConvertGatherDimensionNumbers(const xla::GatherDimensionNumbers & dnums,mlir::Builder * builder)46 mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers(
47     const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
48   std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
49                                    dnums.offset_dims().end());
50   std::vector<int64_t> collapsed_slice_dims(
51       dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
52   std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
53                                        dnums.start_index_map().end());
54   return mlir::mhlo::GatherDimensionNumbersAttr::get(
55       builder->getContext(), offset_dims, collapsed_slice_dims, start_index_map,
56       dnums.index_vector_dim());
57 }
58 
ConvertScatterDimensionNumbers(const xla::ScatterDimensionNumbers & dnums,mlir::Builder * builder)59 mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers(
60     const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
61   std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
62                                           dnums.update_window_dims().end());
63   std::vector<int64_t> inserted_window_dims(
64       dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
65   std::vector<int64_t> scatter_dims_to_operand_dims(
66       dnums.scatter_dims_to_operand_dims().begin(),
67       dnums.scatter_dims_to_operand_dims().end());
68   return mlir::mhlo::ScatterDimensionNumbersAttr::get(
69       builder->getContext(), update_window_dims, inserted_window_dims,
70       scatter_dims_to_operand_dims, dnums.index_vector_dim());
71 }
72 
ConvertDotDimensionNumbers(const DotDimensionNumbers & dnums,mlir::Builder * builder)73 mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers(
74     const DotDimensionNumbers& dnums, mlir::Builder* builder) {
75   auto arrayref = [](absl::Span<const int64_t> array) {
76     return llvm::ArrayRef<int64_t>{array.data(), array.size()};
77   };
78   return mlir::mhlo::DotDimensionNumbersAttr::get(
79       builder->getContext(), arrayref(dnums.lhs_batch_dimensions()),
80       arrayref(dnums.rhs_batch_dimensions()),
81       arrayref(dnums.lhs_contracting_dimensions()),
82       arrayref(dnums.rhs_contracting_dimensions()));
83 }
84 
ConvertConvDimensionNumbers(const xla::ConvolutionDimensionNumbers & dnums,mlir::Builder * builder)85 mlir::mhlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers(
86     const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
87   auto arrayref = [](absl::Span<const int64_t> array) {
88     return llvm::ArrayRef<int64_t>{array.data(), array.size()};
89   };
90   llvm::SmallVector<int64_t, 4> input_spatial_dims(
91       dnums.input_spatial_dimensions().begin(),
92       dnums.input_spatial_dimensions().end());
93   llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
94       dnums.kernel_spatial_dimensions().begin(),
95       dnums.kernel_spatial_dimensions().end());
96   llvm::SmallVector<int64_t, 4> output_spatial_dims(
97       dnums.output_spatial_dimensions().begin(),
98       dnums.output_spatial_dimensions().end());
99   return mlir::mhlo::ConvDimensionNumbersAttr::get(
100       builder->getContext(), dnums.input_batch_dimension(),
101       dnums.input_feature_dimension(),
102       arrayref(dnums.input_spatial_dimensions()),
103       dnums.kernel_input_feature_dimension(),
104       dnums.kernel_output_feature_dimension(),
105       arrayref(dnums.kernel_spatial_dimensions()),
106       dnums.output_batch_dimension(), dnums.output_feature_dimension(),
107       arrayref(dnums.output_spatial_dimensions()));
108 }
109 
ConvertFftType(FftType type)110 StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type) {
111   switch (type) {
112     case FftType::FFT:
113       return mlir::mhlo::FftType::FFT;
114     case FftType::IFFT:
115       return mlir::mhlo::FftType::IFFT;
116     case FftType::RFFT:
117       return mlir::mhlo::FftType::RFFT;
118     case FftType::IRFFT:
119       return mlir::mhlo::FftType::IRFFT;
120     default:
121       return InvalidArgument("Unknown FFT type enum value #%d", type);
122   }
123 }
124 
ConvertTranspose(xla::TriangularSolveOptions_Transpose transpose)125 StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
126     xla::TriangularSolveOptions_Transpose transpose) {
127   switch (transpose) {
128     case TriangularSolveOptions::NO_TRANSPOSE:
129       return mlir::mhlo::Transpose::NO_TRANSPOSE;
130     case TriangularSolveOptions::TRANSPOSE:
131       return mlir::mhlo::Transpose::TRANSPOSE;
132     case TriangularSolveOptions::ADJOINT:
133       return mlir::mhlo::Transpose::ADJOINT;
134     case TriangularSolveOptions::TRANSPOSE_INVALID:
135       return mlir::mhlo::Transpose::TRANSPOSE_INVALID;
136     default:
137       return InvalidArgument("Unknown transpose enum value #%d", transpose);
138   }
139 }
140 
ConvertCustomCallApiVersion(xla::CustomCallApiVersion api_version)141 StatusOr<mlir::mhlo::CustomCallApiVersion> ConvertCustomCallApiVersion(
142     xla::CustomCallApiVersion api_version) {
143   switch (api_version) {
144     case xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED:
145       return mlir::mhlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED;
146     case xla::CustomCallApiVersion::API_VERSION_ORIGINAL:
147       return mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL;
148     case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
149       return mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING;
150     case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
151       return mlir::mhlo::CustomCallApiVersion::
152           API_VERSION_STATUS_RETURNING_UNIFIED;
153     default:
154       return InvalidArgument("Unknown CustomCallApiVersion enum value #%d (%s)",
155                              api_version,
156                              xla::CustomCallApiVersion_Name(api_version));
157   }
158 }
159 
ExtractLayoutsFromShapes(const absl::Span<const Shape> shapes_with_layouts,mlir::Builder * builder)160 StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
161     const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder) {
162   std::vector<mlir::Attribute> layouts;
163   for (auto& shape_and_layout : shapes_with_layouts) {
164     if (shape_and_layout.IsTuple())
165       return tensorflow::errors::Unimplemented(
166           "Layout support for nested tuples is not implemented.");
167     // XLA can have invalid layout for certain values (such as token types).
168     // These are imported as empty layout in MHLO.
169     if (!shape_and_layout.IsArray()) {
170       layouts.push_back(builder->getIndexTensorAttr({}));
171       continue;
172     }
173 
174     // Only a subset of layout specification in XLA is supported in MHLO
175     // currently. The layout has to be dense, and only specify the order of
176     // dimensions. Sparse, tiled layout or non-default memory space fields
177     // cannot be expressed in MHLO layout yet.
178     if (!xla::LayoutUtil::IsDenseArray(shape_and_layout)) {
179       return tensorflow::errors::Unimplemented(
180           "Only dense arrays are supported.");
181     }
182 
183     const xla::Layout& xla_layout = shape_and_layout.layout();
184     if (!xla_layout.tiles().empty())
185       return tensorflow::errors::Unimplemented(
186           "Tiled layout is not supported yet");
187     if (xla_layout.memory_space() != xla::Layout::kDefaultMemorySpace)
188       return tensorflow::errors::Unimplemented(
189           "Layout support for non-default memory space is not yet implemented");
190 
191     llvm::SmallVector<int64_t> layout;
192     for (int64_t dim_index : xla_layout.minor_to_major())
193       layout.push_back(dim_index);
194     layouts.push_back(builder->getIndexTensorAttr(layout));
195   }
196   return builder->getArrayAttr(layouts);
197 }
198 
ExtractLayoutsFromTuple(const Shape shape,mlir::Builder * builder)199 StatusOr<mlir::ArrayAttr> ExtractLayoutsFromTuple(const Shape shape,
200                                                   mlir::Builder* builder) {
201   if (!shape.IsTuple()) return InvalidArgument("Expected shape to be Tuple");
202   return ExtractLayoutsFromShapes(shape.tuple_shapes(), builder);
203 }
204 
205 }  // namespace xla
206