/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"

#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/dnn.h"

namespace xla {

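// Converts the MHLO dialect's ConvDimensionNumbersAttr to the corresponding
// XLA HLO ConvolutionDimensionNumbers proto.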
ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
    mlir::mhlo::ConvDimensionNumbersAttr input) {
  ConvolutionDimensionNumbers output;

  output.set_input_batch_dimension(input.getInputBatchDimension());
  output.set_input_feature_dimension(input.getInputFeatureDimension());
  for (auto v : input.getInputSpatialDimensions()) {
    output.add_input_spatial_dimensions(v);
  }

  output.set_kernel_input_feature_dimension(
      input.getKernelInputFeatureDimension());
  output.set_kernel_output_feature_dimension(
      input.getKernelOutputFeatureDimension());

  for (auto v : input.getKernelSpatialDimensions()) {
    output.add_kernel_spatial_dimensions(v);
  }

  output.set_output_batch_dimension(input.getOutputBatchDimension());
  output.set_output_feature_dimension(input.getOutputFeatureDimension());

  for (auto v : input.getOutputSpatialDimensions()) {
    output.add_output_spatial_dimensions(v);
  }

  return output;
}

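// Converts an LMHLO GPU dialect activation kind to the StreamExecutor DNN
// activation mode. Returns an internal error for unrecognized values.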
StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
    mlir::lmhlo_gpu::Activation activation) {
  switch (activation) {
    case mlir::lmhlo_gpu::Activation::None:
      return stream_executor::dnn::kNone;
    case mlir::lmhlo_gpu::Activation::Sigmoid:
      return stream_executor::dnn::kSigmoid;
    case mlir::lmhlo_gpu::Activation::Tanh:
      return stream_executor::dnn::kTanh;
    case mlir::lmhlo_gpu::Activation::Relu:
      return stream_executor::dnn::kRelu;
    case mlir::lmhlo_gpu::Activation::Relu6:
      return stream_executor::dnn::kRelu6;
    case mlir::lmhlo_gpu::Activation::ReluX:
      return stream_executor::dnn::kReluX;
    case mlir::lmhlo_gpu::Activation::BandPass:
      return stream_executor::dnn::kBandPass;
    default:
      return InternalError("Unexpected activation");
  }
}

// Converts a replica group from the MLIR encoding to HLO.
// See HloFunctionImporter::ConvertReplicaGroups for the MLIR encoding.
StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
    mlir::DenseIntElementsAttr input) {
  mlir::RankedTensorType type =
      input.getType().dyn_cast<mlir::RankedTensorType>();
  if (!type || type.getRank() != 2 ||
      !type.getElementType().isInteger(/*width=*/64)) {
    return InternalError("Expected replica group to be a rank 2 tensor of i64");
  }
  // Dimension 0 is the number of groups; dimension 1 is the group size.
  auto replica_group_values_it = input.getValues<uint64_t>().begin();
  std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
  for (ReplicaGroup& group : replica_groups) {
    for (int64_t element_idx = 0; element_idx < type.getDimSize(1);
         ++element_idx, ++replica_group_values_it) {
      // In the replica group attribute, -1 indicates padding added by
      // HloFunctionImporter::ConvertReplicaGroups. Padding always appears at
      // the end of a group and is dropped when converting back to XLA HLO
      // ReplicaGroups.
      if (*replica_group_values_it != -1) {
        group.add_replica_ids(*replica_group_values_it);
      }
    }
  }
  return replica_groups;
}

// Converts a (N, 2) dense attribute to a list of tuples. This is the way
// padding and source-target pairs are defined in HLO.
StatusOr<std::vector<std::pair<int64_t, int64_t>>> ConvertNx2Attribute(
    llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
  if (!optional_attr.has_value())
    return std::vector<std::pair<int64_t, int64_t>>{};
  mlir::DenseIntElementsAttr attr = *optional_attr;
  auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
  if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
    return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
  auto it = attr.getValues<int64_t>().begin();
  std::vector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
  for (auto& item : out) {
    int64_t first = *it;
    ++it;
    int64_t second = *it;
    ++it;
    item = {first, second};
  }
  return out;
}

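// Converts an MHLO FFT type string (e.g. "FFT", "IRFFT") to the XLA FftType
// enum. Returns an error for unrecognized strings.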
StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
  llvm::Optional<mlir::mhlo::FftType> type =
      mlir::mhlo::symbolizeEnum<mlir::mhlo::FftType>(type_string);
  if (!type) return InvalidArgument("Unknown FFT type %s", type_string.str());

  switch (*type) {
    case mlir::mhlo::FftType::FFT:
      return xla::FftType::FFT;
    case mlir::mhlo::FftType::IFFT:
      return xla::FftType::IFFT;
    case mlir::mhlo::FftType::RFFT:
      return xla::FftType::RFFT;
    case mlir::mhlo::FftType::IRFFT:
      return xla::FftType::IRFFT;
    default:
      return InvalidArgument("Unknown FFT type enum #%d", *type);
  }
}

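// Converts an MHLO transpose string to the XLA TriangularSolveOptions
// transpose enum. Returns an error for unrecognized strings.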
StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
    llvm::StringRef transpose_string) {
  llvm::Optional<mlir::mhlo::Transpose> transpose =
      mlir::mhlo::symbolizeTranspose(transpose_string);
  if (!transpose)
    return InvalidArgument("Unknown transpose type %s", transpose_string.str());

  switch (*transpose) {
    case mlir::mhlo::Transpose::NO_TRANSPOSE:
      return TriangularSolveOptions::NO_TRANSPOSE;
    case mlir::mhlo::Transpose::TRANSPOSE:
      return TriangularSolveOptions::TRANSPOSE;
    case mlir::mhlo::Transpose::ADJOINT:
      return TriangularSolveOptions::ADJOINT;
    case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
      return TriangularSolveOptions::TRANSPOSE_INVALID;
    default:
      return InvalidArgument("Unknown transpose enum value #%d", *transpose);
  }
}

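// Maps the MHLO CustomCallApiVersion enum to the XLA CustomCallApiVersion
// enum of the same name. Returns an error for unrecognized values.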
StatusOr<xla::CustomCallApiVersion> ConvertCustomCallApiVersion(
    mlir::mhlo::CustomCallApiVersion api_version) {
  switch (api_version) {
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED:
      return xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED;
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
      return xla::CustomCallApiVersion::API_VERSION_ORIGINAL;
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
      return xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING;
    case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
      return xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED;
    default:
      return InvalidArgument("Unknown CustomCallApiVersion enum value #%d",
                             api_version);
  }
}

}  // namespace xla