xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/hlo_function_importer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
17 
18 #include <unordered_map>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/types/optional.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "mlir/IR/Location.h"  // from @llvm-project
32 #include "mlir/IR/Region.h"  // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
34 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
35 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
36 #include "tensorflow/compiler/xla/comparison_util.h"
37 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
38 #include "tensorflow/compiler/xla/protobuf_util.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_computation.h"
41 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
42 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
43 #include "tensorflow/compiler/xla/service/hlo_module.h"
44 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
45 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/xla_data.pb.h"
48 #include "tensorflow/core/platform/statusor.h"
49 
50 using llvm::APInt;
51 using llvm::makeArrayRef;
52 using mlir::DenseIntElementsAttr;
53 using mlir::NamedAttribute;
54 using mlir::Operation;
55 using mlir::RankedTensorType;
56 using mlir::Type;
57 using mlir::Value;
58 using mlir::func::FuncOp;
59 
60 namespace xla {
61 
62 namespace {
63 
64 constexpr char kShardingAttr[] = "mhlo.sharding";
65 
66 // Note: This sanitization function causes an irreversible many-to-one mapping
67 // and any solution to mitigate this would cause issues with the reverse
68 // direction. Longterm solution is to add a function attribute to maintain the
69 // original HLO naming.
SanitizeFunctionName(llvm::StringRef name)70 std::string SanitizeFunctionName(llvm::StringRef name) {
71   std::string output(name);
72   llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
73   return output;
74 }
75 
76 // Returns whether the instruction is a default dot operation.
DotIsDefault(const HloInstruction * instruction)77 bool DotIsDefault(const HloInstruction* instruction) {
78   const auto& operands = instruction->operands();
79   // eg. vector[3] dot matrix[3, 2] => [2] not default dot
80   if (operands[0]->shape().rank() < operands[1]->shape().rank()) {
81     return false;
82   }
83   auto dnums = instruction->dot_dimension_numbers();
84   DotDimensionNumbers default_dimension_numbers;
85   default_dimension_numbers.add_lhs_contracting_dimensions(
86       instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
87   default_dimension_numbers.add_rhs_contracting_dimensions(0);
88   return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
89 }
90 
91 // Returns an MLIR Location generated from HLO Instruction. Uses instruction
92 // metadata if present or instruction name.
GenerateInstructionLocation(const HloInstruction * instruction,mlir::OpBuilder * func_builder)93 mlir::Location GenerateInstructionLocation(const HloInstruction* instruction,
94                                            mlir::OpBuilder* func_builder) {
95   const std::string& op_name = instruction->metadata().op_name();
96   if (op_name.empty()) {
97     return mlir::NameLoc::get(func_builder->getStringAttr(instruction->name()));
98   }
99 
100   mlir::Location op_name_loc =
101       mlir::NameLoc::get(func_builder->getStringAttr(op_name));
102   const std::string& source_file = instruction->metadata().source_file();
103   if (source_file.empty()) {
104     return op_name_loc;
105   }
106 
107   return func_builder->getFusedLoc(
108       {op_name_loc,
109        mlir::FileLineColLoc::get(func_builder->getContext(), source_file,
110                                  instruction->metadata().source_line(), 0)});
111 }
112 
113 // Clean up the GetTupleElementOp, created during the flattening of
114 // tuple arguments and return values, if eligible for folding. Removal of
115 // get-tuple-element can transitively make the defining TupleOp dead to be
116 // removed subsequently.
CleanUpTupleOps(mlir::Block * block,mlir::OpBuilder * builder)117 void CleanUpTupleOps(mlir::Block* block, mlir::OpBuilder* builder) {
118   bool changed = true;
119   llvm::SmallVector<Value> folded_results;
120 
121   while (changed) {
122     changed = false;
123     for (Operation& op : llvm::make_early_inc_range(block->getOperations())) {
124       if (llvm::isa<mlir::mhlo::GetTupleElementOp>(op)) {
125         folded_results.clear();
126         if (failed(builder->tryFold(&op, folded_results))) continue;
127         op.replaceAllUsesWith(folded_results);
128         op.erase();
129         changed = true;
130       } else if (llvm::isa<mlir::mhlo::TupleOp>(op) &&
131                  mlir::isOpTriviallyDead(&op)) {
132         op.erase();
133         changed = true;
134       }
135     }
136   }
137 }
138 
139 }  // namespace
140 
ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation * op,llvm::ArrayRef<mlir::Value> implicit_operands)141 void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
142     mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands) {
143   assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
144           mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
145          "Unexpected mlir op in "
146          "HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands!");
147 
148   int implicit_operand_index = 0;
149   for (auto& region : op->getRegions()) {
150     for (auto arg : region.getArguments()) {
151       assert(implicit_operand_index < implicit_operands.size());
152       arg.replaceAllUsesWith(implicit_operands[implicit_operand_index++]);
153     }
154     region.front().eraseArguments(
155         llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
156   }
157 }
158 
CreateTupleFromOpResults(mlir::OpBuilder * func_builder,mlir::Location loc,mlir::Operation * op,mlir::Type type)159 mlir::Operation* HloFunctionImporter::CreateTupleFromOpResults(
160     mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op,
161     mlir::Type type) {
162   if (!type.isa<mlir::TupleType>()) return op;
163 
164   llvm::SmallVector<Value> flattened_results = op->getResults();
165   llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
166   auto result =
167       CreateTupleValue(func_builder, loc, flattened_results_ref, type);
168   auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
169   assert(defining_tuple_op && "builder didn't return the right type");
170   auto tupleOp = defining_tuple_op.getOperation();
171   return tupleOp;
172 }
173 
IsNestedTupleInData(Type type)174 static bool IsNestedTupleInData(Type type) {
175   auto tuple_type = type.dyn_cast<mlir::TupleType>();
176   if (!tuple_type) return false;
177 
178   assert(tuple_type.getType(1).isa<mlir::mhlo::TokenType>() &&
179          "Infeed: Non token type");
180   auto data_type = tuple_type.getType(0);
181 
182   auto data_tuple_type = data_type.dyn_cast<mlir::TupleType>();
183   if (!data_tuple_type) return false;
184 
185   for (auto child_type : data_tuple_type.getTypes()) {
186     if (child_type.isa<mlir::TupleType>()) return true;
187   }
188 
189   return false;
190 }
191 
FlattenTupleType(Type type,llvm::SmallVectorImpl<Type> & flattened_types)192 void HloFunctionImporter::FlattenTupleType(
193     Type type, llvm::SmallVectorImpl<Type>& flattened_types) {
194   auto tuple_type = type.dyn_cast<mlir::TupleType>();
195   if (!tuple_type) {
196     flattened_types.push_back(type);
197     return;
198   }
199 
200   for (auto child_type : tuple_type.getTypes()) {
201     FlattenTupleType(child_type, flattened_types);
202   }
203 }
204 
FlattenTupleValue(mlir::OpBuilder * func_builder,mlir::Location loc,Value value,llvm::SmallVectorImpl<Value> & flattened_values)205 void HloFunctionImporter::FlattenTupleValue(
206     mlir::OpBuilder* func_builder, mlir::Location loc, Value value,
207     llvm::SmallVectorImpl<Value>& flattened_values) {
208   auto tuple_type = value.getType().dyn_cast<mlir::TupleType>();
209   if (!tuple_type) {
210     flattened_values.push_back(value);
211     return;
212   }
213 
214   int flattenIdx = 0;
215   for (auto child_type : tuple_type.getTypes()) {
216     auto sub_value = func_builder->create<mlir::mhlo::GetTupleElementOp>(
217         loc, child_type, value, func_builder->getI32IntegerAttr(flattenIdx++));
218     FlattenTupleValue(func_builder, loc, sub_value, flattened_values);
219   }
220 }
221 
CreateTupleValue(mlir::OpBuilder * func_builder,mlir::Location loc,llvm::MutableArrayRef<Value> & flatten_values,Type type)222 Value HloFunctionImporter::CreateTupleValue(
223     mlir::OpBuilder* func_builder, mlir::Location loc,
224     llvm::MutableArrayRef<Value>& flatten_values, Type type) {
225   auto tuple_type = type.dyn_cast<mlir::TupleType>();
226   if (!tuple_type) {
227     assert(!flatten_values.empty());
228     auto retval = flatten_values.front();
229     flatten_values = flatten_values.drop_front();
230     return retval;
231   }
232 
233   llvm::SmallVector<mlir::Value> flatten_sub_values;
234   for (auto child_type : tuple_type.getTypes())
235     flatten_sub_values.push_back(
236         CreateTupleValue(func_builder, loc, flatten_values, child_type));
237 
238   return func_builder->create<mlir::mhlo::TupleOp>(loc, flatten_sub_values)
239       .getResult();
240 }
241 
// Imports `computation` into `module` as an MLIR FuncOp. `function_map`
// memoizes already-imported computations so each is converted only once;
// `is_main` forces the function name "main" and public visibility.
Status HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation, mlir::ModuleOp module,
    std::unordered_map<const HloComputation*, FuncOp>* function_map,
    mlir::Builder* builder, bool is_main) {
  HloFunctionImporter importer(module, function_map, builder);
  return importer.ImportAsFunc(computation, is_main).status();
}
249 
// Imports `computation` into the body of `region`. No function map is used,
// so nested computations referenced here are imported fresh. When
// `flatten_region_arg_tuple` is set, tuple-typed parameters are expanded into
// one block argument per leaf element.
Status HloFunctionImporter::ImportAsRegion(
    const xla::HloComputation& computation, mlir::Region* region,
    mlir::Builder* builder, bool flatten_region_arg_tuple) {
  HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
                               builder);
  return importer.ImportAsRegion(computation, region, flatten_region_arg_tuple);
}
257 
// Imports `computation` as a FuncOp in `module_`, returning the (possibly
// previously imported) function. Parameter and root shardings, when present,
// are attached as "mhlo.sharding" string attributes.
StatusOr<FuncOp> HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation, bool is_main) {
  // "main" keeps its name verbatim; other computations are sanitized so the
  // name is a valid MLIR symbol.
  std::string computation_name =
      is_main ? "main" : SanitizeFunctionName(computation.name());

  // With a function map, reuse an already-imported function if present.
  // Without one, importing the same name twice would be a redeclaration.
  FuncOp* imported(nullptr);
  if (function_map_) {
    imported = &((*function_map_)[&computation]);
    if (*imported) {
      return *imported;
    }
  } else {
    TF_RET_CHECK(!module_.lookupSymbol<FuncOp>(computation_name))
        << "Attempting to redeclare an existing function named "
        << computation.name();
  }

  // Derive the function type from the computation's parameters and root.
  llvm::SmallVector<Type, 4> args, rets;
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
  TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
  auto func_type = mlir::FunctionType::get(context_, args, rets);

  // Construct the MLIR function and map arguments.
  llvm::ArrayRef<mlir::NamedAttribute> attrs;
  auto function = FuncOp::create(mlir::UnknownLoc::get(context_),
                                 computation_name, func_type, attrs);
  // Only the entry function is externally visible.
  auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
                                               : FuncOp::Visibility::Private;
  function.setVisibility(visibility);

  // Attach each parameter's sharding (serialized OpSharding proto) as an
  // argument attribute.
  for (auto& entry : llvm::enumerate(computation.parameter_instructions())) {
    HloInstruction* parameter = entry.value();
    if (parameter->has_sharding()) {
      function.setArgAttr(
          entry.index(), kShardingAttr,
          builder_->getStringAttr(
              parameter->sharding().ToProto().SerializeAsString()));
    }
  }
  // Likewise for the root's sharding; HLO computations have a single root,
  // so exactly one MLIR result is expected here.
  if (computation.root_instruction()->has_sharding()) {
    auto result = computation.root_instruction();
    if (function.getNumResults() != 1) {
      return tensorflow::errors::Internal(absl::StrCat(
          "Expected only a single result but got ", function.getNumResults()));
    }
    function.setResultAttr(
        0, kShardingAttr,
        builder_->getStringAttr(
            result->sharding().ToProto().SerializeAsString()));
  }

  module_.push_back(function);

  // Add to the map right away for function calls if map is set. This must
  // happen before importing the body so recursive references resolve.
  if (imported) {
    *imported = function;
  }

  // Populate the function body from the computation's instructions. Entry
  // functions keep tuple-typed arguments as-is (no flattening).
  mlir::Block* block = function.addEntryBlock();
  TF_RETURN_IF_ERROR(ImportInstructions(computation, block,
                                        /*flatten_region_arg_tuple=*/false));

  return function;
}
321 
// Imports `computation` into `region`: creates the region's entry block,
// adds block arguments matching the computation parameters (flattened into
// leaf types when `flatten_region_arg_tuple` is set), then imports the body.
tensorflow::Status HloFunctionImporter::ImportAsRegion(
    const HloComputation& computation, mlir::Region* region,
    bool flatten_region_arg_tuple) {
  auto loc = region->getLoc();
  // TODO(hinsu): Store computation name as an attribute for round-trip.
  auto* block = new mlir::Block;
  region->push_back(block);

  llvm::SmallVector<Type, 4> args;
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));

  // Flatten the tuple-typed arguments.
  if (flatten_region_arg_tuple) {
    // Each tuple-typed parameter contributes one block argument per leaf
    // element; non-tuple parameters contribute exactly one.
    for (auto arg : args) {
      llvm::SmallVector<Type> flattened_arg_types;
      FlattenTupleType(arg, flattened_arg_types);
      block->addArguments(
          flattened_arg_types,
          mlir::SmallVector<mlir::Location>(flattened_arg_types.size(), loc));
    }
  } else {
    block->addArguments(args,
                        mlir::SmallVector<mlir::Location>(args.size(), loc));
  }

  return ImportInstructions(computation, block, flatten_region_arg_tuple);
}
349 
// Imports all instructions of `computation` into the builder's insertion
// point, binding the i-th parameter to `arguments[i]`. Returns the MLIR value
// corresponding to the computation's root instruction.
StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
    const xla::HloComputation& computation,
    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
  // Setup the input parameters.
  const int num_parameters = computation.num_parameters();

  for (int i = 0; i < num_parameters; i++) {
    auto hlo_parameter = computation.parameter_instruction(i);
    instruction_value_map_[hlo_parameter] = arguments[i];
  }

  // Post-order guarantees every operand is imported before its users, so
  // GetOperands below always finds a mapped value.
  for (auto instruction : computation.MakeInstructionPostOrder()) {
    TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
    TF_ASSIGN_OR_RETURN(
        auto new_operation,
        ImportInstructionWithLayout(instruction, operands, builder));
    // Parameters import to null (they map to block arguments, not new ops).
    if (new_operation) {
      instruction_value_map_[instruction] = new_operation->getResult(0);
    }
  }

  // Setup the return type (HLO only supports a single return value).
  return GetMlirValue(computation.root_instruction());
}
374 
// Imports `computation` into `block` and appends the appropriate terminator
// (func.return for functions, mhlo.return for other region parents). When
// `flatten_region_arg_tuple` is set and `block` belongs to a non-FuncOp
// region, tuple-typed parameters/results are reconstructed from / flattened
// into the block's leaf-typed arguments and return operands.
Status HloFunctionImporter::ImportInstructions(
    const HloComputation& computation, mlir::Block* block,
    bool flatten_region_arg_tuple) {
  llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);

  // TODO(suderman): Add location tracking details.
  mlir::Location loc = builder.getUnknownLoc();

  Value result;
  if (!llvm::isa<FuncOp>(block->getParentOp()) && flatten_region_arg_tuple) {
    // 'effective_arguments' stores the mhlo value corresponding to each
    // computation parameter. The value could be a BlockArgument, if the
    // corresponding computation parameter is non-tuple typed, or a TupleOp,
    // otherwise.
    llvm::SmallVector<Value> effective_arguments;

    llvm::SmallVector<Type> computation_arg_types;
    TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(),
                                    &computation_arg_types));
    // 'flatten_idx' walks the flattened block arguments; each computation
    // parameter consumes as many of them as it has leaf elements.
    int flatten_idx = 0;
    for (Type computation_arg_type : computation_arg_types) {
      auto orig_tuple_arg_type =
          computation_arg_type.dyn_cast<mlir::TupleType>();

      // If the computation-parameter type is non-tuple, no action is needed.
      if (!orig_tuple_arg_type) {
        effective_arguments.push_back(arguments[flatten_idx]);
        flatten_idx++;
        continue;
      }

      // For each tuple-typed computation parameter, create a mhlo::TupleOp
      // value in the region body, using the already flattened values in
      // 'arguments'. For example: With computation parameters: [tuple<T1>,
      // tuple<T2, T4>] We have, 'arguments' = [T1 arg1, T2 arg2, T3 arg3] and
      // we need to create two tuples tuples, one using arg1, and the other
      // using arg2 and arg3.
      llvm::SmallVector<Type> flattened_arg_type;
      FlattenTupleType(orig_tuple_arg_type, flattened_arg_type);

      llvm::MutableArrayRef<Value> sub_args(
          arguments.begin() + flatten_idx,
          arguments.begin() + flatten_idx + flattened_arg_type.size());

      auto tupleVal =
          CreateTupleValue(&builder, loc, sub_args, orig_tuple_arg_type);
      effective_arguments.push_back(tupleVal);

      flatten_idx += flattened_arg_type.size();
    }

    TF_ASSIGN_OR_RETURN(
        result,
        ImportInstructionsImpl(computation, effective_arguments, &builder));
  } else {
    TF_ASSIGN_OR_RETURN(
        result, ImportInstructionsImpl(computation, arguments, &builder));
  }

  // Create terminator op depending on the parent op of this region.
  if (llvm::isa<FuncOp>(block->getParentOp())) {
    builder.create<mlir::func::ReturnOp>(loc, result);
  } else {
    if (flatten_region_arg_tuple) {
      // Flatten tuples in results of this region.
      llvm::SmallVector<Value> flattened_return_operands;
      FlattenTupleValue(&builder, loc, result, flattened_return_operands);
      builder.create<mlir::mhlo::ReturnOp>(loc, flattened_return_operands);
    } else {
      builder.create<mlir::mhlo::ReturnOp>(loc, result);
    }
  }

  // Fold away the tuple/get_tuple_element pairs introduced by flattening.
  CleanUpTupleOps(block, &builder);

  return ::tensorflow::OkStatus();
}
453 
// Imports `computation`'s instructions at `builder`'s current insertion
// point, with parameters bound to `arguments`. Requires the builder to be
// positioned inside a block; returns the value of the computation root.
StatusOr<Value> HloFunctionImporter::ImportInstructions(
    const xla::HloComputation& computation,
    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
  mlir::Block* block = builder->getBlock();
  if (block == nullptr)
    return InvalidArgument(
        "ImportInstructions requires a valid block in the builder");

  HloFunctionImporter importer(
      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
  return importer.ImportInstructionsImpl(computation, arguments, builder);
}
466 
// Imports a single HLO instruction (given already-imported `operands`) at
// `builder`'s insertion point. `mode` controls whether dynamic result shapes
// are converted to static ones. Requires the builder to be inside a block.
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
    const xla::HloInstruction* instr,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* builder, DynamicShapeHandlingMode mode) {
  mlir::Block* block = builder->getBlock();
  if (block == nullptr)
    return InvalidArgument(
        "ImportInstructions requires a valid block in the builder");

  HloFunctionImporter importer(
      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);

  return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}
481 
ImportInstructionImpl(const HloInstruction * instruction,const llvm::SmallVectorImpl<mlir::Value> & operands,mlir::OpBuilder * func_builder,DynamicShapeHandlingMode mode)482 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
483     const HloInstruction* instruction,
484     const llvm::SmallVectorImpl<mlir::Value>& operands,
485     mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
486   const Shape& instruction_shape = instruction->shape();
487   const Shape& shape = mode == DynamicShapeHandlingMode::kConvertToStatic
488                            ? xla::ShapeUtil::MakeStaticShape(instruction_shape)
489                            : instruction_shape;
490   TF_ASSIGN_OR_RETURN(auto result_type,
491                       ConvertShapeToType<RankedTensorType>(shape, *builder_));
492   mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);
493 
494   llvm::SmallVector<NamedAttribute, 10> attributes;
495   if (instruction->has_sharding()) {
496     attributes.push_back(builder_->getNamedAttr(
497         kShardingAttr,
498         builder_->getStringAttr(
499             instruction->sharding().ToProto().SerializeAsString())));
500   }
501 
502   switch (instruction->opcode()) {
503     case HloOpcode::kParameter: {
504       return nullptr;
505     }
506     case HloOpcode::kConstant: {
507       const Literal& literal = instruction->literal();
508       auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
509       if (!attr.ok()) return attr.status();
510       mlir::Operation* new_operation =
511           func_builder->create<mlir::mhlo::ConstantOp>(loc, attr.ValueOrDie());
512       for (auto attr : attributes) {
513         new_operation->setAttr(attr.getName(), attr.getValue());
514       }
515       return new_operation;
516     }
517     case HloOpcode::kIota: {
518       return func_builder
519           ->create<mlir::mhlo::IotaOp>(
520               loc, result_type,
521               func_builder->getI64IntegerAttr(
522                   Cast<HloIotaInstruction>(instruction)->iota_dimension()))
523           .getOperation();
524     }
525     case HloOpcode::kBroadcast: {
526       // Note that the HLO broadcast is more powerful than the XLA broadcast
527       // op. BroadcastInDim offers a superset of the HLO op's functionality.
528       attributes.push_back(
529           builder_->getNamedAttr("broadcast_dimensions",
530                                  ConvertDimensions(instruction->dimensions())));
531       return func_builder
532           ->create<mlir::mhlo::BroadcastInDimOp>(loc, result_type, operands,
533                                                  attributes)
534           .getOperation();
535     }
536 
537     case HloOpcode::kBatchNormGrad:
538     case HloOpcode::kBatchNormInference:
539     case HloOpcode::kBatchNormTraining:
540       attributes.push_back(builder_->getNamedAttr(
541           "epsilon", builder_->getF32FloatAttr(instruction->epsilon())));
542       attributes.push_back(builder_->getNamedAttr(
543           "feature_index",
544           builder_->getI64IntegerAttr(instruction->feature_index())));
545       if (instruction->opcode() == HloOpcode::kBatchNormGrad) {
546         // Flatten the return type if they are tuple-typed.
547         llvm::SmallVector<Type> flattened_ret_types;
548         FlattenTupleType(result_type, flattened_ret_types);
549 
550         auto op = func_builder
551                       ->create<mlir::mhlo::BatchNormGradOp>(
552                           loc, flattened_ret_types, operands, attributes)
553                       .getOperation();
554 
555         return CreateTupleFromOpResults(func_builder, loc, op, result_type);
556       } else if (instruction->opcode() == HloOpcode::kBatchNormInference) {
557         return func_builder
558             ->create<mlir::mhlo::BatchNormInferenceOp>(loc, result_type,
559                                                        operands, attributes)
560             .getOperation();
561       } else {
562         assert(instruction->opcode() == HloOpcode::kBatchNormTraining);
563 
564         // Flatten the return type if they are tuple-typed.
565         llvm::SmallVector<Type> flattened_ret_types;
566         FlattenTupleType(result_type, flattened_ret_types);
567 
568         auto op = func_builder
569                       ->create<mlir::mhlo::BatchNormTrainingOp>(
570                           loc, flattened_ret_types, operands, attributes)
571                       .getOperation();
572 
573         return CreateTupleFromOpResults(func_builder, loc, op, result_type);
574       }
575 
576     case HloOpcode::kDot: {
577       attributes.push_back(builder_->getNamedAttr(
578           "precision_config",
579           ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
580 
581       // Consider consolidating DotOps together.
582       if (DotIsDefault(instruction)) {
583         return func_builder
584             ->create<mlir::mhlo::DotOp>(loc, result_type, operands, attributes)
585             .getOperation();
586       }
587 
588       attributes.push_back(builder_->getNamedAttr(
589           "dot_dimension_numbers",
590           ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
591                                      builder_)));
592       return func_builder
593           ->create<mlir::mhlo::DotGeneralOp>(loc, result_type, operands,
594                                              attributes)
595           .getOperation();
596     }
597     case HloOpcode::kCall: {
598       TF_ASSIGN_OR_RETURN(
599           FuncOp function,
600           ImportAsFunc(*instruction->to_apply(), /*is_main=*/false));
601       mlir::Operation* new_operation =
602           func_builder->create<mlir::func::CallOp>(loc, function, operands);
603       return new_operation;
604     }
605     case HloOpcode::kCollectivePermute: {
606       attributes.push_back(ConvertSourceTargetPairs(
607           instruction->source_target_pairs(), builder_));
608       return func_builder
609           ->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
610                                                     attributes)
611           .getOperation();
612     }
613     case HloOpcode::kCustomCall: {
614       auto custom_call = Cast<HloCustomCallInstruction>(instruction);
615       const auto& called_computations = custom_call->called_computations();
616       if (!called_computations.empty()) {
617         llvm::SmallVector<mlir::Attribute> callees;
618         callees.reserve(called_computations.size());
619         for (HloComputation* callee : called_computations) {
620           TF_ASSIGN_OR_RETURN(FuncOp function, ImportAsFunc(*callee,
621                                                             /*is_main=*/false));
622           callees.push_back(mlir::FlatSymbolRefAttr::get(builder_->getContext(),
623                                                          function.getName()));
624         }
625         attributes.push_back(builder_->getNamedAttr(
626             "called_computations",
627             mlir::ArrayAttr::get(builder_->getContext(), callees)));
628       }
629       if (custom_call->layout_constrained()) {
630         TF_ASSIGN_OR_RETURN(
631             mlir::ArrayAttr operand_layouts,
632             ExtractLayoutsFromShapes(custom_call->operand_shapes_with_layout(),
633                                      builder_));
634         attributes.push_back(
635             builder_->getNamedAttr("operand_layouts", operand_layouts));
636         mlir::ArrayAttr result_layouts;
637         if (custom_call->shape().IsTuple()) {
638           TF_ASSIGN_OR_RETURN(
639               result_layouts,
640               ExtractLayoutsFromTuple(custom_call->shape(), builder_));
641         } else {
642           TF_ASSIGN_OR_RETURN(
643               result_layouts,
644               ExtractLayoutsFromShapes({custom_call->shape()}, builder_));
645         }
646         attributes.push_back(
647             builder_->getNamedAttr("result_layouts", result_layouts));
648       }
649 
650       TF_ASSIGN_OR_RETURN(
651           auto mlir_api_version,
652           ConvertCustomCallApiVersion(custom_call->api_version()));
653       attributes.push_back(builder_->getNamedAttr(
654           "call_target_name",
655           builder_->getStringAttr(custom_call->custom_call_target())));
656       attributes.push_back(builder_->getNamedAttr(
657           "has_side_effect",
658           builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
659       attributes.push_back(builder_->getNamedAttr(
660           "backend_config",
661           builder_->getStringAttr(custom_call->raw_backend_config_string())));
662       attributes.push_back(builder_->getNamedAttr(
663           "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
664                              builder_->getContext(), mlir_api_version)));
665       return func_builder
666           ->create<mlir::mhlo::CustomCallOp>(loc, result_type, operands,
667                                              attributes)
668           .getOperation();
669     }
    case HloOpcode::kCompare: {
      // Import as mhlo.compare; direction is always attached, the comparison
      // type only when it differs from the element type's default.
      auto compare = Cast<HloCompareInstruction>(instruction);
      attributes.push_back(ConvertComparisonDirection(compare->direction()));
      auto default_type = Comparison::DefaultComparisonType(
          compare->operand(0)->shape().element_type());
      if (compare->type() != default_type)
        attributes.push_back(ConvertComparisonType(compare->type()));
      return func_builder
          ->create<mlir::mhlo::CompareOp>(loc, result_type, operands,
                                          attributes)
          .getOperation();
    }
    case HloOpcode::kCholesky: {
      // Import as mhlo.cholesky, carrying only the lower/upper triangle flag.
      attributes.push_back(builder_->getNamedAttr(
          "lower",
          builder_->getBoolAttr(instruction->cholesky_options().lower())));
      return func_builder
          ->create<mlir::mhlo::CholeskyOp>(loc, result_type, operands,
                                           attributes)
          .getOperation();
    }
    case HloOpcode::kGather: {
      // Import as mhlo.gather: dimension numbers, slice sizes and the
      // indices_are_sorted hint are all converted to attributes.
      auto gather_instruction = Cast<HloGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertGatherDimensionNumbers(
              gather_instruction->gather_dimension_numbers(), builder_)));

      std::vector<int64_t> slice_sizes(
          gather_instruction->gather_slice_sizes().begin(),
          gather_instruction->gather_slice_sizes().end());
      attributes.push_back(
          builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(gather_instruction->indices_are_sorted())));

      return func_builder
          ->create<mlir::mhlo::GatherOp>(loc, result_type, operands, attributes)
          .getOperation();
    }
    case HloOpcode::kDynamicSlice: {
      // Import as mhlo.dynamic_slice: first operand is the data, remaining
      // operands are the per-dimension start indices.
      std::vector<int64_t> slice_sizes(
          instruction->dynamic_slice_sizes().begin(),
          instruction->dynamic_slice_sizes().end());
      return func_builder
          ->create<mlir::mhlo::DynamicSliceOp>(
              loc, result_type, operands[0],
              makeArrayRef(operands).drop_front(), Convert(slice_sizes))
          .getOperation();
    }
    case HloOpcode::kDynamicUpdateSlice: {
      // Import as mhlo.dynamic_update_slice: operand 0 is the data, operand 1
      // the update, and operands 2.. are the start indices.
      return func_builder
          ->create<mlir::mhlo::DynamicUpdateSliceOp>(
              loc, result_type, operands[0], operands[1],
              llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
          .getOperation();
    }
    case HloOpcode::kInfeed: {
      // Import as mhlo.infeed. Nested tuples in the infeed payload are not
      // representable here, so they abort the import.
      if (IsNestedTupleInData(result_type)) {
        llvm_unreachable(
            "Importing xla::kInfeed with nested tuple shape not supported");
      }

      // Carry the opaque infeed config verbatim as a string attribute.
      attributes.push_back(builder_->getNamedAttr(
          "infeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->infeed_config())));

      // Record the per-leaf layouts of the infeed shape as an array attribute.
      llvm::SmallVector<mlir::Attribute> flattened_attr;
      TF_RETURN_IF_ERROR(
          ConvertShapeToMlirLayout(instruction->shape(), flattened_attr));
      attributes.push_back(builder_->getNamedAttr(
          "layout", builder_->getArrayAttr(makeArrayRef(flattened_attr))));

      // Flatten the return-type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto op = func_builder->create<mlir::mhlo::InfeedOp>(
          loc, flattened_ret_types, operands, attributes);

      // Re-tuple the flattened results so callers still see the HLO shape.
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
755     case HloOpcode::kOutfeed: {
756       attributes.push_back(builder_->getNamedAttr(
757           "outfeed_config",
758           mlir::StringAttr::get(builder_->getContext(),
759                                 instruction->outfeed_config())));
760 
761       assert(operands.size() == 2 && "Expected 2 operands for HLO Infeed");
762 
763       // In case operands[0] is a tuple, flatten it.
764       llvm::SmallVector<Value> flattened_operands;
765       FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
766       flattened_operands.push_back(operands[1]);
767 
768       auto op = func_builder->create<mlir::mhlo::OutfeedOp>(
769           loc, result_type, flattened_operands, attributes);
770 
771       return op.getOperation();
772     }
    case HloOpcode::kPad: {
      // Import as mhlo.pad: split the HLO padding config into separate
      // low/high/interior padding arrays, one entry per dimension.
      const auto& padding_config = instruction->padding_config();
      llvm::SmallVector<int64_t, 4> edge_padding_low;
      llvm::SmallVector<int64_t, 4> edge_padding_high;
      llvm::SmallVector<int64_t, 4> interior_padding;
      edge_padding_low.reserve(padding_config.dimensions_size());
      edge_padding_high.reserve(padding_config.dimensions_size());
      interior_padding.reserve(padding_config.dimensions_size());

      for (const auto& dimension : padding_config.dimensions()) {
        edge_padding_low.push_back(dimension.edge_padding_low());
        edge_padding_high.push_back(dimension.edge_padding_high());
        interior_padding.push_back(dimension.interior_padding());
      }

      // operands[0] is the data, operands[1] the padding value.
      return func_builder
          ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
                                      operands[1], Convert(edge_padding_low),
                                      Convert(edge_padding_high),
                                      Convert(interior_padding))
          .getOperation();
    }
    case HloOpcode::kScatter: {
      // Import as mhlo.scatter with the update computation imported as a
      // (flattened) region.
      auto scatter = Cast<HloScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension_numbers",
          ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
                                         builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(scatter->indices_are_sorted())));
      attributes.push_back(builder_->getNamedAttr(
          "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));

      // Tuple results are flattened for the op and re-tupled below.
      llvm::SmallVector<Type> flattened_types;
      FlattenTupleType(result_type, flattened_types);

      auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
          loc, flattened_types, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
                                        &scatter_op.update_computation(),
                                        /*flatten_region_arg_tuple=*/true));
      // NOTE: this deliberately shadows the outer `result_type` with a type
      // re-derived from the instruction shape.
      TF_ASSIGN_OR_RETURN(auto result_type,
                          ConvertShapeToType<RankedTensorType>(
                              instruction->shape(), *builder_));
      return CreateTupleFromOpResults(func_builder, loc,
                                      scatter_op.getOperation(), result_type);
    }
    case HloOpcode::kSelectAndScatter: {
      // Import as mhlo.select_and_scatter: window dims/strides/padding become
      // attributes; select and scatter computations become regions.
      auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
      llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : select_scatter->window().dimensions()) {
        window_strides.push_back(dim.stride());
        window_dimensions.push_back(dim.size());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(window_strides)));
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  Convert(window_dimensions)));
      attributes.push_back(ConvertPadding(padding));
      auto select_scatter_op =
          func_builder->create<mlir::mhlo::SelectAndScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
                                        &select_scatter_op.select(),
                                        /*flatten_region_arg_tuple=*/true));
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
                                        &select_scatter_op.scatter(),
                                        /*flatten_region_arg_tuple=*/true));
      return select_scatter_op.getOperation();
    }
    case HloOpcode::kSetDimensionSize: {
      // Import as mhlo.set_dimension_size with the target dimension index.
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::SetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    }
    case HloOpcode::kSlice: {
      // Import as mhlo.slice with start/limit/stride index attributes.
      return func_builder
          ->create<mlir::mhlo::SliceOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->slice_starts()),
              ConvertDimensions(instruction->slice_limits()),
              ConvertDimensions(instruction->slice_strides()))
          .getOperation();
    }
    case HloOpcode::kSort: {
      // Import as mhlo.sort. mhlo.sort is variadic, so a tuple-typed HLO
      // result is un-tupled for the op and re-tupled afterwards.
      auto sort_instruction = Cast<HloSortInstruction>(instruction);

      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }

      auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
          loc, return_types, operands,
          builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
          builder_->getBoolAttr(sort_instruction->is_stable()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*sort_instruction->to_apply(),
                                        &sort_op.comparator(),
                                        /*flatten_region_arg_tuple=*/true));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return sort_op.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
          .getOperation();
    }
    case HloOpcode::kConditional: {
      // HLO Conditional maps to either mhlo.if (boolean predicate) or
      // mhlo.case (integer branch index). Tuple operands and results are
      // flattened; branch computations become regions.
      llvm::SmallVector<Type, 4> rets;

      // Flatten the tuple-typed operands.
      llvm::SmallVector<Value> flattened_operands;
      for (auto& operand : operands)
        FlattenTupleValue(func_builder, loc, operand, flattened_operands);

      // If/Case Op has a single operand; we collect the other operands to
      // replace the corresponding block arguments.
      llvm::ArrayRef<Value> implicit_operands(flattened_operands.begin() + 1,
                                              flattened_operands.end());

      mlir::Type pred_or_index_type =
          operands[0].getType().cast<mlir::TensorType>().getElementType();
      // It is a predicated conditional if first argument is a boolean and
      // should be mapped to If op.
      if (pred_or_index_type.isInteger(1)) {
        // The result type is derived from the true branch's root.
        TF_RETURN_IF_ERROR(GetMlirTypes(
            {instruction->true_computation()->root_instruction()}, &rets));

        // Flatten the return-type.
        llvm::SmallVector<Type> flattened_ret_types;
        assert(rets.size() == 1);
        FlattenTupleType(rets[0], flattened_ret_types);

        auto op = func_builder->create<mlir::mhlo::IfOp>(
            loc, flattened_ret_types, flattened_operands[0], attributes);
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
                                          &op.true_branch(),
                                          /*flatten_region_arg_tuple=*/true));
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
                                          &op.false_branch(),
                                          /*flatten_region_arg_tuple=*/true));

        // Replace the uses of block-arguments of the IfOp with the
        // implicit_operands.
        ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                  implicit_operands);

        return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                        rets[0]);
      }

      // Otherwise, it is a indexed conditional and should be mapped to Case
      // op.
      TF_RETURN_IF_ERROR(GetMlirTypes(
          {instruction->branch_computation(0)->root_instruction()}, &rets));

      // Flatten the return-type.
      llvm::SmallVector<Type> flattened_ret_types;
      assert(rets.size() == 1);
      FlattenTupleType(rets[0], flattened_ret_types);

      int num_branches = instruction->branch_count();
      auto op = func_builder->create<mlir::mhlo::CaseOp>(
          loc, flattened_ret_types, flattened_operands[0], attributes,
          num_branches);
      // Import one region per branch computation.
      for (const auto& index_and_computation :
           llvm::enumerate(instruction->branch_computations())) {
        auto index = index_and_computation.index();
        HloComputation* computation = index_and_computation.value();
        TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index],
                                          /*flatten_region_arg_tuple=*/true));
      }

      // Replace the uses of block-arguments of the CaseOp with the
      // implicit_operands.
      ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                implicit_operands);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      rets[0]);
    }
    case HloOpcode::kConcatenate: {
      // TODO(b/132057942): Support taking an uint64_t instead of an
      // IntegerAttr for concatenate dimension.
      return func_builder
          ->create<mlir::mhlo::ConcatenateOp>(
              loc, result_type, operands,
              builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
          .getOperation();
    }
    case HloOpcode::kAllGather: {
      // Import as mhlo.all_gather; the channel handle is only attached when
      // the HLO instruction carries a channel id.
      auto all_gather = Cast<HloAllGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "all_gather_dim",
          builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(all_gather->replica_groups(), builder_));
      if (all_gather->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_gather->channel_id().value()));
      return func_builder
          ->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kAllReduce: {
      // Import as mhlo.all_reduce with the reduction computation as a region.
      auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
      attributes.push_back(
          ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
      if (all_reduce->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_reduce->channel_id().value()));
      if (all_reduce->use_global_device_ids())
        attributes.push_back(ConvertUseGlobalDeviceIds());
      auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
                                        &all_reduce_op.computation()));
      return all_reduce_op.getOperation();
    }
    case HloOpcode::kAllToAll: {
      // TODO(b/207152612): all-to-all HLO can either have pre-split operands
      // (and returns a tuple) or a single operand that is split across
      // `split_dimension` into the number of replicas in a group. Only the
      // latter case (array all-to-all) is supported in importer right now and
      // the former (tuple all-to-all) is not supported yet.
      auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
      if (all_to_all->shape().IsTuple())
        return tensorflow::errors::Unimplemented(
            "Importing tuple all-to-all HLO is not supported yet");

      // Check invariants of array all-to-all. This is a sanity check and is
      // verified by the HLO verifier.
      if (!all_to_all->split_dimension().has_value() || operands.size() != 1 ||
          all_to_all->replica_groups().empty())
        return tensorflow::errors::InvalidArgument(
            "Array all-to-all should have a split dimension, one operand and "
            "non-empty replica groups");

      auto replica_groups_attr =
          ConvertReplicaGroups(all_to_all->replica_groups(), builder_)
              .getValue()
              .cast<DenseIntElementsAttr>();
      // For array all-to-all the concat dimension equals the split dimension.
      uint64_t split_dim = all_to_all->split_dimension().value();
      uint64_t concat_dim = split_dim;
      uint64_t split_count = all_to_all->replica_groups()[0].replica_ids_size();

      return func_builder
          ->create<mlir::mhlo::AllToAllOp>(loc, result_type, operands[0],
                                           split_dim, concat_dim, split_count,
                                           replica_groups_attr)
          .getOperation();
    }
    case HloOpcode::kReduce: {
      // Operands in the first half are reduction inputs and the remaining
      // operands are corresponding initial values.
      size_t num_inputs = operands.size() / 2;
      // mhlo.reduce is variadic, so a tuple-typed HLO result is un-tupled for
      // the op and re-tupled afterwards.
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }

      auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
          loc, return_types,
          llvm::makeArrayRef(operands).take_front(num_inputs),
          llvm::makeArrayRef(operands).drop_front(num_inputs),
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                        &reduce.body(),
                                        /*flatten_region_arg_tuple=*/true));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kReverse: {
      // Import as mhlo.reverse over the instruction's dimensions.
      return func_builder
          ->create<mlir::mhlo::ReverseOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->dimensions()))
          .getOperation();
    }
    case HloOpcode::kRng: {
      // Import as mhlo.rng; the output shape is materialized as a constant
      // operand. Only UNIFORM and NORMAL distributions are supported.
      auto shape = func_builder->create<mlir::mhlo::ConstantOp>(
          loc, Convert(result_type.cast<RankedTensorType>().getShape()));
      switch (instruction->random_distribution()) {
        case xla::RNG_UNIFORM:
          return func_builder
              ->create<mlir::mhlo::RngOp>(
                  loc, result_type, operands[0], operands[1], shape,
                  ::mlir::mhlo::RngDistribution::UNIFORM)
              .getOperation();

        case xla::RNG_NORMAL:
          return func_builder
              ->create<mlir::mhlo::RngOp>(loc, result_type, operands[0],
                                          operands[1], shape,
                                          ::mlir::mhlo::RngDistribution::NORMAL)
              .getOperation();

        default:
          return tensorflow::errors::InvalidArgument(absl::StrCat(
              "Unsupported distribution: ",
              RandomDistributionToString(instruction->random_distribution())));
      }
    }
    case HloOpcode::kRngBitGenerator: {
      auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);

      // Flatten the return type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
          builder_->getContext(),
          *mlir::mhlo::symbolizeRngAlgorithm(rng_op->algorithm()));
      auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
          loc, flattened_ret_types, algorithm_attr, operands[0]);

      // Re-tuple the flattened results to match the HLO result shape.
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
    case HloOpcode::kRngGetAndUpdateState: {
      // Import as mhlo.xla.rng_get_and_update_state with the delta attribute.
      return func_builder
          ->create<mlir::mhlo::XlaRngGetAndUpdateStateOp>(
              loc, result_type,
              func_builder->getI64IntegerAttr(
                  Cast<HloRngGetAndUpdateStateInstruction>(instruction)
                      ->delta()))
          .getOperation();
    }
    case HloOpcode::kWhile: {
      // Import as mhlo.while. The single (possibly tuple-typed) loop-carried
      // operand is flattened for the op and re-tupled afterwards.
      llvm::SmallVector<Value> flattened_operands;
      llvm::SmallVector<Type> flattened_operand_types;
      FlattenTupleType(operands[0].getType(), flattened_operand_types);
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);

      auto op = func_builder->create<mlir::mhlo::WhileOp>(
          loc, flattened_operand_types, flattened_operands);

      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_condition(),
                                        &op.cond(),
                                        /*flatten_region_arg_tuple=*/true));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_body(), &op.body(),
                                        /*flatten_region_arg_tuple=*/true));
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      operands[0].getType());
    }
    case HloOpcode::kGetTupleElement: {
      // Import as mhlo.get_tuple_element; the index attribute is i32.
      attributes.push_back(builder_->getNamedAttr(
          "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
                                            instruction->tuple_index())));
      return func_builder
          ->create<mlir::mhlo::GetTupleElementOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    };
    case HloOpcode::kGetDimensionSize: {
      // Import as mhlo.get_dimension_size with the dimension as i64.
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::GetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    };
    case HloOpcode::kTranspose: {
      // Import as mhlo.transpose; HLO dimensions() carry the permutation.
      attributes.push_back(builder_->getNamedAttr(
          "permutation", ConvertDimensions(instruction->dimensions())));
      return func_builder
          ->create<mlir::mhlo::TransposeOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kTriangularSolve: {
      // Import as mhlo.triangular_solve, translating the four solver options
      // (left_side, lower, unit_diagonal, transpose_a) into attributes.
      attributes.push_back(builder_->getNamedAttr(
          "left_side",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().left_side())));
      attributes.push_back(builder_->getNamedAttr(
          "lower", builder_->getBoolAttr(
                       instruction->triangular_solve_options().lower())));
      attributes.push_back(builder_->getNamedAttr(
          "unit_diagonal",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().unit_diagonal())));
      // Round-trip the transpose enum through its proto name to the mhlo enum.
      auto transpose_a = mlir::mhlo::TransposeAttr::get(
          builder_->getContext(),
          mlir::mhlo::symbolizeTranspose(
              TriangularSolveOptions::Transpose_Name(
                  instruction->triangular_solve_options().transpose_a()))
              .getValue());

      attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
      return func_builder
          ->create<mlir::mhlo::TriangularSolveOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    }
    case HloOpcode::kReduceScatter: {
      // Import as mhlo.reduce_scatter with the reduction as a region.
      auto reduce_scatter = Cast<HloReduceScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension",
          builder_->getI64IntegerAttr(reduce_scatter->scatter_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
      if (reduce_scatter->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(reduce_scatter->channel_id().value()));
      auto reduce_scatter_op =
          func_builder->create<mlir::mhlo::ReduceScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
                                        &reduce_scatter_op.computation(),
                                        /*flatten_region_arg_tuple=*/true));

      return reduce_scatter_op.getOperation();
    }
    case HloOpcode::kReduceWindow: {
      // Import as mhlo.reduce_window. mhlo.reduce_window is variadic, so a
      // tuple-typed HLO result is un-tupled for the op and re-tupled below.
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }
      llvm::SmallVector<int64_t, 4> sizes, strides, base_dilations,
          win_dilations;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : instruction->window().dimensions()) {
        sizes.push_back(dim.size());
        strides.push_back(dim.stride());
        base_dilations.push_back(dim.base_dilation());
        win_dilations.push_back(dim.window_dilation());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  ConvertDimensions(sizes)));
      attributes.push_back(
          builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
      attributes.push_back(builder_->getNamedAttr(
          "base_dilations", ConvertDimensions(base_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "window_dilations", ConvertDimensions(win_dilations)));
      attributes.push_back(ConvertPadding(padding));
      auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
          loc, return_types, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                        &reduce.body(),
                                        /*flatten_region_arg_tuple=*/true));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kMap: {
      // Import as mhlo.map with the mapped computation as a region.
      auto op = func_builder->create<mlir::mhlo::MapOp>(
          loc, result_type, operands,
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                        &op.computation(),
                                        /*flatten_region_arg_tuple=*/true));
      return op.getOperation();
    }
    case HloOpcode::kConvolution: {
      // Import as mhlo.convolution: the HLO window config is split into
      // per-dimension stride/dilation/padding/reversal attribute arrays.
      llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
      llvm::SmallVector<bool, 4> reversals;
      llvm::SmallVector<int64_t, 8> paddings;
      for (const auto& dim : instruction->window().dimensions()) {
        strides.push_back(dim.stride());
        lhs_dilations.push_back(dim.base_dilation());
        rhs_dilations.push_back(dim.window_dilation());
        paddings.push_back(dim.padding_low());
        paddings.push_back(dim.padding_high());
        reversals.push_back(dim.window_reversal());
      }

      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(strides)));
      attributes.push_back(ConvertPadding(paddings));
      attributes.push_back(
          builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
      attributes.push_back(
          builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
      attributes.push_back(
          builder_->getNamedAttr("window_reversal", Convert(reversals)));
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertConvDimensionNumbers(
              instruction->convolution_dimension_numbers(), builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "feature_group_count",
          builder_->getI64IntegerAttr(instruction->feature_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "batch_group_count",
          builder_->getI64IntegerAttr(instruction->batch_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));

      return func_builder
          ->create<mlir::mhlo::ConvolutionOp>(loc, result_type, operands,
                                              attributes)
          .getOperation();
    }

    case HloOpcode::kFft: {
      // Import as mhlo.fft, round-tripping the fft type enum through its
      // proto name to the mhlo enum.
      auto fft_type = mlir::mhlo::FftTypeAttr::get(
          builder_->getContext(),
          mlir::mhlo::symbolizeFftType(FftType_Name(instruction->fft_type()))
              .getValue());

      std::vector<int64_t> fft_length(instruction->fft_length().begin(),
                                      instruction->fft_length().end());

      attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
      attributes.push_back(
          builder_->getNamedAttr("fft_length", Convert(fft_length)));
      return func_builder
          ->create<mlir::mhlo::FftOp>(loc, result_type, operands, attributes)
          .getOperation();
    }
1311 
    case HloOpcode::kAdd: {
      // HLO add ops on PRED elements are actually boolean or, but MHLO dialect
      // AddOps on i1 are just addition with overflow; so, we have to implement
      // the special behavior of HLO add ops on PRED here by creating an
      // mhlo::OrOp instead.
      if (instruction->shape().element_type() == PRED) {
        return func_builder
            ->create<mlir::mhlo::OrOp>(loc, result_type, operands, attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AddOp>(loc, result_type, operands, attributes)
            .getOperation();
      }
    }
    case HloOpcode::kAfterAll: {
      // HLO AfterAll ops without any token input are used to just create a
      // token. MHLO has a special op CreateToken for this case.
      if (instruction->operands().empty()) {
        return func_builder
            ->create<mlir::mhlo::CreateTokenOp>(loc, result_type, operands,
                                                attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AfterAllOp>(loc, result_type, operands,
                                             attributes)
            .getOperation();
      }
    }
1342 
1343     case HloOpcode::kConvert: {
1344       // Convert to boolean is special, it requires a comparison to 0 instead of
1345       // a truncation to i1, otherwise it is a 1-1 translation.
1346       auto ranked_type = result_type.dyn_cast<mlir::RankedTensorType>();
1347       mlir::IntegerType integer_type =
1348           (ranked_type)
1349               ? ranked_type.getElementType().dyn_cast<mlir::IntegerType>()
1350               : nullptr;
1351       if (!integer_type || integer_type.getWidth() != 1) {
1352         // Simple case: 1-1 mapping.
1353         return {func_builder->create<mlir::mhlo::ConvertOp>(
1354             loc, result_type, operands, attributes)};
1355       }
1356 
1357       // Return type is boolean, let's use `operand != 0` instead of Convert.
1358       xla::Shape input_shape = instruction->operand(0)->shape();
1359       TF_ASSIGN_OR_RETURN(mlir::Type type,
1360                           ConvertTensorShapeToType<mlir::RankedTensorType>(
1361                               input_shape, *func_builder));
1362       auto zero = func_builder->create<mlir::mhlo::ConstantOp>(
1363           loc, func_builder->getZeroAttr(type));
1364       return {func_builder->create<mlir::mhlo::CompareOp>(
1365           loc, operands[0], zero, mlir::mhlo::ComparisonDirection::NE)};
1366     }
1367     case HloOpcode::kOptimizationBarrier: {
1368       llvm::SmallVector<Value> flattened_operands;
1369       llvm::SmallVector<Type> flattened_operand_types;
1370       FlattenTupleType(operands[0].getType(), flattened_operand_types);
1371       FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
1372 
1373       auto op = func_builder->create<mlir::mhlo::OptimizationBarrierOp>(
1374           loc, flattened_operand_types, flattened_operands);
1375 
1376       return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
1377                                       operands[0].getType());
1378     }
1379     case HloOpcode::kDomain: {
1380       auto domain_kind = mlir::mhlo::symbolizeDomainKind(
1381           instruction->user_side_metadata().Kind());
1382       if (!domain_kind || *domain_kind != mlir::mhlo::DomainKind::sharding) {
1383         return tensorflow::errors::InvalidArgument(
1384             "Invalid domain kind in hlo -> mhlo import. Only 'sharding' is "
1385             "supported");
1386       }
1387       attributes.push_back(builder_->getNamedAttr(
1388           "kind", mlir::mhlo::DomainKindAttr::get(func_builder->getContext(),
1389                                                   *domain_kind)));
1390 
1391       // In XLA, DomainMetadata is open-world, but in the proto, it is hardcoded
1392       // to be ShardingMetadata. Thankfully, the only other implementation of
1393       // DomainMetadata is OpName, which is generally used for debugging and
1394       // never for compiling production models.
1395       //
1396       // Since this is hardcoded as such in the proto, we must follow suit.
1397       // TODO(b/208783683): The one improvement we can make on this is to move
1398       // from the a serialized proto representation to a parsable string
1399       auto exit_metadata = ShardingMetadata::ToShardingMetadata(
1400           &instruction->operand_side_metadata());
1401       auto entry_metadata = ShardingMetadata::ToShardingMetadata(
1402           &instruction->user_side_metadata());
1403       attributes.push_back(builder_->getNamedAttr(
1404           "exit_metadata",
1405           builder_->getStringAttr(
1406               (*exit_metadata)->sharding()->ToProto().SerializeAsString())));
1407       attributes.push_back(builder_->getNamedAttr(
1408           "entry_metadata",
1409           builder_->getStringAttr(
1410               (*entry_metadata)->sharding()->ToProto().SerializeAsString())));
1411 
1412       return func_builder
1413           ->create<mlir::mhlo::DomainOp>(loc, result_type, operands, attributes)
1414           .getOperation();
1415     }
1416 
1417 #define NO_ATTRIBUTE_CASE(hlo_op_code, mlir_op)                               \
1418   case HloOpcode::hlo_op_code: {                                              \
1419     return func_builder                                                       \
1420         ->create<mlir::mhlo::mlir_op>(loc, result_type, operands, attributes) \
1421         .getOperation();                                                      \
1422   }
1423 
1424       // broadcast dimensions are never added here because they don't exist as
1425       // part of the HLO instruction. They are only a convenience in the XLA
1426       // builder API.
1427       NO_ATTRIBUTE_CASE(kAbs, AbsOp);
1428       NO_ATTRIBUTE_CASE(kAddDependency, AddDependencyOp);
1429       NO_ATTRIBUTE_CASE(kAnd, AndOp);
1430       NO_ATTRIBUTE_CASE(kAtan2, Atan2Op);
1431       NO_ATTRIBUTE_CASE(kBitcastConvert, BitcastConvertOp);
1432       NO_ATTRIBUTE_CASE(kCbrt, CbrtOp);
1433       NO_ATTRIBUTE_CASE(kClz, ClzOp);
1434       NO_ATTRIBUTE_CASE(kCeil, CeilOp);
1435       NO_ATTRIBUTE_CASE(kClamp, ClampOp);
1436       NO_ATTRIBUTE_CASE(kComplex, ComplexOp);
1437       NO_ATTRIBUTE_CASE(kCos, CosineOp);
1438       NO_ATTRIBUTE_CASE(kDivide, DivOp);
1439       NO_ATTRIBUTE_CASE(kExp, ExpOp);
1440       NO_ATTRIBUTE_CASE(kExpm1, Expm1Op);
1441       NO_ATTRIBUTE_CASE(kFloor, FloorOp);
1442       NO_ATTRIBUTE_CASE(kIsFinite, IsFiniteOp);
1443       NO_ATTRIBUTE_CASE(kImag, ImagOp);
1444       NO_ATTRIBUTE_CASE(kLog, LogOp);
1445       NO_ATTRIBUTE_CASE(kLog1p, Log1pOp);
1446       NO_ATTRIBUTE_CASE(kMaximum, MaxOp);
1447       NO_ATTRIBUTE_CASE(kMinimum, MinOp);
1448       NO_ATTRIBUTE_CASE(kMultiply, MulOp);
1449       NO_ATTRIBUTE_CASE(kNegate, NegOp);
1450       NO_ATTRIBUTE_CASE(kNot, NotOp);
1451       NO_ATTRIBUTE_CASE(kOr, OrOp);
1452       NO_ATTRIBUTE_CASE(kPartitionId, PartitionIdOp);
1453       NO_ATTRIBUTE_CASE(kPopulationCount, PopulationCountOp);
1454       NO_ATTRIBUTE_CASE(kPower, PowOp);
1455       NO_ATTRIBUTE_CASE(kReal, RealOp);
1456       NO_ATTRIBUTE_CASE(kRemainder, RemOp);
1457       NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp);
1458       NO_ATTRIBUTE_CASE(kLogistic, LogisticOp);
1459       // The dimensions attribute is not present on the HLO Reshape
1460       // instruction. If dimensions are non-default, the XLA builder
1461       // implements it as a separate transpose.
1462       NO_ATTRIBUTE_CASE(kReshape, ReshapeOp);
1463       NO_ATTRIBUTE_CASE(kRoundNearestAfz, RoundOp);
1464       NO_ATTRIBUTE_CASE(kRoundNearestEven, RoundNearestEvenOp);
1465       NO_ATTRIBUTE_CASE(kRsqrt, RsqrtOp);
1466       NO_ATTRIBUTE_CASE(kSelect, SelectOp);
1467       NO_ATTRIBUTE_CASE(kShiftLeft, ShiftLeftOp);
1468       NO_ATTRIBUTE_CASE(kShiftRightArithmetic, ShiftRightArithmeticOp);
1469       NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp);
1470       NO_ATTRIBUTE_CASE(kSign, SignOp);
1471       NO_ATTRIBUTE_CASE(kSin, SineOp);
1472       NO_ATTRIBUTE_CASE(kSqrt, SqrtOp);
1473       NO_ATTRIBUTE_CASE(kSubtract, SubtractOp);
1474       NO_ATTRIBUTE_CASE(kTanh, TanhOp);
1475       NO_ATTRIBUTE_CASE(kTuple, TupleOp);
1476       NO_ATTRIBUTE_CASE(kXor, XorOp);
1477       // TODO(b/129422361) Copy needs special handling because it is not
1478       // defined in tensorflow/compiler/xla/client/xla_builder.h. See
1479       // operation semantics in
1480       // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
1481       NO_ATTRIBUTE_CASE(kCopy, CopyOp);
1482 
1483 #undef NO_ATTRIBUTE_CASE
1484 
1485     case HloOpcode::kFusion: {
1486       // Flatten the tuple-typed operands.
1487       llvm::SmallVector<Value> flattened_operands;
1488       for (auto& operand : operands)
1489         FlattenTupleValue(func_builder, loc, operand, flattened_operands);
1490 
1491       // Flatten the return type if they are tuple-typed.
1492       llvm::SmallVector<Type> flattened_ret_types;
1493       FlattenTupleType(result_type, flattened_ret_types);
1494 
1495       auto fusion_kind = mlir::mhlo::symbolizeFusionKind(
1496           xla::ToString(instruction->fusion_kind()));
1497       auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
1498           loc, flattened_ret_types, flattened_operands,
1499           mlir::mhlo::FusionKindAttr::get(func_builder->getContext(),
1500                                           fusion_kind.getValue()));
1501       TF_RETURN_IF_ERROR(ImportAsRegion(
1502           *instruction->fused_instructions_computation(),
1503           &fusion.fused_computation(), /*flatten_region_arg_tuple=*/true));
1504 
1505       return CreateTupleFromOpResults(func_builder, loc, fusion.getOperation(),
1506                                       result_type);
1507     }
1508     case HloOpcode::kBitcast: {
1509       auto bitcast = func_builder->create<mlir::mhlo::BitcastOp>(
1510           loc, result_type, operands, attributes);
1511       // Store the source and result layout as attributes. Although the MHLO
1512       // Bitcast operates on tensors, these layouts are relevant as they define
1513       // the mapping between the elements of the source and result.
1514       SetLayoutForMlir(bitcast, instruction->shape(), "result_layout");
1515       SetLayoutForMlir(bitcast, instruction->operand(0)->shape(),
1516                        "source_layout");
1517       return bitcast.getOperation();
1518     }
1519     case HloOpcode::kReducePrecision: {
1520       auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
1521           loc, result_type, operands[0], attributes);
1522       op.exponent_bitsAttr(func_builder->getIntegerAttr(
1523           func_builder->getI32Type(), instruction->exponent_bits()));
1524       op.mantissa_bitsAttr(func_builder->getIntegerAttr(
1525           func_builder->getI32Type(), instruction->mantissa_bits()));
1526       return op.getOperation();
1527     }
1528     default: {
1529       mlir::OperationState result(loc, "mhlo.unknown");
1530       result.addOperands(operands);
1531       result.addTypes(result_type);
1532       for (auto attr : attributes) {
1533         result.attributes.push_back(attr);
1534       }
1535 
1536       return func_builder->create(result);
1537     }
1538   }
1539 }
1540 
SetXlaShape(mlir::Operation * op,const Shape & shape)1541 void SetXlaShape(mlir::Operation* op, const Shape& shape) {
1542   op->setAttr("xla_shape",
1543               mlir::Builder(op->getContext())
1544                   .getStringAttr(shape.ToString(/*print_layout=*/true)));
1545 }
1546 
ImportInstructionWithLayout(const HloInstruction * instruction,const llvm::SmallVectorImpl<mlir::Value> & operands,mlir::OpBuilder * func_builder,DynamicShapeHandlingMode mode)1547 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionWithLayout(
1548     const HloInstruction* instruction,
1549     const llvm::SmallVectorImpl<mlir::Value>& operands,
1550     mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
1551   TF_ASSIGN_OR_RETURN(
1552       mlir::Operation * op,
1553       ImportInstructionImpl(instruction, operands, func_builder, mode));
1554   if (op == nullptr) return op;
1555 
1556   // See MlirToHloConversionOptions for more about layouts.
1557   //
1558   // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
1559   // in physical minor-to-major order.
1560   if (instruction->shape().IsArray()) {
1561     if (instruction->shape().has_layout() &&
1562         !instruction->shape().layout().minor_to_major().empty() &&
1563         instruction->shape().layout() !=
1564             LayoutUtil::MakeDescendingLayout(
1565                 instruction->shape().dimensions().size())) {
1566       SetXlaShape(op, instruction->shape());
1567     }
1568   } else {
1569     SetXlaShape(op, instruction->shape());
1570   }
1571   return op;
1572 }
1573 
GetOperands(const HloInstruction * instruction)1574 StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
1575     const HloInstruction* instruction) {
1576   llvm::SmallVector<mlir::Value, 4> operands;
1577   for (const auto& operand : instruction->operands()) {
1578     auto input_it = instruction_value_map_.find(operand);
1579     if (input_it == instruction_value_map_.end()) {
1580       return tensorflow::errors::Internal(
1581           absl::StrCat("Could not find input value: ", operand->name(),
1582                        " for instruction ", instruction->name()));
1583     }
1584     operands.push_back(input_it->second);
1585   }
1586   return operands;
1587 }
1588 
GetMlirTypes(const std::vector<HloInstruction * > & instructions,llvm::SmallVectorImpl<mlir::Type> * types)1589 tensorflow::Status HloFunctionImporter::GetMlirTypes(
1590     const std::vector<HloInstruction*>& instructions,
1591     llvm::SmallVectorImpl<mlir::Type>* types) {
1592   for (auto instruction : instructions) {
1593     TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
1594                                            instruction->shape(), *builder_));
1595     types->push_back(ret_type);
1596   }
1597   return ::tensorflow::OkStatus();
1598 }
1599 
GetMlirValue(const HloInstruction * instruction)1600 StatusOr<Value> HloFunctionImporter::GetMlirValue(
1601     const HloInstruction* instruction) {
1602   auto lookup = instruction_value_map_.find(instruction);
1603   if (lookup != instruction_value_map_.end()) {
1604     return lookup->second;
1605   }
1606 
1607   return tensorflow::errors::Internal(absl::StrCat(
1608       "Unable to find value for input: ", instruction->ToString()));
1609 }
1610 
ConvertComparisonDirection(ComparisonDirection direction)1611 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
1612     ComparisonDirection direction) {
1613   return builder_->getNamedAttr(
1614       "comparison_direction",
1615       mlir::mhlo::ComparisonDirectionAttr::get(
1616           builder_->getContext(), mlir::mhlo::symbolizeComparisonDirection(
1617                                       ComparisonDirectionToString(direction))
1618                                       .getValue()));
1619 }
1620 
ConvertComparisonType(Comparison::Type type)1621 mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
1622     Comparison::Type type) {
1623   return builder_->getNamedAttr(
1624       "compare_type",
1625       mlir::mhlo::ComparisonTypeAttr::get(
1626           builder_->getContext(),
1627           mlir::mhlo::symbolizeComparisonType(ComparisonTypeToString(type))
1628               .getValue()));
1629 }
1630 
ConvertDimensions(absl::Span<const int64_t> op_dimensions)1631 mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
1632     absl::Span<const int64_t> op_dimensions) {
1633   llvm::SmallVector<APInt, 8> dimensions;
1634   dimensions.reserve(op_dimensions.size());
1635   for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
1636 
1637   return DenseIntElementsAttr::get(
1638       RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
1639       dimensions);
1640 }
1641 
Convert(llvm::ArrayRef<int64_t> elements)1642 mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
1643     llvm::ArrayRef<int64_t> elements) {
1644   return DenseIntElementsAttr::get(
1645       RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
1646       elements);
1647 }
1648 
Convert(llvm::ArrayRef<bool> elements)1649 mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
1650     llvm::ArrayRef<bool> elements) {
1651   return DenseIntElementsAttr::get(
1652       RankedTensorType::get(elements.size(), builder_->getI1Type()), elements);
1653 }
1654 
ConvertPadding(llvm::ArrayRef<int64_t> padding)1655 mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
1656     llvm::ArrayRef<int64_t> padding) {
1657   auto ty =
1658       mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
1659                                   builder_->getIntegerType(64));
1660   auto attr = DenseIntElementsAttr::get(ty, padding);
1661   return builder_->getNamedAttr("padding", attr);
1662 }
1663 
ConvertSourceTargetPairs(const std::vector<std::pair<int64_t,int64_t>> & source_target_pairs,mlir::Builder * builder)1664 mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
1665     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
1666     mlir::Builder* builder) {
1667   std::vector<int64_t> attr(source_target_pairs.size() * 2);
1668   for (const auto& p : llvm::enumerate(source_target_pairs)) {
1669     attr[2 * p.index()] = p.value().first;
1670     attr[2 * p.index() + 1] = p.value().second;
1671   }
1672   auto type = mlir::RankedTensorType::get(
1673       {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
1674   return builder->getNamedAttr("source_target_pairs",
1675                                DenseIntElementsAttr::get(type, attr));
1676 }
1677 
ConvertReplicaGroups(absl::Span<const ReplicaGroup> replica_groups,mlir::Builder * builder)1678 mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
1679     absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
1680   const int64_t num_groups = replica_groups.size();
1681   // Replica groups in HLO can be non-uniform in size, for example:
1682   // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
1683   // tensor, pad the smaller sized replica groups with -1.
1684   const int64_t group_size = absl::c_accumulate(
1685       replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
1686         return std::max<int64_t>(current, g.replica_ids_size());
1687       });
1688   // Initialize all elements to -1 to support non-uniform replica groups.
1689   std::vector<int64_t> attr(num_groups * group_size, -1);
1690   for (int i = 0; i < num_groups; ++i) {
1691     int index = i * group_size;
1692     for (const int64_t& id : replica_groups[i].replica_ids())
1693       attr[index++] = id;
1694   }
1695   auto type = mlir::RankedTensorType::get({num_groups, group_size},
1696                                           builder->getIntegerType(64));
1697   return builder->getNamedAttr("replica_groups",
1698                                DenseIntElementsAttr::get(type, attr));
1699 }
1700 
ConvertChannelHandle(std::optional<int64_t> channel_id)1701 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
1702     std::optional<int64_t> channel_id) {
1703   xla::ChannelHandle channel_handle;
1704   if (channel_id) channel_handle.set_handle(*channel_id);
1705   return ConvertChannelHandle(channel_handle);
1706 }
1707 
ConvertChannelHandle(const xla::ChannelHandle & channel)1708 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
1709     const xla::ChannelHandle& channel) {
1710   return builder_->getNamedAttr(
1711       "channel_handle", mlir::mhlo::ChannelHandleAttr::get(
1712                             context_, channel.handle(), channel.type()));
1713 }
1714 
ConvertUseGlobalDeviceIds()1715 mlir::NamedAttribute HloFunctionImporter::ConvertUseGlobalDeviceIds() {
1716   return builder_->getNamedAttr("use_global_device_ids",
1717                                 builder_->getUnitAttr());
1718 }
1719 
SetLayoutForMlir(mlir::Operation * op,const Shape & shape,llvm::StringRef attr_name)1720 void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
1721                                            const Shape& shape,
1722                                            llvm::StringRef attr_name) {
1723   llvm::SmallVector<int64_t, 4> minor_to_major(
1724       shape.layout().minor_to_major().begin(),
1725       shape.layout().minor_to_major().end());
1726   op->setAttr(
1727       attr_name,
1728       mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
1729 }
1730 
ConvertShapeToMlirLayout(const xla::Shape & shape,llvm::SmallVectorImpl<mlir::Attribute> & flattened_attr)1731 Status HloFunctionImporter::ConvertShapeToMlirLayout(
1732     const xla::Shape& shape,
1733     llvm::SmallVectorImpl<mlir::Attribute>& flattened_attr) {
1734   if (shape.IsToken()) {
1735     return ::tensorflow::OkStatus();
1736   }
1737   if (shape.IsTuple()) {
1738     std::vector<mlir::Attribute> tuple_layouts;
1739     for (int i = 0; i < shape.tuple_shapes_size(); i++) {
1740       TF_RETURN_IF_ERROR(
1741           ConvertShapeToMlirLayout(shape.tuple_shapes(i), flattened_attr));
1742     }
1743     return ::tensorflow::OkStatus();
1744   }
1745   if (shape.IsArray()) {
1746     const xla::Layout l = shape.layout();
1747     std::vector<mlir::Attribute> minor_to_major;
1748     for (int64_t i : l.minor_to_major()) {
1749       minor_to_major.push_back(builder_->getI64IntegerAttr(i));
1750     }
1751     llvm::ArrayRef<mlir::Attribute> array_ref(minor_to_major);
1752     flattened_attr.push_back(builder_->getArrayAttr(array_ref));
1753     return ::tensorflow::OkStatus();
1754   }
1755   return tensorflow::errors::Internal("Couldn't convert layout.");
1756 }
1757 
1758 }  // namespace xla
1759