/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/attribute_importer.h"

#include <sys/types.h>

#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

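// Converts the precision config to an array of precision attributes. Returns
// a null attribute if the config is absent.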
mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
                                       mlir::Builder* builder) {
  if (!config) return {};

  // TODO(b/129709049) The HLO text format elides this field when all operands
  // use DEFAULT precision, and the parser adds it back in. Maybe we should do
  // the same.
  llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;

  for (auto prec : config->operand_precision()) {
    operand_precision_attrs.push_back(mlir::mhlo::PrecisionAttr::get(
        builder->getContext(),
        mlir::mhlo::symbolizePrecision(PrecisionConfig_Precision_Name(prec))
            .getValue()));
  }
  return builder->getArrayAttr(operand_precision_attrs);
}

// Converts the gather dimensions to attributes.
mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers(
    const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
  std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
                                   dnums.offset_dims().end());
  std::vector<int64_t> collapsed_slice_dims(
      dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
  std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
                                       dnums.start_index_map().end());
  return mlir::mhlo::GatherDimensionNumbersAttr::get(
      builder->getContext(), offset_dims, collapsed_slice_dims, start_index_map,
      dnums.index_vector_dim());
}

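// Converts the scatter dimension numbers to attributes.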
mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers(
    const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
  std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
                                          dnums.update_window_dims().end());
  std::vector<int64_t> inserted_window_dims(
      dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
  std::vector<int64_t> scatter_dims_to_operand_dims(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  return mlir::mhlo::ScatterDimensionNumbersAttr::get(
      builder->getContext(), update_window_dims, inserted_window_dims,
      scatter_dims_to_operand_dims, dnums.index_vector_dim());
}

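// Converts the dot dimension numbers to attributes.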
mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers(
    const DotDimensionNumbers& dnums, mlir::Builder* builder) {
  auto arrayref = [](absl::Span<const int64_t> array) {
    return llvm::ArrayRef<int64_t>{array.data(), array.size()};
  };
  return mlir::mhlo::DotDimensionNumbersAttr::get(
      builder->getContext(), arrayref(dnums.lhs_batch_dimensions()),
      arrayref(dnums.rhs_batch_dimensions()),
      arrayref(dnums.lhs_contracting_dimensions()),
      arrayref(dnums.rhs_contracting_dimensions()));
}

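// Converts the convolution dimension numbers to attributes.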
mlir::mhlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers(
    const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
  auto arrayref = [](absl::Span<const int64_t> array) {
    return llvm::ArrayRef<int64_t>{array.data(), array.size()};
  };
  llvm::SmallVector<int64_t, 4> input_spatial_dims(
      dnums.input_spatial_dimensions().begin(),
      dnums.input_spatial_dimensions().end());
  llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
      dnums.kernel_spatial_dimensions().begin(),
      dnums.kernel_spatial_dimensions().end());
  llvm::SmallVector<int64_t, 4> output_spatial_dims(
      dnums.output_spatial_dimensions().begin(),
      dnums.output_spatial_dimensions().end());
  return mlir::mhlo::ConvDimensionNumbersAttr::get(
      builder->getContext(), dnums.input_batch_dimension(),
      dnums.input_feature_dimension(),
      arrayref(dnums.input_spatial_dimensions()),
      dnums.kernel_input_feature_dimension(),
      dnums.kernel_output_feature_dimension(),
      arrayref(dnums.kernel_spatial_dimensions()),
      dnums.output_batch_dimension(), dnums.output_feature_dimension(),
      arrayref(dnums.output_spatial_dimensions()));
}

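// Converts the XLA FFT type enum to the corresponding MHLO enum value.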
StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type) {
  switch (type) {
    case FftType::FFT:
      return mlir::mhlo::FftType::FFT;
    case FftType::IFFT:
      return mlir::mhlo::FftType::IFFT;
    case FftType::RFFT:
      return mlir::mhlo::FftType::RFFT;
    case FftType::IRFFT:
      return mlir::mhlo::FftType::IRFFT;
    default:
      return InvalidArgument("Unknown FFT type enum value #%d", type);
  }
}

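// Converts the triangular solve transpose option to the corresponding MHLO
// enum value.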
StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
    xla::TriangularSolveOptions_Transpose transpose) {
  switch (transpose) {
    case TriangularSolveOptions::NO_TRANSPOSE:
      return mlir::mhlo::Transpose::NO_TRANSPOSE;
    case TriangularSolveOptions::TRANSPOSE:
      return mlir::mhlo::Transpose::TRANSPOSE;
    case TriangularSolveOptions::ADJOINT:
      return mlir::mhlo::Transpose::ADJOINT;
    case TriangularSolveOptions::TRANSPOSE_INVALID:
      return mlir::mhlo::Transpose::TRANSPOSE_INVALID;
    default:
      return InvalidArgument("Unknown transpose enum value #%d", transpose);
  }
}

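// Converts the custom call API version enum to the corresponding MHLO enum
// value.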
StatusOr<mlir::mhlo::CustomCallApiVersion> ConvertCustomCallApiVersion(
    xla::CustomCallApiVersion api_version) {
  switch (api_version) {
    case xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED:
      return mlir::mhlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED;
    case xla::CustomCallApiVersion::API_VERSION_ORIGINAL:
      return mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL;
    case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
      return mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING;
    case xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
      return mlir::mhlo::CustomCallApiVersion::
          API_VERSION_STATUS_RETURNING_UNIFIED;
    default:
      return InvalidArgument("Unknown CustomCallApiVersion enum value #%d (%s)",
                             api_version,
                             xla::CustomCallApiVersion_Name(api_version));
  }
}

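// Extracts the minor-to-major layout of each given shape as an index tensor
// attribute and returns the layouts wrapped in an array attribute. Non-array
// shapes (such as tokens) are represented by an empty layout.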
StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
    const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder) {
  std::vector<mlir::Attribute> layouts;
  for (auto& shape_and_layout : shapes_with_layouts) {
    if (shape_and_layout.IsTuple())
      return tensorflow::errors::Unimplemented(
          "Layout support for nested tuples is not implemented.");
    // XLA can have an invalid layout for certain values (such as token types).
    // These are imported with an empty layout in MHLO.
    if (!shape_and_layout.IsArray()) {
      layouts.push_back(builder->getIndexTensorAttr({}));
      continue;
    }

    // Only a subset of the XLA layout specification is currently supported in
    // MHLO. The layout has to be dense and may only specify the order of
    // dimensions; sparse or tiled layouts and non-default memory space fields
    // cannot be expressed in an MHLO layout yet.
    if (!xla::LayoutUtil::IsDenseArray(shape_and_layout)) {
      return tensorflow::errors::Unimplemented(
          "Only dense arrays are supported.");
    }

    const xla::Layout& xla_layout = shape_and_layout.layout();
    if (!xla_layout.tiles().empty())
      return tensorflow::errors::Unimplemented(
          "Tiled layout is not supported yet");
    if (xla_layout.memory_space() != xla::Layout::kDefaultMemorySpace)
      return tensorflow::errors::Unimplemented(
          "Layout support for non-default memory space is not yet implemented");

    llvm::SmallVector<int64_t> layout;
    for (int64_t dim_index : xla_layout.minor_to_major())
      layout.push_back(dim_index);
    layouts.push_back(builder->getIndexTensorAttr(layout));
  }
  return builder->getArrayAttr(layouts);
}

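// Extracts the layouts of the element shapes of a tuple shape.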
StatusOr<mlir::ArrayAttr> ExtractLayoutsFromTuple(const Shape shape,
                                                  mlir::Builder* builder) {
  if (!shape.IsTuple()) return InvalidArgument("Expected shape to be Tuple");
  return ExtractLayoutsFromShapes(shape.tuple_shapes(), builder);
}

}  // namespace xla