xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/hlo_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines helpers useful when creating or manipulating lhlo/hlo.
17 
18 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
19 
20 #include "mlir/IR/AffineMap.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
26 #include "tensorflow/core/platform/bfloat16.h"
27 #include "tensorflow/core/platform/logging.h"
28 
29 namespace xla {
30 namespace {
31 
32 using mlir::AffineMap;
33 using mlir::Builder;
34 using mlir::DenseElementsAttr;
35 using mlir::ShapedType;
36 using xla::LiteralBase;
37 using xla::StatusOr;
38 
39 template <typename CppType>
CreateDenseAttrFromLiteral(const ShapedType & type,const LiteralBase & literal)40 ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(
41     const ShapedType& type, const LiteralBase& literal) {
42   auto data_span = literal.data<CppType>();
43   return ::mlir::DenseElementsAttr::get(
44       type, llvm::makeArrayRef(data_span.data(), data_span.size()));
45 }
46 
GetPermutationIfAvailable(const Shape & shape,mlir::Builder builder)47 StatusOr<AffineMap> GetPermutationIfAvailable(const Shape& shape,
48                                               mlir::Builder builder) {
49   // N.B. IsMonotonicWithDim0Major ignores tiling, and I can't change it because
50   // some XLA code relies on it treating tiled layouts as equivalent to untiled
51   // layouts, so the check to rule out tiling has to come /before/ the
52   // early-return branch, or we'd miss tiled monotonic layouts.
53   if (!shape.layout().tiles().empty()) {
54     return tensorflow::errors::Internal("Tiled layouts are not yet supported");
55   }
56   if (!shape.has_layout() ||
57       LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
58     return AffineMap();
59   }
60   if (!shape.is_static()) {
61     return tensorflow::errors::Internal(
62         "Permutations for dynamic shapes are not yet supported");
63   }
64   int64_t accumulated_stride = 1;
65   llvm::SmallVector<int64_t, 4> strides(shape.rank(), 1);
66   for (int64_t dim : LayoutUtil::MinorToMajor(shape)) {
67     strides[dim] = accumulated_stride;
68     accumulated_stride *= shape.dimensions(dim);
69   }
70   if (accumulated_stride == 0) {
71     return AffineMap();
72   }
73   return makeStridedLinearLayoutMap(strides, /*offset=*/0,
74                                     builder.getContext());
75 }
76 
77 template <typename T>
CopyDenseElementsBy(mlir::DenseElementsAttr data,std::vector<uint8_t> * output)78 void CopyDenseElementsBy(mlir::DenseElementsAttr data,
79                          std::vector<uint8_t>* output) {
80   output->resize(data.getNumElements() * sizeof(T));
81   int i = 0;
82   for (T element : data.getValues<T>()) {
83     std::memcpy(&(*output)[i], &element, sizeof(T));
84     i += sizeof(T);
85   }
86 }
87 
88 }  // namespace
89 
ConvertTensorShapeToMemRefType(const Shape & shape,mlir::Builder builder)90 StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
91     const Shape& shape, mlir::Builder builder) {
92   auto element_type_or =
93       ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder);
94   if (!element_type_or.ok()) return element_type_or.status();
95 
96   using mlir::MemRefType;
97   auto dimensions = shape.dimensions();
98   llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
99   auto permutation_or = GetPermutationIfAvailable(shape, builder);
100   if (!permutation_or.ok()) return permutation_or.status();
101   return MemRefType::get(array, element_type_or.ValueOrDie(),
102                          permutation_or.ValueOrDie());
103 }
104 
CreateDenseElementsAttrFromLiteral(const LiteralBase & literal,Builder builder)105 StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
106     const LiteralBase& literal, Builder builder) {
107   TF_ASSIGN_OR_RETURN(auto type,
108                       ConvertTensorShapeToType<mlir::RankedTensorType>(
109                           literal.shape(), builder));
110 
111   // TODO(hinsu): Support remaining XLA primitive types.
112   auto element_type = literal.shape().element_type();
113   switch (element_type) {
114     case PrimitiveType::PRED:
115       return CreateDenseAttrFromLiteral<bool>(type, literal);
116     case PrimitiveType::F16:
117       return CreateDenseAttrFromLiteral<half>(type, literal);
118     case PrimitiveType::BF16:
119       return CreateDenseAttrFromLiteral<bfloat16>(type, literal);
120     case PrimitiveType::F32:
121       return CreateDenseAttrFromLiteral<float>(type, literal);
122     case PrimitiveType::F64:
123       return CreateDenseAttrFromLiteral<double>(type, literal);
124     case PrimitiveType::S8:
125       return CreateDenseAttrFromLiteral<int8_t>(type, literal);
126     case PrimitiveType::S16:
127       return CreateDenseAttrFromLiteral<int16_t>(type, literal);
128     case PrimitiveType::S32:
129       return CreateDenseAttrFromLiteral<int32_t>(type, literal);
130     case PrimitiveType::S64:
131       return CreateDenseAttrFromLiteral<int64_t>(type, literal);
132     case PrimitiveType::U8:
133       return CreateDenseAttrFromLiteral<uint8_t>(type, literal);
134     case PrimitiveType::U16:
135       return CreateDenseAttrFromLiteral<uint16_t>(type, literal);
136     case PrimitiveType::U32:
137       return CreateDenseAttrFromLiteral<uint32_t>(type, literal);
138     case PrimitiveType::U64:
139       return CreateDenseAttrFromLiteral<uint64_t>(type, literal);
140     case PrimitiveType::C64:
141       return CreateDenseAttrFromLiteral<complex64>(type, literal);
142     case PrimitiveType::C128:
143       return CreateDenseAttrFromLiteral<complex128>(type, literal);
144     default:
145       return tensorflow::errors::Internal(
146           absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
147   }
148 }
149 
CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,std::vector<uint8_t> * output)150 Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
151                                         std::vector<uint8_t>* output) {
152   mlir::Type element_type = data.getType().getElementType();
153 
154   // TODO(hinsu): Support remaining XLA primitive types.
155   if (element_type.isInteger(1)) {
156     CopyDenseElementsBy<bool>(data, output);
157     return ::tensorflow::OkStatus();
158   }
159   if (element_type.isInteger(8)) {
160     CopyDenseElementsBy<uint8_t>(data, output);
161     return ::tensorflow::OkStatus();
162   }
163   if (element_type.isInteger(16)) {
164     CopyDenseElementsBy<uint16_t>(data, output);
165     return ::tensorflow::OkStatus();
166   }
167   if (element_type.isInteger(32)) {
168     CopyDenseElementsBy<uint32_t>(data, output);
169     return ::tensorflow::OkStatus();
170   }
171   if (element_type.isInteger(64)) {
172     CopyDenseElementsBy<uint64_t>(data, output);
173     return ::tensorflow::OkStatus();
174   }
175   if (element_type.isBF16()) {
176     CopyDenseElementsBy<bfloat16>(data, output);
177     return ::tensorflow::OkStatus();
178   }
179   if (element_type.isF16()) {
180     CopyDenseElementsBy<half>(data, output);
181     return ::tensorflow::OkStatus();
182   }
183   if (element_type.isF32()) {
184     CopyDenseElementsBy<float>(data, output);
185     return ::tensorflow::OkStatus();
186   }
187   if (element_type.isF64()) {
188     CopyDenseElementsBy<double>(data, output);
189     return ::tensorflow::OkStatus();
190   }
191   if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
192     if (complex_type.getElementType().isF32()) {
193       CopyDenseElementsBy<complex64>(data, output);
194       return ::tensorflow::OkStatus();
195     }
196     if (complex_type.getElementType().isF64()) {
197       CopyDenseElementsBy<complex128>(data, output);
198       return ::tensorflow::OkStatus();
199     }
200   }
201   return tensorflow::errors::Internal(
202       "Unsupported type in CopyDenseElementsDataToXlaFormat");
203 }
204 
GetElementTypeBytes(mlir::Type type)205 StatusOr<int> GetElementTypeBytes(mlir::Type type) {
206   if (type.isInteger(1)) {
207     return 1;
208   }
209   if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
210     TF_ASSIGN_OR_RETURN(int bytes,
211                         GetElementTypeBytes(complex_type.getElementType()));
212     return bytes * 2;
213   }
214   int width = type.getIntOrFloatBitWidth();
215   TF_RET_CHECK(width % 8 == 0);
216   return width / 8;
217 }
218 
CreateDenseIntElementsAttrFromVector(const llvm::ArrayRef<int64_t> vector,mlir::Builder builder,llvm::ArrayRef<int64_t> shape)219 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
220     const llvm::ArrayRef<int64_t> vector, mlir::Builder builder,
221     llvm::ArrayRef<int64_t> shape) {
222   return mlir::DenseIntElementsAttr::get(
223       mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
224                                   builder.getIntegerType(64)),
225       vector);
226 }
227 
ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,mlir::Builder builder)228 StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
229                                                     mlir::Builder builder) {
230   switch (element_type) {
231     case PrimitiveType::PRED:
232       return builder.getI1Type();
233     case PrimitiveType::F16:
234       return builder.getF16Type();
235     case PrimitiveType::BF16:
236       return builder.getBF16Type();
237     case PrimitiveType::F32:
238       return builder.getF32Type();
239     case PrimitiveType::F64:
240       return builder.getF64Type();
241     case PrimitiveType::S8:
242       return builder.getIntegerType(8);
243     case PrimitiveType::S16:
244       return builder.getIntegerType(16);
245     case PrimitiveType::S32:
246       return builder.getIntegerType(32);
247     case PrimitiveType::S64:
248       return builder.getIntegerType(64);
249     case PrimitiveType::U8:
250       return builder.getIntegerType(8, /*isSigned=*/false);
251     case PrimitiveType::U16:
252       return builder.getIntegerType(16, /*isSigned=*/false);
253     case PrimitiveType::U32:
254       return builder.getIntegerType(32, /*isSigned=*/false);
255     case PrimitiveType::U64:
256       return builder.getIntegerType(64, /*isSigned=*/false);
257     case PrimitiveType::C64:
258       return mlir::ComplexType::get(builder.getF32Type());
259     case PrimitiveType::C128:
260       return mlir::ComplexType::get(builder.getF64Type());
261     // TODO(b/130356985): Support unsigned primitive types.
262     default:
263       return tensorflow::errors::Internal(
264           absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
265   }
266 }
267 
CreateGatherDimensionNumbers(const GatherDimensionNumbers & input,mlir::Builder builder)268 mlir::mhlo::GatherDimensionNumbersAttr CreateGatherDimensionNumbers(
269     const GatherDimensionNumbers& input, mlir::Builder builder) {
270   auto get_i64_array = [](absl::Span<const int64_t> container) {
271     return llvm::ArrayRef<int64_t>{container.data(), container.size()};
272   };
273   return mlir::mhlo::GatherDimensionNumbersAttr::get(
274       builder.getContext(), get_i64_array(input.offset_dims()),
275       get_i64_array(input.collapsed_slice_dims()),
276       get_i64_array(input.start_index_map()), input.index_vector_dim());
277 }
278 
MhloToHloOpcode(mlir::Operation * op)279 StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
280   using mlir::isa;
281 
282   if (isa<mlir::mhlo::ConstantOp, mlir::lmhlo::ConstantOp>(op)) {
283     return xla::HloOpcode::kConstant;
284   } else if (isa<mlir::mhlo::IotaOp, mlir::lmhlo::IotaOp>(op)) {
285     return xla::HloOpcode::kIota;
286   } else if (isa<mlir::mhlo::ConvertOp, mlir::lmhlo::ConvertOp>(op)) {
287     return xla::HloOpcode::kConvert;
288   } else if (isa<mlir::mhlo::AddOp, mlir::lmhlo::AddOp>(op)) {
289     return xla::HloOpcode::kAdd;
290   } else if (isa<mlir::mhlo::Atan2Op, mlir::lmhlo::Atan2Op>(op)) {
291     return xla::HloOpcode::kAtan2;
292   } else if (isa<mlir::mhlo::DivOp, mlir::lmhlo::DivOp>(op)) {
293     return xla::HloOpcode::kDivide;
294   } else if (isa<mlir::mhlo::MaxOp, mlir::lmhlo::MaxOp>(op)) {
295     return xla::HloOpcode::kMaximum;
296   } else if (isa<mlir::mhlo::MinOp, mlir::lmhlo::MinOp>(op)) {
297     return xla::HloOpcode::kMinimum;
298   } else if (isa<mlir::mhlo::MulOp, mlir::lmhlo::MulOp>(op)) {
299     return xla::HloOpcode::kMultiply;
300   } else if (isa<mlir::mhlo::PowOp, mlir::lmhlo::PowOp>(op)) {
301     return xla::HloOpcode::kPower;
302   } else if (isa<mlir::mhlo::RemOp, mlir::lmhlo::RemOp>(op)) {
303     return xla::HloOpcode::kRemainder;
304   } else if (isa<mlir::mhlo::ShiftLeftOp, mlir::lmhlo::ShiftLeftOp>(op)) {
305     return xla::HloOpcode::kShiftLeft;
306   } else if (isa<mlir::mhlo::ShiftRightArithmeticOp,
307                  mlir::lmhlo::ShiftRightArithmeticOp>(op)) {
308     return xla::HloOpcode::kShiftRightArithmetic;
309   } else if (isa<mlir::mhlo::ShiftRightLogicalOp,
310                  mlir::lmhlo::ShiftRightLogicalOp>(op)) {
311     return xla::HloOpcode::kShiftRightLogical;
312   } else if (isa<mlir::mhlo::SubtractOp, mlir::lmhlo::SubtractOp>(op)) {
313     return xla::HloOpcode::kSubtract;
314   } else if (isa<mlir::mhlo::XorOp, mlir::lmhlo::XorOp>(op)) {
315     return xla::HloOpcode::kXor;
316   } else if (isa<mlir::mhlo::InfeedOp, mlir::lmhlo::InfeedOp>(op)) {
317     return xla::HloOpcode::kInfeed;
318   } else if (isa<mlir::mhlo::OutfeedOp, mlir::lmhlo::OutfeedOp>(op)) {
319     return xla::HloOpcode::kOutfeed;
320   } else if (isa<mlir::mhlo::SendOp>(op)) {
321     return xla::HloOpcode::kSend;
322   } else if (isa<mlir::mhlo::RecvOp>(op)) {
323     return xla::HloOpcode::kRecv;
324   } else if (isa<mlir::mhlo::ReplicaIdOp, mlir::lmhlo::ReplicaIdOp>(op)) {
325     return xla::HloOpcode::kReplicaId;
326   } else if (isa<mlir::mhlo::AfterAllOp>(op)) {
327     return xla::HloOpcode::kAfterAll;
328   } else if (isa<mlir::mhlo::AllReduceOp, mlir::lmhlo::AllReduceOp>(op)) {
329     return xla::HloOpcode::kAllReduce;
330   } else if (isa<mlir::mhlo::AllToAllOp>(op)) {
331     return xla::HloOpcode::kAllToAll;
332   } else if (isa<mlir::mhlo::TupleOp>(op)) {
333     return xla::HloOpcode::kTuple;
334   } else if (isa<mlir::mhlo::BatchNormGradOp, mlir::lmhlo::BatchNormGradOp>(
335                  op)) {
336     return xla::HloOpcode::kBatchNormGrad;
337   } else if (isa<mlir::mhlo::BatchNormInferenceOp,
338                  mlir::lmhlo::BatchNormInferenceOp>(op)) {
339     return xla::HloOpcode::kBatchNormInference;
340   } else if (isa<mlir::mhlo::BatchNormTrainingOp,
341                  mlir::lmhlo::BatchNormTrainingOp>(op)) {
342     return xla::HloOpcode::kBatchNormTraining;
343   } else if (isa<mlir::mhlo::BitcastConvertOp, mlir::lmhlo::BitcastConvertOp>(
344                  op)) {
345     return xla::HloOpcode::kBitcastConvert;
346   } else if (isa<mlir::mhlo::BroadcastOp, mlir::lmhlo::BroadcastOp>(op)) {
347     return xla::HloOpcode::kBroadcast;
348   } else if (isa<mlir::mhlo::CholeskyOp, mlir::lmhlo::CholeskyOp>(op)) {
349     return xla::HloOpcode::kCholesky;
350   } else if (isa<mlir::mhlo::ClampOp, mlir::lmhlo::ClampOp>(op)) {
351     return xla::HloOpcode::kClamp;
352   } else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
353     return xla::HloOpcode::kConcatenate;
354   } else if (isa<mlir::mhlo::ConvolutionOp, mlir::lmhlo::ConvolutionOp>(op)) {
355     return xla::HloOpcode::kConvolution;
356   } else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
357     return xla::HloOpcode::kSort;
358   } else if (isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
359     return xla::HloOpcode::kRngBitGenerator;
360   } else if (isa<mlir::mhlo::XlaRngGetAndUpdateStateOp>(op)) {
361     return xla::HloOpcode::kRngGetAndUpdateState;
362   } else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
363     return xla::HloOpcode::kFusion;
364   } else if (isa<mlir::mhlo::BitcastOp>(op)) {
365     return xla::HloOpcode::kBitcast;
366   } else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
367     return xla::HloOpcode::kAbs;
368   } else if (isa<mlir::mhlo::CbrtOp, mlir::lmhlo::CbrtOp>(op)) {
369     return xla::HloOpcode::kCbrt;
370   } else if (isa<mlir::mhlo::CeilOp, mlir::lmhlo::CeilOp>(op)) {
371     return xla::HloOpcode::kCeil;
372   } else if (isa<mlir::mhlo::ClzOp, mlir::lmhlo::ClzOp>(op)) {
373     return xla::HloOpcode::kClz;
374   } else if (isa<mlir::mhlo::CosineOp, mlir::lmhlo::CosineOp>(op)) {
375     return xla::HloOpcode::kCos;
376   } else if (isa<mlir::mhlo::ExpOp, mlir::lmhlo::ExpOp>(op)) {
377     return xla::HloOpcode::kExp;
378   } else if (isa<mlir::mhlo::Expm1Op, mlir::lmhlo::Expm1Op>(op)) {
379     return xla::HloOpcode::kExpm1;
380   } else if (isa<mlir::mhlo::FloorOp, mlir::lmhlo::FloorOp>(op)) {
381     return xla::HloOpcode::kFloor;
382   } else if (isa<mlir::mhlo::ImagOp, mlir::lmhlo::ImagOp>(op)) {
383     return xla::HloOpcode::kImag;
384   } else if (isa<mlir::mhlo::IsFiniteOp, mlir::lmhlo::IsFiniteOp>(op)) {
385     return xla::HloOpcode::kIsFinite;
386   } else if (isa<mlir::mhlo::LogOp, mlir::lmhlo::LogOp>(op)) {
387     return xla::HloOpcode::kLog;
388   } else if (isa<mlir::mhlo::Log1pOp, mlir::lmhlo::Log1pOp>(op)) {
389     return xla::HloOpcode::kLog1p;
390   } else if (isa<mlir::mhlo::LogisticOp>(op)) {
391     return xla::HloOpcode::kLogistic;
392   } else if (isa<mlir::mhlo::NotOp, mlir::lmhlo::NotOp>(op)) {
393     return xla::HloOpcode::kNot;
394   } else if (isa<mlir::mhlo::NegOp, mlir::lmhlo::NegOp>(op)) {
395     return xla::HloOpcode::kNegate;
396   } else if (isa<mlir::mhlo::PopulationCountOp, mlir::lmhlo::PopulationCountOp>(
397                  op)) {
398     return xla::HloOpcode::kPopulationCount;
399   } else if (isa<mlir::mhlo::RealOp, mlir::lmhlo::RealOp>(op)) {
400     return xla::HloOpcode::kReal;
401   } else if (isa<mlir::mhlo::RoundOp, mlir::lmhlo::RoundOp>(op)) {
402     return xla::HloOpcode::kRoundNearestAfz;
403   } else if (isa<mlir::mhlo::RoundNearestEvenOp,
404                  mlir::lmhlo::RoundNearestEvenOp>(op)) {
405     return xla::HloOpcode::kRoundNearestEven;
406   } else if (isa<mlir::mhlo::RsqrtOp, mlir::lmhlo::RsqrtOp>(op)) {
407     return xla::HloOpcode::kRsqrt;
408   } else if (isa<mlir::mhlo::SignOp, mlir::lmhlo::SignOp>(op)) {
409     return xla::HloOpcode::kSign;
410   } else if (isa<mlir::mhlo::SineOp, mlir::lmhlo::SineOp>(op)) {
411     return xla::HloOpcode::kSin;
412   } else if (isa<mlir::mhlo::SqrtOp, mlir::lmhlo::SqrtOp>(op)) {
413     return xla::HloOpcode::kSqrt;
414   } else if (isa<mlir::mhlo::TanhOp, mlir::lmhlo::TanhOp>(op)) {
415     return xla::HloOpcode::kTanh;
416   } else if (isa<mlir::mhlo::ComplexOp, mlir::lmhlo::ComplexOp>(op)) {
417     return xla::HloOpcode::kComplex;
418   } else if (isa<mlir::mhlo::AndOp, mlir::lmhlo::AndOp>(op)) {
419     return xla::HloOpcode::kAnd;
420   } else if (isa<mlir::mhlo::OrOp, mlir::lmhlo::OrOp>(op)) {
421     return xla::HloOpcode::kOr;
422   } else if (isa<mlir::mhlo::WhileOp, mlir::lmhlo::WhileOp>(op)) {
423     return xla::HloOpcode::kWhile;
424   } else if (isa<mlir::mhlo::ReduceOp, mlir::lmhlo::ReduceOp>(op)) {
425     return xla::HloOpcode::kReduce;
426   } else if (isa<mlir::mhlo::GetTupleElementOp>(op)) {
427     return xla::HloOpcode::kGetTupleElement;
428   } else if (isa<mlir::mhlo::CompareOp, mlir::lmhlo::CompareOp>(op)) {
429     return xla::HloOpcode::kCompare;
430   } else if (isa<mlir::mhlo::SliceOp, mlir::lmhlo::SliceOp>(op)) {
431     return xla::HloOpcode::kSlice;
432   } else if (isa<mlir::mhlo::DynamicSliceOp, mlir::lmhlo::DynamicSliceOp>(op)) {
433     return xla::HloOpcode::kDynamicSlice;
434   } else if (isa<mlir::mhlo::DynamicUpdateSliceOp,
435                  mlir::lmhlo::DynamicUpdateSliceOp>(op)) {
436     return xla::HloOpcode::kDynamicUpdateSlice;
437   } else if (isa<mlir::mhlo::CollectivePermuteOp,
438                  mlir::lmhlo::CollectivePermuteOp>(op)) {
439     return xla::HloOpcode::kCollectivePermute;
440   } else if (isa<mlir::mhlo::CopyOp, mlir::lmhlo::CopyOp>(op)) {
441     return xla::HloOpcode::kCopy;
442   } else if (isa<mlir::mhlo::CustomCallOp, mlir::lmhlo::CustomCallOp>(op)) {
443     return xla::HloOpcode::kCustomCall;
444   } else if (isa<mlir::mhlo::DotOp, mlir::lmhlo::DotOp>(op)) {
445     return xla::HloOpcode::kDot;
446   } else if (isa<mlir::mhlo::FftOp, mlir::lmhlo::FftOp>(op)) {
447     return xla::HloOpcode::kFft;
448   } else if (isa<mlir::mhlo::GatherOp, mlir::lmhlo::GatherOp>(op)) {
449     return xla::HloOpcode::kGather;
450   } else if (isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
451     return xla::HloOpcode::kGetDimensionSize;
452   } else if (isa<mlir::mhlo::MapOp, mlir::lmhlo::MapOp>(op)) {
453     return xla::HloOpcode::kMap;
454   } else if (isa<mlir::mhlo::ReshapeOp, mlir::lmhlo::ReshapeOp>(op)) {
455     return xla::HloOpcode::kReshape;
456   } else if (isa<mlir::mhlo::DynamicReshapeOp>(op)) {
457     return xla::HloOpcode::kDynamicReshape;
458   } else if (isa<mlir::mhlo::ScatterOp, mlir::lmhlo::ScatterOp>(op)) {
459     return xla::HloOpcode::kScatter;
460   } else if (isa<mlir::mhlo::SelectOp, mlir::lmhlo::SelectOp>(op)) {
461     return xla::HloOpcode::kSelect;
462   } else if (isa<mlir::mhlo::SelectAndScatterOp,
463                  mlir::lmhlo::SelectAndScatterOp>(op)) {
464     return xla::HloOpcode::kSelectAndScatter;
465   } else if (isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
466     return xla::HloOpcode::kSetDimensionSize;
467   } else if (isa<mlir::mhlo::ReverseOp, mlir::lmhlo::ReverseOp>(op)) {
468     return xla::HloOpcode::kReverse;
469   } else if (isa<mlir::mhlo::PadOp, mlir::lmhlo::PadOp>(op)) {
470     return xla::HloOpcode::kPad;
471   } else if (isa<mlir::mhlo::TransposeOp, mlir::lmhlo::TransposeOp>(op)) {
472     return xla::HloOpcode::kTranspose;
473   } else if (isa<mlir::mhlo::TriangularSolveOp, mlir::lmhlo::TriangularSolveOp>(
474                  op)) {
475     return xla::HloOpcode::kTriangularSolve;
476   } else if (isa<mlir::mhlo::ReduceWindowOp, mlir::lmhlo::ReduceWindowOp>(op)) {
477     return xla::HloOpcode::kReduceWindow;
478   } else if (isa<mlir::mhlo::ReducePrecisionOp, mlir::lmhlo::ReducePrecisionOp>(
479                  op)) {
480     return xla::HloOpcode::kReducePrecision;
481   } else if (isa<mlir::mhlo::DotGeneralOp>(op)) {
482     return xla::HloOpcode::kDot;
483   } else if (isa<mlir::mhlo::BroadcastInDimOp, mlir::lmhlo::BroadcastInDimOp>(
484                  op)) {
485     return xla::HloOpcode::kBroadcast;
486   } else {
487     std::string s;
488     {
489       llvm::raw_string_ostream os(s);
490       op->print(os);
491     }
492     return tensorflow::errors::Unimplemented(
493         "Unimplemented MHLO -> HloOpcode: ", s);
494   }
495 }
496 
497 }  // namespace xla
498