1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file defines helpers useful when creating or manipulating lhlo/hlo.
17
18 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
19
20 #include "mlir/IR/AffineMap.h" // from @llvm-project
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
26 #include "tensorflow/core/platform/bfloat16.h"
27 #include "tensorflow/core/platform/logging.h"
28
29 namespace xla {
30 namespace {
31
32 using mlir::AffineMap;
33 using mlir::Builder;
34 using mlir::DenseElementsAttr;
35 using mlir::ShapedType;
36 using xla::LiteralBase;
37 using xla::StatusOr;
38
39 template <typename CppType>
CreateDenseAttrFromLiteral(const ShapedType & type,const LiteralBase & literal)40 ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(
41 const ShapedType& type, const LiteralBase& literal) {
42 auto data_span = literal.data<CppType>();
43 return ::mlir::DenseElementsAttr::get(
44 type, llvm::makeArrayRef(data_span.data(), data_span.size()));
45 }
46
GetPermutationIfAvailable(const Shape & shape,mlir::Builder builder)47 StatusOr<AffineMap> GetPermutationIfAvailable(const Shape& shape,
48 mlir::Builder builder) {
49 // N.B. IsMonotonicWithDim0Major ignores tiling, and I can't change it because
50 // some XLA code relies on it treating tiled layouts as equivalent to untiled
51 // layouts, so the check to rule out tiling has to come /before/ the
52 // early-return branch, or we'd miss tiled monotonic layouts.
53 if (!shape.layout().tiles().empty()) {
54 return tensorflow::errors::Internal("Tiled layouts are not yet supported");
55 }
56 if (!shape.has_layout() ||
57 LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
58 return AffineMap();
59 }
60 if (!shape.is_static()) {
61 return tensorflow::errors::Internal(
62 "Permutations for dynamic shapes are not yet supported");
63 }
64 int64_t accumulated_stride = 1;
65 llvm::SmallVector<int64_t, 4> strides(shape.rank(), 1);
66 for (int64_t dim : LayoutUtil::MinorToMajor(shape)) {
67 strides[dim] = accumulated_stride;
68 accumulated_stride *= shape.dimensions(dim);
69 }
70 if (accumulated_stride == 0) {
71 return AffineMap();
72 }
73 return makeStridedLinearLayoutMap(strides, /*offset=*/0,
74 builder.getContext());
75 }
76
77 template <typename T>
CopyDenseElementsBy(mlir::DenseElementsAttr data,std::vector<uint8_t> * output)78 void CopyDenseElementsBy(mlir::DenseElementsAttr data,
79 std::vector<uint8_t>* output) {
80 output->resize(data.getNumElements() * sizeof(T));
81 int i = 0;
82 for (T element : data.getValues<T>()) {
83 std::memcpy(&(*output)[i], &element, sizeof(T));
84 i += sizeof(T);
85 }
86 }
87
88 } // namespace
89
ConvertTensorShapeToMemRefType(const Shape & shape,mlir::Builder builder)90 StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
91 const Shape& shape, mlir::Builder builder) {
92 auto element_type_or =
93 ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder);
94 if (!element_type_or.ok()) return element_type_or.status();
95
96 using mlir::MemRefType;
97 auto dimensions = shape.dimensions();
98 llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
99 auto permutation_or = GetPermutationIfAvailable(shape, builder);
100 if (!permutation_or.ok()) return permutation_or.status();
101 return MemRefType::get(array, element_type_or.ValueOrDie(),
102 permutation_or.ValueOrDie());
103 }
104
// Converts an XLA literal to an MLIR DenseElementsAttr whose tensor type is
// derived from the literal's shape. Returns an Internal error for element
// types without a case below.
StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
    const LiteralBase& literal, Builder builder) {
  TF_ASSIGN_OR_RETURN(auto type,
                      ConvertTensorShapeToType<mlir::RankedTensorType>(
                          literal.shape(), builder));

  // TODO(hinsu): Support remaining XLA primitive types.
  auto element_type = literal.shape().element_type();
  // Dispatch on the XLA primitive type so the literal's raw data is read back
  // with the matching C++ element type.
  switch (element_type) {
    case PrimitiveType::PRED:
      return CreateDenseAttrFromLiteral<bool>(type, literal);
    case PrimitiveType::F16:
      return CreateDenseAttrFromLiteral<half>(type, literal);
    case PrimitiveType::BF16:
      return CreateDenseAttrFromLiteral<bfloat16>(type, literal);
    case PrimitiveType::F32:
      return CreateDenseAttrFromLiteral<float>(type, literal);
    case PrimitiveType::F64:
      return CreateDenseAttrFromLiteral<double>(type, literal);
    case PrimitiveType::S8:
      return CreateDenseAttrFromLiteral<int8_t>(type, literal);
    case PrimitiveType::S16:
      return CreateDenseAttrFromLiteral<int16_t>(type, literal);
    case PrimitiveType::S32:
      return CreateDenseAttrFromLiteral<int32_t>(type, literal);
    case PrimitiveType::S64:
      return CreateDenseAttrFromLiteral<int64_t>(type, literal);
    case PrimitiveType::U8:
      return CreateDenseAttrFromLiteral<uint8_t>(type, literal);
    case PrimitiveType::U16:
      return CreateDenseAttrFromLiteral<uint16_t>(type, literal);
    case PrimitiveType::U32:
      return CreateDenseAttrFromLiteral<uint32_t>(type, literal);
    case PrimitiveType::U64:
      return CreateDenseAttrFromLiteral<uint64_t>(type, literal);
    case PrimitiveType::C64:
      return CreateDenseAttrFromLiteral<complex64>(type, literal);
    case PrimitiveType::C128:
      return CreateDenseAttrFromLiteral<complex128>(type, literal);
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}
149
// Serializes the elements of `data` into `output` as raw bytes (native byte
// order, one fixed-size record per element) via CopyDenseElementsBy. Returns
// an Internal error for element types with no handler below.
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
                                        std::vector<uint8_t>* output) {
  mlir::Type element_type = data.getType().getElementType();

  // TODO(hinsu): Support remaining XLA primitive types.
  if (element_type.isInteger(1)) {
    CopyDenseElementsBy<bool>(data, output);
    return ::tensorflow::OkStatus();
  }
  // Signedness does not change the byte representation, so every integer
  // width is copied through the unsigned C++ type.
  if (element_type.isInteger(8)) {
    CopyDenseElementsBy<uint8_t>(data, output);
    return ::tensorflow::OkStatus();
  }
  if (element_type.isInteger(16)) {
    CopyDenseElementsBy<uint16_t>(data, output);
    return ::tensorflow::OkStatus();
  }
  if (element_type.isInteger(32)) {
    CopyDenseElementsBy<uint32_t>(data, output);
    return ::tensorflow::OkStatus();
  }
  if (element_type.isInteger(64)) {
    CopyDenseElementsBy<uint64_t>(data, output);
    return ::tensorflow::OkStatus();
  }
  if (element_type.isBF16()) {
    CopyDenseElementsBy<bfloat16>(data, output);
    return ::tensorflow::OkStatus();
  }
  if (element_type.isF16()) {
    CopyDenseElementsBy<half>(data, output);
    return ::tensorflow::OkStatus();
  }
  if (element_type.isF32()) {
    CopyDenseElementsBy<float>(data, output);
    return ::tensorflow::OkStatus();
  }
  if (element_type.isF64()) {
    CopyDenseElementsBy<double>(data, output);
    return ::tensorflow::OkStatus();
  }
  // Complex types are only supported with f32/f64 components.
  if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
    if (complex_type.getElementType().isF32()) {
      CopyDenseElementsBy<complex64>(data, output);
      return ::tensorflow::OkStatus();
    }
    if (complex_type.getElementType().isF64()) {
      CopyDenseElementsBy<complex128>(data, output);
      return ::tensorflow::OkStatus();
    }
  }
  return tensorflow::errors::Internal(
      "Unsupported type in CopyDenseElementsDataToXlaFormat");
}
204
GetElementTypeBytes(mlir::Type type)205 StatusOr<int> GetElementTypeBytes(mlir::Type type) {
206 if (type.isInteger(1)) {
207 return 1;
208 }
209 if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
210 TF_ASSIGN_OR_RETURN(int bytes,
211 GetElementTypeBytes(complex_type.getElementType()));
212 return bytes * 2;
213 }
214 int width = type.getIntOrFloatBitWidth();
215 TF_RET_CHECK(width % 8 == 0);
216 return width / 8;
217 }
218
CreateDenseIntElementsAttrFromVector(const llvm::ArrayRef<int64_t> vector,mlir::Builder builder,llvm::ArrayRef<int64_t> shape)219 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
220 const llvm::ArrayRef<int64_t> vector, mlir::Builder builder,
221 llvm::ArrayRef<int64_t> shape) {
222 return mlir::DenseIntElementsAttr::get(
223 mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
224 builder.getIntegerType(64)),
225 vector);
226 }
227
// Converts an XLA primitive element type to the corresponding MLIR type.
// Signed XLA integers map to signless MLIR integers (the MHLO convention);
// unsigned XLA integers map to explicitly-unsigned MLIR integers. Returns an
// Internal error for types without a mapping.
StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
                                                    mlir::Builder builder) {
  switch (element_type) {
    case PrimitiveType::PRED:
      return builder.getI1Type();
    case PrimitiveType::F16:
      return builder.getF16Type();
    case PrimitiveType::BF16:
      return builder.getBF16Type();
    case PrimitiveType::F32:
      return builder.getF32Type();
    case PrimitiveType::F64:
      return builder.getF64Type();
    case PrimitiveType::S8:
      return builder.getIntegerType(8);
    case PrimitiveType::S16:
      return builder.getIntegerType(16);
    case PrimitiveType::S32:
      return builder.getIntegerType(32);
    case PrimitiveType::S64:
      return builder.getIntegerType(64);
    case PrimitiveType::U8:
      return builder.getIntegerType(8, /*isSigned=*/false);
    case PrimitiveType::U16:
      return builder.getIntegerType(16, /*isSigned=*/false);
    case PrimitiveType::U32:
      return builder.getIntegerType(32, /*isSigned=*/false);
    case PrimitiveType::U64:
      return builder.getIntegerType(64, /*isSigned=*/false);
    case PrimitiveType::C64:
      return mlir::ComplexType::get(builder.getF32Type());
    case PrimitiveType::C128:
      return mlir::ComplexType::get(builder.getF64Type());
    // NOTE(review): the former TODO(b/130356985) about unsigned types looks
    // stale -- U8..U64 are handled above. Remaining primitive types
    // (presumably TUPLE, TOKEN, etc. -- confirm against PrimitiveType) fall
    // through to the error below.
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}
267
CreateGatherDimensionNumbers(const GatherDimensionNumbers & input,mlir::Builder builder)268 mlir::mhlo::GatherDimensionNumbersAttr CreateGatherDimensionNumbers(
269 const GatherDimensionNumbers& input, mlir::Builder builder) {
270 auto get_i64_array = [](absl::Span<const int64_t> container) {
271 return llvm::ArrayRef<int64_t>{container.data(), container.size()};
272 };
273 return mlir::mhlo::GatherDimensionNumbersAttr::get(
274 builder.getContext(), get_i64_array(input.offset_dims()),
275 get_i64_array(input.collapsed_slice_dims()),
276 get_i64_array(input.start_index_map()), input.index_vector_dim());
277 }
278
// Maps an MHLO or LMHLO operation to the corresponding XLA HloOpcode.
// Most branches accept both the mhlo and lmhlo variant of an op; ops with no
// lmhlo counterpart list only the mhlo form. Returns Unimplemented (with the
// op printed into the message) for anything unrecognized.
StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
  using mlir::isa;

  if (isa<mlir::mhlo::ConstantOp, mlir::lmhlo::ConstantOp>(op)) {
    return xla::HloOpcode::kConstant;
  } else if (isa<mlir::mhlo::IotaOp, mlir::lmhlo::IotaOp>(op)) {
    return xla::HloOpcode::kIota;
  } else if (isa<mlir::mhlo::ConvertOp, mlir::lmhlo::ConvertOp>(op)) {
    return xla::HloOpcode::kConvert;
  } else if (isa<mlir::mhlo::AddOp, mlir::lmhlo::AddOp>(op)) {
    return xla::HloOpcode::kAdd;
  } else if (isa<mlir::mhlo::Atan2Op, mlir::lmhlo::Atan2Op>(op)) {
    return xla::HloOpcode::kAtan2;
  } else if (isa<mlir::mhlo::DivOp, mlir::lmhlo::DivOp>(op)) {
    return xla::HloOpcode::kDivide;
  } else if (isa<mlir::mhlo::MaxOp, mlir::lmhlo::MaxOp>(op)) {
    return xla::HloOpcode::kMaximum;
  } else if (isa<mlir::mhlo::MinOp, mlir::lmhlo::MinOp>(op)) {
    return xla::HloOpcode::kMinimum;
  } else if (isa<mlir::mhlo::MulOp, mlir::lmhlo::MulOp>(op)) {
    return xla::HloOpcode::kMultiply;
  } else if (isa<mlir::mhlo::PowOp, mlir::lmhlo::PowOp>(op)) {
    return xla::HloOpcode::kPower;
  } else if (isa<mlir::mhlo::RemOp, mlir::lmhlo::RemOp>(op)) {
    return xla::HloOpcode::kRemainder;
  } else if (isa<mlir::mhlo::ShiftLeftOp, mlir::lmhlo::ShiftLeftOp>(op)) {
    return xla::HloOpcode::kShiftLeft;
  } else if (isa<mlir::mhlo::ShiftRightArithmeticOp,
                 mlir::lmhlo::ShiftRightArithmeticOp>(op)) {
    return xla::HloOpcode::kShiftRightArithmetic;
  } else if (isa<mlir::mhlo::ShiftRightLogicalOp,
                 mlir::lmhlo::ShiftRightLogicalOp>(op)) {
    return xla::HloOpcode::kShiftRightLogical;
  } else if (isa<mlir::mhlo::SubtractOp, mlir::lmhlo::SubtractOp>(op)) {
    return xla::HloOpcode::kSubtract;
  } else if (isa<mlir::mhlo::XorOp, mlir::lmhlo::XorOp>(op)) {
    return xla::HloOpcode::kXor;
  } else if (isa<mlir::mhlo::InfeedOp, mlir::lmhlo::InfeedOp>(op)) {
    return xla::HloOpcode::kInfeed;
  } else if (isa<mlir::mhlo::OutfeedOp, mlir::lmhlo::OutfeedOp>(op)) {
    return xla::HloOpcode::kOutfeed;
  } else if (isa<mlir::mhlo::SendOp>(op)) {
    return xla::HloOpcode::kSend;
  } else if (isa<mlir::mhlo::RecvOp>(op)) {
    return xla::HloOpcode::kRecv;
  } else if (isa<mlir::mhlo::ReplicaIdOp, mlir::lmhlo::ReplicaIdOp>(op)) {
    return xla::HloOpcode::kReplicaId;
  } else if (isa<mlir::mhlo::AfterAllOp>(op)) {
    return xla::HloOpcode::kAfterAll;
  } else if (isa<mlir::mhlo::AllReduceOp, mlir::lmhlo::AllReduceOp>(op)) {
    return xla::HloOpcode::kAllReduce;
  } else if (isa<mlir::mhlo::AllToAllOp>(op)) {
    return xla::HloOpcode::kAllToAll;
  } else if (isa<mlir::mhlo::TupleOp>(op)) {
    return xla::HloOpcode::kTuple;
  } else if (isa<mlir::mhlo::BatchNormGradOp, mlir::lmhlo::BatchNormGradOp>(
                 op)) {
    return xla::HloOpcode::kBatchNormGrad;
  } else if (isa<mlir::mhlo::BatchNormInferenceOp,
                 mlir::lmhlo::BatchNormInferenceOp>(op)) {
    return xla::HloOpcode::kBatchNormInference;
  } else if (isa<mlir::mhlo::BatchNormTrainingOp,
                 mlir::lmhlo::BatchNormTrainingOp>(op)) {
    return xla::HloOpcode::kBatchNormTraining;
  } else if (isa<mlir::mhlo::BitcastConvertOp, mlir::lmhlo::BitcastConvertOp>(
                 op)) {
    return xla::HloOpcode::kBitcastConvert;
  } else if (isa<mlir::mhlo::BroadcastOp, mlir::lmhlo::BroadcastOp>(op)) {
    return xla::HloOpcode::kBroadcast;
  } else if (isa<mlir::mhlo::CholeskyOp, mlir::lmhlo::CholeskyOp>(op)) {
    return xla::HloOpcode::kCholesky;
  } else if (isa<mlir::mhlo::ClampOp, mlir::lmhlo::ClampOp>(op)) {
    return xla::HloOpcode::kClamp;
  } else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
    return xla::HloOpcode::kConcatenate;
  } else if (isa<mlir::mhlo::ConvolutionOp, mlir::lmhlo::ConvolutionOp>(op)) {
    return xla::HloOpcode::kConvolution;
  } else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
    return xla::HloOpcode::kSort;
  } else if (isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
    return xla::HloOpcode::kRngBitGenerator;
  } else if (isa<mlir::mhlo::XlaRngGetAndUpdateStateOp>(op)) {
    return xla::HloOpcode::kRngGetAndUpdateState;
  } else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
    return xla::HloOpcode::kFusion;
  } else if (isa<mlir::mhlo::BitcastOp>(op)) {
    return xla::HloOpcode::kBitcast;
  } else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
    return xla::HloOpcode::kAbs;
  } else if (isa<mlir::mhlo::CbrtOp, mlir::lmhlo::CbrtOp>(op)) {
    return xla::HloOpcode::kCbrt;
  } else if (isa<mlir::mhlo::CeilOp, mlir::lmhlo::CeilOp>(op)) {
    return xla::HloOpcode::kCeil;
  } else if (isa<mlir::mhlo::ClzOp, mlir::lmhlo::ClzOp>(op)) {
    return xla::HloOpcode::kClz;
  } else if (isa<mlir::mhlo::CosineOp, mlir::lmhlo::CosineOp>(op)) {
    return xla::HloOpcode::kCos;
  } else if (isa<mlir::mhlo::ExpOp, mlir::lmhlo::ExpOp>(op)) {
    return xla::HloOpcode::kExp;
  } else if (isa<mlir::mhlo::Expm1Op, mlir::lmhlo::Expm1Op>(op)) {
    return xla::HloOpcode::kExpm1;
  } else if (isa<mlir::mhlo::FloorOp, mlir::lmhlo::FloorOp>(op)) {
    return xla::HloOpcode::kFloor;
  } else if (isa<mlir::mhlo::ImagOp, mlir::lmhlo::ImagOp>(op)) {
    return xla::HloOpcode::kImag;
  } else if (isa<mlir::mhlo::IsFiniteOp, mlir::lmhlo::IsFiniteOp>(op)) {
    return xla::HloOpcode::kIsFinite;
  } else if (isa<mlir::mhlo::LogOp, mlir::lmhlo::LogOp>(op)) {
    return xla::HloOpcode::kLog;
  } else if (isa<mlir::mhlo::Log1pOp, mlir::lmhlo::Log1pOp>(op)) {
    return xla::HloOpcode::kLog1p;
  } else if (isa<mlir::mhlo::LogisticOp>(op)) {
    return xla::HloOpcode::kLogistic;
  } else if (isa<mlir::mhlo::NotOp, mlir::lmhlo::NotOp>(op)) {
    return xla::HloOpcode::kNot;
  } else if (isa<mlir::mhlo::NegOp, mlir::lmhlo::NegOp>(op)) {
    return xla::HloOpcode::kNegate;
  } else if (isa<mlir::mhlo::PopulationCountOp, mlir::lmhlo::PopulationCountOp>(
                 op)) {
    return xla::HloOpcode::kPopulationCount;
  } else if (isa<mlir::mhlo::RealOp, mlir::lmhlo::RealOp>(op)) {
    return xla::HloOpcode::kReal;
  } else if (isa<mlir::mhlo::RoundOp, mlir::lmhlo::RoundOp>(op)) {
    // mhlo RoundOp rounds half away from zero, hence kRoundNearestAfz.
    return xla::HloOpcode::kRoundNearestAfz;
  } else if (isa<mlir::mhlo::RoundNearestEvenOp,
                 mlir::lmhlo::RoundNearestEvenOp>(op)) {
    return xla::HloOpcode::kRoundNearestEven;
  } else if (isa<mlir::mhlo::RsqrtOp, mlir::lmhlo::RsqrtOp>(op)) {
    return xla::HloOpcode::kRsqrt;
  } else if (isa<mlir::mhlo::SignOp, mlir::lmhlo::SignOp>(op)) {
    return xla::HloOpcode::kSign;
  } else if (isa<mlir::mhlo::SineOp, mlir::lmhlo::SineOp>(op)) {
    return xla::HloOpcode::kSin;
  } else if (isa<mlir::mhlo::SqrtOp, mlir::lmhlo::SqrtOp>(op)) {
    return xla::HloOpcode::kSqrt;
  } else if (isa<mlir::mhlo::TanhOp, mlir::lmhlo::TanhOp>(op)) {
    return xla::HloOpcode::kTanh;
  } else if (isa<mlir::mhlo::ComplexOp, mlir::lmhlo::ComplexOp>(op)) {
    return xla::HloOpcode::kComplex;
  } else if (isa<mlir::mhlo::AndOp, mlir::lmhlo::AndOp>(op)) {
    return xla::HloOpcode::kAnd;
  } else if (isa<mlir::mhlo::OrOp, mlir::lmhlo::OrOp>(op)) {
    return xla::HloOpcode::kOr;
  } else if (isa<mlir::mhlo::WhileOp, mlir::lmhlo::WhileOp>(op)) {
    return xla::HloOpcode::kWhile;
  } else if (isa<mlir::mhlo::ReduceOp, mlir::lmhlo::ReduceOp>(op)) {
    return xla::HloOpcode::kReduce;
  } else if (isa<mlir::mhlo::GetTupleElementOp>(op)) {
    return xla::HloOpcode::kGetTupleElement;
  } else if (isa<mlir::mhlo::CompareOp, mlir::lmhlo::CompareOp>(op)) {
    return xla::HloOpcode::kCompare;
  } else if (isa<mlir::mhlo::SliceOp, mlir::lmhlo::SliceOp>(op)) {
    return xla::HloOpcode::kSlice;
  } else if (isa<mlir::mhlo::DynamicSliceOp, mlir::lmhlo::DynamicSliceOp>(op)) {
    return xla::HloOpcode::kDynamicSlice;
  } else if (isa<mlir::mhlo::DynamicUpdateSliceOp,
                 mlir::lmhlo::DynamicUpdateSliceOp>(op)) {
    return xla::HloOpcode::kDynamicUpdateSlice;
  } else if (isa<mlir::mhlo::CollectivePermuteOp,
                 mlir::lmhlo::CollectivePermuteOp>(op)) {
    return xla::HloOpcode::kCollectivePermute;
  } else if (isa<mlir::mhlo::CopyOp, mlir::lmhlo::CopyOp>(op)) {
    return xla::HloOpcode::kCopy;
  } else if (isa<mlir::mhlo::CustomCallOp, mlir::lmhlo::CustomCallOp>(op)) {
    return xla::HloOpcode::kCustomCall;
  } else if (isa<mlir::mhlo::DotOp, mlir::lmhlo::DotOp>(op)) {
    return xla::HloOpcode::kDot;
  } else if (isa<mlir::mhlo::FftOp, mlir::lmhlo::FftOp>(op)) {
    return xla::HloOpcode::kFft;
  } else if (isa<mlir::mhlo::GatherOp, mlir::lmhlo::GatherOp>(op)) {
    return xla::HloOpcode::kGather;
  } else if (isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
    return xla::HloOpcode::kGetDimensionSize;
  } else if (isa<mlir::mhlo::MapOp, mlir::lmhlo::MapOp>(op)) {
    return xla::HloOpcode::kMap;
  } else if (isa<mlir::mhlo::ReshapeOp, mlir::lmhlo::ReshapeOp>(op)) {
    return xla::HloOpcode::kReshape;
  } else if (isa<mlir::mhlo::DynamicReshapeOp>(op)) {
    return xla::HloOpcode::kDynamicReshape;
  } else if (isa<mlir::mhlo::ScatterOp, mlir::lmhlo::ScatterOp>(op)) {
    return xla::HloOpcode::kScatter;
  } else if (isa<mlir::mhlo::SelectOp, mlir::lmhlo::SelectOp>(op)) {
    return xla::HloOpcode::kSelect;
  } else if (isa<mlir::mhlo::SelectAndScatterOp,
                 mlir::lmhlo::SelectAndScatterOp>(op)) {
    return xla::HloOpcode::kSelectAndScatter;
  } else if (isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
    return xla::HloOpcode::kSetDimensionSize;
  } else if (isa<mlir::mhlo::ReverseOp, mlir::lmhlo::ReverseOp>(op)) {
    return xla::HloOpcode::kReverse;
  } else if (isa<mlir::mhlo::PadOp, mlir::lmhlo::PadOp>(op)) {
    return xla::HloOpcode::kPad;
  } else if (isa<mlir::mhlo::TransposeOp, mlir::lmhlo::TransposeOp>(op)) {
    return xla::HloOpcode::kTranspose;
  } else if (isa<mlir::mhlo::TriangularSolveOp, mlir::lmhlo::TriangularSolveOp>(
                 op)) {
    return xla::HloOpcode::kTriangularSolve;
  } else if (isa<mlir::mhlo::ReduceWindowOp, mlir::lmhlo::ReduceWindowOp>(op)) {
    return xla::HloOpcode::kReduceWindow;
  } else if (isa<mlir::mhlo::ReducePrecisionOp, mlir::lmhlo::ReducePrecisionOp>(
                 op)) {
    return xla::HloOpcode::kReducePrecision;
  } else if (isa<mlir::mhlo::DotGeneralOp>(op)) {
    // DotGeneralOp has no dedicated HLO opcode; it also lowers to kDot.
    return xla::HloOpcode::kDot;
  } else if (isa<mlir::mhlo::BroadcastInDimOp, mlir::lmhlo::BroadcastInDimOp>(
                 op)) {
    // BroadcastInDimOp likewise shares kBroadcast with BroadcastOp.
    return xla::HloOpcode::kBroadcast;
  } else {
    // Print the offending op into the error message to aid debugging.
    std::string s;
    {
      llvm::raw_string_ostream os(s);
      op->print(os);
    }
    return tensorflow::errors::Unimplemented(
        "Unimplemented MHLO -> HloOpcode: ", s);
  }
}
496
497 } // namespace xla
498