xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/attribute_exporter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
17 
18 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
19 #include "tensorflow/compiler/xla/types.h"
20 #include "tensorflow/compiler/xla/util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22 #include "tensorflow/stream_executor/dnn.h"
23 
24 namespace xla {
25 
ConvertConvDimensionNumbers(mlir::mhlo::ConvDimensionNumbersAttr input)26 ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
27     mlir::mhlo::ConvDimensionNumbersAttr input) {
28   ConvolutionDimensionNumbers output;
29 
30   output.set_input_batch_dimension(input.getInputBatchDimension());
31   output.set_input_feature_dimension(input.getInputFeatureDimension());
32   for (auto v : input.getInputSpatialDimensions()) {
33     output.add_input_spatial_dimensions(v);
34   }
35 
36   output.set_kernel_input_feature_dimension(
37       input.getKernelInputFeatureDimension());
38   output.set_kernel_output_feature_dimension(
39       input.getKernelOutputFeatureDimension());
40 
41   for (auto v : input.getKernelSpatialDimensions()) {
42     output.add_kernel_spatial_dimensions(v);
43   }
44 
45   output.set_output_batch_dimension(input.getOutputBatchDimension());
46   output.set_output_feature_dimension(input.getOutputFeatureDimension());
47 
48   for (auto v : input.getOutputSpatialDimensions()) {
49     output.add_output_spatial_dimensions(v);
50   }
51 
52   return output;
53 }
54 
ConvertConvActivationMode(mlir::lmhlo_gpu::Activation activation)55 StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
56     mlir::lmhlo_gpu::Activation activation) {
57   switch (activation) {
58     case mlir::lmhlo_gpu::Activation::None:
59       return stream_executor::dnn::kNone;
60     case mlir::lmhlo_gpu::Activation::Sigmoid:
61       return stream_executor::dnn::kSigmoid;
62     case mlir::lmhlo_gpu::Activation::Tanh:
63       return stream_executor::dnn::kTanh;
64     case mlir::lmhlo_gpu::Activation::Relu:
65       return stream_executor::dnn::kRelu;
66     case mlir::lmhlo_gpu::Activation::Relu6:
67       return stream_executor::dnn::kRelu6;
68     case mlir::lmhlo_gpu::Activation::ReluX:
69       return stream_executor::dnn::kReluX;
70     case mlir::lmhlo_gpu::Activation::BandPass:
71       return stream_executor::dnn::kBandPass;
72     default:
73       return InternalError("Unexpected activation");
74   }
75 }
76 
77 // Convert replica group from MLIR encoding to HLO.
78 // See HloFunctionImporter::ConvertReplicaGroups for the MLIR encoding.
ConvertReplicaGroups(mlir::DenseIntElementsAttr input)79 StatusOr<std::vector<ReplicaGroup>> ConvertReplicaGroups(
80     mlir::DenseIntElementsAttr input) {
81   mlir::RankedTensorType type =
82       input.getType().dyn_cast<mlir::RankedTensorType>();
83   if (!type || type.getRank() != 2 ||
84       !type.getElementType().isInteger(/*width=*/64)) {
85     return InternalError("Execpted replica group to be a rank 2 tensor of i64");
86   }
87   // rank 0 is num_groups, rank 1 is group size.
88   auto replica_group_values_it = input.getValues<uint64_t>().begin();
89   std::vector<ReplicaGroup> replica_groups(type.getDimSize(0));
90   for (ReplicaGroup& group : replica_groups) {
91     for (int64_t element_idx = 0; element_idx < type.getDimSize(1);
92          ++element_idx, ++replica_group_values_it) {
93       // For replica group attribute, -1 indicates padding added by
94       // HloFunctionImporter::ConvertReplicaGroups. This should always be at the
95       // end and can be dropped when converting back to XLA HLO ReplicaGroups.
96       if (*replica_group_values_it != -1) {
97         group.add_replica_ids(*replica_group_values_it);
98       }
99     }
100   }
101   return replica_groups;
102 }
103 
104 // Convert a (N, 2) dense attribute to a list of tuples. This is the way padding
105 // and source-target pairs are defined in HLO.
ConvertNx2Attribute(llvm::Optional<mlir::DenseIntElementsAttr> optional_attr)106 StatusOr<std::vector<std::pair<int64_t, int64_t>>> ConvertNx2Attribute(
107     llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
108   if (!optional_attr.has_value())
109     return std::vector<std::pair<int64_t, int64_t>>{};
110   mlir::DenseIntElementsAttr attr = *optional_attr;
111   auto type = attr.getType().dyn_cast<mlir::RankedTensorType>();
112   if (!type || type.getRank() != 2 || type.getShape()[1] != 2)
113     return InternalError("expected Nx2 attribute to be a tensor of shape Nx2");
114   auto it = attr.getValues<int64_t>().begin();
115   std::vector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
116   for (auto& item : out) {
117     int64_t first = *it;
118     ++it;
119     int64_t second = *it;
120     ++it;
121     item = {first, second};
122   }
123   return out;
124 }
125 
ConvertFftType(llvm::StringRef type_string)126 StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
127   llvm::Optional<mlir::mhlo::FftType> type =
128       mlir::mhlo::symbolizeEnum<mlir::mhlo::FftType>(type_string);
129   if (!type) return InvalidArgument("Unknown FFT type %s", type_string.str());
130 
131   switch (*type) {
132     case mlir::mhlo::FftType::FFT:
133       return xla::FftType::FFT;
134     case mlir::mhlo::FftType::IFFT:
135       return xla::FftType::IFFT;
136     case mlir::mhlo::FftType::RFFT:
137       return xla::FftType::RFFT;
138     case mlir::mhlo::FftType::IRFFT:
139       return xla::FftType::IRFFT;
140     default:
141       return InvalidArgument("Unknown FFT type enum #%d", *type);
142   }
143 }
144 
ConvertTranspose(llvm::StringRef transpose_string)145 StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
146     llvm::StringRef transpose_string) {
147   llvm::Optional<mlir::mhlo::Transpose> transpose =
148       mlir::mhlo::symbolizeTranspose(transpose_string);
149   if (!transpose)
150     return InvalidArgument("Unknown transpose type %s", transpose_string.str());
151 
152   switch (*transpose) {
153     case mlir::mhlo::Transpose::NO_TRANSPOSE:
154       return TriangularSolveOptions::NO_TRANSPOSE;
155     case mlir::mhlo::Transpose::TRANSPOSE:
156       return TriangularSolveOptions::TRANSPOSE;
157     case mlir::mhlo::Transpose::ADJOINT:
158       return TriangularSolveOptions::ADJOINT;
159     case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
160       return TriangularSolveOptions::TRANSPOSE_INVALID;
161     default:
162       return InvalidArgument("Unknown transpose enum value #%d", *transpose);
163   }
164 }
165 
ConvertCustomCallApiVersion(mlir::mhlo::CustomCallApiVersion api_version)166 StatusOr<xla::CustomCallApiVersion> ConvertCustomCallApiVersion(
167     mlir::mhlo::CustomCallApiVersion api_version) {
168   switch (api_version) {
169     case mlir::mhlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED:
170       return xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED;
171     case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
172       return xla::CustomCallApiVersion::API_VERSION_ORIGINAL;
173     case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
174       return xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING;
175     case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
176       return xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED;
177     default:
178       return InvalidArgument("Unknown CustomCallApiVersion enum value #%d",
179                              api_version);
180   }
181 }
182 
183 }  // namespace xla
184